forked from foundation-model-stack/fms-hf-tuning
Merge remote-tracking branch 'upstream/main'
Showing 57 changed files with 1,434 additions and 832 deletions.
@@ -0,0 +1,30 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helpful datasets for configuring individual unit tests.
"""
# Standard
import os

### Constants used for data
PREDEFINED_DATA_CONFIGS = os.path.join(os.path.dirname(__file__))
APPLY_CUSTOM_TEMPLATE_YAML = os.path.join(
    PREDEFINED_DATA_CONFIGS, "apply_custom_template.yaml"
)
PRETOKENIZE_JSON_DATA_YAML = os.path.join(
    PREDEFINED_DATA_CONFIGS, "pretokenized_json_data.yaml"
)
TOKENIZE_AND_APPLY_INPUT_MASKING_YAML = os.path.join(
    PREDEFINED_DATA_CONFIGS, "tokenize_and_apply_input_masking.yaml"
)
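These constants just build absolute paths to the YAML configs added below, each of which leaves a "FILE_PATH" placeholder for the test to fill in with a real dataset. As a rough illustration only (not code from this commit: the helper, the PyYAML usage, and the import path of the constants module are all assumptions), a test could resolve a config like this:

```python
# Hypothetical helper, not part of the diff: load one of the predefined data
# configs and swap the "FILE_PATH" placeholder for a concrete dataset path.
import yaml  # assumes PyYAML is available in the test environment

# Assumed import path; the constants file's name is not shown in this diff.
from tests.artifacts.predefined_data_configs import APPLY_CUSTOM_TEMPLATE_YAML


def load_data_config(yaml_path: str, data_path: str) -> dict:
    with open(yaml_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    for dataset in config["datasets"]:
        # Replace the placeholder while leaving any real paths untouched.
        dataset["data_paths"] = [
            data_path if p == "FILE_PATH" else p for p in dataset["data_paths"]
        ]
    return config


config = load_data_config(APPLY_CUSTOM_TEMPLATE_YAML, "/tmp/twitter_complaints.jsonl")
```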
tests/artifacts/predefined_data_configs/apply_custom_template.yaml (14 additions, 0 deletions)
@@ -0,0 +1,14 @@
dataprocessor:
  type: default
datasets:
  - name: apply_custom_data_template
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: apply_custom_data_formatting_template
        arguments:
          remove_columns: all
          batched: false
          fn_kwargs:
            dataset_text_field: "dataset_text_field"
            dataset_template: "dataset_template"
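The dataset_template value here is a placeholder that tests overwrite; the test file later in this diff exercises templates of the form `### Input: {{Tweet text}} ...`. As a hedged sketch of the substitution that syntax implies (the regex helper is hypothetical; the real handler is apply_custom_data_formatting_template in tuning.data.data_handlers):

```python
# Hypothetical sketch of {{column}} substitution, not the handler's source.
# A missing column raises KeyError, consistent with the wrong-keys test below.
import re


def render_template_sketch(template: str, row: dict) -> str:
    # Swap each {{column}} placeholder for the row's value under that column.
    return re.sub(r"\{\{([^}]+)\}\}", lambda m: str(row[m.group(1)]), template)


row = {
    "Tweet text": "@HMRCcustomers No this is my first job",
    "text_label": "no complaint",
}
print(render_template_sketch(
    "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}", row
))
```

Per the first test below, the real handler additionally appends the tokenizer's EOS token to the rendered string.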
tests/artifacts/predefined_data_configs/pretokenized_json_data.yaml (6 additions, 0 deletions)
@@ -0,0 +1,6 @@
dataprocessor:
  type: default
datasets:
  - name: pretokenized_dataset
    data_paths:
      - "FILE_PATH"
tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml (14 additions, 0 deletions)
@@ -0,0 +1,14 @@
dataprocessor:
  type: default
datasets:
  - name: text_dataset_input_output_masking
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: tokenize_and_apply_input_masking
        arguments:
          remove_columns: all
          batched: false
          fn_kwargs:
            input_field: "INPUT"
            output_field: "OUTPUT"
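The input_field/output_field kwargs suggest the handler tokenizes the combined prompt and response and masks the loss on the prompt portion. A conceptual sketch only, under that assumption (this is not the tuning library's implementation):

```python
# Conceptual sketch, not tuning.data.data_handlers source: combine input and
# output, tokenize, and set the input portion's labels to -100 so the loss is
# computed only on output tokens. Note that real subword tokenizers may merge
# tokens at the input/output boundary, which the library must account for.
def tokenize_and_mask_sketch(example: dict, tokenizer, input_field: str, output_field: str) -> dict:
    prompt_ids = tokenizer(example[input_field])["input_ids"]
    full_ids = tokenizer(
        example[input_field] + " " + example[output_field] + tokenizer.eos_token
    )["input_ids"]
    labels = [-100] * len(prompt_ids) + full_ids[len(prompt_ids):]
    return {"input_ids": full_ids, "labels": labels}
```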
32 files renamed without changes.
@@ -0,0 +1,110 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/

# Third Party
from transformers import AutoTokenizer
import datasets
import pytest

# First Party
from tests.artifacts.testdata import MODEL_NAME, TWITTER_COMPLAINTS_DATA_JSONL

# Local
from tuning.data.data_handlers import (
    apply_custom_data_formatting_template,
    combine_sequence,
)


def test_apply_custom_formatting_template():
    json_dataset = datasets.load_dataset(
        "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL
    )
    template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    formatted_dataset_field = "formatted_data_field"
    formatted_dataset = json_dataset.map(
        apply_custom_data_formatting_template,
        fn_kwargs={
            "tokenizer": tokenizer,
            "dataset_text_field": formatted_dataset_field,
            "template": template,
        },
    )
    # First response from the data file that is read.
    expected_response = (
        "### Input: @HMRCcustomers No this is my first job"
        + " \n\n ### Response: no complaint"
        + tokenizer.eos_token
    )

    # A new dataset_text_field column is created in the Dataset.
    assert formatted_dataset_field in formatted_dataset["train"][0]
    assert formatted_dataset["train"][0][formatted_dataset_field] == expected_response


def test_apply_custom_formatting_template_gives_error_with_wrong_keys():
    """Test that the formatting function raises an error when the template
    references keys that do not exist in the dataset."""
    json_dataset = datasets.load_dataset(
        "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL
    )
    template = "### Input: {{not found}} \n\n ### Response: {{text_label}}"
    formatted_dataset_field = "formatted_data_field"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    with pytest.raises(KeyError):
        json_dataset.map(
            apply_custom_data_formatting_template,
            fn_kwargs={
                "tokenizer": tokenizer,
                "dataset_text_field": formatted_dataset_field,
                "template": template,
            },
        )


@pytest.mark.parametrize(
    "input_element,output_element,expected_res",
    [
        ("foo ", "bar", "foo bar"),
        ("foo\n", "bar", "foo\nbar"),
        ("foo\t", "bar", "foo\tbar"),
        ("foo", "bar", "foo bar"),
    ],
)
def test_combine_sequence(input_element, output_element, expected_res):
    """Ensure that input / output elements are combined with correct whitespace handling."""
    comb_seq = combine_sequence(input_element, output_element)
    assert isinstance(comb_seq, str)
    assert comb_seq == expected_res


@pytest.mark.parametrize(
    "input_element,output_element,expected_res",
    [
        ("foo ", "bar", "foo bar"),
        ("foo\n", "bar", "foo\nbar"),
        ("foo\t", "bar", "foo\tbar"),
        ("foo", "bar", "foo bar"),
    ],
)
def test_combine_sequence_adds_eos(input_element, output_element, expected_res):
    """Ensure that the EOS token is appended when input / output elements are combined."""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    comb_seq = combine_sequence(input_element, output_element, tokenizer.eos_token)
    expected_res += tokenizer.eos_token
    assert isinstance(comb_seq, str)
    assert comb_seq == expected_res
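Together the two parametrized tests pin down combine_sequence's whitespace contract exactly. A minimal sketch of the behavior they imply (an assumption derived from the test cases; the real implementation lives in tuning.data.data_handlers):

```python
# Sketch of the contract the tests above imply, not the library's source:
# join with a single space only when the input does not already end in
# whitespace, then append eos_token when one is provided.
def combine_sequence_sketch(input_element: str, output_element: str, eos_token: str = "") -> str:
    if input_element and not input_element[-1].isspace():
        return input_element + " " + output_element + eos_token
    return input_element + output_element + eos_token


assert combine_sequence_sketch("foo", "bar") == "foo bar"
assert combine_sequence_sketch("foo\n", "bar") == "foo\nbar"
assert combine_sequence_sketch("foo ", "bar", "</s>") == "foo bar</s>"
```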