diff --git a/caikit_nlp/modules/text_generation/peft_config.py b/caikit_nlp/modules/text_generation/peft_config.py
new file mode 100644
index 00000000..90883b3d
--- /dev/null
+++ b/caikit_nlp/modules/text_generation/peft_config.py
@@ -0,0 +1,201 @@
+# Copyright The Caikit Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+from enum import Enum
+import os
+
+# Third Party
+from peft import MultitaskPromptTuningInit
+from transformers import AutoConfig
+
+# First Party
+from caikit import get_config
+import alog
+
+# Local
+from ...data_model import PromptOutputModelType
+from ...resources.pretrained_model import PretrainedModelBase
+from ...toolkit.data_type_utils import get_torch_dtype
+from ...toolkit.verbalizer_utils import is_valid_verbalizer
+
+# NOTE: We do not allow all the methods exposed by MPT / PT, such as `EXACT_SOURCE_TASK`
+# since those are for experimental use and would not be useful / applicable
+# for end-user use-cases
+allowed_tuning_init_methods = [
+    "TEXT",
+    "RANDOM",
+    "ONLY_SOURCE_SHARED",
+    "AVERAGE_SOURCE_TASKS",
+]
+
+log = alog.use_channel("PFT_CNFG_TLKT")
+
+
+class TuningType(str, Enum):
+    PROMPT_TUNING = "PROMPT_TUNING"
+    MULTITASK_PROMPT_TUNING = "MULTITASK_PROMPT_TUNING"
+    # MULTITASK_PREFIX_TUNING = "MULTITASK_PREFIX_TUNING"
+    # P_TUNING = "P_TUNING"
+    # PREFIX_TUNING = "PREFIX_TUNING"
+    # LORA = "LORA"
+
+
+def resolve_base_model(base_model, cls, error, torch_dtype):
+    if isinstance(base_model, str):
+        model_config = AutoConfig.from_pretrained(
+            base_model, local_files_only=not get_config().allow_downloads
+        )
+
+        resource_type = None
+        for resource in cls.supported_resources:
+            if model_config.model_type in resource.SUPPORTED_MODEL_TYPES:
+                resource_type = resource
+                break
+
+        if not resource_type:
+            error(
+                "",
+                "{} model type is not supported currently!".format(
+                    model_config.model_type
+                ),
+            )
+        log.debug("Bootstrapping base resource [%s]", base_model)
+        base_model = resource_type.bootstrap(base_model, torch_dtype=torch_dtype)
+    return base_model
+
+
+def get_peft_config(
+    tuning_type, tuning_config, error, base_model, cls, torch_dtype, verbalizer
+):
+
+    if tuning_type not in TuningType._member_names_:
+        raise NotImplementedError("{} tuning type not supported!".format(tuning_type))
+
+    if tuning_config.prompt_tuning_init_method:
+        # NOTE: GK-APR-5-2023
+        # MultitaskPromptTuningInit and MultitaskPrefixTuningInit are same at the
+        # time of writing, which is a superset of PromptTuningInit
+        init_method = tuning_config.prompt_tuning_init_method
+
+        error.value_check(
+            "",
+            init_method in allowed_tuning_init_methods,
+            f"Init method [{init_method}] not in allowed init methods: "
+            f"[{allowed_tuning_init_methods}]",
+        )
+
+        init_method = MultitaskPromptTuningInit(init_method)
+        log.info("Using initialization method [%s]", init_method)
+
+        # If init method provided relates to one that requires source model,
+        # make sure the source prompt model is provided.
+        if init_method in [
+            MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS,
+            MultitaskPromptTuningInit.ONLY_SOURCE_SHARED,
+        ]:
+            # NOTE: prompt_tuning_init_source_model is currently a path. In future
+            # we will replace this with caikit.resources to properly cataloging these
+            error.type_check(
+                "",
+                str,
+                prompt_tuning_init_source_model=tuning_config.prompt_tuning_init_source_model,
+            )
+            tuning_config.prompt_tuning_init_source_model = os.path.join(
+                get_config().source_prompt_base,
+                tuning_config.prompt_tuning_init_source_model,
+            )
+
+            error.file_check(
+                "", tuning_config.prompt_tuning_init_source_model
+            )
+            log.debug(
+                "Validated tuning source prompt [%s]",
+                tuning_config.prompt_tuning_init_source_model,
+            )
+
+    error.type_check("", PretrainedModelBase, base_model=base_model)
+
+    # Validate if tuned output model type is compatible with base model or not
+    if not tuning_config.output_model_types:
+        output_model_types = base_model.PROMPT_OUTPUT_TYPES
+    else:
+        # If the first element is not PromptOutputModelType, assume the entire list
+        # isn't and convert
+        if not isinstance(tuning_config.output_model_types[0], PromptOutputModelType):
+            output_model_types = []
+            for output_type in tuning_config.output_model_types:
+                output_model_types.append(PromptOutputModelType(output_type))
+        else:
+            output_model_types = tuning_config.output_model_types
+        error.value_check(
+            "",
+            all(
+                output_type in base_model.PROMPT_OUTPUT_TYPES
+                for output_type in output_model_types
+            ),
+            "{} not supported for base model type {}".format(
+                output_model_types, base_model.MODEL_TYPE
+            ),
+        )
+
+    error.value_check(
+        "",
+        len(output_model_types) <= base_model.MAX_NUM_TRANSFORMERS,
+        f"Too many output model types. Got {len(output_model_types)}, "
+        f"maximum {base_model.MAX_NUM_TRANSFORMERS}",
+    )
+    # Ensure that our verbalizer is a string and will not render to a hardcoded string
+    error.value_check(
+        "",
+        is_valid_verbalizer(verbalizer),
+        "Provided verbalizer is an invalid type or has no renderable placeholders",
+    )
+
+    # NOTE: Base model is a resource at this point
+    task_type = base_model.TASK_TYPE
+
+    if isinstance(tuning_type, str):
+        error.value_check(
+            "",
+            tuning_type in TuningType._member_names_,
+            f"Invalid tuning type [{tuning_type}]. Allowed types: "
+            f"[{TuningType._member_names_}]",
+        )
+        tuning_type = TuningType(tuning_type)
+    error.type_check("", TuningType, tuning_type=tuning_type)
+
+    # Coerce the passed model into a resource; if we have one, this is a noop
+    # TODO: When splitting up this mono-module, use the configured resource
+    # type of the concrete class to bootstrap
+    torch_dtype = get_torch_dtype(torch_dtype)
+
+    # Take tokenizer name/path from the model
+    tokenizer_name_or_path = base_model.model.config._name_or_path
+
+    # Build the peft config; this is how we determine that we want a sequence classifier.
+    # If we want more types, we will likely need to map this to data model outputs etc.
+
+    # NOTE: We currently only support TEXT as init type, this is to later only easily
+    # switch to MPT
+    peft_config = cls.create_hf_tuning_config(
+        base_model=base_model,
+        tuning_type=tuning_type,
+        task_type=task_type,
+        tokenizer_name_or_path=tokenizer_name_or_path,
+        tuning_config=tuning_config,
+        output_model_types=output_model_types,
+    )
+
+    return task_type, output_model_types, peft_config, tuning_type
diff --git a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
index e0f64077..c71da2a1 100644
--- a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
+++ b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
@@ -13,7 +13,8 @@
 # limitations under the License.
 """This module contains prompt tuning through PEFT"""
 # Standard
-from enum import Enum
+
+# Standard
 from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 import gc
 import json
@@ -23,7 +24,6 @@
 from accelerate import Accelerator
 from peft import (
     MultitaskPromptTuningConfig,
-    MultitaskPromptTuningInit,
     PeftConfig,
     PeftModel,
     PeftType,
@@ -35,7 +35,6 @@
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import (
-    AutoConfig,
     AutoModelForCausalLM,
     DataCollatorForLanguageModeling,
     default_data_collator,
@@ -46,7 +45,6 @@
 import torch
 
 # First Party
-from caikit import get_config
 from caikit.core.data_model import DataStream
 from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver, module
 from caikit.core.toolkit import error_handler
@@ -78,32 +76,12 @@
     generate_text_func,
     generate_text_func_stream,
 )
-from ...toolkit.verbalizer_utils import is_valid_verbalizer, render_verbalizer
+from ...toolkit.verbalizer_utils import render_verbalizer
+from .peft_config import TuningType, get_peft_config, resolve_base_model
 
 log = alog.use_channel("PEFT_PROMPT")
 error = error_handler.get(log)
-
-# NOTE: We do not allow all the methods exposed by MPT / PT, such as `EXACT_SOURCE_TASK`
-# since those are for experimental use and would not be useful / applicable
-# for end-user use-cases
-allowed_tuning_init_methods = [
-    "TEXT",
-    "RANDOM",
-    "ONLY_SOURCE_SHARED",
-    "AVERAGE_SOURCE_TASKS",
-]
-
-
-class TuningType(str, Enum):
-    PROMPT_TUNING = "PROMPT_TUNING"
-    MULTITASK_PROMPT_TUNING = "MULTITASK_PROMPT_TUNING"
-    # MULTITASK_PREFIX_TUNING = "MULTITASK_PREFIX_TUNING"
-    # P_TUNING = "P_TUNING"
-    # PREFIX_TUNING = "PREFIX_TUNING"
-    # LORA = "LORA"
-
-
 # TODO: try to refactor this into a smaller module
 # pylint: disable=too-many-lines,too-many-instance-attributes
 @module(
@@ -365,133 +343,26 @@ def train(
             Instance of this class with tuned prompt vectors.
""" - # TODO: Move all of the validation into a separate function - - if tuning_type not in TuningType._member_names_: - raise NotImplementedError( - "{} tuning type not supported!".format(tuning_type) - ) - - if tuning_config.prompt_tuning_init_method: - # NOTE: GK-APR-5-2023 - # MultitaskPromptTuningInit and MultitaskPrefixTuningInit are same at the - # time of writing, which is a superset of PromptTuningInit - init_method = tuning_config.prompt_tuning_init_method - - error.value_check( - "", - init_method in allowed_tuning_init_methods, - f"Init method [{init_method}] not in allowed init methods: " - f"[{allowed_tuning_init_methods}]", - ) + # HACK - These things can't be passed through the train API currently - init_method = MultitaskPromptTuningInit(init_method) - log.info("Using initialization method [%s]", init_method) - - # If init method provided relates to one that requires source model, - # make sure the source prompt model is provided. - if init_method in [ - MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS, - MultitaskPromptTuningInit.ONLY_SOURCE_SHARED, - ]: - # NOTE: prompt_tuning_init_source_model is currently a path. In future - # we will replace this with caikit.resources to properly cataloging these - error.type_check( - "", - str, - prompt_tuning_init_source_model=tuning_config.prompt_tuning_init_source_model, - ) - tuning_config.prompt_tuning_init_source_model = os.path.join( - get_config().source_prompt_base, - tuning_config.prompt_tuning_init_source_model, - ) + metric = kwargs.get("metric") - error.file_check( - "", tuning_config.prompt_tuning_init_source_model - ) - log.debug( - "Validated tuning source prompt [%s]", - tuning_config.prompt_tuning_init_source_model, - ) + base_model = resolve_base_model(base_model, cls, error, torch_dtype) + base_model_name = base_model._model_name + task_type, output_model_types, peft_config, tuning_type = get_peft_config( + tuning_type, + tuning_config, + error, + base_model, + cls, + torch_dtype, + verbalizer, + ) # Coerce the passed model into a resource; if we have one, this is a noop # TODO: When splitting up this mono-module, use the configured resource # type of the concrete class to bootstrap torch_dtype = get_torch_dtype(torch_dtype) - if isinstance(base_model, str): - model_config = AutoConfig.from_pretrained( - base_model, local_files_only=not get_config().allow_downloads - ) - - resource_type = None - for resource in cls.supported_resources: - if model_config.model_type in resource.SUPPORTED_MODEL_TYPES: - resource_type = resource - break - - if not resource_type: - error( - "", - "{} model type is not supported currently!".format( - model_config.model_type - ), - ) - log.debug("Bootstrapping base resource [%s]", base_model) - base_model = resource_type.bootstrap(base_model, torch_dtype=torch_dtype) - error.type_check("", PretrainedModelBase, base_model=base_model) - - # Validate if tuned output model type is compatible with base model or not - if not tuning_config.output_model_types: - output_model_types = base_model.PROMPT_OUTPUT_TYPES - else: - # If the first element is not PromptOutputModelType, assume the entire list - # isn't and convert - if not isinstance( - tuning_config.output_model_types[0], PromptOutputModelType - ): - output_model_types = [] - for output_type in tuning_config.output_model_types: - output_model_types.append(PromptOutputModelType(output_type)) - else: - output_model_types = tuning_config.output_model_types - error.value_check( - "", - all( - output_type in base_model.PROMPT_OUTPUT_TYPES - for 
output_type in output_model_types - ), - "{} not supported for base model type {}".format( - output_model_types, base_model.MODEL_TYPE - ), - ) - - error.value_check( - "", - len(output_model_types) <= base_model.MAX_NUM_TRANSFORMERS, - f"Too many output model types. Got {len(output_model_types)}, " - f"maximum {base_model.MAX_NUM_TRANSFORMERS}", - ) - # Ensure that our verbalizer is a string and will not render to a hardcoded string - error.value_check( - "", - is_valid_verbalizer(verbalizer), - "Provided verbalizer is an invalid type or has no renderable placeholders", - ) - - # NOTE: Base model is a resource at this point - task_type = base_model.TASK_TYPE - - # HACK - These things can't be passed through the train API currently - metric = kwargs.get("metric") - if isinstance(tuning_type, str): - error.value_check( - "", - tuning_type in TuningType._member_names_, - f"Invalid tuning type [{tuning_type}]. Allowed types: " - f"[{TuningType._member_names_}]", - ) - tuning_type = TuningType(tuning_type) - error.type_check("", TuningType, tuning_type=tuning_type) train_stream = train_stream.map(convert_to_generation_record) if val_stream: @@ -509,24 +380,6 @@ def train( max_target_length=max_target_length, ) - base_model_name = base_model._model_name - - # Take tokenizer name/path from the model - tokenizer_name_or_path = base_model.model.config._name_or_path - - # Build the peft config; this is how we determine that we want a sequence classifier. - # If we want more types, we will likely need to map this to data model outputs etc. - - # NOTE: We currently only support TEXT as init type, this is to later only easily - # switch to MPT - peft_config = cls.create_hf_tuning_config( - base_model=base_model, - tuning_type=tuning_type, - task_type=task_type, - tokenizer_name_or_path=tokenizer_name_or_path, - tuning_config=tuning_config, - output_model_types=output_model_types, - ) log.debug("Peft config [%s]", peft_config) # FIXME: Should only do following line for causal LM (and bloomz?) 
- check that is the case if isinstance(base_model, HFAutoCausalLM): diff --git a/tests/modules/text_generation/test_peft_config.py b/tests/modules/text_generation/test_peft_config.py new file mode 100644 index 00000000..6eeb513b --- /dev/null +++ b/tests/modules/text_generation/test_peft_config.py @@ -0,0 +1,99 @@ +# Standard +from unittest.mock import Mock + +# Third Party +import pytest + +# Local +from caikit_nlp.data_model import PromptOutputModelType +from caikit_nlp.modules.text_generation.peft_config import TuningType, get_peft_config + + +@pytest.fixture +def mock_error(): + # Create a mock error object with the expected behavior + error = Mock() + error.value_check.side_effect = ( + lambda code, condition, message: None if condition else error(code, message) + ) + return error + + +@pytest.fixture +def mock_base_model(): + base_model = Mock() + base_model.PROMPT_OUTPUT_TYPES = [ + PromptOutputModelType.ENCODER, + PromptOutputModelType.DECODER, + ] + base_model.MAX_NUM_TRANSFORMERS = 2 + return base_model + + +@pytest.fixture +def mock_cls(): + return Mock() + + +@pytest.fixture +def mock_torch_dtype(): + return Mock() + + +@pytest.fixture +def mock_verbalizer(): + return Mock() + + +@pytest.fixture +def mock_tuning_config(): + # Create a mock tuning_config with a list of output_model_types + tuning_config = Mock( + prompt_tuning_init_method="TEXT", prompt_tuning_init_source_model="source_model" + ) + tuning_config.output_model_types = [ + PromptOutputModelType.ENCODER, + PromptOutputModelType.DECODER, + ] + + return tuning_config + + +def test_get_peft_config( + mock_error, + mock_base_model, + mock_cls, + mock_torch_dtype, + mock_verbalizer, + mock_tuning_config, +): + # Define some sample values for testing + tuning_type = TuningType.PROMPT_TUNING + + output_model_types = [PromptOutputModelType.DECODER] + + # Call the function being tested + task_type, output_model_types, peft_config, tuning_type = get_peft_config( + tuning_type, + mock_tuning_config, + mock_error, + mock_base_model, + mock_cls, + "float32", + mock_verbalizer, + ) + + # Add assertions to validate the behavior of the function + assert task_type == mock_base_model.TASK_TYPE + assert output_model_types == mock_tuning_config.output_model_types + assert peft_config == mock_cls.create_hf_tuning_config.return_value + assert tuning_type == TuningType.PROMPT_TUNING + + mock_cls.create_hf_tuning_config.assert_called_once_with( + base_model=mock_base_model, + tuning_type=TuningType.PROMPT_TUNING, + task_type=mock_base_model.TASK_TYPE, + tokenizer_name_or_path=mock_base_model.model.config._name_or_path, + tuning_config=mock_tuning_config, + output_model_types=mock_tuning_config.output_model_types, + )
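For context, below is a minimal usage sketch of the two extracted helpers, wired together the way the refactored train() uses them. It is illustrative only: the model path, the TuningConfig fields, the verbalizer value, and the dtype handling are assumptions for the example, not part of this patch.

# Illustrative sketch (not part of the patch). The model path, TuningConfig
# fields, and verbalizer placeholder syntax below are assumed values.
import alog
from caikit.core.toolkit import error_handler

from caikit_nlp.data_model import TuningConfig  # assumed to live in data_model
from caikit_nlp.modules.text_generation import PeftPromptTuning
from caikit_nlp.modules.text_generation.peft_config import (
    TuningType,
    get_peft_config,
    resolve_base_model,
)

log = alog.use_channel("DEMO")
error = error_handler.get(log)

# Resolve a model name/path into a caikit resource; this is a no-op if the
# caller already passed a resource. Passing the dtype as a string assumes the
# resource bootstrap converts it internally.
base_model = resolve_base_model(
    "bigscience/bloom-560m", PeftPromptTuning, error, "float32"
)

# Stand-in tuning config; field names are assumptions for illustration.
tuning_config = TuningConfig(
    num_virtual_tokens=8,
    prompt_tuning_init_method="TEXT",
    prompt_tuning_init_text="Classify the sentiment of the text:",
)

# Returns the resolved task type, validated output model types, the HF peft
# config built by cls.create_hf_tuning_config, and the coerced TuningType.
task_type, output_model_types, peft_config, tuning_type = get_peft_config(
    TuningType.PROMPT_TUNING,
    tuning_config,
    error,
    base_model,
    PeftPromptTuning,
    "float32",
    "{{input}}",  # verbalizer; placeholder syntax assumed
)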