diff --git a/caikit_nlp/modules/text_generation/peft_config.py b/caikit_nlp/modules/text_generation/peft_config.py
new file mode 100644
index 00000000..3ec0ff61
--- /dev/null
+++ b/caikit_nlp/modules/text_generation/peft_config.py
@@ -0,0 +1,197 @@
+# Copyright The Caikit Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+from enum import Enum
+import os
+
+# Third Party
+from peft import MultitaskPromptTuningConfig, MultitaskPromptTuningInit
+from transformers import AutoConfig
+
+# First Party
+from caikit import get_config
+
+# Local
+from ...data_model import PromptOutputModelType
+from ...resources.pretrained_model import PretrainedModelBase
+from ...toolkit.data_type_utils import get_torch_dtype, str_to_torch_dtype
+from ...toolkit.verbalizer_utils import is_valid_verbalizer, render_verbalizer
+
+# NOTE: We do not allow all the methods exposed by MPT / PT, such as `EXACT_SOURCE_TASK`
+# since those are for experimental use and would not be useful / applicable
+# for end-user use-cases
+allowed_tuning_init_methods = [
+    "TEXT",
+    "RANDOM",
+    "ONLY_SOURCE_SHARED",
+    "AVERAGE_SOURCE_TASKS",
+]
+
+
+class TuningType(str, Enum):
+    PROMPT_TUNING = "PROMPT_TUNING"
+    MULTITASK_PROMPT_TUNING = "MULTITASK_PROMPT_TUNING"
+    # MULTITASK_PREFIX_TUNING = "MULTITASK_PREFIX_TUNING"
+    # P_TUNING = "P_TUNING"
+    # PREFIX_TUNING = "PREFIX_TUNING"
+    # LORA = "LORA"
+
+
+def validate_peft_config(
+    tuning_type,
+    tuning_config,
+    error,
+    log,
+    base_model,
+    cls,
+    torch_dtype,
+    verbalizer,
+):
+    """Validate prompt tuning configuration and resolve the base model resource.
+
+    Returns:
+        task_type, output_model_types, base_model, tuning_type, torch_dtype, where
+        base_model is coerced to a PretrainedModelBase resource, tuning_type to a
+        TuningType, and torch_dtype to a torch.dtype.
+    """
+    if tuning_type not in TuningType._member_names_:
+        raise NotImplementedError(
+            "{} tuning type not supported!".format(tuning_type)
+        )
+
+    if tuning_config.prompt_tuning_init_method:
+        # NOTE: GK-APR-5-2023
+        # MultitaskPromptTuningInit and MultitaskPrefixTuningInit are the same at
+        # the time of writing, which is a superset of PromptTuningInit
+        init_method = tuning_config.prompt_tuning_init_method
+
+        error.value_check(
+            "",
+            init_method in allowed_tuning_init_methods,
+            f"Init method [{init_method}] not in allowed init methods: "
+            f"[{allowed_tuning_init_methods}]",
+        )
+
+        init_method = MultitaskPromptTuningInit(init_method)
+        log.info("Using initialization method [%s]", init_method)
+
+        # If the provided init method requires a source model,
+        # make sure the source prompt model is provided.
+        if init_method in [
+            MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS,
+            MultitaskPromptTuningInit.ONLY_SOURCE_SHARED,
+        ]:
+            # NOTE: prompt_tuning_init_source_model is currently a path. In future
+            # we will replace this with caikit.resources to properly catalog these
+            error.type_check(
+                "",
+                str,
+                prompt_tuning_init_source_model=tuning_config.prompt_tuning_init_source_model,
+            )
+            tuning_config.prompt_tuning_init_source_model = os.path.join(
+                get_config().source_prompt_base,
+                tuning_config.prompt_tuning_init_source_model,
+            )
+
+            error.file_check("", tuning_config.prompt_tuning_init_source_model)
+            log.debug(
+                "Validated tuning source prompt [%s]",
+                tuning_config.prompt_tuning_init_source_model,
+            )
+
+    # Coerce the passed model into a resource; if we have one, this is a noop
+    # TODO: When splitting up this mono-module, use the configured resource
+    #       type of the concrete class to bootstrap
+    torch_dtype = get_torch_dtype(torch_dtype)
+    if isinstance(base_model, str):
+        model_config = AutoConfig.from_pretrained(
+            base_model, local_files_only=not get_config().allow_downloads
+        )
+
+        resource_type = None
+        for resource in cls.supported_resources:
+            if model_config.model_type in resource.SUPPORTED_MODEL_TYPES:
+                resource_type = resource
+                break
+
+        if not resource_type:
+            error(
+                "",
+                "{} model type is not supported currently!".format(
+                    model_config.model_type
+                ),
+            )
+        log.debug("Bootstrapping base resource [%s]", base_model)
+        base_model = resource_type.bootstrap(base_model, torch_dtype=torch_dtype)
+    error.type_check("", PretrainedModelBase, base_model=base_model)
+
+    # Validate that the tuned output model types are compatible with the base model
+    if not tuning_config.output_model_types:
+        output_model_types = base_model.PROMPT_OUTPUT_TYPES
+    else:
+        # If the first element is not PromptOutputModelType, assume the entire list
+        # isn't and convert
+        if not isinstance(
+            tuning_config.output_model_types[0], PromptOutputModelType
+        ):
+            output_model_types = []
+            for output_type in tuning_config.output_model_types:
+                output_model_types.append(PromptOutputModelType(output_type))
+        else:
+            output_model_types = tuning_config.output_model_types
+        error.value_check(
+            "",
+            all(
+                output_type in base_model.PROMPT_OUTPUT_TYPES
+                for output_type in output_model_types
+            ),
+            "{} not supported for base model type {}".format(
+                output_model_types, base_model.MODEL_TYPE
+            ),
+        )
+
+    error.value_check(
+        "",
+        len(output_model_types) <= base_model.MAX_NUM_TRANSFORMERS,
+        f"Too many output model types. Got {len(output_model_types)}, "
+        f"maximum {base_model.MAX_NUM_TRANSFORMERS}",
+    )
+
+    # Ensure that our verbalizer is a string and will not render to a hardcoded string
+    error.value_check(
+        "",
+        is_valid_verbalizer(verbalizer),
+        "Provided verbalizer is an invalid type or has no renderable placeholders",
+    )
+
+    # NOTE: Base model is a resource at this point
+    task_type = base_model.TASK_TYPE
+
+    if isinstance(tuning_type, str):
+        error.value_check(
+            "",
+            tuning_type in TuningType._member_names_,
+            f"Invalid tuning type [{tuning_type}]. Allowed types: "
+            f"[{TuningType._member_names_}]",
+        )
+        tuning_type = TuningType(tuning_type)
+    error.type_check("", TuningType, tuning_type=tuning_type)
+
+    return task_type, output_model_types, base_model, tuning_type, torch_dtype
diff --git a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
index c1b4ca53..e36147a0 100644
--- a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
+++ b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
@@ -13,7 +13,7 @@
 # limitations under the License.
"""This module contains prompt tuning through PEFT""" # Standard -from enum import Enum + from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import gc import json @@ -35,7 +35,6 @@ from torch.utils.data import DataLoader from tqdm import tqdm from transformers import ( - AutoConfig, AutoModelForCausalLM, DataCollatorForLanguageModeling, default_data_collator, @@ -45,7 +44,7 @@ import torch # First Party -from caikit import get_config + from caikit.core.data_model import DataStream from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver, module from caikit.core.toolkit import error_handler @@ -79,30 +78,11 @@ ) from ...toolkit.verbalizer_utils import is_valid_verbalizer, render_verbalizer +from peft_config import TuningType, allowed_tuning_init_methods, validate_peft_config + log = alog.use_channel("PEFT_PROMPT") error = error_handler.get(log) - -# NOTE: We do not allow all the methods exposed by MPT / PT, such as `EXACT_SOURCE_TASK` -# since those are for experimental use and would not be useful / applicable -# for end-user use-cases -allowed_tuning_init_methods = [ - "TEXT", - "RANDOM", - "ONLY_SOURCE_SHARED", - "AVERAGE_SOURCE_TASKS", -] - - -class TuningType(str, Enum): - PROMPT_TUNING = "PROMPT_TUNING" - MULTITASK_PROMPT_TUNING = "MULTITASK_PROMPT_TUNING" - # MULTITASK_PREFIX_TUNING = "MULTITASK_PREFIX_TUNING" - # P_TUNING = "P_TUNING" - # PREFIX_TUNING = "PREFIX_TUNING" - # LORA = "LORA" - - # TODO: try to refactor this into a smaller module # pylint: disable=too-many-lines,too-many-instance-attributes @module( @@ -362,133 +342,10 @@ def train( Instance of this class with tuned prompt vectors. """ - # TODO: Move all of the validation into a separate function - - if tuning_type not in TuningType._member_names_: - raise NotImplementedError( - "{} tuning type not supported!".format(tuning_type) - ) - - if tuning_config.prompt_tuning_init_method: - # NOTE: GK-APR-5-2023 - # MultitaskPromptTuningInit and MultitaskPrefixTuningInit are same at the - # time of writing, which is a superset of PromptTuningInit - init_method = tuning_config.prompt_tuning_init_method - - error.value_check( - "", - init_method in allowed_tuning_init_methods, - f"Init method [{init_method}] not in allowed init methods: " - f"[{allowed_tuning_init_methods}]", - ) - - init_method = MultitaskPromptTuningInit(init_method) - log.info("Using initialization method [%s]", init_method) - - # If init method provided relates to one that requires source model, - # make sure the source prompt model is provided. - if init_method in [ - MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS, - MultitaskPromptTuningInit.ONLY_SOURCE_SHARED, - ]: - # NOTE: prompt_tuning_init_source_model is currently a path. 
-                # we will replace this with caikit.resources to properly cataloging these
-                error.type_check(
-                    "",
-                    str,
-                    prompt_tuning_init_source_model=tuning_config.prompt_tuning_init_source_model,
-                )
-                tuning_config.prompt_tuning_init_source_model = os.path.join(
-                    get_config().source_prompt_base,
-                    tuning_config.prompt_tuning_init_source_model,
-                )
-
-                error.file_check(
-                    "", tuning_config.prompt_tuning_init_source_model
-                )
-                log.debug(
-                    "Validated tuning source prompt [%s]",
-                    tuning_config.prompt_tuning_init_source_model,
-                )
-
-        # Coerce the passed model into a resource; if we have one, this is a noop
-        # TODO: When splitting up this mono-module, use the configured resource
-        #       type of the concrete class to bootstrap
-        torch_dtype = get_torch_dtype(torch_dtype)
-        if isinstance(base_model, str):
-            model_config = AutoConfig.from_pretrained(
-                base_model, local_files_only=not get_config().allow_downloads
-            )
-
-            resource_type = None
-            for resource in cls.supported_resources:
-                if model_config.model_type in resource.SUPPORTED_MODEL_TYPES:
-                    resource_type = resource
-                    break
-
-            if not resource_type:
-                error(
-                    "",
-                    "{} model type is not supported currently!".format(
-                        model_config.model_type
-                    ),
-                )
-            log.debug("Bootstrapping base resource [%s]", base_model)
-            base_model = resource_type.bootstrap(base_model, torch_dtype=torch_dtype)
-        error.type_check("", PretrainedModelBase, base_model=base_model)
-
-        # Validate if tuned output model type is compatible with base model or not
-        if not tuning_config.output_model_types:
-            output_model_types = base_model.PROMPT_OUTPUT_TYPES
-        else:
-            # If the first element is not PromptOutputModelType, assume the entire list
-            # isn't and convert
-            if not isinstance(
-                tuning_config.output_model_types[0], PromptOutputModelType
-            ):
-                output_model_types = []
-                for output_type in tuning_config.output_model_types:
-                    output_model_types.append(PromptOutputModelType(output_type))
-            else:
-                output_model_types = tuning_config.output_model_types
-            error.value_check(
-                "",
-                all(
-                    output_type in base_model.PROMPT_OUTPUT_TYPES
-                    for output_type in output_model_types
-                ),
-                "{} not supported for base model type {}".format(
-                    output_model_types, base_model.MODEL_TYPE
-                ),
-            )
-
-        error.value_check(
-            "",
-            len(output_model_types) <= base_model.MAX_NUM_TRANSFORMERS,
-            f"Too many output model types. Got {len(output_model_types)}, "
-            f"maximum {base_model.MAX_NUM_TRANSFORMERS}",
-        )
-        # Ensure that our verbalizer is a string and will not render to a hardcoded string
-        error.value_check(
-            "",
-            is_valid_verbalizer(verbalizer),
-            "Provided verbalizer is an invalid type or has no renderable placeholders",
-        )
-
-        # NOTE: Base model is a resource at this point
-        task_type = base_model.TASK_TYPE
-
         # HACK - These things can't be passed through the train API currently
         metric = kwargs.get("metric")
-        if isinstance(tuning_type, str):
-            error.value_check(
-                "",
-                tuning_type in TuningType._member_names_,
-                f"Invalid tuning type [{tuning_type}]. Allowed types: "
-                f"[{TuningType._member_names_}]",
-            )
-            tuning_type = TuningType(tuning_type)
-        error.type_check("", TuningType, tuning_type=tuning_type)
+        (
+            task_type,
+            output_model_types,
+            base_model,
+            tuning_type,
+            torch_dtype,
+        ) = validate_peft_config(
+            tuning_type,
+            tuning_config,
+            error,
+            log,
+            base_model,
+            cls,
+            torch_dtype,
+            verbalizer,
+        )
+
         train_stream = train_stream.map(convert_to_generation_record)
         if val_stream:
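
For context, a minimal usage sketch of the extracted helper, roughly mirroring the new call site in peft_prompt_tuning.py; this is not part of the diff, and the model path, verbalizer, and TuningConfig values below are illustrative placeholders rather than values taken from the repository:

# Illustrative sketch only; argument values are placeholders.
import alog
from caikit.core.toolkit import error_handler

from caikit_nlp.data_model import TuningConfig
from caikit_nlp.modules.text_generation import PeftPromptTuning
from caikit_nlp.modules.text_generation.peft_config import validate_peft_config

log = alog.use_channel("DEMO")
error = error_handler.get(log)

# Example tuning configuration; field values here are assumptions for the sketch
tuning_config = TuningConfig(
    num_virtual_tokens=8,
    prompt_tuning_init_method="RANDOM",
)

task_type, output_model_types, base_model, tuning_type, torch_dtype = (
    validate_peft_config(
        "PROMPT_TUNING",          # tuning_type: str or TuningType member
        tuning_config,            # TuningConfig as supplied to train()
        error,                    # caikit error handler bound to the log channel
        log,                      # alog channel
        "path/to/base/model",     # base_model: name/path or PretrainedModelBase
        PeftPromptTuning,         # cls: module class exposing supported_resources
        "float32",                # torch_dtype
        "{{input}}",              # verbalizer template (placeholder)
    )
)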