-
Notifications
You must be signed in to change notification settings - Fork 50
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[caikit-nlp-163] Refactor peft module to take out common peft config functionality #174
Conversation
Signed-off-by: Trevor Grant <[email protected]>
@gkumbhat i moved a batch of stuff to a new file- is this what you had in mind? wanted to confirm i was on the right path before i did more stuff |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @rawkintrevo for quick turn-around. I think 1 other thing that would move would be https://github.com/caikit/caikit-nlp/pull/174/files#diff-1cb191003903163320c02f8ffaf7c5edd48ca6649cd3092d1b4c3a0fdcd038c1R376 function, since that returns the peft_config
which is then passed to get the peft_model
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Signed-off-by: Trevor Grant <[email protected]>
I think I've bashed all the bugs, we'll see what the checks say. If so I'll write unit tests tomorrow |
Signed-off-by: Trevor Grant <[email protected]>
Signed-off-by: Trevor Grant <[email protected]>
Signed-off-by: Trevor Grant <[email protected]>
"Validated tuning source prompt [%s]", | ||
tuning_config.prompt_tuning_init_source_model, | ||
) | ||
base_model_name = base_model._model_name |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
since in runtime use-cases we do accept str
as base_model
this would error out. Can we bring back the base_model
parsing logic we had in the train function earlier to resolve that. ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The linked block exists in peft_config.py at line 100
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yep, but the problem is that one is getting called after we do line 347. And so line 347 can fail if base_model
is a string.
Since the code to resolve the base_model is common. It might make sense to pull that out from validate_peft_config
function into a separate resolve_base_model
function and we call that function before we call validate_peft_config
.. So essentially:
base_model = resolve_base_model(base_model)
task_type, output_model_types, peft_config, tuning_type = validate_peft_config(..., base_model, ...)
# LORA = "LORA" | ||
|
||
|
||
def validate_peft_config( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: can we rename the function to get_peft_config
as this function takes the raw config from our side and returns back the "peft config` ?
|
||
|
||
def validate_peft_config( | ||
tuning_type, tuning_config, error, log, base_model, cls, torch_dtype, verbalizer |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We usually initialize logger per module / file to allow easy backtracking. To do this, we can initialize log
at the beginning of this file and re-use that everywhere in this file.
log = alog.use_channel("PFT_CNFG_TLKT")
"Validated tuning source prompt [%s]", | ||
tuning_config.prompt_tuning_init_source_model, | ||
) | ||
base_model_name = base_model._model_name |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yep, but the problem is that one is getting called after we do line 347. And so line 347 can fail if base_model
is a string.
Since the code to resolve the base_model is common. It might make sense to pull that out from validate_peft_config
function into a separate resolve_base_model
function and we call that function before we call validate_peft_config
.. So essentially:
base_model = resolve_base_model(base_model)
task_type, output_model_types, peft_config, tuning_type = validate_peft_config(..., base_model, ...)
Signed-off-by: Trevor Grant <[email protected]>
Signed-off-by: Trevor Grant <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks Trevor. We need to coordinate these merges a bit as they are all changing same files.
ack @gkumbhat - will let someone else coordinate the merge |
Pushed a rebased version of this branch to #197! |
No description provided.