diff --git a/.pylintrc b/.pylintrc index d6f8a5d6c..222bdf6cb 100644 --- a/.pylintrc +++ b/.pylintrc @@ -638,7 +638,7 @@ callbacks=cb_, dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ # Argument names that match this expression will be ignored. -ignored-argument-names=_.*|^ignored_|^unused_ +ignored-argument-names=_.*|^ignored_|^unused_|kwargs # Tells whether we should check for unused import in __init__ files. init-import=no diff --git a/tests/artifacts/predefined_data_configs/__init__.py b/tests/artifacts/predefined_data_configs/__init__.py new file mode 100644 index 000000000..f9b766be6 --- /dev/null +++ b/tests/artifacts/predefined_data_configs/__init__.py @@ -0,0 +1,30 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpful datasets for configuring individual unit tests. +""" +# Standard +import os + +### Constants used for data +PREDEFINED_DATA_CONFIGS = os.path.join(os.path.dirname(__file__)) +APPLY_CUSTOM_TEMPLATE_YAML = os.path.join( + PREDEFINED_DATA_CONFIGS, "apply_custom_template.yaml" +) +PRETOKENIZE_JSON_DATA_YAML = os.path.join( + PREDEFINED_DATA_CONFIGS, "pretokenized_json_data.yaml" +) +TOKENIZE_AND_APPLY_INPUT_MASKING_YAML = os.path.join( + PREDEFINED_DATA_CONFIGS, "tokenize_and_apply_input_masking.yaml" +) diff --git a/tests/artifacts/predefined_data_configs/apply_custom_template.yaml b/tests/artifacts/predefined_data_configs/apply_custom_template.yaml new file mode 100644 index 000000000..4aab0d76a --- /dev/null +++ b/tests/artifacts/predefined_data_configs/apply_custom_template.yaml @@ -0,0 +1,14 @@ +dataprocessor: + type: default +datasets: + - name: apply_custom_data_template + data_paths: + - "FILE_PATH" + data_handlers: + - name: apply_custom_data_formatting_template + arguments: + remove_columns: all + batched: false + fn_kwargs: + dataset_text_field: "dataset_text_field" + dataset_template: "dataset_template" \ No newline at end of file diff --git a/tests/artifacts/predefined_data_configs/pretokenized_json_data.yaml b/tests/artifacts/predefined_data_configs/pretokenized_json_data.yaml new file mode 100644 index 000000000..833173dea --- /dev/null +++ b/tests/artifacts/predefined_data_configs/pretokenized_json_data.yaml @@ -0,0 +1,6 @@ +dataprocessor: + type: default +datasets: + - name: pretokenized_dataset + data_paths: + - "FILE_PATH" \ No newline at end of file diff --git a/tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml b/tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml new file mode 100644 index 000000000..d8fc16eec --- /dev/null +++ b/tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml @@ -0,0 +1,14 @@ +dataprocessor: + type: default +datasets: + - name: text_dataset_input_output_masking + data_paths: + - "FILE_PATH" + data_handlers: + - name: tokenize_and_apply_input_masking + arguments: + remove_columns: all + batched: false + fn_kwargs: + input_field: "INPUT" + output_field: "OUTPUT" \ No newline 
at end of file diff --git a/tests/data/test_data_handlers.py b/tests/data/test_data_handlers.py new file mode 100644 index 000000000..d2a390fe9 --- /dev/null +++ b/tests/data/test_data_handlers.py @@ -0,0 +1,110 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 +# https://spdx.dev/learn/handling-license-info/ + +# Third Party +from transformers import AutoTokenizer +import datasets +import pytest + +# First Party +from tests.artifacts.testdata import MODEL_NAME, TWITTER_COMPLAINTS_DATA_JSONL + +# Local +from tuning.data.data_handlers import ( + apply_custom_data_formatting_template, + combine_sequence, +) + + +def test_apply_custom_formatting_template(): + json_dataset = datasets.load_dataset( + "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL + ) + template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + formatted_dataset_field = "formatted_data_field" + formatted_dataset = json_dataset.map( + apply_custom_data_formatting_template, + fn_kwargs={ + "tokenizer": tokenizer, + "dataset_text_field": formatted_dataset_field, + "template": template, + }, + ) + # First response from the data file that is read. 
+ expected_response = ( + "### Input: @HMRCcustomers No this is my first job" + + " \n\n ### Response: no complaint" + + tokenizer.eos_token + ) + + # a new dataset_text_field is created in Dataset + assert formatted_dataset_field in formatted_dataset["train"][0] + assert formatted_dataset["train"][0][formatted_dataset_field] == expected_response + + +def test_apply_custom_formatting_template_gives_error_with_wrong_keys(): + """Tests that the formatting function will throw error if wrong keys are passed to template""" + json_dataset = datasets.load_dataset( + "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL + ) + template = "### Input: {{not found}} \n\n ### Response: {{text_label}}" + formatted_dataset_field = "formatted_data_field" + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + with pytest.raises(KeyError): + json_dataset.map( + apply_custom_data_formatting_template, + fn_kwargs={ + "tokenizer": tokenizer, + "dataset_text_field": formatted_dataset_field, + "template": template, + }, + ) + + +@pytest.mark.parametrize( + "input_element,output_element,expected_res", + [ + ("foo ", "bar", "foo bar"), + ("foo\n", "bar", "foo\nbar"), + ("foo\t", "bar", "foo\tbar"), + ("foo", "bar", "foo bar"), + ], +) +def test_combine_sequence(input_element, output_element, expected_res): + """Ensure that input / output elements are combined with correct whitespace handling.""" + comb_seq = combine_sequence(input_element, output_element) + assert isinstance(comb_seq, str) + assert comb_seq == expected_res + + +@pytest.mark.parametrize( + "input_element,output_element,expected_res", + [ + ("foo ", "bar", "foo bar"), + ("foo\n", "bar", "foo\nbar"), + ("foo\t", "bar", "foo\tbar"), + ("foo", "bar", "foo bar"), + ], +) +def test_combine_sequence_adds_eos(input_element, output_element, expected_res): + """Ensure that input / output elements are combined with correct whitespace handling.""" + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + comb_seq = combine_sequence(input_element, output_element, tokenizer.eos_token) + expected_res += tokenizer.eos_token + assert isinstance(comb_seq, str) + assert comb_seq == expected_res diff --git a/tests/utils/test_preprocessing_utils.py b/tests/data/test_data_preprocessing_utils.py similarity index 51% rename from tests/utils/test_preprocessing_utils.py rename to tests/data/test_data_preprocessing_utils.py index cd67a78bb..02308b2f5 100644 --- a/tests/utils/test_preprocessing_utils.py +++ b/tests/data/test_data_preprocessing_utils.py @@ -1,13 +1,36 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
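As an aside on the tests above: apply_custom_data_formatting_template is a plain function mapping one record dict to one formatted record, so it can also be sanity-checked without datasets.map. A minimal sketch, assuming only the signature exercised in test_data_handlers.py (the record literal mirrors the first row of the twitter complaints test data):

# Sketch only, not part of this patch.
from transformers import AutoTokenizer

from tests.artifacts.testdata import MODEL_NAME
from tuning.data.data_handlers import apply_custom_data_formatting_template

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
record = {
    "Tweet text": "@HMRCcustomers No this is my first job",
    "text_label": "no complaint",
}
formatted = apply_custom_data_formatting_template(
    record,
    tokenizer=tokenizer,
    dataset_text_field="formatted_data_field",
    template="### Input: {{Tweet text}} \n\n ### Response: {{text_label}}",
)
# formatted["formatted_data_field"] equals expected_response above, EOS token included.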
+ +# Standard +import json +import tempfile + # Third Party from datasets import Dataset -from datasets.exceptions import DatasetGenerationError from transformers import AutoTokenizer, DataCollatorForSeq2Seq from trl import DataCollatorForCompletionOnlyLM +import datasets import pytest +import yaml # First Party +from tests.artifacts.predefined_data_configs import ( + APPLY_CUSTOM_TEMPLATE_YAML, + PRETOKENIZE_JSON_DATA_YAML, + TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, +) from tests.artifacts.testdata import ( - MALFORMATTED_DATA, MODEL_NAME, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, @@ -19,137 +42,149 @@ # Local from tuning.config import configs -from tuning.utils.preprocessing_utils import ( - combine_sequence, - format_dataset, - get_data_collator, - get_formatted_dataset_with_single_sequence, - get_preprocessed_dataset, +from tuning.data.data_config import DataPreProcessorConfig, DataSetConfig +from tuning.data.data_preprocessing_utils import get_data_collator +from tuning.data.data_processors import DataPreProcessor, get_datapreprocessor +from tuning.data.setup_dataprocessor import ( + _process_dataconfig_file, is_pretokenized_dataset, - load_hf_dataset_from_file, - validate_data_args, + process_dataargs, ) @pytest.mark.parametrize( - "input_element,output_element,expected_res", + "datafile, column_names", [ - ("foo ", "bar", "foo bar"), - ("foo\n", "bar", "foo\nbar"), - ("foo\t", "bar", "foo\tbar"), - ("foo", "bar", "foo bar"), + ( + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, + set(["ID", "Label", "input", "output"]), + ), + ( + TWITTER_COMPLAINTS_TOKENIZED_JSONL, + set( + [ + "Tweet text", + "ID", + "Label", + "text_label", + "output", + "input_ids", + "labels", + ] + ), + ), + ( + TWITTER_COMPLAINTS_DATA_JSONL, + set(["Tweet text", "ID", "Label", "text_label", "output"]), + ), ], ) -def test_combine_sequence(input_element, output_element, expected_res): - """Ensure that input / output elements are combined with correct whitespace handling.""" - comb_seq = combine_sequence(input_element, output_element) - assert isinstance(comb_seq, str) - assert comb_seq == expected_res +def test_load_dataset_with_datafile(datafile, column_names): + """Ensure that both dataset is loaded with datafile.""" + processor = get_datapreprocessor( + processor_config=DataPreProcessorConfig(), tokenizer=None + ) + load_dataset = processor.load_dataset( + datasetconfig=None, splitName="train", datafile=datafile + ) + assert set(load_dataset.column_names) == column_names @pytest.mark.parametrize( - "input_element,output_element,expected_res", + "datafile, column_names, datasetconfigname", [ - ("foo ", "bar", "foo bar"), - ("foo\n", "bar", "foo\nbar"), - ("foo\t", "bar", "foo\tbar"), - ("foo", "bar", "foo bar"), + ( + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, + set(["ID", "Label", "input", "output"]), + "text_dataset_input_output_masking", + ), + ( + TWITTER_COMPLAINTS_TOKENIZED_JSONL, + set( + [ + "Tweet text", + "ID", + "Label", + "text_label", + "output", + "input_ids", + "labels", + ] + ), + "pretokenized_dataset", + ), + ( + TWITTER_COMPLAINTS_DATA_JSONL, + set(["Tweet text", "ID", "Label", "text_label", "output"]), + "apply_custom_data_template", + ), ], ) -def test_combine_sequence_adds_eos(input_element, output_element, expected_res): - """Ensure that input / output elements are combined with correct whitespace handling.""" - tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - comb_seq = combine_sequence(input_element, output_element, tokenizer.eos_token) 
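For orientation, the processor API that the load_dataset tests here drive looks roughly like this when called directly (a sketch; the file path and dataset name are placeholders):

# Sketch only, not part of this patch.
from tuning.data.data_config import DataPreProcessorConfig, DataSetConfig
from tuning.data.data_processors import get_datapreprocessor

processor = get_datapreprocessor(processor_config=DataPreProcessorConfig(), tokenizer=None)

# Either point it at a single data file...
ds = processor.load_dataset(datasetconfig=None, splitName="train", datafile="train.jsonl")

# ...or at a DataSetConfig naming one or more files; passing both (or neither) raises ValueError.
cfg = DataSetConfig(name="my_dataset", data_paths=["train.jsonl"])
ds = processor.load_dataset(datasetconfig=cfg, splitName="train", datafile=None)
print(ds.column_names)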
- expected_res += tokenizer.eos_token - assert isinstance(comb_seq, str) - assert comb_seq == expected_res +def test_load_dataset_with_datasetconfig(datafile, column_names, datasetconfigname): + """Ensure that both dataset is loaded with datafile.""" + datasetconfig = DataSetConfig(name=datasetconfigname, data_paths=[datafile]) + processor = get_datapreprocessor( + processor_config=DataPreProcessorConfig(), tokenizer=None + ) + load_dataset = processor.load_dataset( + datasetconfig=datasetconfig, splitName="train", datafile=None + ) + assert set(load_dataset.column_names) == column_names -# Tests for loading the dataset from disk @pytest.mark.parametrize( - "dataset_path", - [TWITTER_COMPLAINTS_DATA_JSONL, TWITTER_COMPLAINTS_DATA_JSON], + "datafile, datasetconfigname", + [ + ( + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, + "text_dataset_input_output_masking", + ), + (TWITTER_COMPLAINTS_TOKENIZED_JSONL, "pretokenized_dataset"), + (TWITTER_COMPLAINTS_DATA_JSONL, "apply_custom_data_template"), + ], ) -def test_load_hf_dataset_from_file(dataset_path): - input_field_name = "Tweet text" - output_field_name = "text_label" - data = load_hf_dataset_from_file( - dataset_path, - input_field_name=input_field_name, - output_field_name=output_field_name, +def test_load_dataset_with_dataconfig_and_datafile(datafile, datasetconfigname): + """Ensure that both datasetconfig and datafile cannot be passed.""" + datasetconfig = DataSetConfig(name=datasetconfigname, data_paths=[datafile]) + processor = get_datapreprocessor( + processor_config=DataPreProcessorConfig(), tokenizer=None ) - # Our dataset should contain dicts that contain the input / output field name types - next_data = next(iter(data)) - assert input_field_name in next_data - assert output_field_name in next_data - - -def test_load_hf_dataset_from_jsonl_file_wrong_keys(): - """Ensure that we explode if the keys are not in the jsonl file.""" - with pytest.raises(DatasetGenerationError): - load_hf_dataset_from_file( - TWITTER_COMPLAINTS_DATA_JSONL, - input_field_name="foo", - output_field_name="bar", - ) - - -def test_load_hf_dataset_from_malformatted_data(): - """Ensure that we explode if the data is not properly formatted.""" - # NOTE: The actual keys don't matter here - with pytest.raises(DatasetGenerationError): - load_hf_dataset_from_file( - MALFORMATTED_DATA, input_field_name="foo", output_field_name="bar" + with pytest.raises(ValueError): + processor.load_dataset( + datasetconfig=datasetconfig, splitName="train", datafile=datafile ) -def test_load_hf_dataset_from_jsonl_file_duplicate_keys(): - """Ensure we cannot have the same key for input / output.""" +def test_load_dataset_without_dataconfig_and_datafile(): + """Ensure that both datasetconfig and datafile cannot be None.""" + processor = get_datapreprocessor( + processor_config=DataPreProcessorConfig(), tokenizer=None + ) with pytest.raises(ValueError): - load_hf_dataset_from_file( - TWITTER_COMPLAINTS_DATA_JSONL, - input_field_name="Tweet text", - output_field_name="Tweet text", - ) + processor.load_dataset(datasetconfig=None, splitName="train", datafile=None) -# Tests for custom masking / preprocessing logic @pytest.mark.parametrize( - "dataset_path, max_sequence_length", + "data, result", [ - (TWITTER_COMPLAINTS_DATA_JSONL, 1), - (TWITTER_COMPLAINTS_DATA_JSONL, 10), - (TWITTER_COMPLAINTS_DATA_JSONL, 100), - (TWITTER_COMPLAINTS_DATA_JSONL, 1000), - (TWITTER_COMPLAINTS_DATA_JSON, 1), - (TWITTER_COMPLAINTS_DATA_JSON, 10), - (TWITTER_COMPLAINTS_DATA_JSON, 100), - 
(TWITTER_COMPLAINTS_DATA_JSON, 1000), + (TWITTER_COMPLAINTS_DATA_JSONL, False), + ( + Dataset.from_list( + [ + { + "input_ids": [9437, 29, 210], + "attention_mask": [1, 1, 1], + "labels": [1, 20, 30], + } + ] + ), + True, + ), ], ) -def test_get_preprocessed_dataset(dataset_path, max_sequence_length): - """Ensure we can handle preprocessed datasets with different max_sequence_lengths - to ensure proper tokenization and truncation. - """ - tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - preprocessed_data = get_preprocessed_dataset( - data_path=dataset_path, - tokenizer=tokenizer, - max_sequence_length=max_sequence_length, - input_field_name="Tweet text", - output_field_name="text_label", - ) - for tok_res in preprocessed_data: - # Since the padding is left to the collator, there should be no 0s in the attention mask yet - assert sum(tok_res["attention_mask"]) == len(tok_res["attention_mask"]) - # If the source text isn't empty, we start with masked inputs - assert tok_res["labels"][0] == -100 - # All keys in the produced record must be the same length - key_lengths = {len(tok_res[k]) for k in tok_res.keys()} - assert len(key_lengths) == 1 - # And also that length should be less than or equal to the max length depending on if we - # are going up to / over the max size and truncating - padding is handled separately - assert key_lengths.pop() <= max_sequence_length +def test_is_pretokenized_data(data, result): + """Ensure that the correct collator type is fetched based on the data args""" + assert is_pretokenized_dataset(data=data) == result @pytest.mark.parametrize( @@ -158,10 +193,10 @@ def test_get_preprocessed_dataset(dataset_path, max_sequence_length): ( False, "\n### Label:", - load_hf_dataset_from_file( - TWITTER_COMPLAINTS_DATA_JSONL, - input_field_name="Tweet text", - output_field_name="text_label", + datasets.load_dataset( + "json", + data_files=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, + split="train", ), 1024, DataCollatorForCompletionOnlyLM, @@ -195,35 +230,12 @@ def test_get_data_collator( packing, response_template, AutoTokenizer.from_pretrained(MODEL_NAME), - formatted_train_dataset, + is_pretokenized_dataset(formatted_train_dataset), max_seq_length, ) assert isinstance(collator, expected_collator) -@pytest.mark.parametrize( - "data, result", - [ - (TWITTER_COMPLAINTS_DATA_JSONL, False), - ( - Dataset.from_list( - [ - { - "input_ids": [9437, 29, 210], - "attention_mask": [1, 1, 1], - "labels": [1, 20, 30], - } - ] - ), - True, - ), - ], -) -def test_is_pretokenized_data(data, result): - """Ensure that the correct collator type is fetched based on the data args""" - assert is_pretokenized_dataset(data=data) == result - - # Tests for validating data args # Invalid args return ValueError @pytest.mark.parametrize( @@ -310,63 +322,75 @@ def test_is_pretokenized_data(data, result): ), ], ) -def test_validate_args(data_args, packing): +def test_process_data_args_throws_error_where_needed(data_args, packing): """Ensure that respective errors are thrown for incorrect data arguments""" with pytest.raises(ValueError): - validate_data_args(data_args, packing) - - -@pytest.mark.parametrize( - "data_args, packing", - [ - # pretokenized train dataset and no validation dataset passed - ( - configs.DataArguments( - training_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL, - ), - False, - ), - # pretokenized train and validation datasets - ( - configs.DataArguments( - training_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL, - validation_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL, - ), - 
False, - ), - ], -) -def test_validate_args_pretokenized(data_args, packing): - """Ensure that supported data args do not error out when passing pretokenized datasets""" - validate_data_args(data_args, packing) + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + TRAIN_ARGS = configs.TrainingArguments( + packing=packing, + max_seq_length=1024, + output_dir="tmp", # Not needed but positional + ) + (_, _, _, _, _, _) = process_dataargs(data_args, tokenizer, TRAIN_ARGS) @pytest.mark.parametrize( - "data_path, dataset_text_field, data_formatter_template", + "data_config_path, data_path", [ - (TWITTER_COMPLAINTS_DATA_JSON, "output", None), - (TWITTER_COMPLAINTS_DATA_JSONL, "output", None), + (APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON), + (APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL), + (PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSON), + (PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSONL), ( + TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, - "formatted_field", - "### Text:{{input}} \n\n### Label: {{output}}", ), ( + TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, - "formatted_field", - "### Text:{{input}} \n\n### Label: {{output}}", ), ], ) -def test_get_formatted_dataset_with_single_sequence( - data_path, dataset_text_field, data_formatter_template -): +def test_process_dataconfig_file(data_config_path, data_path): + """Ensure that datasets are formatted and validated correctly based on the arguments passed in config file.""" + with open(data_config_path, "r") as f: + yaml_content = yaml.safe_load(f) + yaml_content["datasets"][0]["data_paths"][0] = data_path + datasets_name = yaml_content["datasets"][0]["name"] + + # Modify input_field_name and output_field_name according to dataset + if datasets_name == "text_dataset_input_output_masking": + yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = { + "input_field_name": "input", + "output_field_name": "output", + } + + # Modify dataset_text_field and template according to dataset + formatted_dataset_field = "formatted_data_field" + if datasets_name == "apply_custom_data_template": + template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" + yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = { + "dataset_text_field": formatted_dataset_field, + "template": template, + } + + with tempfile.NamedTemporaryFile( + "w", delete=False, suffix=".yaml" + ) as temp_yaml_file: + yaml.dump(yaml_content, temp_yaml_file) + temp_yaml_file_path = temp_yaml_file.name + data_args = configs.DataArguments(data_config_path=temp_yaml_file_path) + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - formatted_dataset = get_formatted_dataset_with_single_sequence( - data_path, dataset_text_field, tokenizer, data_formatter_template - ) - assert isinstance(formatted_dataset, Dataset) - assert dataset_text_field in formatted_dataset.column_names + (train_set, _, _) = _process_dataconfig_file(data_args, tokenizer) + assert isinstance(train_set, Dataset) + if datasets_name == "text_dataset_input_output_masking": + column_names = set(["input_ids", "attention_mask", "labels"]) + assert set(train_set.column_names) == column_names + elif datasets_name == "pretokenized_dataset": + assert set(["input_ids", "labels"]).issubset(set(train_set.column_names)) + elif datasets_name == "apply_custom_data_template": + assert formatted_dataset_field in set(train_set.column_names) 
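test_process_dataconfig_file above is the end-to-end path for the new data_config_path argument. For context, roughly the same flow through the public process_dataargs entry point would look like the sketch below (model name, config path and response template are placeholders; collator selection still depends on response_template and packing, exactly as in the raw-argument path):

# Sketch only, not part of this patch. Assumes a data config YAML shaped like
# tests/artifacts/predefined_data_configs/apply_custom_template.yaml with real data_paths.
from transformers import AutoTokenizer

from tuning.config import configs
from tuning.data.setup_dataprocessor import process_dataargs

tokenizer = AutoTokenizer.from_pretrained("my-base-model")  # placeholder checkpoint
data_args = configs.DataArguments(
    data_config_path="my_data_config.yaml",  # placeholder path
    response_template="\n### Label:",        # needed by the completion-only collator when not packing
)
train_args = configs.TrainingArguments(packing=False, max_seq_length=1024, output_dir="tmp")

(train_set, eval_set, text_field, collator, max_len, dataset_kwargs) = process_dataargs(
    data_args, tokenizer, train_args
)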
@pytest.mark.parametrize( @@ -395,8 +419,8 @@ def test_get_formatted_dataset_with_single_sequence( configs.DataArguments( training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, - dataset_text_field="formatted_field", data_formatter_template="### Text:{{input}} \n\n### Label: {{output}}", + response_template="\n### Label:", ) ), # data formatter template with input/output JSONL @@ -404,8 +428,8 @@ def test_get_formatted_dataset_with_single_sequence( configs.DataArguments( training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, - dataset_text_field="formatted_field", data_formatter_template="### Text:{{input}} \n\n### Label: {{output}}", + response_template="\n### Label:", ) ), # input/output JSON with masking on input @@ -424,11 +448,16 @@ def test_get_formatted_dataset_with_single_sequence( ), ], ) -def test_format_dataset(data_args): +def test_process_dataargs(data_args): """Ensure that the train/eval data are properly formatted based on the data args / text field""" tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - train_set, eval_set, dataset_text_field = format_dataset( - data_args, tokenizer, max_seq_length=1024 + TRAIN_ARGS = configs.TrainingArguments( + packing=False, + max_seq_length=1024, + output_dir="tmp", # Not needed but positional + ) + (train_set, eval_set, dataset_text_field, _, _, _) = process_dataargs( + data_args, tokenizer, TRAIN_ARGS ) assert isinstance(train_set, Dataset) assert isinstance(eval_set, Dataset) @@ -472,9 +501,17 @@ def test_format_dataset(data_args): ), ], ) -def test_format_dataset_pretokenized(data_args): +def test_process_dataargs_pretokenized(data_args): """Ensure that pretokenized datasets are loaded and returned as is""" - train_set, eval_set, _ = format_dataset(data_args, None, max_seq_length=1024) + TRAIN_ARGS = configs.TrainingArguments( + packing=False, + max_seq_length=1024, + output_dir="tmp", # Not needed but positional + ) + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + (train_set, eval_set, _, _, _, _) = process_dataargs( + data_args, tokenizer, TRAIN_ARGS + ) assert isinstance(train_set, Dataset) if eval_set: assert isinstance(eval_set, Dataset) @@ -482,3 +519,52 @@ def test_format_dataset_pretokenized(data_args): assert set(["input_ids", "labels"]).issubset(set(train_set.column_names)) if eval_set: assert set(["input_ids", "labels"]).issubset(set(eval_set.column_names)) + + +@pytest.mark.parametrize( + "datafile, column_names, datasetconfigname", + [ + ( + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, + set(["ID", "Label", "input", "output"]), + "text_dataset_input_output_masking", + ), + ( + TWITTER_COMPLAINTS_TOKENIZED_JSON, + set( + [ + "Tweet text", + "ID", + "Label", + "text_label", + "output", + "input_ids", + "labels", + ] + ), + "pretokenized_dataset", + ), + ( + TWITTER_COMPLAINTS_DATA_JSON, + set(["Tweet text", "ID", "Label", "text_label", "output"]), + "apply_custom_data_template", + ), + ], +) +def test_process_dataset_configs(datafile, column_names, datasetconfigname): + """Test process_dataset_configs for expected output.""" + dataprocessor_config = DataPreProcessorConfig() + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + processor = DataPreProcessor( + processor_config=dataprocessor_config, + tokenizer=tokenizer, + ) + datasetconfig = [DataSetConfig(name=datasetconfigname, data_paths=[datafile])] + train_dataset = 
processor.process_dataset_configs(dataset_configs=datasetconfig) + + assert isinstance(train_dataset, Dataset) + assert set(train_dataset.column_names) == column_names + + with open(datafile, "r") as file: + data = json.load(file) + assert len(train_dataset) == len(data) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 0a4ab3d14..69ccbf4fa 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -300,7 +300,7 @@ def test_run_train_fails_training_data_path_not_exist(): """Check fails when data path not found.""" updated_data_path_args = copy.deepcopy(DATA_ARGS) updated_data_path_args.training_data_path = "fake/path" - with pytest.raises(FileNotFoundError): + with pytest.raises(ValueError): sft_trainer.train(MODEL_ARGS, updated_data_path_args, TRAIN_ARGS, None) @@ -906,15 +906,12 @@ def test_empty_data(): def test_data_path_is_a_directory(): - """Ensure that we get FileNotFoundError if we point the data path at a dir, not a file.""" + """Ensure that we get ValueError if we point the data path at a dir, not a file.""" with tempfile.TemporaryDirectory() as tempdir: data_args = copy.deepcopy(DATA_ARGS) data_args.training_data_path = tempdir - # Confusingly, if we pass a directory for our data path, it will throw a - # FileNotFoundError saying "unable to find ''", since it can't - # find a matchable file in the path. - with pytest.raises(FileNotFoundError): + with pytest.raises(ValueError): sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_PT_ARGS) diff --git a/tests/utils/test_data_utils.py b/tests/utils/test_data_utils.py deleted file mode 100644 index e56a708b5..000000000 --- a/tests/utils/test_data_utils.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright The FMS HF Tuning Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# SPDX-License-Identifier: Apache-2.0 -# https://spdx.dev/learn/handling-license-info/ - -# Third Party -import datasets -import pytest - -# First Party -from tests.artifacts.testdata import TWITTER_COMPLAINTS_DATA_JSONL - -# Local -from tuning.utils import data_utils - - -def test_apply_custom_formatting_template(): - json_dataset = datasets.load_dataset( - "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL - ) - template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" - # First response from the data file that is read. 
- expected_response = ( - "### Input: @HMRCcustomers No this is my first job" - + " \n\n ### Response: no complaint" - ) - formatted_dataset_field = "formatted_data_field" - formatted_dataset = data_utils.apply_custom_formatting_template( - json_dataset, template, formatted_dataset_field - ) - # a new dataset_text_field is created in Dataset - assert formatted_dataset_field in formatted_dataset["train"][0] - assert formatted_dataset["train"][0][formatted_dataset_field] == expected_response - - -def test_apply_custom_formatting_template_adds_eos_token(): - json_dataset = datasets.load_dataset( - "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL - ) - template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" - # First response from the data file that is read. - expected_response = ( - "### Input: @HMRCcustomers No this is my first job" - + " \n\n ### Response: no complaintEOS" - ) - formatted_dataset_field = "formatted_data_field" - formatted_dataset = data_utils.apply_custom_formatting_template( - json_dataset, template, formatted_dataset_field, "EOS" - ) - # a new dataset_text_field is created in Dataset - assert formatted_dataset_field in formatted_dataset["train"][0] - assert formatted_dataset["train"][0][formatted_dataset_field] == expected_response - - -def test_apply_custom_formatting_template_gives_error_with_wrong_keys(): - """Tests that the formatting function will throw error if wrong keys are passed to template""" - json_dataset = datasets.load_dataset( - "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL - ) - template = "### Input: {{not found}} \n\n ### Response: {{text_label}}" - formatted_dataset_field = "formatted_data_field" - with pytest.raises(KeyError): - data_utils.apply_custom_formatting_template( - json_dataset, template, formatted_dataset_field, "EOS" - ) diff --git a/tests/utils/test_logging.py b/tests/utils/test_logging.py index 7b7aa1a2a..88a38c839 100644 --- a/tests/utils/test_logging.py +++ b/tests/utils/test_logging.py @@ -28,33 +28,32 @@ from tuning.utils.logging import set_log_level -@mock.patch.dict(os.environ, {}, clear=True) def test_set_log_level_for_logger_default(): """ Ensure that the correct log level is being set for python native logger and transformers logger when no env var or CLI flag is passed """ - train_args = copy.deepcopy(TRAIN_ARGS) - training_args, logger = set_log_level(train_args) - assert logger.getEffectiveLevel() == logging.WARNING - assert training_args.log_level == "passive" + with mock.patch.dict(os.environ, {}, clear=True): + train_args = copy.deepcopy(TRAIN_ARGS) + training_args, logger = set_log_level(train_args) + assert logger.getEffectiveLevel() == logging.WARNING + assert training_args.log_level == "passive" -@mock.patch.dict(os.environ, {"LOG_LEVEL": "info"}, clear=True) def test_set_log_level_for_logger_with_env_var(): """ Ensure that the correct log level is being set for python native logger and transformers logger when env var LOG_LEVEL is used """ - train_args_env = copy.deepcopy(TRAIN_ARGS) - training_args, logger = set_log_level(train_args_env) - assert logger.getEffectiveLevel() == logging.INFO - assert training_args.log_level == "info" + with mock.patch.dict(os.environ, {"LOG_LEVEL": "info"}, clear=True): + train_args_env = copy.deepcopy(TRAIN_ARGS) + training_args, logger = set_log_level(train_args_env) + assert logger.getEffectiveLevel() == logging.INFO + assert training_args.log_level == "info" -@mock.patch.dict(os.environ, {"TRANSFORMERS_VERBOSITY": "info"}, clear=True) def 
test_set_log_level_for_logger_with_set_verbosity_and_cli(): """ Ensure that the correct log level is being set for python native logger and @@ -62,14 +61,14 @@ def test_set_log_level_for_logger_with_set_verbosity_and_cli(): and CLI flag is passed """ - train_args = copy.deepcopy(TRAIN_ARGS) - train_args.log_level = "error" - training_args, logger = set_log_level(train_args) - assert logger.getEffectiveLevel() == logging.ERROR - assert training_args.log_level == "error" + with mock.patch.dict(os.environ, {"TRANSFORMERS_VERBOSITY": "info"}, clear=True): + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.log_level = "error" + training_args, logger = set_log_level(train_args) + assert logger.getEffectiveLevel() == logging.ERROR + assert training_args.log_level == "error" -@mock.patch.dict(os.environ, {"LOG_LEVEL": "info"}, clear=True) def test_set_log_level_for_logger_with_env_var_and_cli(): """ Ensure that the correct log level is being set for python native logger and @@ -77,8 +76,9 @@ def test_set_log_level_for_logger_with_env_var_and_cli(): In this case, CLI arg takes precedence over the set env var LOG_LEVEL. """ - train_args = copy.deepcopy(TRAIN_ARGS) - train_args.log_level = "error" - training_args, logger = set_log_level(train_args) - assert logger.getEffectiveLevel() == logging.ERROR - assert training_args.log_level == "error" + with mock.patch.dict(os.environ, {"LOG_LEVEL": "info"}, clear=True): + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.log_level = "error" + training_args, logger = set_log_level(train_args) + assert logger.getEffectiveLevel() == logging.ERROR + assert training_args.log_level == "error" diff --git a/tuning/config/configs.py b/tuning/config/configs.py index 4bff99f19..222bf4424 100644 --- a/tuning/config/configs.py +++ b/tuning/config/configs.py @@ -95,6 +95,13 @@ class DataArguments: or data_formatter_template needs to be supplied." }, ) + data_config_path: str = field( + default=None, + metadata={ + "help": "data config file which specifies the data preprocessing logic to apply.\ + Supports both JSON and YAML based config files." + }, + ) @dataclass diff --git a/tuning/data/__init__.py b/tuning/data/__init__.py new file mode 100644 index 000000000..38a9531ef --- /dev/null +++ b/tuning/data/__init__.py @@ -0,0 +1,13 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tuning/data/data_config.py b/tuning/data/data_config.py new file mode 100644 index 000000000..7e3ccd83b --- /dev/null +++ b/tuning/data/data_config.py @@ -0,0 +1,134 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from dataclasses import dataclass +from typing import Dict, List, Optional +import logging +import os + +# Local +from tuning.utils.utils import load_yaml_or_json + + +@dataclass +class DataHandlerConfig: + name: str + arguments: Optional[Dict] + + +@dataclass +class DataSetConfig: + name: str + data_paths: List[str] + sampling: Optional[Dict] = None + data_handlers: Optional[List[DataHandlerConfig]] = None + + +@dataclass +class DataPreProcessorConfig: + type: Optional[str] = "default" + + +@dataclass +class DataConfig: + dataprocessor: DataPreProcessorConfig + datasets: List[DataSetConfig] + + +def _validate_data_handler_config(data_handler) -> DataHandlerConfig: + kwargs = data_handler + assert isinstance(kwargs, dict), "data_handlers in data_config needs to be a dict" + assert "name" in kwargs and isinstance( + kwargs["name"], str + ), "data_handlers need to have a name with type str" + assert "arguments" in kwargs, "data handlers need to have arguments" + assert isinstance( + kwargs["arguments"], dict + ), "data handler arguments should be of the type dict" + return DataHandlerConfig(**kwargs) + + +def _validate_dataset_config(dataset_config) -> DataSetConfig: + kwargs = dataset_config + assert isinstance(kwargs, dict), "dataset_config in data_config needs to be a dict" + + c = DataSetConfig(name=kwargs.get("name", ""), data_paths=[]) + + if "name" in kwargs: + assert isinstance(kwargs["name"], str), "dataset name should be string" + if "data_paths" not in kwargs: + raise ValueError("data_paths should be specified for each dataset") + data_paths = kwargs["data_paths"] + # TODO: Support that data_paths can be a directory or directories + assert isinstance(data_paths, list), "data_paths should be an array of files" + c.data_paths = [] + for p in data_paths: + assert isinstance(p, str), f"path {p} should be of the type string" + assert os.path.exists(p), f"data_paths {p} does not exist" + if not os.path.isabs(p): + _p = os.path.abspath(p) + logging.warning( + " Provided path %s is not absolute changing it to %s", p, _p + ) + p = _p + c.data_paths.append(p) + if "sampling" in kwargs: + sampling_kwargs = kwargs["sampling"] + assert isinstance( + dict, sampling_kwargs + ), "sampling arguments should be of the type dict" + if "ratio" in sampling_kwargs: + ratio = sampling_kwargs["ratio"] + assert isinstance(ratio, float) and ( + 0 <= ratio <= 1.0 + ), f"sampling ratio: {ratio} should be float and in range [0.0,1.0]" + c.sampling = sampling_kwargs + if "data_handlers" in kwargs: + c.data_handlers = [] + for handler in kwargs["data_handlers"]: + c.data_handlers.append(_validate_data_handler_config(handler)) + return c + + +def _validate_dataprocessor_config(dataprocessor_config) -> DataPreProcessorConfig: + kwargs = dataprocessor_config + c = DataPreProcessorConfig() + assert isinstance(kwargs, dict), "dataprocessor in data_config needs to be a dict" + return c + + +def validate_data_config(dataconfig: DataConfig): + _validate_dataprocessor_config(dataconfig.dataprocessor) + for d in dataconfig.datasets: + _validate_dataset_config(d) + + +def 
load_and_validate_data_config(data_config_file: str) -> DataConfig: + raw_data = load_yaml_or_json(data_config_file) + assert isinstance( + raw_data, dict + ), f"The provided data_config file is invalid: {data_config_file}" + assert "datasets" in raw_data, "datasets should be provided in data config" + assert isinstance( + raw_data["datasets"], list + ), "datasets should be provided as a list" + datasets = [] + for d in raw_data["datasets"]: + datasets.append(_validate_dataset_config(d)) + if "dataprocessor" in raw_data: + dataprocessor = _validate_dataprocessor_config(raw_data["dataprocessor"]) + + data_config = DataConfig(dataprocessor=dataprocessor, datasets=datasets) + return data_config diff --git a/tuning/data/data_handlers.py b/tuning/data/data_handlers.py new file mode 100644 index 000000000..f0100072b --- /dev/null +++ b/tuning/data/data_handlers.py @@ -0,0 +1,142 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Definition of some predefined data preprocessing functions that we need. + +# Standard +from typing import Dict, List +import re + +# Third Party +from transformers import AutoTokenizer + + +### Utils for custom masking / manipulating input / output strs, etc +def combine_sequence(input_element: str, output_element: str, eos_token: str = ""): + """Combines / concatenates input & output element. + + Args: + input_element: str + Input component of the combined sequence. + output_element: str + Output component of the combined sequence. + eos_token: str + EOS token associated with the tokenizer. \ + If passed, it will be concatenated at end + + Returns: + str + Sequence combined with whitespace. + """ + if not input_element.endswith((" ", "\n", "\t")) and not output_element.startswith( + (" ", "\n", "\t") + ): + return input_element + " " + output_element + eos_token + return input_element + output_element + eos_token + + +def tokenize_and_apply_input_masking( + element: Dict[str, str], + tokenizer: AutoTokenizer, + column_names: List[str], + input_field_name: str, + output_field_name: str, + **tokenizer_kwargs, +): + if (input_field_name or output_field_name) not in column_names: + raise ValueError( + f"Dataset should contain {input_field_name} \ + and {output_field_name} field if \ + no dataset_text_field or data_formatter_template specified" + ) + + input_text = element[input_field_name] + output_text = element[output_field_name] + + combined = combine_sequence(input_text, output_text, eos_token=tokenizer.eos_token) + + fn_kwargs = tokenizer_kwargs.get("fn_kwargs", {}) + tokenizer_inner_kwargs = fn_kwargs.get("tokenizer_kwargs", {}) + + tokenized_comb_seqs = tokenizer(combined, **tokenizer_inner_kwargs) + tokenized_input = tokenizer(input_text, **tokenizer_inner_kwargs) + + masked_labels = [-100] * len( + tokenized_input.input_ids + ) + tokenized_comb_seqs.input_ids[len(tokenized_input.input_ids) :] + + # Any benefit of retaining the old columns? 
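A quick illustration of the label-masking arithmetic in tokenize_and_apply_input_masking (numbers are made up): every position that belongs to the input keeps the ignore index -100, and only the ids of the combined sequence past the input length are copied into labels, so just the output tokens contribute to the loss.

# Illustrative only: the input tokenizes to 3 ids, the combined input+output sequence to 5.
combined_input_ids = [11, 12, 13, 21, 22]
input_len = 3
masked_labels = [-100] * input_len + combined_input_ids[input_len:]
assert masked_labels == [-100, -100, -100, 21, 22]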
+ return { + "input_ids": tokenized_comb_seqs.input_ids, + "labels": masked_labels, + "attention_mask": tokenized_comb_seqs.attention_mask, + } + + +def apply_dataset_formatting( + element: Dict[str, str], + tokenizer: AutoTokenizer, + dataset_text_field: str, + **kwargs, +): + return { + f"{dataset_text_field}": element[f"{dataset_text_field}"] + tokenizer.eos_token + } + + +def apply_custom_data_formatting_template( + element: Dict[str, str], + tokenizer: AutoTokenizer, + dataset_text_field: str, + template: str, + **kwargs, +): + """Function to format datasets with Alpaca style / other templates. + Expects to be run as a HF Map API function. + Args: + element: the HF Dataset element loaded from a JSON or DatasetDict object. + template: Template to format data with. Features of Dataset + should be referred to by {{key}} + formatted_dataset_field: Dataset_text_field + eos_token: string EOS token to be appended while formatting data to a single sequence. + Defaults to empty + Returns: + Formatted HF Dataset + """ + + template += tokenizer.eos_token + + def replace_text(match_obj): + captured_groups = match_obj.groups() + if len(captured_groups) != 1: + raise ValueError( + "Unexpectedly captured multiple groups in template formatting" + ) + + index_object = captured_groups[0] + if index_object not in element: + raise KeyError("Requested template string is not a valid key in dict") + + return element[index_object] + + return { + dataset_text_field: re.sub(r"{{([\s0-9a-zA-Z_\-\.]+)}}", replace_text, template) + } + + +AVAILABLE_DATA_HANDLERS = { + "tokenize_and_apply_input_masking": tokenize_and_apply_input_masking, + "apply_dataset_formatting": apply_dataset_formatting, + "apply_custom_data_formatting_template": apply_custom_data_formatting_template, +} diff --git a/tuning/data/data_preprocessing_utils.py b/tuning/data/data_preprocessing_utils.py new file mode 100644 index 000000000..589e4c9ef --- /dev/null +++ b/tuning/data/data_preprocessing_utils.py @@ -0,0 +1,74 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Standard +from typing import Callable, Optional + +# Third Party +from transformers import AutoTokenizer, DataCollatorForSeq2Seq +from trl import DataCollatorForCompletionOnlyLM + +# Local +from tuning.config import configs + + +def get_data_collator( + packing: bool, + response_template: Optional[str], + tokenizer: AutoTokenizer, + is_traindata_tokenized: bool, + max_seq_length: int, +) -> Callable: + """Create and return the the appropriate collator type based on the configuration for packing, + response_template, and dataset_text_field. + + Args: + packing: bool + Whether or not we should apply packing or not. + response_template: Optional[str] + Response template to be used for formatting by TRL. + tokenizer: AutoTokenizer + Loaded tokenizer object to be used by the collator. 
+ is_traindata_tokenized: bool + Whether train Dataset is tokenized or not + max_seq_length: int + Max sequence length expected + + Returns: + Callable + Callable collator to be leveraged by the trainer. + """ + + if not packing: + # TODO: near term - how response template ids are parsed out needs to be cleaned. + # The [2:] here applies if response template has \n prefix, it is needed to strip \n, + # otherwise template is not found. We will create issue to clean this out after we discuss + # data formats and collators we will support. + if response_template: + response_template_ids = tokenizer.encode( + response_template, add_special_tokens=False + )[2:] + return DataCollatorForCompletionOnlyLM( + response_template=response_template_ids, + tokenizer=tokenizer, + ignore_index=configs.IGNORE_INDEX, + ) + # Note that this automatically pads labels with -100 + # TODO check if this is sufficient for preprocessed + if is_traindata_tokenized: + return DataCollatorForSeq2Seq( + tokenizer=tokenizer, padding=True, max_length=max_seq_length + ) + raise ValueError( + "Could not pick a data collator. Please refer to supported data formats" + ) diff --git a/tuning/data/data_processors.py b/tuning/data/data_processors.py new file mode 100644 index 000000000..f6f3b0ec9 --- /dev/null +++ b/tuning/data/data_processors.py @@ -0,0 +1,213 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from typing import Dict, List, Union +import logging +import os + +# Third Party +from datasets import Dataset, DatasetDict, IterableDataset +from datasets.exceptions import DatasetNotFoundError +from transformers import AutoTokenizer +import datasets +import torch + +# Local +from tuning.data.data_config import DataConfig, DataPreProcessorConfig, DataSetConfig +from tuning.data.data_handlers import AVAILABLE_DATA_HANDLERS +from tuning.utils.utils import get_extension, get_loader_for_filepath + + +class DataPreProcessor: + + tokenizer = None + data_config: DataConfig = None + processor_config: DataPreProcessorConfig = None + registered_handlers: Dict[str, callable] = None + + def __init__( + self, processor_config: DataPreProcessorConfig, tokenizer: AutoTokenizer + ): + self.tokenizer = tokenizer + self.processor_config = processor_config + + # Initialize other objects + self.registered_handlers = {} + + def register_data_handler(self, name: str, func: callable): + self.registered_handlers[name] = func + + def load_dataset( + self, + datasetconfig: DataSetConfig, + splitName: str, + datafile: str = None, + **kwargs, + ): + + if datafile and datasetconfig: + raise ValueError("Both datafile and datasetconfig should not be set") + if (not datafile) and (not datasetconfig): + raise ValueError("Either datafile or datasetconfig must be set") + + if datafile: + files = [datafile] + loader = get_loader_for_filepath(file_path=datafile) + elif datasetconfig: + files = datasetconfig.data_paths + name = datasetconfig.name + # simple check to make sure all files are of same type. 
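To summarize the branches of get_data_collator in one place, a caller sees roughly the following (a sketch; the checkpoint name is a placeholder):

# Sketch only, not part of this patch.
from transformers import AutoTokenizer, DataCollatorForSeq2Seq
from trl import DataCollatorForCompletionOnlyLM

from tuning.data.data_preprocessing_utils import get_data_collator

tokenizer = AutoTokenizer.from_pretrained("my-base-model")  # placeholder checkpoint

# Untokenized text data, no packing, response template given -> completion-only masking.
collator = get_data_collator(False, "\n### Label:", tokenizer, False, 1024)
assert isinstance(collator, DataCollatorForCompletionOnlyLM)

# Pretokenized data, no packing, no template -> seq2seq collator that pads labels with -100.
collator = get_data_collator(False, None, tokenizer, True, 1024)
assert isinstance(collator, DataCollatorForSeq2Seq)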
+ extns = [get_extension(f) for f in files] + assert extns.count(extns[0]) == len( + extns + ), f"All files in the dataset {name} should have the same extension" + loader = get_loader_for_filepath(file_path=files[0]) + + if loader in (None, ""): + raise ValueError(f"data path is invalid [{', '.join(files)}]") + + try: + return datasets.load_dataset( + loader, + data_files=files, + split=splitName, + **kwargs, + ) + except DatasetNotFoundError as e: + raise e + except FileNotFoundError as e: + raise ValueError(f"data path is invalid [{', '.join(files)}]") from e + + def _process_dataset_configs( + self, dataset_configs: List[DataSetConfig], **extra_kwargs + ) -> Union[Dataset, IterableDataset]: + train_dataset = None + final_datasets = None + splitName = "train" # default + + logging.info("Starting DataPreProcessor...") + # Iterate over the multiple datasets provided to us + for d in dataset_configs: + logging.info("Loading %s", d.name) + + # In future the streaming etc go as kwargs of this function + raw_dataset = self.load_dataset(d, splitName) + + logging.info("Loaded raw dataset : {raw_datasets}") + + raw_datasets = DatasetDict() + + # Assume all is train split + if isinstance(raw_dataset, Dataset): + raw_datasets[splitName] = raw_dataset + else: + raw_datasets = raw_dataset + + if d.sampling: + logging.warning("Sampling multiple datasets is not supported yet") + + if d.data_handlers: # Execute the datahandlers + for data_handler in d.data_handlers: + handler_name: str = data_handler.name + handler: callable = self.registered_handlers[handler_name] + kwargs: Dict = data_handler.arguments + + if "batched" not in kwargs: + kwargs["batched"] = False + + column_names = raw_datasets[splitName].column_names + + # remove __content__ from all processing + if "__content__" in column_names: + column_names.remove("__content__") + + if "remove_columns" not in kwargs: + kwargs["remove_columns"] = None + if kwargs["remove_columns"] == "all": + kwargs["remove_columns"] = column_names + + if "num_proc" not in kwargs: + kwargs["num_proc"] = os.cpu_count() + + if "fn_kwargs" not in kwargs: + kwargs["fn_kwargs"] = {} + + kwargs["fn_kwargs"]["tokenizer"] = self.tokenizer + kwargs["fn_kwargs"]["column_names"] = column_names + + kwargs["fn_kwargs"] = dict(kwargs["fn_kwargs"], **extra_kwargs) + + logging.info("Applying Handler: %s Args: %s", data_handler, kwargs) + + raw_datasets = raw_datasets.map(handler, **kwargs) + + if final_datasets is None: + final_datasets = raw_datasets + else: + for k in raw_datasets.keys(): + if k in final_datasets: + final_datasets[k] = datasets.concatenate_datasets( + [final_datasets[k], raw_datasets[k]] + ) + else: + final_datasets[k] = raw_datasets[k] + + if "train" in final_datasets: + train_dataset = final_datasets["train"] + + return train_dataset + + def process_dataset_configs( + self, dataset_configs: List[DataSetConfig], **kwargs + ) -> Union[Dataset, IterableDataset]: + train_dataset = None + + if torch.distributed.is_available() and torch.distributed.is_initialized(): + if torch.distributed.get_rank() == 0: + logging.info("Processing data on rank 0...") + train_dataset = self._process_dataset_configs(dataset_configs, **kwargs) + else: + train_dataset = None + + # Use broadcast_object_list to share the dataset object across ranks + # TODO: Check if torch.distributed.barrier() is called in broadcast_object_list() + # See https://github.com/pytorch/pytorch/issues/56142 + # for why the list is shared like this + to_share = [train_dataset] + 
torch.distributed.broadcast_object_list(to_share, src=0) + train_dataset = to_share[0] + else: + logging.info("Processing data...") + train_dataset = self._process_dataset_configs(dataset_configs, **kwargs) + + return train_dataset + + +def autoregister_available_handlers(processor: DataPreProcessor): + if processor is None: + return + for name, func in AVAILABLE_DATA_HANDLERS.items(): + processor.register_data_handler(name=name, func=func) + + +def get_datapreprocessor( + processor_config: DataPreProcessorConfig, tokenizer: AutoTokenizer +) -> DataPreProcessor: + processor = DataPreProcessor( + processor_config=processor_config, + tokenizer=tokenizer, + ) + autoregister_available_handlers(processor) + return processor diff --git a/tuning/data/setup_dataprocessor.py b/tuning/data/setup_dataprocessor.py new file mode 100644 index 000000000..5db8e0aee --- /dev/null +++ b/tuning/data/setup_dataprocessor.py @@ -0,0 +1,322 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from typing import Union +import logging + +# Third Party +from datasets import Dataset, IterableDataset + +# Third +from transformers import AutoTokenizer + +# Local +from tuning.config.configs import DataArguments, TrainingArguments +from tuning.data.data_config import ( + DataHandlerConfig, + DataPreProcessorConfig, + DataSetConfig, + load_and_validate_data_config, +) +from tuning.data.data_preprocessing_utils import get_data_collator +from tuning.data.data_processors import get_datapreprocessor + +# In future we may make the fields configurable +DEFAULT_JSON_INPUT_KEY = "input" +DEFAULT_JSON_OUTPUT_KEY = "output" + +# check if the provided dataset is pretokenized or not +# the check is taken from trl +# https://github.com/huggingface/trl/blob/ddf4c8dc3ecf6d9ee2b24f94c62182ffd682c808/trl/trainer/sft_trainer.py#L498-L509 +def is_pretokenized_dataset(data: Union[str, Dataset, IterableDataset]): + if not data: + return False + if isinstance(data, str): + # Create a data processor with default processor config + processor = get_datapreprocessor( + processor_config=DataPreProcessorConfig(), tokenizer=None + ) + data = processor.load_dataset(None, splitName="train[:1]", datafile=data) + + return ("input_ids" in data.column_names) and ("labels" in data.column_names) + + +# TODO: For now assume only training dataset is passed via data config file. 
+# This is very limited but is done to keep first implementation minimal +def _process_dataconfig_file(data_args: DataArguments, tokenizer: AutoTokenizer): + data_config = load_and_validate_data_config(data_args.data_config_path) + processor = get_datapreprocessor( + processor_config=data_config.dataprocessor, tokenizer=tokenizer + ) + train_dataset = processor.process_dataset_configs(data_config.datasets) + + return (train_dataset, None, data_args.dataset_text_field) + + +# Data Format 1: Pretokenized Data +def _get_pretokenized_dataset_handlers(data_args, packing, is_eval_tokenized): + + # if the provided train dataset is pretokenized + # however user provides formatting flags, error out + if ( + data_args.response_template + or data_args.data_formatter_template + or data_args.dataset_text_field + ): + raise ValueError( + "fields response_template, data_formatter_template, and dataset_text_field \ + are not applicable for pretokenized datasets" + ) + + # if the train dataset is pretokenized + # ensure validation dataset is pretokenized otherwise error out + if is_eval_tokenized: + raise ValueError( + "validation data should be pretokenized to be used \ + along with pretokenized train data" + ) + + # Support for packing pretokenized datasets has been merged in trl library + # see: https://github.com/huggingface/trl/pull/2011 + # but we wait till a new transformers version is released to remove this check. + if packing: + raise ValueError("packing will not be used when datasets are pretokenized") + + # We do not need a handler here as this is tokenized dataset + return [], None + + +### Data format 2 +def _get_dataset_formatting_handlers(data_args, packing): + + if data_args.response_template is None: + if packing is False: + raise ValueError( + "Since dataset_text_field or data_formatter_template \ + is provided and packing is disabled, \ + needs a corresponding response template for masking" + ) + + if data_args.response_template: + # To use Response template, pass datasets with single sequence instances \ + # or a formatter template to create single sequence on the fly. + if not (data_args.dataset_text_field or data_args.data_formatter_template): + raise ValueError( + "dataset_text_field and data_formatter_template are None. \ + One of them needs to be set to use response_template" + ) + # Only one of dataset_text_field or data_formatter_template should be set. 
+ if data_args.dataset_text_field and data_args.data_formatter_template: + raise ValueError( + "dataset_text_field and data_formatter_template are both set,\ + but are mutually exclusive options" + ) + + fn_kwargs = {} + dataset_text_field = data_args.dataset_text_field + + if dataset_text_field is None: + dataset_text_field = "new_formatted_field" + + fn_kwargs["dataset_text_field"] = dataset_text_field + if data_args.data_formatter_template is None: + handler = DataHandlerConfig( + "apply_dataset_formatting", + arguments={"fn_kwargs": fn_kwargs, "batched": False}, + ) + else: + fn_kwargs["template"] = data_args.data_formatter_template + handler = DataHandlerConfig( + "apply_custom_data_formatting_template", + arguments={"fn_kwargs": fn_kwargs, "batched": False}, + ) + return [handler], dataset_text_field + + +### Data format 3 +def _get_default_json_dataset_handlers(data_args, tokenizer_kwargs): + + fn_kwargs = {} + fn_kwargs["input_field_name"] = DEFAULT_JSON_INPUT_KEY + fn_kwargs["output_field_name"] = DEFAULT_JSON_OUTPUT_KEY + fn_kwargs["tokenizer_kwargs"] = tokenizer_kwargs + + kwargs = { + "fn_kwargs": fn_kwargs, + "batched": False, + "remove_columns": "all", + } + + handler = DataHandlerConfig("tokenize_and_apply_input_masking", arguments=kwargs) + return [handler], data_args.dataset_text_field + + +# Process raw dataargs for various usecases. +# Data Format 1: Pretokenized Data +# Use pretokenized data as-is without preprocessing. +# No handlers are needed for this format. +# Data Format 2: Single Sequence Dataset +# If a text field is specified, append the tokenizer's EOS token to it. +# If a formatter template is provided, apply it and save the result. +# Data remains un-tokenized. +# Data Format 3: JSON Dataset with Input/Output Fields +# Combine input and output fields, tokenize the data, and apply input attention masking. +# Requires both input and output fields; throws an error if missing. +def _process_raw_data_args( + data_args: DataArguments, + tokenizer: AutoTokenizer, + packing: bool, + max_seq_length: int, +): + + # Create a data processor with default processor config + default_processor_config = DataPreProcessorConfig() + data_processor = get_datapreprocessor( + processor_config=default_processor_config, tokenizer=tokenizer + ) + + assert isinstance( + data_args.training_data_path, str + ), "Training data path has to be set and str" + + is_eval_dataset_present = False + if data_args.validation_data_path: + is_eval_dataset_present = True + + # TODO: This check loads first slice of the dataset to view its columns + # Since this load is not done via processor it is redundant + is_traindata_tokenized = is_pretokenized_dataset(data_args.training_data_path) + is_evaldata_tokenized = is_pretokenized_dataset(data_args.validation_data_path) + + train_dataset_config = DataSetConfig( + name="training_data", + data_paths=[data_args.training_data_path], + data_handlers=None, + ) + if is_eval_dataset_present: + eval_dataset_config = DataSetConfig( + name="validation_data", + data_paths=[data_args.validation_data_path], + data_handlers=None, + ) + + # Setup some tokenizer kwargs for when we need a tokenizer + # TODO: Figure out a way to not hardcode this. 
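+    # Note: only the Data Format 3 handler (tokenize_and_apply_input_masking)
+    # receives these settings, passed through as its fn_kwargs["tokenizer_kwargs"].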
+ tokenizer_kwargs = {} + tokenizer_kwargs["max_length"] = max_seq_length + tokenizer_kwargs["truncation"] = True + tokenizer_kwargs["padding"] = False + + handlers = None + dataset_text_field = None + if is_traindata_tokenized: + # Data Format 1: Pretokenized Data + handlers, dataset_text_field = _get_pretokenized_dataset_handlers( + data_args, packing, (is_eval_dataset_present and not is_evaldata_tokenized) + ) + elif data_args.data_formatter_template or data_args.dataset_text_field: + # Data Format 2: Single Sequence Dataset + handlers, dataset_text_field = _get_dataset_formatting_handlers( + data_args, packing + ) + else: + # Data Format 3: JSON Dataset with Input/Output Fields + handlers, dataset_text_field = _get_default_json_dataset_handlers( + data_args, tokenizer_kwargs + ) + + # Now set handlers in the dataset configs + train_dataset_config.data_handlers = handlers + if is_eval_dataset_present: + eval_dataset_config.data_handlers = handlers + + # And let processor handle the logic + train_dataset = data_processor.process_dataset_configs([train_dataset_config]) + + eval_dataset = None + if is_eval_dataset_present: + eval_dataset = data_processor.process_dataset_configs([eval_dataset_config]) + + return (train_dataset, eval_dataset, dataset_text_field) + + +# If a data config file is provided, load it to get the training dataset. +# - Assumes only the training dataset is specified in the config file. +# - Expects a complete and valid data config file from the user. +# +# If no data config file is specified, process the remaining data arguments +# to determine the use case based on their presence, as explained in _process_raw_data_args. +def process_dataargs( + data_args: DataArguments, tokenizer: AutoTokenizer, train_args: TrainingArguments +): + """ + Args: + data_args: tuning.config.configs.DataArguments + tokenizer: AutoTokenizer + train_args: TrainingArguments + Training arguments passed to the library + Used for packing and max_seq_length + Returns: + Tuple(Dataset, Dataset, str, DataCollator, int, Dict) + tuple containing train_dataset, eval_dataset, dataset_text_field, + data_collator, max_seq_length and dataset_kwargs + + """ + + max_seq_length = min(train_args.max_seq_length, tokenizer.model_max_length) + logging.info("Max sequence length is %s", max_seq_length) + if train_args.max_seq_length > tokenizer.model_max_length: + logging.warning( + "max_seq_length %s exceeds tokenizer.model_max_length \ + %s, using tokenizer.model_max_length %s", + train_args.max_seq_length, + tokenizer.model_max_length, + tokenizer.model_max_length, + ) + + train_dataset = eval_dataset = dataset_text_field = None + + if data_args.data_config_path: + train_dataset, eval_dataset, dataset_text_field = _process_dataconfig_file( + data_args, tokenizer + ) + else: + train_dataset, eval_dataset, dataset_text_field = _process_raw_data_args( + data_args, tokenizer, train_args.packing, max_seq_length + ) + + data_collator = get_data_collator( + train_args.packing, + data_args.response_template, + tokenizer, + # Note: This check should not be removed. + # Its important to recompute this post handling to + # check if we already tokenized the dataset or not. 
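+        # e.g. the Data Format 3 handler tokenizes raw JSON data, so a dataset that
+        # started out un-tokenized may well be pretokenized at this point.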
+ is_pretokenized_dataset(train_dataset), + max_seq_length, + ) + + dataset_kwargs = {} + if is_pretokenized_dataset(train_dataset or eval_dataset): + dataset_kwargs["skip_prepare_dataset"] = True + + return ( + train_dataset, + eval_dataset, + dataset_text_field, + data_collator, + max_seq_length, + dataset_kwargs, + ) diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index fa7d0875c..c02d73781 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -53,6 +53,7 @@ FileLoggingTrackerConfig, TrackerConfigFactory, ) +from tuning.data.setup_dataprocessor import process_dataargs from tuning.trackers.tracker_factory import FILE_LOGGING_TRACKER, get_tracker from tuning.trainercontroller import TrainerControllerCallback from tuning.utils.config_utils import get_hf_peft_config, get_json_config @@ -63,12 +64,6 @@ write_termination_log, ) from tuning.utils.logging import set_log_level -from tuning.utils.preprocessing_utils import ( - format_dataset, - get_data_collator, - is_pretokenized_dataset, - validate_data_args, -) from tuning.utils.tokenizer_data_utils import tokenizer_and_embedding_resize @@ -257,17 +252,6 @@ def train( elif isinstance(tokenizer, (GPT2Tokenizer, GPTNeoXTokenizerFast)): special_tokens_dict["pad_token"] = "" - max_seq_length = min(train_args.max_seq_length, tokenizer.model_max_length) - logger.info("Max sequence length is %s", max_seq_length) - if train_args.max_seq_length > tokenizer.model_max_length: - logger.warning( - "max_seq_length %s exceeds tokenizer.model_max_length \ - %s, using tokenizer.model_max_length %s", - train_args.max_seq_length, - tokenizer.model_max_length, - tokenizer.model_max_length, - ) - # add special tokens only when a custom tokenizer is not passed if not model_args.tokenizer_name_or_path: # TODO: we need to change this, perhaps follow what open instruct does? 
@@ -302,28 +286,20 @@ def train( ) # Configure the collator and validate args related to packing prior to formatting the dataset - if train_args.packing: - logger.info("Packing is set to True") - data_collator = None - packing = True - else: - logger.info("Packing is set to False") - packing = False - - # Validate if data args are set properly - validate_data_args(data_args, packing) + data_collator = None + logger.info("Packing is set to %s ", train_args.packing) + data_preprocessing_time = time.time() ( formatted_train_dataset, formatted_validation_dataset, dataset_text_field, - ) = format_dataset(data_args, tokenizer, max_seq_length) - data_collator = get_data_collator( - packing, - data_args.response_template, - tokenizer, - formatted_train_dataset, + data_collator, max_seq_length, + dataset_kwargs, + ) = process_dataargs(data_args, tokenizer, train_args) + additional_metrics["data_preprocessing_time"] = ( + time.time() - data_preprocessing_time ) if framework is not None and framework.requires_agumentation: @@ -348,17 +324,12 @@ def train( } training_args = SFTConfig(**transformer_kwargs) - dataset_kwargs = {} - if is_pretokenized_dataset( - data_args.training_data_path or data_args.validation_data_path - ): - dataset_kwargs["skip_prepare_dataset"] = True trainer = SFTTrainer( model=model, tokenizer=tokenizer, train_dataset=formatted_train_dataset, eval_dataset=formatted_validation_dataset, - packing=packing, + packing=train_args.packing, data_collator=data_collator, dataset_text_field=dataset_text_field, args=training_args, diff --git a/tuning/utils/data_utils.py b/tuning/utils/data_utils.py deleted file mode 100644 index db5ff0f0f..000000000 --- a/tuning/utils/data_utils.py +++ /dev/null @@ -1,47 +0,0 @@ -# Standard -import re - - -def apply_custom_formatting_template( - dataset, template, formatted_dataset_field, eos_token="" -): - """Function to format datasets with Alpaca style / other templates. - Args: - dataset: the HF Dataset element loaded from a JSON or DatasetDict object. - template: Template to format data with. Features of Dataset - should be referred to by {{key}} - formatted_dataset_field: Dataset_text_field - eos_token: string EOS token to be appended while formatting data to a single sequence. - Defaults to empty - Returns: - Formatted HF Dataset - """ - - template += eos_token - - if not formatted_dataset_field: - raise ValueError( - "Unable to apply custom formatting because the formatted_dataset_field was not provided" - ) - - def formatter(element): - def replace_text(match_obj): - captured_groups = match_obj.groups() - if len(captured_groups) != 1: - raise ValueError( - "Unexpectedly captured multiple groups in template formatting" - ) - - index_object = captured_groups[0] - if index_object not in element: - raise KeyError("Requested template string is not a valid key in dict") - - return element[index_object] - - return { - formatted_dataset_field: re.sub( - r"{{([\s0-9a-zA-Z_\-\.]+)}}", replace_text, template - ) - } - - return dataset.map(formatter) diff --git a/tuning/utils/preprocessing_utils.py b/tuning/utils/preprocessing_utils.py deleted file mode 100644 index a07e99a4e..000000000 --- a/tuning/utils/preprocessing_utils.py +++ /dev/null @@ -1,451 +0,0 @@ -# Copyright The FMS HF Tuning Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Standard -from typing import Any, Callable, Dict, Optional, Union -import json -import logging -import os - -# Third Party -from datasets import Dataset, IterableDataset -from datasets.exceptions import DatasetGenerationError -from transformers import AutoTokenizer, DataCollatorForSeq2Seq -from trl import DataCollatorForCompletionOnlyLM -import datasets - -# Local -from tuning.config import configs -from tuning.utils.data_utils import apply_custom_formatting_template - -# In future we may make the fields configurable -JSON_INPUT_KEY = "input" -JSON_OUTPUT_KEY = "output" - - -# check if the provided dataset is pretokenized or not -# the check is taken from trl -# https://github.com/huggingface/trl/blob/ddf4c8dc3ecf6d9ee2b24f94c62182ffd682c808/trl/trainer/sft_trainer.py#L498-L509 -def is_pretokenized_dataset(data: Union[str, Dataset, IterableDataset]): - if not data: - return False - if isinstance(data, str): - try: - data = datasets.load_dataset("json", data_files=data, split="train[:1]") - except DatasetGenerationError as e: - raise DatasetGenerationError("failed to load the provided dataset") from e - - return ("input_ids" in data.column_names) and ("labels" in data.column_names) - - -def validate_data_args(data_args: configs.DataArguments, packing: bool): - - assert isinstance( - data_args.training_data_path, str - ), "Training data path has to be set and str" - - is_train_data_pretokenized = is_pretokenized_dataset(data_args.training_data_path) - is_eval_data_pretokenized = is_pretokenized_dataset(data_args.validation_data_path) - - ### Data format 1 - # if the provided train dataset is pretokenized - # however user provides formatting flags, error out - if is_train_data_pretokenized: - if ( - data_args.response_template - or data_args.data_formatter_template - or data_args.dataset_text_field - ): - raise ValueError( - "fields response_template, data_formatter_template, and dataset_text_field \ - are not applicable for pretokenized datasets" - ) - - # if the train dataset is pretokenized - # ensure validation dataset is pretokenized otherwise error out - if data_args.validation_data_path and not is_eval_data_pretokenized: - raise ValueError( - "validation data should be pretokenized to be used \ - along with pretokenized train data" - ) - - # packing wont be available for pretokenized datasets in trl library - # see: https://github.com/huggingface/trl/issues/1848 - if packing: - raise ValueError("packing will not be used when datasets are pretokenized") - return - - ### Data format 2 - # Dataset containing single sequence needs a response template for masking - if data_args.dataset_text_field or data_args.data_formatter_template: - if data_args.response_template is None: - if packing is False: - raise ValueError( - "Since dataset_text_field or data_formatter_template \ - is provided and packing is disabled, \ - needs a corresponding response template for masking" - ) - - if data_args.response_template: - # To use Response template, pass datasets with single sequence instances \ - # or a formatter template to create single sequence on the fly. 
- if not (data_args.dataset_text_field or data_args.data_formatter_template): - raise ValueError( - "dataset_text_field and data_formatter_template are None. \ - One of them needs to be set to use response_template" - ) - # Only one of dataset_text_field or data_formatter_template should be set. - if data_args.dataset_text_field and data_args.data_formatter_template: - raise ValueError( - "dataset_text_field and data_formatter_template are both set,\ - but are mutually exclusive options" - ) - - ### Data format 3 - # If not single sequence, JSON should contain input/output fields - if not (data_args.dataset_text_field or data_args.data_formatter_template): - json_dataset = datasets.load_dataset( - "json", data_files=data_args.training_data_path - ) - if JSON_INPUT_KEY not in json_dataset["train"].column_names: - raise ValueError( - "JSON should contain input field if no dataset_text_field or \ - data_formatter_template specified" - ) - if JSON_OUTPUT_KEY not in json_dataset["train"].column_names: - raise ValueError( - "JSON should contain output field if no dataset_text_field or \ - data_formatter_template specified" - ) - - -def get_data_collator( - packing: bool, - response_template: Optional[str], - tokenizer: AutoTokenizer, - formatted_train_dataset: Dataset, - max_seq_length: int, -) -> Callable: - """Create and return the the appropriate collator type based on the configuration for packing, - response_template, and dataset_text_field. - - Args: - packing: bool - Whether or not we should apply packing or not. - response_template: Optional[str] - Response template to be used for formatting by TRL. - tokenizer: AutoTokenizer - Loaded tokenizer object to be used by the collator. - formatted_train_dataset: Dataset - Train Dataset formatted for tuning - max_seq_length: int - Max sequence length expected - - Returns: - Callable - Callable collator to be leveraged by the trainer. - """ - is_train_data_pretokenized = is_pretokenized_dataset(formatted_train_dataset) - - if not packing: - # TODO: near term - how response template ids are parsed out needs to be cleaned. - # The [2:] here applies if response template has \n prefix, it is needed to strip \n, - # otherwise template is not found. We will create issue to clean this out after we discuss - # data formats and collators we will support. - if response_template: - response_template_ids = tokenizer.encode( - response_template, add_special_tokens=False - )[2:] - return DataCollatorForCompletionOnlyLM( - response_template=response_template_ids, - tokenizer=tokenizer, - ignore_index=configs.IGNORE_INDEX, - ) - # Note that this automatically pads labels with -100 - # TODO check if this is sufficient for preprocessed - if is_train_data_pretokenized: - return DataCollatorForSeq2Seq( - tokenizer=tokenizer, padding=True, max_length=max_seq_length - ) - raise ValueError( - "Could not pick a data collator. 
Please refer to supported data formats" - ) - - -def format_dataset( - data_args: configs.DataArguments, tokenizer: AutoTokenizer, max_seq_length: int -): - """ - Args: - data_args: tuning.config.configs.DataArguments - tokenizer: AutoTokenizer - max_seq_length: int - Max sequence length expected - Returns: - Tuple(Dataset, Dataset, str) - tuple containing train_dataset, eval_dataset and dataset_text_field - """ - eval_dataset = None - is_train_data_pretokenized = is_pretokenized_dataset(data_args.training_data_path) - - if is_train_data_pretokenized: - train_dataset = datasets.load_dataset( - "json", data_files=data_args.training_data_path, split="train" - ) - if data_args.validation_data_path: - eval_dataset = datasets.load_dataset( - "json", data_files=data_args.validation_data_path, split="train" - ) - # dataset_text_field is irrelevant to pretokenized datasets - return train_dataset, eval_dataset, None - - dataset_text_field = data_args.dataset_text_field - if data_args.data_formatter_template or dataset_text_field: - if dataset_text_field is None: - dataset_text_field = "new_formatted_field" - train_dataset = get_formatted_dataset_with_single_sequence( - data_args.training_data_path, - dataset_text_field, - tokenizer, - data_args.data_formatter_template, - ) - logging.info("Training dataset length is %s", len(train_dataset)) - if data_args.validation_data_path: - (eval_dataset) = get_formatted_dataset_with_single_sequence( - data_args.validation_data_path, - dataset_text_field, - tokenizer, - data_args.data_formatter_template, - ) - logging.info("Validation dataset length is %s", len(eval_dataset)) - else: - # This is for JSON containing input/output fields - train_dataset = get_preprocessed_dataset( - data_args.training_data_path, - tokenizer, - max_seq_length, - input_field_name=JSON_INPUT_KEY, - output_field_name=JSON_OUTPUT_KEY, - ) - if data_args.validation_data_path: - eval_dataset = get_preprocessed_dataset( - data_args.validation_data_path, - tokenizer, - max_seq_length, - input_field_name=JSON_INPUT_KEY, - output_field_name=JSON_OUTPUT_KEY, - ) - - return train_dataset, eval_dataset, dataset_text_field - - -def get_formatted_dataset_with_single_sequence( - data_path: str, - dataset_text_field: str, - tokenizer: AutoTokenizer, - data_formatter_template: Optional[str] = None, -) -> Dataset: - """Applies formatting to the loaded dataset instance; does NOT pretokenize data. - - Args: - data_path: str - Path to the file to be loaded. - dataset_text_field: str - Dataset text field to be used for formatting. - If data_formatter_template specified, \ - this will be the new field creating single sequence. - tokenizer: AutoTokenizer - Loaded tokenizer object to be used by the collator. - data_formatter_template: str - Template to apply to create single sequence and store it in dataset_text_field - - Returns: - Dataset - HF Dataset with formatted [str] data. 
- """ - - json_dataset = datasets.load_dataset("json", data_files=data_path) - format_dataset_EOS = ( - lambda example: { # pylint: disable=unnecessary-lambda-assignment - f"{dataset_text_field}": example[f"{dataset_text_field}"] - + tokenizer.eos_token - } - ) - if data_formatter_template: - formatted_train_dataset = apply_custom_formatting_template( - json_dataset["train"], - data_formatter_template, - dataset_text_field, - tokenizer.eos_token, - ) - else: - formatted_train_dataset = json_dataset.map(format_dataset_EOS)[ - "train" - ] # HACK - for now, we just do both datasets separately; train is the default split - return formatted_train_dataset - - -def get_preprocessed_dataset( - data_path: str, - tokenizer: AutoTokenizer, - max_sequence_length: int, - input_field_name: str, - output_field_name: str, -) -> Dataset: - """Loads the dataset and applies the tokenizer + custom masking logic. - - Args: - data_path: str - Path to the file to be loaded. - tokenizer: AutoTokenizer - Loaded tokenizer object to be used by the collator. - max_sequence_length: int - Max sequence length to be used for sequence tokenization. - input_field_name: str - Name of the input field in the data. - output_field_name: str - Name of the output field in the data. - - Returns: - Dataset - HF Dataset with the pretokenized data. - """ - dataset = load_hf_dataset_from_file(data_path, input_field_name, output_field_name) - return dataset.map( - preprocess_and_tokenize, - fn_kwargs={ - "tokenizer": tokenizer, - "max_seq_length": max_sequence_length, - "input_field_name": input_field_name, - "output_field_name": output_field_name, - }, - remove_columns=[input_field_name, output_field_name], - ) - - -### Utils for loading the data from disk in supported formats [currently only jsonl] -def load_hf_dataset_from_file( - data_path: str, input_field_name: str, output_field_name: str -) -> Dataset: - """Loads the HuggingFace dataset from JSON or JSONL file. - - Args: - data_path: str - Path to the file to be loaded. - input_field_name: str - Name of the input field in the data. - output_field_name: str - Name of the output field in the data. - - Returns: - Dataset - HF Dataset with the data to be tokenized. - """ - if input_field_name == output_field_name: - raise ValueError("Input field name and output field name should not match!") - - def get_json_object(): - with open(data_path, "r", encoding="utf-8") as json_file: - file_extension = os.path.splitext(data_path)[-1].lower() - if file_extension == ".jsonl": - data_stream = (json.loads(line) for line in json_file) - elif file_extension == ".json": - data_stream = json.load(json_file) - else: - raise ValueError("Unsupported file format! Use 'json' or 'jsonl'.") - - for data in data_stream: - yield { - input_field_name: data[input_field_name], - output_field_name: data[output_field_name], - } - - return Dataset.from_generator(get_json_object) - - -### Utils for custom masking / manipulating input / output strs, etc -def combine_sequence(input_element: str, output_element: str, eos_token: str = ""): - """Combines / concatenates input & output element. - - Args: - input_element: str - Input component of the combined sequence. - output_element: str - Output component of the combined sequence. - eos_token: str - EOS token associated with the tokenizer. \ - If passed, it will be concatenated at end - - Returns: - str - Sequence combined with whitespace. 
- """ - if not input_element.endswith((" ", "\n", "\t")) and not output_element.startswith( - (" ", "\n", "\t") - ): - return input_element + " " + output_element + eos_token - return input_element + output_element + eos_token - - -def preprocess_and_tokenize( - element: Dict[str, str], - tokenizer: AutoTokenizer, - max_seq_length: int, - input_field_name: str, - output_field_name: str, -) -> Dict[str, Any]: - """Loads the dataset and applies the tokenizer + custom masking logic. - NOTE: Truncation is done in this step, but padding is not, and generally - handled by the collator. - - Args: - element: Dict[str, str] - A single element of the raw Dataset of strings, whose data we would like to apply - custom masking + tokenization logic to. - tokenizer: AutoTokenizer - Loaded tokenizer object to be used by the collator. - max_sequence_length: int - Max sequence length to be used for sequence tokenization. - input_field_name: str - Name of the input field in the data. - output_field_name: str - Name of the output field in the data. - - Returns: - Dict[str, Any] - Dictionary containing the input IDs/labels/attention mask for this record. - """ - combined_seq = combine_sequence( - element[input_field_name], element[output_field_name], tokenizer.eos_token - ) - - tokenized_comb_seqs = tokenizer( - combined_seq, max_length=max_seq_length, truncation=True, padding=False - ) - tokenized_input = tokenizer( - element[input_field_name], - max_length=max_seq_length, - truncation=True, - padding=False, - ) - - # mask the prompt part for avoiding loss - masked_labels = [-100] * len( - tokenized_input.input_ids - ) + tokenized_comb_seqs.input_ids[len(tokenized_input.input_ids) :] - - return { - "input_ids": tokenized_comb_seqs.input_ids, - "labels": masked_labels, - "attention_mask": tokenized_comb_seqs.attention_mask, - } diff --git a/tuning/utils/utils.py b/tuning/utils/utils.py new file mode 100644 index 000000000..9def53df9 --- /dev/null +++ b/tuning/utils/utils.py @@ -0,0 +1,44 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +import json +import os + +# Third Party +import yaml + + +def get_extension(file_path: str) -> str: + _, ext = os.path.splitext(file_path) + return ext.lower() + + +def get_loader_for_filepath(file_path: str) -> str: + ext = get_extension(file_path) + if ext in (".txt", ".md"): + return "text" + if ext in (".json", ".jsonl"): + return "json" + return ext + + +def load_yaml_or_json(file_path: str) -> dict: + with open(file_path, "r", encoding="utf-8") as f: + ext = get_extension(file_path) + if ext in (".yaml", ".yml"): + return yaml.safe_load(f) + if ext == ".json": + return json.load(f) + return None
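
As a quick reference for how the pieces above fit together, the following is a minimal usage sketch of the new process_dataargs entry point for the raw JSON path (Data Format 3). The model name, output directory, and data path are illustrative assumptions, and it presumes DataArguments / TrainingArguments can be constructed with these keyword arguments:

    # Illustrative sketch only -- model name and file paths are placeholders.
    from transformers import AutoTokenizer

    from tuning.config.configs import DataArguments, TrainingArguments
    from tuning.data.setup_dataprocessor import process_dataargs

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")  # any HF tokenizer
    data_args = DataArguments(
        training_data_path="train.jsonl",  # JSON/JSONL records with "input"/"output" keys
    )
    train_args = TrainingArguments(output_dir="/tmp/out", packing=False, max_seq_length=1024)

    (
        train_dataset,
        eval_dataset,
        dataset_text_field,
        data_collator,
        max_seq_length,
        dataset_kwargs,
    ) = process_dataargs(data_args, tokenizer, train_args)

If data_args.data_config_path is set instead, the same call routes through _process_dataconfig_file and the handlers declared in the data config file are applied.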