feat: allow for padding free plugin to be used without response template (foundation-model-stack#430)

* fix: allow for padding free + pretraining

Signed-off-by: Harikrishnan Balagopal <[email protected]>

* add data collator for padding free plugin scenario to be used for extended pretraining

Signed-off-by: Dushyant Behl <[email protected]>

* fix: update value error

Signed-off-by: Mehant Kammakomati <[email protected]>

* fix: delete images only when exists

Signed-off-by: Mehant Kammakomati <[email protected]>

---------

Signed-off-by: Harikrishnan Balagopal <[email protected]>
Signed-off-by: Dushyant Behl <[email protected]>
Signed-off-by: Mehant Kammakomati <[email protected]>
Co-authored-by: Harikrishnan Balagopal <[email protected]>
Co-authored-by: Mehant Kammakomati <[email protected]>
3 people authored Jan 9, 2025
1 parent 8851227 commit 53a9d18
Showing 6 changed files with 99 additions and 23 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/image.yaml
@@ -15,9 +15,8 @@ jobs:
sudo swapoff -a
sudo rm -f /swapfile
sudo apt clean
docker rmi $(docker image ls -aq)
if [ "$(docker image ls -q)" ]; then docker rmi $(docker image ls -aq); fi
df -h
- name: Build image
run: |
docker build -t fms-hf-tuning:dev . -f build/Dockerfile
62 changes: 49 additions & 13 deletions tests/data/test_data_preprocessing_utils.py
@@ -489,7 +489,7 @@ def test_is_pretokenized_data(data, result):

@pytest.mark.parametrize(
"packing, response_template, formatted_train_dataset,\
max_seq_length, instruction_template, expected_collator",
max_seq_length, instruction_template, is_padding_free, expected_collator",
[
(
False,
@@ -501,6 +501,7 @@ def test_is_pretokenized_data(data, result):
),
1024,
None,
False,
DataCollatorForCompletionOnlyLM,
),
(
@@ -517,6 +518,7 @@ def test_is_pretokenized_data(data, result):
),
1024,
None,
False,
DataCollatorForSeq2Seq,
),
(
@@ -529,6 +531,7 @@ def test_is_pretokenized_data(data, result):
),
1024,
"\n### Text:",
False,
DataCollatorForCompletionOnlyLM,
),
(
@@ -545,6 +548,20 @@ def test_is_pretokenized_data(data, result):
),
1024,
"\n### Text:",
False,
DataCollatorForSeq2Seq,
),
(
False,
None,
datasets.load_dataset(
"json",
data_files=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
split="train",
),
1024,
None,
True,
DataCollatorForSeq2Seq,
),
],
@@ -555,6 +572,7 @@ def test_get_data_collator(
formatted_train_dataset,
max_seq_length,
instruction_template,
is_padding_free,
expected_collator,
):
"""Ensure that the correct collator type is fetched based on the data args"""
@@ -565,6 +583,7 @@ def test_get_data_collator(
is_pretokenized_dataset(formatted_train_dataset),
max_seq_length,
instruction_template,
is_padding_free,
)
assert isinstance(collator, expected_collator)

@@ -1044,7 +1063,7 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(


@pytest.mark.parametrize(
"data_args",
"data_args, is_padding_free",
[
# single sequence JSON and response template
(
@@ -1053,7 +1072,8 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
validation_data_path=TWITTER_COMPLAINTS_DATA_JSON,
dataset_text_field="output",
response_template="\n### Label:",
)
),
False,
),
# single sequence JSONL and response template
(
@@ -1062,7 +1082,8 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
validation_data_path=TWITTER_COMPLAINTS_DATA_JSONL,
dataset_text_field="output",
response_template="\n### Label:",
)
),
False,
),
# single sequence PARQUET and response template
(
@@ -1071,7 +1092,8 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
validation_data_path=TWITTER_COMPLAINTS_DATA_PARQUET,
dataset_text_field="output",
response_template="\n### Label:",
)
),
False,
),
# data formatter template with input/output JSON
(
@@ -1080,7 +1102,8 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
data_formatter_template="### Text:{{input}} \n\n### Label: {{output}}",
response_template="\n### Label:",
)
),
False,
),
# data formatter template with input/output JSONL
(
@@ -1089,7 +1112,8 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
data_formatter_template="### Text:{{input}} \n\n### Label: {{output}}",
response_template="\n### Label:",
)
),
False,
),
# data formatter template with input/output PARQUET
(
@@ -1098,32 +1122,44 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
data_formatter_template="### Text:{{input}} \n\n### Label: {{output}}",
response_template="\n### Label:",
)
),
False,
),
# input/output JSON with masking on input
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
)
),
False,
),
# input/output JSONL with masking on input
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
)
),
False,
),
# input/output PARQUET with masking on input
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
)
),
False,
),
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_DATA_JSON,
validation_data_path=TWITTER_COMPLAINTS_DATA_JSON,
dataset_text_field="output",
),
True,
),
],
)
def test_process_dataargs(data_args):
def test_process_dataargs(data_args, is_padding_free):
"""Ensure that the train/eval data are properly formatted based on the data args / text field"""
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
TRAIN_ARGS = configs.TrainingArguments(
Expand All @@ -1132,7 +1168,7 @@ def test_process_dataargs(data_args):
output_dir="tmp", # Not needed but positional
)
(train_set, eval_set, dataset_text_field, _, _, _) = process_dataargs(
data_args, tokenizer, TRAIN_ARGS
data_args, tokenizer, TRAIN_ARGS, is_padding_free=is_padding_free
)
assert isinstance(train_set, Dataset)
assert isinstance(eval_set, Dataset)
@@ -47,3 +47,7 @@ class AttentionAndDistributedPackingConfig:
def __post_init__(self):
# ensure nested dataclasses initialized
ensure_nested_dataclasses_initialized(self)

@property
def is_padding_free(self):
return self.padding_free is not None
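
For orientation, a minimal sketch (not part of the diff) of what the new property reports: it is simply whether a padding_free sub-config was supplied. The import path is assumed (the file header for this hunk is not shown above), as is the fact that padding_free defaults to None.

from tuning.config.acceleration_configs import AttentionAndDistributedPackingConfig  # assumed path

cfg = AttentionAndDistributedPackingConfig()  # padding_free left unset (None)
assert cfg.is_padding_free is False

# Any non-None padding_free sub-config flips the flag to True; the concrete
# sub-config type and its fields are not shown in this commit.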
13 changes: 13 additions & 0 deletions tuning/data/data_preprocessing_utils.py
@@ -29,6 +29,7 @@ def get_data_collator(
is_traindata_tokenized: bool,
max_seq_length: int,
instruction_template: Optional[str],
is_padding_free: bool = False,
) -> Callable:
"""Create and return the the appropriate collator type based on the configuration for packing,
response_template, and dataset_text_field.
Expand All @@ -46,6 +47,8 @@ def get_data_collator(
Max sequence length expected
instruction_template: str
str representing the human response in a chat template
is_padding_free: bool
            whether the padding free plugin is used
Returns:
Callable
Expand Down Expand Up @@ -74,6 +77,16 @@ def get_data_collator(
tokenizer=tokenizer,
ignore_index=configs.IGNORE_INDEX,
)

if is_padding_free:
        # When packing is False, the padding_free plugin is used, and no
        # response template is provided, this is a pretraining scenario.
        # The current plugin in fms-acceleration is compatible with the
        # `DataCollatorForSeq2Seq` collator, hence we use it here.
return DataCollatorForSeq2Seq(
tokenizer=tokenizer, padding=False, max_length=max_seq_length
)

# Note that this automatically pads labels with -100
# TODO check if this is sufficient for preprocessed
if is_traindata_tokenized:
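
A short usage sketch of the new branch (mine, not from the commit), mirroring the new test case added above. The positional argument order is assumed from that test, and "gpt2" is only a stand-in tokenizer.

from transformers import AutoTokenizer, DataCollatorForSeq2Seq
from tuning.data.data_preprocessing_utils import get_data_collator

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer works for the illustration

# Packing disabled, no response/instruction template, padding-free plugin on:
# the returned collator should be a plain DataCollatorForSeq2Seq.
collator = get_data_collator(
    False,   # packing
    None,    # response_template
    tokenizer,
    False,   # is_traindata_tokenized
    1024,    # max_seq_length
    None,    # instruction_template
    is_padding_free=True,
)
assert isinstance(collator, DataCollatorForSeq2Seq)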
28 changes: 21 additions & 7 deletions tuning/data/setup_dataprocessor.py
@@ -107,15 +107,22 @@ def _get_pretokenized_dataset_handlers(data_args, packing, is_eval_tokenized):


### Data format 2
def _get_dataset_formatting_handlers(data_args, packing):
def _get_dataset_formatting_handlers(data_args, packing, is_padding_free=False):

if data_args.response_template is None:
if packing is False:
raise ValueError(
"Since dataset_text_field or data_formatter_template \
is provided and packing is disabled, \
needs a corresponding response template for masking"
)
if is_padding_free:
logger.debug(
"Assuming pretraining scenario (loss over all tokens) "
+ "because, packing is false,"
+ " padding_free plugin is used and no response template was provided."
)
else:
raise ValueError(
"Since response_template is not provided for masking, \
                    either use packing or padding_free to enable the \
                        pretraining scenario (loss over all tokens)."
)

if data_args.response_template:
# To use Response template, pass datasets with single sequence instances \
@@ -209,6 +216,7 @@ def _process_raw_data_args(
packing: bool,
max_seq_length: int,
additional_data_handlers: Dict[str, Callable] = None,
is_padding_free: bool = False,
):

# Create a data processor with default processor config
@@ -248,6 +256,7 @@
tokenizer_kwargs = {}
tokenizer_kwargs["max_length"] = max_seq_length
tokenizer_kwargs["truncation"] = True
    # Let's not pad in the tokenizer; we can handle that in the collator
tokenizer_kwargs["padding"] = False

handlers = None
@@ -266,7 +275,7 @@
elif data_args.data_formatter_template or data_args.dataset_text_field:
# Data Format 3: Single Sequence Dataset
handlers, dataset_text_field = _get_dataset_formatting_handlers(
data_args, packing
data_args, packing, is_padding_free
)
else:
# Default Data Format: Dataset with Input/Output Fields
@@ -300,6 +309,7 @@ def process_dataargs(
tokenizer: AutoTokenizer,
train_args: TrainingArguments,
additional_data_handlers: Dict[str, Callable] = None,
is_padding_free: bool = False,
):
"""
Args:
Expand All @@ -310,6 +320,8 @@ def process_dataargs(
Used for packing and max_seq_length
additional_data_handlers: A Dict of [str, callable] data handlers
which need to be registered with the data preprocessor
        is_padding_free: A bool indicating whether the padding free plugin is enabled.
Defaults to False.
Returns:
Tuple(Dataset, Dataset, str, DataCollator, int, Dict)
tuple containing
@@ -345,6 +357,7 @@
train_args.packing,
max_seq_length,
additional_data_handlers,
is_padding_free,
)

# Note: This check should not be removed.
@@ -359,6 +372,7 @@
is_tokenized_dataset,
max_seq_length,
data_args.instruction_template,
is_padding_free=is_padding_free,
)

dataset_kwargs = {}
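
The behavioural change, summarized as a sketch (not from the commit). It reuses the TWITTER_COMPLAINTS_DATA_JSON constant from the test module above, the TrainingArguments fields mirror the partially shown TRAIN_ARGS construction there, and "gpt2" is only a stand-in checkpoint.

import pytest
from transformers import AutoTokenizer

from tuning.config import configs
from tuning.data.setup_dataprocessor import process_dataargs

tokenizer = AutoTokenizer.from_pretrained("gpt2")
train_args = configs.TrainingArguments(
    packing=False,
    max_seq_length=1024,
    output_dir="tmp",  # not needed but positional
)
data_args = configs.DataArguments(
    training_data_path=TWITTER_COMPLAINTS_DATA_JSON,
    validation_data_path=TWITTER_COMPLAINTS_DATA_JSON,
    dataset_text_field="output",  # note: no response_template
)

# Without packing and without the padding-free plugin, a missing response
# template is still rejected.
with pytest.raises(ValueError):
    process_dataargs(data_args, tokenizer, train_args, is_padding_free=False)

# With the plugin enabled, the same arguments are treated as extended
# pretraining (loss over all tokens) and processing succeeds.
process_dataargs(data_args, tokenizer, train_args, is_padding_free=True)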
12 changes: 11 additions & 1 deletion tuning/sft_trainer.py
@@ -306,6 +310,10 @@ def train(
data_collator = None
logger.info("Packing is set to %s ", train_args.packing)

is_padding_free = False
if attention_and_distributed_packing_config is not None:
is_padding_free = attention_and_distributed_packing_config.is_padding_free

data_preprocessing_time = time.time()
(
formatted_train_dataset,
@@ -314,7 +318,13 @@
data_collator,
train_args.max_seq_length,
dataset_kwargs,
) = process_dataargs(data_args, tokenizer, train_args, additional_data_handlers)
) = process_dataargs(
data_args,
tokenizer,
train_args,
additional_data_handlers,
is_padding_free=is_padding_free,
)
additional_metrics["data_preprocessing_time"] = (
time.time() - data_preprocessing_time
)
