generated from caikit/caikit-template
-
Notifications
You must be signed in to change notification settings - Fork 50
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[caikit-nlp-163] Refactor peft module to take out common peft config functionality #174
Closed
Closed
Changes from 5 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
ddc2587
we'll see if this defeats DCO
rawkintrevo 4a29eab
bashing various bugs
rawkintrevo 1f86d90
clean and pretty
rawkintrevo 33660a6
lint
rawkintrevo 3eff84e
unit tests
rawkintrevo a9d4643
Merge branch 'caikit:main' into 163
rawkintrevo 3473f6a
respond to pr comments
rawkintrevo 1af19c6
add resolve base model
rawkintrevo File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,193 @@ | ||
# Copyright The Caikit Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# Standard | ||
from enum import Enum | ||
import os | ||
|
||
# Third Party | ||
from peft import MultitaskPromptTuningInit | ||
from transformers import AutoConfig | ||
|
||
# First Party | ||
from caikit import get_config | ||
|
||
# Local | ||
from ...data_model import PromptOutputModelType | ||
from ...resources.pretrained_model import PretrainedModelBase | ||
from ...toolkit.data_type_utils import get_torch_dtype | ||
from ...toolkit.verbalizer_utils import is_valid_verbalizer | ||
|
||
# NOTE: We do not allow all the methods exposed by MPT / PT, such as `EXACT_SOURCE_TASK` | ||
# since those are for experimental use and would not be useful / applicable | ||
# for end-user use-cases | ||
allowed_tuning_init_methods = [ | ||
"TEXT", | ||
"RANDOM", | ||
"ONLY_SOURCE_SHARED", | ||
"AVERAGE_SOURCE_TASKS", | ||
] | ||
|
||
|
||
class TuningType(str, Enum): | ||
PROMPT_TUNING = "PROMPT_TUNING" | ||
MULTITASK_PROMPT_TUNING = "MULTITASK_PROMPT_TUNING" | ||
# MULTITASK_PREFIX_TUNING = "MULTITASK_PREFIX_TUNING" | ||
# P_TUNING = "P_TUNING" | ||
# PREFIX_TUNING = "PREFIX_TUNING" | ||
# LORA = "LORA" | ||
|
||
|
||
def validate_peft_config( | ||
tuning_type, tuning_config, error, log, base_model, cls, torch_dtype, verbalizer | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We usually initialize logger per module / file to allow easy backtracking. To do this, we can initialize
|
||
): | ||
if tuning_type not in TuningType._member_names_: | ||
raise NotImplementedError("{} tuning type not supported!".format(tuning_type)) | ||
|
||
if tuning_config.prompt_tuning_init_method: | ||
# NOTE: GK-APR-5-2023 | ||
# MultitaskPromptTuningInit and MultitaskPrefixTuningInit are same at the | ||
# time of writing, which is a superset of PromptTuningInit | ||
init_method = tuning_config.prompt_tuning_init_method | ||
|
||
error.value_check( | ||
"<NLP11848053E>", | ||
init_method in allowed_tuning_init_methods, | ||
f"Init method [{init_method}] not in allowed init methods: " | ||
f"[{allowed_tuning_init_methods}]", | ||
) | ||
|
||
init_method = MultitaskPromptTuningInit(init_method) | ||
log.info("Using initialization method [%s]", init_method) | ||
|
||
# If init method provided relates to one that requires source model, | ||
# make sure the source prompt model is provided. | ||
if init_method in [ | ||
MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS, | ||
MultitaskPromptTuningInit.ONLY_SOURCE_SHARED, | ||
]: | ||
# NOTE: prompt_tuning_init_source_model is currently a path. In future | ||
# we will replace this with caikit.resources to properly cataloging these | ||
error.type_check( | ||
"<NLP89108490E>", | ||
str, | ||
prompt_tuning_init_source_model=tuning_config.prompt_tuning_init_source_model, | ||
) | ||
tuning_config.prompt_tuning_init_source_model = os.path.join( | ||
get_config().source_prompt_base, | ||
tuning_config.prompt_tuning_init_source_model, | ||
) | ||
|
||
error.file_check( | ||
"<NLP96030210E>", tuning_config.prompt_tuning_init_source_model | ||
) | ||
log.debug( | ||
"Validated tuning source prompt [%s]", | ||
tuning_config.prompt_tuning_init_source_model, | ||
) | ||
|
||
if isinstance(base_model, str): | ||
model_config = AutoConfig.from_pretrained( | ||
base_model, local_files_only=not get_config().allow_downloads | ||
) | ||
|
||
resource_type = None | ||
for resource in cls.supported_resources: | ||
if model_config.model_type in resource.SUPPORTED_MODEL_TYPES: | ||
resource_type = resource | ||
break | ||
|
||
if not resource_type: | ||
error( | ||
"<NLP61784225E>", | ||
"{} model type is not supported currently!".format( | ||
model_config.model_type | ||
), | ||
) | ||
log.debug("Bootstrapping base resource [%s]", base_model) | ||
base_model = resource_type.bootstrap(base_model, torch_dtype=torch_dtype) | ||
error.type_check("<NLP65714919E>", PretrainedModelBase, base_model=base_model) | ||
|
||
# Validate if tuned output model type is compatible with base model or not | ||
if not tuning_config.output_model_types: | ||
output_model_types = base_model.PROMPT_OUTPUT_TYPES | ||
else: | ||
# If the first element is not PromptOutputModelType, assume the entire list | ||
# isn't and convert | ||
if not isinstance(tuning_config.output_model_types[0], PromptOutputModelType): | ||
output_model_types = [] | ||
for output_type in tuning_config.output_model_types: | ||
output_model_types.append(PromptOutputModelType(output_type)) | ||
else: | ||
output_model_types = tuning_config.output_model_types | ||
error.value_check( | ||
"<NLP36947542E>", | ||
all( | ||
output_type in base_model.PROMPT_OUTPUT_TYPES | ||
for output_type in output_model_types | ||
), | ||
"{} not supported for base model type {}".format( | ||
output_model_types, base_model.MODEL_TYPE | ||
), | ||
) | ||
|
||
error.value_check( | ||
"<NLP30542004E>", | ||
len(output_model_types) <= base_model.MAX_NUM_TRANSFORMERS, | ||
f"Too many output model types. Got {len(output_model_types)}, " | ||
f"maximum {base_model.MAX_NUM_TRANSFORMERS}", | ||
) | ||
# Ensure that our verbalizer is a string and will not render to a hardcoded string | ||
error.value_check( | ||
"<NLP83837412E>", | ||
is_valid_verbalizer(verbalizer), | ||
"Provided verbalizer is an invalid type or has no renderable placeholders", | ||
) | ||
|
||
# NOTE: Base model is a resource at this point | ||
task_type = base_model.TASK_TYPE | ||
|
||
if isinstance(tuning_type, str): | ||
error.value_check( | ||
"<NLP65714994E>", | ||
tuning_type in TuningType._member_names_, | ||
f"Invalid tuning type [{tuning_type}]. Allowed types: " | ||
f"[{TuningType._member_names_}]", | ||
) | ||
tuning_type = TuningType(tuning_type) | ||
error.type_check("<NLP65714993E>", TuningType, tuning_type=tuning_type) | ||
|
||
# Coerce the passed model into a resource; if we have one, this is a noop | ||
# TODO: When splitting up this mono-module, use the configured resource | ||
# type of the concrete class to bootstrap | ||
torch_dtype = get_torch_dtype(torch_dtype) | ||
|
||
# Take tokenizer name/path from the model | ||
tokenizer_name_or_path = base_model.model.config._name_or_path | ||
|
||
# Build the peft config; this is how we determine that we want a sequence classifier. | ||
# If we want more types, we will likely need to map this to data model outputs etc. | ||
|
||
# NOTE: We currently only support TEXT as init type, this is to later only easily | ||
# switch to MPT | ||
peft_config = cls.create_hf_tuning_config( | ||
base_model=base_model, | ||
tuning_type=tuning_type, | ||
task_type=task_type, | ||
tokenizer_name_or_path=tokenizer_name_or_path, | ||
tuning_config=tuning_config, | ||
output_model_types=output_model_types, | ||
) | ||
|
||
return task_type, output_model_types, peft_config, tuning_type |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: can we rename the function to
get_peft_config
as this function takes the raw config from our side and returns back the "peft config` ?