[caikit-nlp-163] Refactor peft module to take out common peft config functionality #174

Closed
wants to merge 8 commits
Changes from 1 commit
197 changes: 197 additions & 0 deletions caikit_nlp/modules/text_generation/peft_config.py
@@ -0,0 +1,197 @@
# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
from enum import Enum
import os

# Third Party
from peft import MultitaskPromptTuningConfig, MultitaskPromptTuningInit
from transformers import AutoConfig

# First Party
from caikit import get_config

# Local
from ...data_model import PromptOutputModelType
from ...resources.pretrained_model import PretrainedModelBase
from ...toolkit.data_type_utils import get_torch_dtype, str_to_torch_dtype
from ...toolkit.verbalizer_utils import is_valid_verbalizer, render_verbalizer


# NOTE: We do not allow all the methods exposed by MPT / PT, such as `EXACT_SOURCE_TASK`
# since those are for experimental use and would not be useful / applicable
# for end-user use-cases
allowed_tuning_init_methods = [
    "TEXT",
    "RANDOM",
    "ONLY_SOURCE_SHARED",
    "AVERAGE_SOURCE_TASKS",
]


class TuningType(str, Enum):
    PROMPT_TUNING = "PROMPT_TUNING"
    MULTITASK_PROMPT_TUNING = "MULTITASK_PROMPT_TUNING"
    # MULTITASK_PREFIX_TUNING = "MULTITASK_PREFIX_TUNING"
    # P_TUNING = "P_TUNING"
    # PREFIX_TUNING = "PREFIX_TUNING"
    # LORA = "LORA"


def validate_peft_config(
    tuning_type, tuning_config, error, log, base_model, cls, torch_dtype, verbalizer
):
    """Validate the requested tuning type and tuning config against the given base
    model, bootstrapping the base model into a resource if a path was provided.

    Returns:
        Tuple of the base model's task type and the validated output model types.
    """
    if tuning_type not in TuningType._member_names_:
        raise NotImplementedError(
            "{} tuning type not supported!".format(tuning_type)
        )

    if tuning_config.prompt_tuning_init_method:
        # NOTE: GK-APR-5-2023
        # MultitaskPromptTuningInit and MultitaskPrefixTuningInit are the same at the
        # time of writing, which is a superset of PromptTuningInit
        init_method = tuning_config.prompt_tuning_init_method

        error.value_check(
            "<NLP11848053E>",
            init_method in allowed_tuning_init_methods,
            f"Init method [{init_method}] not in allowed init methods: "
            f"[{allowed_tuning_init_methods}]",
        )

        init_method = MultitaskPromptTuningInit(init_method)
        log.info("Using initialization method [%s]", init_method)

        # If the provided init method is one that requires a source model,
        # make sure the source prompt model is provided.
        if init_method in [
            MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS,
            MultitaskPromptTuningInit.ONLY_SOURCE_SHARED,
        ]:
            # NOTE: prompt_tuning_init_source_model is currently a path. In the future
            # we will replace this with caikit.resources to properly catalog these
            error.type_check(
                "<NLP89108490E>",
                str,
                prompt_tuning_init_source_model=tuning_config.prompt_tuning_init_source_model,
            )
            tuning_config.prompt_tuning_init_source_model = os.path.join(
                get_config().source_prompt_base,
                tuning_config.prompt_tuning_init_source_model,
            )

            error.file_check(
                "<NLP96030210E>", tuning_config.prompt_tuning_init_source_model
            )
            log.debug(
                "Validated tuning source prompt [%s]",
                tuning_config.prompt_tuning_init_source_model,
            )

    if isinstance(base_model, str):
        model_config = AutoConfig.from_pretrained(
            base_model, local_files_only=not get_config().allow_downloads
        )

        resource_type = None
        for resource in cls.supported_resources:
            if model_config.model_type in resource.SUPPORTED_MODEL_TYPES:
                resource_type = resource
                break

        if not resource_type:
            error(
                "<NLP61784225E>",
                "{} model type is not supported currently!".format(
                    model_config.model_type
                ),
            )
        log.debug("Bootstrapping base resource [%s]", base_model)
        base_model = resource_type.bootstrap(base_model, torch_dtype=torch_dtype)
    error.type_check("<NLP65714919E>", PretrainedModelBase, base_model=base_model)

    # Validate that the tuned output model type is compatible with the base model
    if not tuning_config.output_model_types:
        output_model_types = base_model.PROMPT_OUTPUT_TYPES
    else:
        # If the first element is not PromptOutputModelType, assume the entire list
        # isn't and convert
        if not isinstance(
            tuning_config.output_model_types[0], PromptOutputModelType
        ):
            output_model_types = []
            for output_type in tuning_config.output_model_types:
                output_model_types.append(PromptOutputModelType(output_type))
        else:
            output_model_types = tuning_config.output_model_types
        error.value_check(
            "<NLP36947542E>",
            all(
                output_type in base_model.PROMPT_OUTPUT_TYPES
                for output_type in output_model_types
            ),
            "{} not supported for base model type {}".format(
                output_model_types, base_model.MODEL_TYPE
            ),
        )

    error.value_check(
        "<NLP30542004E>",
        len(output_model_types) <= base_model.MAX_NUM_TRANSFORMERS,
        f"Too many output model types. Got {len(output_model_types)}, "
        f"maximum {base_model.MAX_NUM_TRANSFORMERS}",
    )
    # Ensure that our verbalizer is a string and will not render to a hardcoded string
    error.value_check(
        "<NLP83837412E>",
        is_valid_verbalizer(verbalizer),
        "Provided verbalizer is an invalid type or has no renderable placeholders",
    )

    # NOTE: Base model is a resource at this point
    task_type = base_model.TASK_TYPE

    if isinstance(tuning_type, str):
        error.value_check(
            "<NLP65714994E>",
            tuning_type in TuningType._member_names_,
            f"Invalid tuning type [{tuning_type}]. Allowed types: "
            f"[{TuningType._member_names_}]",
        )
        tuning_type = TuningType(tuning_type)
    error.type_check("<NLP65714993E>", TuningType, tuning_type=tuning_type)

    # Coerce the passed model into a resource; if we have one, this is a noop
    # TODO: When splitting up this mono-module, use the configured resource
    # type of the concrete class to bootstrap
    torch_dtype = get_torch_dtype(torch_dtype)

    return task_type, output_model_types
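
For orientation, here is a minimal sketch of how a module's train() method could delegate its up-front validation to this helper. The wrapper function below is hypothetical and not part of this PR; cls, error, log, and the other arguments stand in for the objects a caikit module already has in scope, and tuning_type is passed as a plain string exactly as the train() API receives it.

# Minimal usage sketch -- run_peft_validation is a hypothetical wrapper that only
# illustrates the argument order of validate_peft_config.
from caikit_nlp.modules.text_generation.peft_config import validate_peft_config


def run_peft_validation(
    cls, error, log, base_model, tuning_config, torch_dtype, verbalizer
):
    # Delegate the tuning-type / tuning-config / base-model checks to the shared helper
    task_type, output_model_types = validate_peft_config(
        "PROMPT_TUNING",
        tuning_config,
        error,
        log,
        base_model,
        cls,
        torch_dtype,
        verbalizer,
    )
    return task_type, output_model_types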

155 changes: 6 additions & 149 deletions caikit_nlp/modules/text_generation/peft_prompt_tuning.py
gkumbhat (Collaborator) commented on Sep 5, 2023:

Thanks @rawkintrevo for the quick turn-around. I think one other thing that should move is the https://github.com/caikit/caikit-nlp/pull/174/files#diff-1cb191003903163320c02f8ffaf7c5edd48ca6649cd3092d1b4c3a0fdcd038c1R376 function, since it returns the peft_config that is then passed in to get the peft_model.
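
The function referenced in this comment is not included in the diff shown here. Purely as an illustration of the shape such a move could take (the helper name and the exact fields passed are assumptions, not the actual caikit-nlp code), the relocated builder might look roughly like this, returning the HF PEFT config that later goes to get_peft_model:

# Hypothetical sketch only: names and fields below are illustrative, not the code
# being moved. It relies on the same peft classes already imported in peft_config.py.
from peft import MultitaskPromptTuningConfig, MultitaskPromptTuningInit


def build_hf_tuning_config(task_type, num_virtual_tokens, tokenizer_name_or_path):
    # Build the peft config object that would later be handed to peft.get_peft_model()
    return MultitaskPromptTuningConfig(
        task_type=task_type,  # e.g. "CAUSAL_LM" or "SEQ_2_SEQ_LM"
        prompt_tuning_init=MultitaskPromptTuningInit.RANDOM,
        num_virtual_tokens=num_virtual_tokens,
        tokenizer_name_or_path=tokenizer_name_or_path,
    )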

@@ -13,7 +13,7 @@
# limitations under the License.
"""This module contains prompt tuning through PEFT"""
# Standard
from enum import Enum

from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import gc
import json
@@ -35,7 +35,6 @@
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    default_data_collator,
@@ -45,7 +44,7 @@
import torch

# First Party
from caikit import get_config

from caikit.core.data_model import DataStream
from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver, module
from caikit.core.toolkit import error_handler
@@ -79,30 +78,11 @@
)
from ...toolkit.verbalizer_utils import is_valid_verbalizer, render_verbalizer

from .peft_config import TuningType, allowed_tuning_init_methods, validate_peft_config

log = alog.use_channel("PEFT_PROMPT")
error = error_handler.get(log)


# NOTE: We do not allow all the methods exposed by MPT / PT, such as `EXACT_SOURCE_TASK`
# since those are for experimental use and would not be useful / applicable
# for end-user use-cases
allowed_tuning_init_methods = [
"TEXT",
"RANDOM",
"ONLY_SOURCE_SHARED",
"AVERAGE_SOURCE_TASKS",
]


class TuningType(str, Enum):
PROMPT_TUNING = "PROMPT_TUNING"
MULTITASK_PROMPT_TUNING = "MULTITASK_PROMPT_TUNING"
# MULTITASK_PREFIX_TUNING = "MULTITASK_PREFIX_TUNING"
# P_TUNING = "P_TUNING"
# PREFIX_TUNING = "PREFIX_TUNING"
# LORA = "LORA"


# TODO: try to refactor this into a smaller module
# pylint: disable=too-many-lines,too-many-instance-attributes
@module(
@@ -362,133 +342,10 @@ def train(
                Instance of this class with tuned prompt vectors.
        """

        # TODO: Move all of the validation into a separate function

        if tuning_type not in TuningType._member_names_:
            raise NotImplementedError(
                "{} tuning type not supported!".format(tuning_type)
            )

        if tuning_config.prompt_tuning_init_method:
            # NOTE: GK-APR-5-2023
            # MultitaskPromptTuningInit and MultitaskPrefixTuningInit are same at the
            # time of writing, which is a superset of PromptTuningInit
            init_method = tuning_config.prompt_tuning_init_method

            error.value_check(
                "<NLP11848053E>",
                init_method in allowed_tuning_init_methods,
                f"Init method [{init_method}] not in allowed init methods: "
                f"[{allowed_tuning_init_methods}]",
            )

            init_method = MultitaskPromptTuningInit(init_method)
            log.info("Using initialization method [%s]", init_method)

            # If init method provided relates to one that requires source model,
            # make sure the source prompt model is provided.
            if init_method in [
                MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS,
                MultitaskPromptTuningInit.ONLY_SOURCE_SHARED,
            ]:
                # NOTE: prompt_tuning_init_source_model is currently a path. In future
                # we will replace this with caikit.resources to properly cataloging these
                error.type_check(
                    "<NLP89108490E>",
                    str,
                    prompt_tuning_init_source_model=tuning_config.prompt_tuning_init_source_model,
                )
                tuning_config.prompt_tuning_init_source_model = os.path.join(
                    get_config().source_prompt_base,
                    tuning_config.prompt_tuning_init_source_model,
                )

                error.file_check(
                    "<NLP96030210E>", tuning_config.prompt_tuning_init_source_model
                )
                log.debug(
                    "Validated tuning source prompt [%s]",
                    tuning_config.prompt_tuning_init_source_model,
                )

        # Coerce the passed model into a resource; if we have one, this is a noop
        # TODO: When splitting up this mono-module, use the configured resource
        # type of the concrete class to bootstrap
        torch_dtype = get_torch_dtype(torch_dtype)
        if isinstance(base_model, str):
            model_config = AutoConfig.from_pretrained(
                base_model, local_files_only=not get_config().allow_downloads
            )

            resource_type = None
            for resource in cls.supported_resources:
                if model_config.model_type in resource.SUPPORTED_MODEL_TYPES:
                    resource_type = resource
                    break

            if not resource_type:
                error(
                    "<NLP61784225E>",
                    "{} model type is not supported currently!".format(
                        model_config.model_type
                    ),
                )
            log.debug("Bootstrapping base resource [%s]", base_model)
            base_model = resource_type.bootstrap(base_model, torch_dtype=torch_dtype)
        error.type_check("<NLP65714919E>", PretrainedModelBase, base_model=base_model)

        # Validate if tuned output model type is compatible with base model or not
        if not tuning_config.output_model_types:
            output_model_types = base_model.PROMPT_OUTPUT_TYPES
        else:
            # If the first element is not PromptOutputModelType, assume the entire list
            # isn't and convert
            if not isinstance(
                tuning_config.output_model_types[0], PromptOutputModelType
            ):
                output_model_types = []
                for output_type in tuning_config.output_model_types:
                    output_model_types.append(PromptOutputModelType(output_type))
            else:
                output_model_types = tuning_config.output_model_types
            error.value_check(
                "<NLP36947542E>",
                all(
                    output_type in base_model.PROMPT_OUTPUT_TYPES
                    for output_type in output_model_types
                ),
                "{} not supported for base model type {}".format(
                    output_model_types, base_model.MODEL_TYPE
                ),
            )

        error.value_check(
            "<NLP30542004E>",
            len(output_model_types) <= base_model.MAX_NUM_TRANSFORMERS,
            f"Too many output model types. Got {len(output_model_types)}, "
            f"maximum {base_model.MAX_NUM_TRANSFORMERS}",
        )
        # Ensure that our verbalizer is a string and will not render to a hardcoded string
        error.value_check(
            "<NLP83837412E>",
            is_valid_verbalizer(verbalizer),
            "Provided verbalizer is an invalid type or has no renderable placeholders",
        )

        # NOTE: Base model is a resource at this point
        task_type = base_model.TASK_TYPE

        # HACK - These things can't be passed through the train API currently
        metric = kwargs.get("metric")
        if isinstance(tuning_type, str):
            error.value_check(
                "<NLP65714994E>",
                tuning_type in TuningType._member_names_,
                f"Invalid tuning type [{tuning_type}]. Allowed types: "
                f"[{TuningType._member_names_}]",
            )
            tuning_type = TuningType(tuning_type)
        error.type_check("<NLP65714993E>", TuningType, tuning_type=tuning_type)
        task_type, output_model_types = validate_peft_config(
            tuning_type,
            tuning_config,
            error,
            log,
            base_model,
            cls,
            torch_dtype,
            verbalizer,
        )

        train_stream = train_stream.map(convert_to_generation_record)
        if val_stream: