
Commit

we'll see if this defeats DCO
Signed-off-by: Trevor Grant <[email protected]>
rawkintrevo committed Sep 5, 2023
1 parent d0c18d7 commit ddc2587
Showing 2 changed files with 203 additions and 149 deletions.
197 changes: 197 additions & 0 deletions caikit_nlp/modules/text_generation/peft_config.py
@@ -0,0 +1,197 @@
# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
from enum import Enum
import os

# Third Party
from peft import MultitaskPromptTuningConfig, MultitaskPromptTuningInit
from transformers import AutoConfig

# First Party
from caikit import get_config

# Local
from ...data_model import PromptOutputModelType
from ...resources.pretrained_model import PretrainedModelBase
from ...toolkit.data_type_utils import get_torch_dtype, str_to_torch_dtype
from ...toolkit.verbalizer_utils import is_valid_verbalizer, render_verbalizer

# NOTE: We do not allow all the methods exposed by MPT / PT, such as `EXACT_SOURCE_TASK`
# since those are for experimental use and would not be useful / applicable
# for end-user use-cases
allowed_tuning_init_methods = [
    "TEXT",
    "RANDOM",
    "ONLY_SOURCE_SHARED",
    "AVERAGE_SOURCE_TASKS",
]


class TuningType(str, Enum):
    PROMPT_TUNING = "PROMPT_TUNING"
    MULTITASK_PROMPT_TUNING = "MULTITASK_PROMPT_TUNING"
    # MULTITASK_PREFIX_TUNING = "MULTITASK_PREFIX_TUNING"
    # P_TUNING = "P_TUNING"
    # PREFIX_TUNING = "PREFIX_TUNING"
    # LORA = "LORA"


def validate_peft_config(
    tuning_type,
    tuning_config,
    error,
    log,
    base_model,
    cls,
    torch_dtype,
    verbalizer,
):
    """Validate a PEFT prompt-tuning request and coerce its arguments.

    Returns the task type, output model types, bootstrapped base model
    resource, tuning type enum, and resolved torch dtype.
    """
    if tuning_type not in TuningType._member_names_:
        raise NotImplementedError(
            "{} tuning type not supported!".format(tuning_type)
        )

    if tuning_config.prompt_tuning_init_method:
        # NOTE: GK-APR-5-2023
        # MultitaskPromptTuningInit and MultitaskPrefixTuningInit are the same at
        # the time of writing, which is a superset of PromptTuningInit
        init_method = tuning_config.prompt_tuning_init_method

        error.value_check(
            "<NLP11848053E>",
            init_method in allowed_tuning_init_methods,
            f"Init method [{init_method}] not in allowed init methods: "
            f"[{allowed_tuning_init_methods}]",
        )

        init_method = MultitaskPromptTuningInit(init_method)
        log.info("Using initialization method [%s]", init_method)

        # If the provided init method requires a source prompt model, make sure
        # one is provided.
        if init_method in [
            MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS,
            MultitaskPromptTuningInit.ONLY_SOURCE_SHARED,
        ]:
            # NOTE: prompt_tuning_init_source_model is currently a path. In future
            # we will replace this with caikit.resources to properly catalog these
            error.type_check(
                "<NLP89108490E>",
                str,
                prompt_tuning_init_source_model=tuning_config.prompt_tuning_init_source_model,
            )
            tuning_config.prompt_tuning_init_source_model = os.path.join(
                get_config().source_prompt_base,
                tuning_config.prompt_tuning_init_source_model,
            )

            error.file_check(
                "<NLP96030210E>", tuning_config.prompt_tuning_init_source_model
            )
            log.debug(
                "Validated tuning source prompt [%s]",
                tuning_config.prompt_tuning_init_source_model,
            )

    # Resolve the torch dtype and coerce the passed model into a resource; if we
    # already have a resource, this is a noop
    # TODO: When splitting up this mono-module, use the configured resource
    # type of the concrete class to bootstrap
    torch_dtype = get_torch_dtype(torch_dtype)
    if isinstance(base_model, str):
        model_config = AutoConfig.from_pretrained(
            base_model, local_files_only=not get_config().allow_downloads
        )

        resource_type = None
        for resource in cls.supported_resources:
            if model_config.model_type in resource.SUPPORTED_MODEL_TYPES:
                resource_type = resource
                break

        if not resource_type:
            error(
                "<NLP61784225E>",
                "{} model type is not supported currently!".format(
                    model_config.model_type
                ),
            )
        log.debug("Bootstrapping base resource [%s]", base_model)
        base_model = resource_type.bootstrap(base_model, torch_dtype=torch_dtype)
    error.type_check("<NLP65714919E>", PretrainedModelBase, base_model=base_model)

    # Validate that the tuned output model types are compatible with the base model
    if not tuning_config.output_model_types:
        output_model_types = base_model.PROMPT_OUTPUT_TYPES
    else:
        # If the first element is not PromptOutputModelType, assume the entire list
        # isn't and convert
        if not isinstance(
            tuning_config.output_model_types[0], PromptOutputModelType
        ):
            output_model_types = []
            for output_type in tuning_config.output_model_types:
                output_model_types.append(PromptOutputModelType(output_type))
        else:
            output_model_types = tuning_config.output_model_types
    error.value_check(
        "<NLP36947542E>",
        all(
            output_type in base_model.PROMPT_OUTPUT_TYPES
            for output_type in output_model_types
        ),
        "{} not supported for base model type {}".format(
            output_model_types, base_model.MODEL_TYPE
        ),
    )

    error.value_check(
        "<NLP30542004E>",
        len(output_model_types) <= base_model.MAX_NUM_TRANSFORMERS,
        f"Too many output model types. Got {len(output_model_types)}, "
        f"maximum {base_model.MAX_NUM_TRANSFORMERS}",
    )
    # Ensure that our verbalizer is a string and will not render to a hardcoded string
    error.value_check(
        "<NLP83837412E>",
        is_valid_verbalizer(verbalizer),
        "Provided verbalizer is an invalid type or has no renderable placeholders",
    )

    # NOTE: Base model is a resource at this point
    task_type = base_model.TASK_TYPE

    if isinstance(tuning_type, str):
        error.value_check(
            "<NLP65714994E>",
            tuning_type in TuningType._member_names_,
            f"Invalid tuning type [{tuning_type}]. Allowed types: "
            f"[{TuningType._member_names_}]",
        )
        tuning_type = TuningType(tuning_type)
    error.type_check("<NLP65714993E>", TuningType, tuning_type=tuning_type)

    # Return the coerced values so the caller can keep using the validated
    # resource, tuning type, and dtype
    return task_type, output_model_types, base_model, tuning_type, torch_dtype
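
For orientation, the snippet below is a minimal sketch (not part of this commit) of how the new helper can be driven end to end. The model path, dtype, verbalizer, and TuningConfig field values are illustrative placeholders, and TuningConfig / PeftPromptTuning are assumed to keep the names they already have in caikit_nlp; only validate_peft_config itself comes from this change.

# Illustrative sketch only -- the model path, dtype, and verbalizer below are
# placeholders, not tested defaults.
import alog
from caikit.core.toolkit import error_handler

from caikit_nlp.data_model import TuningConfig  # assumed existing data model
from caikit_nlp.modules.text_generation.peft_config import validate_peft_config
from caikit_nlp.modules.text_generation.peft_prompt_tuning import PeftPromptTuning

log = alog.use_channel("PEFT_PROMPT")
error = error_handler.get(log)

tuning_config = TuningConfig(
    num_virtual_tokens=8,
    prompt_tuning_init_method="TEXT",
    prompt_tuning_init_text="Classify the sentiment of this text:",
)

task_type, output_model_types, base_model, tuning_type, torch_dtype = (
    validate_peft_config(
        "PROMPT_TUNING",          # coerced to TuningType inside the helper
        tuning_config,
        error,
        log,
        "bigscience/bloom-560m",  # placeholder; bootstrapped into a resource
        PeftPromptTuning,         # supplies cls.supported_resources
        "float32",                # resolved via get_torch_dtype
        "Input: {{input}}",       # verbalizer with a renderable placeholder
    )
)

Passing the module class as cls mirrors how train forwards its own class so the helper can consult supported_resources when bootstrapping the base model.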

155 changes: 6 additions & 149 deletions caikit_nlp/modules/text_generation/peft_prompt_tuning.py
@@ -13,7 +13,7 @@
# limitations under the License.
"""This module contains prompt tuning through PEFT"""
# Standard
from enum import Enum

from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import gc
import json
@@ -35,7 +35,6 @@
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
AutoConfig,
AutoModelForCausalLM,
DataCollatorForLanguageModeling,
default_data_collator,
@@ -45,7 +44,7 @@
import torch

# First Party
from caikit import get_config

from caikit.core.data_model import DataStream
from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver, module
from caikit.core.toolkit import error_handler
@@ -79,30 +78,11 @@
)
from ...toolkit.verbalizer_utils import is_valid_verbalizer, render_verbalizer

from .peft_config import TuningType, allowed_tuning_init_methods, validate_peft_config

log = alog.use_channel("PEFT_PROMPT")
error = error_handler.get(log)


# NOTE: We do not allow all the methods exposed by MPT / PT, such as `EXACT_SOURCE_TASK`
# since those are for experimental use and would not be useful / applicable
# for end-user use-cases
allowed_tuning_init_methods = [
"TEXT",
"RANDOM",
"ONLY_SOURCE_SHARED",
"AVERAGE_SOURCE_TASKS",
]


class TuningType(str, Enum):
PROMPT_TUNING = "PROMPT_TUNING"
MULTITASK_PROMPT_TUNING = "MULTITASK_PROMPT_TUNING"
# MULTITASK_PREFIX_TUNING = "MULTITASK_PREFIX_TUNING"
# P_TUNING = "P_TUNING"
# PREFIX_TUNING = "PREFIX_TUNING"
# LORA = "LORA"


# TODO: try to refactor this into a smaller module
# pylint: disable=too-many-lines,too-many-instance-attributes
@module(
@@ -362,133 +342,10 @@ def train(
Instance of this class with tuned prompt vectors.
"""

# TODO: Move all of the validation into a separate function

if tuning_type not in TuningType._member_names_:
raise NotImplementedError(
"{} tuning type not supported!".format(tuning_type)
)

if tuning_config.prompt_tuning_init_method:
# NOTE: GK-APR-5-2023
# MultitaskPromptTuningInit and MultitaskPrefixTuningInit are same at the
# time of writing, which is a superset of PromptTuningInit
init_method = tuning_config.prompt_tuning_init_method

error.value_check(
"<NLP11848053E>",
init_method in allowed_tuning_init_methods,
f"Init method [{init_method}] not in allowed init methods: "
f"[{allowed_tuning_init_methods}]",
)

init_method = MultitaskPromptTuningInit(init_method)
log.info("Using initialization method [%s]", init_method)

# If init method provided relates to one that requires source model,
# make sure the source prompt model is provided.
if init_method in [
MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS,
MultitaskPromptTuningInit.ONLY_SOURCE_SHARED,
]:
# NOTE: prompt_tuning_init_source_model is currently a path. In future
# we will replace this with caikit.resources to properly cataloging these
error.type_check(
"<NLP89108490E>",
str,
prompt_tuning_init_source_model=tuning_config.prompt_tuning_init_source_model,
)
tuning_config.prompt_tuning_init_source_model = os.path.join(
get_config().source_prompt_base,
tuning_config.prompt_tuning_init_source_model,
)

error.file_check(
"<NLP96030210E>", tuning_config.prompt_tuning_init_source_model
)
log.debug(
"Validated tuning source prompt [%s]",
tuning_config.prompt_tuning_init_source_model,
)

# Coerce the passed model into a resource; if we have one, this is a noop
# TODO: When splitting up this mono-module, use the configured resource
# type of the concrete class to bootstrap
torch_dtype = get_torch_dtype(torch_dtype)
if isinstance(base_model, str):
model_config = AutoConfig.from_pretrained(
base_model, local_files_only=not get_config().allow_downloads
)

resource_type = None
for resource in cls.supported_resources:
if model_config.model_type in resource.SUPPORTED_MODEL_TYPES:
resource_type = resource
break

if not resource_type:
error(
"<NLP61784225E>",
"{} model type is not supported currently!".format(
model_config.model_type
),
)
log.debug("Bootstrapping base resource [%s]", base_model)
base_model = resource_type.bootstrap(base_model, torch_dtype=torch_dtype)
error.type_check("<NLP65714919E>", PretrainedModelBase, base_model=base_model)

# Validate if tuned output model type is compatible with base model or not
if not tuning_config.output_model_types:
output_model_types = base_model.PROMPT_OUTPUT_TYPES
else:
# If the first element is not PromptOutputModelType, assume the entire list
# isn't and convert
if not isinstance(
tuning_config.output_model_types[0], PromptOutputModelType
):
output_model_types = []
for output_type in tuning_config.output_model_types:
output_model_types.append(PromptOutputModelType(output_type))
else:
output_model_types = tuning_config.output_model_types
error.value_check(
"<NLP36947542E>",
all(
output_type in base_model.PROMPT_OUTPUT_TYPES
for output_type in output_model_types
),
"{} not supported for base model type {}".format(
output_model_types, base_model.MODEL_TYPE
),
)

error.value_check(
"<NLP30542004E>",
len(output_model_types) <= base_model.MAX_NUM_TRANSFORMERS,
f"Too many output model types. Got {len(output_model_types)}, "
f"maximum {base_model.MAX_NUM_TRANSFORMERS}",
)
# Ensure that our verbalizer is a string and will not render to a hardcoded string
error.value_check(
"<NLP83837412E>",
is_valid_verbalizer(verbalizer),
"Provided verbalizer is an invalid type or has no renderable placeholders",
)

# NOTE: Base model is a resource at this point
task_type = base_model.TASK_TYPE

# HACK - These things can't be passed through the train API currently
metric = kwargs.get("metric")
if isinstance(tuning_type, str):
error.value_check(
"<NLP65714994E>",
tuning_type in TuningType._member_names_,
f"Invalid tuning type [{tuning_type}]. Allowed types: "
f"[{TuningType._member_names_}]",
)
tuning_type = TuningType(tuning_type)
error.type_check("<NLP65714993E>", TuningType, tuning_type=tuning_type)
        task_type, output_model_types, base_model, tuning_type, torch_dtype = validate_peft_config(
            tuning_type, tuning_config, error, log, base_model, cls, torch_dtype, verbalizer
        )


train_stream = train_stream.map(convert_to_generation_record)
if val_stream:
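One lightweight way to guard a move like this is a hypothetical regression check (not included in the commit) asserting that the relocated constants are unchanged for existing importers; the test name is an assumption.

# Hypothetical check, not part of this commit: the enum values and the
# allowed-init-method list moved files but should stay identical for callers.
from caikit_nlp.modules.text_generation.peft_config import (
    TuningType,
    allowed_tuning_init_methods,
)


def test_tuning_constants_unchanged():
    # TuningType subclasses str, so members compare equal to their string values
    assert TuningType.PROMPT_TUNING == "PROMPT_TUNING"
    assert TuningType.MULTITASK_PROMPT_TUNING == "MULTITASK_PROMPT_TUNING"
    assert set(allowed_tuning_init_methods) == {
        "TEXT",
        "RANDOM",
        "ONLY_SOURCE_SHARED",
        "AVERAGE_SOURCE_TASKS",
    }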
