From 170406347b507c4a5c74a541fc43bb1683386d31 Mon Sep 17 00:00:00 2001 From: Dushyant Behl Date: Tue, 24 Dec 2024 00:08:50 +0530 Subject: [PATCH 1/8] fix broken README.md link Signed-off-by: Dushyant Behl --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 8b2dd4e06..47e84a848 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # FMS HF Tuning - [Installation](#installation) -- [Data format](#data-format) +- [Data format support](#data-support) - [Supported Models](#supported-models) - [Training](#training) - [Single GPU](#single-gpu) From 3dc8ef79b4a3907928dc6ba54d46d32cbaa9950d Mon Sep 17 00:00:00 2001 From: Dushyant Behl Date: Sat, 4 Jan 2025 06:37:43 +0530 Subject: [PATCH 2/8] feat: Allow hf dataset id to be passed via training_data_path (#431) * Allow hf dataset id to be loaded by training_data_path Signed-off-by: Dushyant Behl * update README Signed-off-by: Dushyant Behl * minor changes Signed-off-by: Abhishek --------- Signed-off-by: Dushyant Behl Signed-off-by: Abhishek Co-authored-by: Abhishek --- README.md | 6 ++- tests/test_sft_trainer.py | 34 ++++++++++++++++- tuning/data/data_handlers.py | 2 +- tuning/data/data_processors.py | 70 ++++++++++++++++++++-------------- 4 files changed, 79 insertions(+), 33 deletions(-) diff --git a/README.md b/README.md index 47e84a848..c3198083a 100644 --- a/README.md +++ b/README.md @@ -62,13 +62,13 @@ pip install fms-hf-tuning[aim] For more details on how to enable and use the trackers, Please see, [the experiment tracking section below](#experiment-tracking). ## Data Support -Users can pass training data in a single file using the `--training_data_path` argument along with other arguments required for various [use cases](#use-cases-supported-with-training_data_path-argument) (see details below) and the file can be in any of the [supported formats](#supported-data-formats). Alternatively, you can use our powerful [data preprocessing backend](./docs/advanced-data-preprocessing.md) to preprocess datasets on the fly. +Users can pass training data as either a single file or a Hugging Face dataset ID using the `--training_data_path` argument along with other arguments required for various [use cases](#use-cases-supported-with-training_data_path-argument) (see details below). If user choose to pass a file, it can be in any of the [supported formats](#supported-data-formats). Alternatively, you can use our powerful [data preprocessing backend](./docs/advanced-data-preprocessing.md) to preprocess datasets on the fly. Below, we mention the list of supported data usecases via `--training_data_path` argument. For details of our advanced data preprocessing see more details in [Advanced Data Preprocessing](./docs/advanced-data-preprocessing.md). ## Supported Data Formats -We support the following data formats via `--training_data_path` argument +We support the following file formats via `--training_data_path` argument Data Format | Tested Support ------------|--------------- @@ -77,6 +77,8 @@ JSONL | ✅ PARQUET | ✅ ARROW | ✅ +As iterated above, we also support passing a HF dataset ID directly via `--training_data_path` argument. + ## Use cases supported with `training_data_path` argument ### 1. Data formats with a single sequence and a specified response_template to use for masking on completion. 
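For illustration, the dataset-id behaviour described in the README hunk above can be exercised from Python roughly as follows. This is a sketch only: the dataset id and templates come from the test added later in this patch, the import paths and model id are borrowed from the test suite in this series, and the output directory and epoch count are placeholders, not recommended settings.

```python
# Sketch: pass a Hugging Face dataset id (not a local file) via training_data_path.
from tuning import sft_trainer
from tuning.config import configs

model_args = configs.ModelArguments(
    model_name_or_path="TinyLlama/TinyLlama-1.1B-Chat-v0.3",  # any causal LM used in the tests
)
data_args = configs.DataArguments(
    training_data_path="lhoestq/demo1",  # HF dataset id -- no local file involved
    data_formatter_template="### Text:{{review}} \n\n### Stars: {{star}}",
    response_template="\n### Stars:",
)
train_args = configs.TrainingArguments(
    output_dir="/tmp/hf-dataset-id-run",  # placeholder
    num_train_epochs=1,
)

sft_trainer.train(model_args, data_args, train_args)
```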
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 529f21b66..0bca40afb 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -25,7 +25,7 @@ import tempfile # Third Party -from datasets.exceptions import DatasetGenerationError +from datasets.exceptions import DatasetGenerationError, DatasetNotFoundError from transformers.trainer_callback import TrainerCallback import pytest import torch @@ -326,7 +326,7 @@ def test_run_train_fails_training_data_path_not_exist(): """Check fails when data path not found.""" updated_data_path_args = copy.deepcopy(DATA_ARGS) updated_data_path_args.training_data_path = "fake/path" - with pytest.raises(ValueError): + with pytest.raises(DatasetNotFoundError): sft_trainer.train(MODEL_ARGS, updated_data_path_args, TRAIN_ARGS, None) @@ -998,6 +998,36 @@ def test_run_chat_style_ft_using_dataconfig(datafiles, dataconfigfile): assert 'Provide two rhyming words for the word "love"' in output_inference +@pytest.mark.parametrize( + "data_args", + [ + ( + # sample hugging face dataset id + configs.DataArguments( + training_data_path="lhoestq/demo1", + data_formatter_template="### Text:{{review}} \n\n### Stars: {{star}}", + response_template="\n### Stars:", + ) + ) + ], +) +def test_run_e2e_with_hf_dataset_id(data_args): + """ + Check if we can run an e2e test with a hf dataset id as training_data_path. + """ + with tempfile.TemporaryDirectory() as tempdir: + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + + sft_trainer.train(MODEL_ARGS, data_args, train_args) + + # validate ft tuning configs + _validate_training(tempdir) + + # validate inference + _test_run_inference(checkpoint_path=_get_checkpoint_path(tempdir)) + + ############################# Helper functions ############################# def _test_run_causallm_ft(training_args, model_args, data_args, tempdir): train_args = copy.deepcopy(training_args) diff --git a/tuning/data/data_handlers.py b/tuning/data/data_handlers.py index d666a6e76..5b80dc4bb 100644 --- a/tuning/data/data_handlers.py +++ b/tuning/data/data_handlers.py @@ -130,7 +130,7 @@ def replace_text(match_obj): if index_object not in element: raise KeyError("Requested template string is not a valid key in dict") - return element[index_object] + return str(element[index_object]) return { dataset_text_field: re.sub(r"{{([\s0-9a-zA-Z_\-\.]+)}}", replace_text, template) diff --git a/tuning/data/data_processors.py b/tuning/data/data_processors.py index 170bc2a81..bdac6947b 100644 --- a/tuning/data/data_processors.py +++ b/tuning/data/data_processors.py @@ -130,42 +130,56 @@ def _load_dataset(data_path=None, builder=None, data_files=None, data_dir=None): f"Failed to generate the dataset from the provided {context}." ) from e - if datafile: - loader = get_loader_for_filepath(file_path=datafile) - if loader in (None, ""): - raise ValueError(f"data path is invalid [{datafile}]") - return _load_dataset(builder=loader, data_files=[datafile]) + def _try_load_dataset(dataset_path, dataset_builder): + """ + Helper function to call load dataset on case by case basis to ensure we handle + directories and files (with or without builders) and hf datasets. - data_paths = datasetconfig.data_paths - builder = datasetconfig.builder - all_datasets = [] + Args: + dataset_path: Path of directory/file, pattern, or hf dataset id. + dataset_builder: Optional builder to use if provided. 
+ Returns: dataset + """ + if not dataset_path: + raise ValueError("Invalid dataset path") - for data_path in data_paths: # CASE 1: User passes directory - if os.path.isdir(data_path): # Checks if path exists and isdirectory + if os.path.isdir(dataset_path): # Checks if path exists and it is a dir # Directory case - if builder: + if dataset_builder: # Load using a builder with a data_dir - dataset = _load_dataset(builder=builder, data_dir=data_path) - else: - # Load directly from the directory - dataset = _load_dataset(data_path=data_path) - else: - # Non-directory (file, pattern, HF dataset name) - # If no builder provided, attempt to infer one - effective_builder = ( - builder if builder else get_loader_for_filepath(data_path) + return _load_dataset(builder=dataset_builder, data_dir=dataset_path) + + # If no builder then load directly from the directory + return _load_dataset(data_path=dataset_path) + + # Non-directory (file, pattern, HF dataset name) + # If no builder provided, attempt to infer one + effective_builder = ( + dataset_builder + if dataset_builder + else get_loader_for_filepath(dataset_path) + ) + + if effective_builder: + # CASE 2: Files passed with builder. Load using the builder and specific files + return _load_dataset( + builder=effective_builder, data_files=[dataset_path] ) - if effective_builder: - # CASE 2: Files passed with builder. Load using the builder and specific files - dataset = _load_dataset( - builder=effective_builder, data_files=[data_path] - ) - else: - # CASE 3: User passes files/folder/pattern/HF_dataset which has no builder - dataset = _load_dataset(data_path=data_path) + # CASE 3: User passes files/folder/pattern/HF_dataset which has no builder + # Still no builder, try if this is a dataset id + return _load_dataset(data_path=dataset_path) + + if datafile: + return _try_load_dataset(datafile, None) + data_paths = datasetconfig.data_paths + builder = datasetconfig.builder + all_datasets = [] + + for data_path in data_paths: + dataset = _try_load_dataset(data_path, builder) all_datasets.append(dataset) # Logs warning if datasets have different columns From 8851227fdd9a40ecc3cb4e98524b48831aa225df Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Tue, 7 Jan 2025 15:19:24 -0500 Subject: [PATCH 3/8] feat: dataclass args for accelerated MoE tuning (#390) * feat: accelerated MoE dataclass and init Signed-off-by: Will Johnson * fix: author's note Signed-off-by: Will Johnson * feat: accelerated moe in acceleration framework Signed-off-by: Will Johnson * feat: accelerated moe to sft_trainer Signed-off-by: Will Johnson * feat: fmt, testing Signed-off-by: Will Johnson * fix: rename accelerated moe to fast moe Signed-off-by: Will Johnson * test: add testing for scatter moe on accel framework Signed-off-by: Will Johnson * fix: model, dtype, assertions Signed-off-by: Will Johnson * fix: post init check removed from FastMoe, experimental set to True Signed-off-by: Will Johnson * fix: if non-iterable nested dataclass, still initialize Signed-off-by: Will Johnson * test: add failing test for wrong ep_degree Signed-off-by: Will Johnson * fix: actually expect failure Signed-off-by: Will Johnson * test: make sure fast moe doesn't work with non-moe model Signed-off-by: Will Johnson * fix: regex of new test Signed-off-by: Will Johnson * comment: explain iterable unpacking Signed-off-by: Will Johnson * docs: fast MOE in README Signed-off-by: Will Johnson * docs: Add note for post-processing Signed-off-by: Will Johnson * fix: Dockerfile Signed-off-by: Will Johnson * test: 
fix params Signed-off-by: Will Johnson * fix: file path Signed-off-by: Will Johnson * fix: expand on docs, remove from Dockerfile, move iterable data to else statement Signed-off-by: Will Johnson * lint Signed-off-by: Will Johnson * fix: spelling Signed-off-by: Will Johnson --------- Signed-off-by: Will Johnson --- README.md | 13 ++ .../test_acceleration_dataclasses.py | 8 + .../test_acceleration_framework.py | 159 +++++++++++++++++- tests/test_sft_trainer.py | 6 +- .../config/acceleration_configs/__init__.py | 1 + .../acceleration_framework_config.py | 13 ++ .../attention_and_distributed_packing.py | 14 ++ .../config/acceleration_configs/fast_moe.py | 36 ++++ tuning/config/acceleration_configs/utils.py | 10 +- tuning/sft_trainer.py | 22 ++- 10 files changed, 270 insertions(+), 12 deletions(-) create mode 100644 tuning/config/acceleration_configs/fast_moe.py diff --git a/README.md b/README.md index c3198083a..742d043b5 100644 --- a/README.md +++ b/README.md @@ -744,6 +744,8 @@ The list of configurations for various `fms_acceleration` plugins: - [attention_and_distributed_packing](./tuning/config/acceleration_configs/attention_and_distributed_packing.py): - `--padding_free`: technique to process multiple examples in single batch without adding padding tokens that waste compute. - `--multipack`: technique for *multi-gpu training* to balance out number of tokens processed in each device, to minimize waiting time. +- [fast_moe_config](./tuning/config/acceleration_configs/fast_moe.py) (experimental): + - `--fast_moe`: trains MoE models in parallel, increasing throughput and decreasing memory usage. Notes: * `quantized_lora_config` requires that it be used along with LoRA tuning technique. See [LoRA tuning section](https://github.com/foundation-model-stack/fms-hf-tuning/tree/main?tab=readme-ov-file#lora-tuning-example) on the LoRA parameters to pass. @@ -762,6 +764,17 @@ Notes: * Notes on Multipack - works only for *multi-gpu*. - currently only includes the version of *multipack* optimized for linear attention implementations like *flash-attn*. + * Notes on Fast MoE + - `--fast_moe` is an integer value that configures the amount of expert parallel sharding (ep_degree). + - `world_size` must be divisible by the `ep_degree` + - Running fast moe modifies the state dict of the model, and must be post-processed using [checkpoint utils](https://github.com/foundation-model-stack/fms-acceleration/blob/main/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py) to run inference (HF, vLLM, etc.). + - The typical usecase for this script is to run: + ``` + python -m fms_acceleration_moe.utils.checkpoint_utils \ + \ + \ + + ``` Note: To pass the above flags via a JSON config, each of the flags expects the value to be a mixed type list, so the values must be a list. For example: ```json diff --git a/tests/acceleration/test_acceleration_dataclasses.py b/tests/acceleration/test_acceleration_dataclasses.py index 130159933..fddf140b6 100644 --- a/tests/acceleration/test_acceleration_dataclasses.py +++ b/tests/acceleration/test_acceleration_dataclasses.py @@ -28,6 +28,7 @@ MultiPack, PaddingFree, ) +from tuning.config.acceleration_configs.fast_moe import FastMoe, FastMoeConfig from tuning.config.acceleration_configs.fused_ops_and_kernels import ( FastKernelsConfig, FusedLoraConfig, @@ -88,6 +89,13 @@ def test_dataclass_parse_successfully(): ) assert isinstance(cfg.multipack, MultiPack) + # 5. 
Specifing "--fast_moe" will parse an FastMoe class + parser = transformers.HfArgumentParser(dataclass_types=FastMoeConfig) + (cfg,) = parser.parse_args_into_dataclasses( + ["--fast_moe", "1"], + ) + assert isinstance(cfg.fast_moe, FastMoe) + def test_two_dataclasses_parse_successfully_together(): """Ensure that the two dataclasses can parse arguments successfully diff --git a/tests/acceleration/test_acceleration_framework.py b/tests/acceleration/test_acceleration_framework.py index d25554fe6..94198d52e 100644 --- a/tests/acceleration/test_acceleration_framework.py +++ b/tests/acceleration/test_acceleration_framework.py @@ -43,6 +43,7 @@ MultiPack, PaddingFree, ) +from tuning.config.acceleration_configs.fast_moe import FastMoe, FastMoeConfig from tuning.config.acceleration_configs.fused_ops_and_kernels import ( FastKernelsConfig, FusedLoraConfig, @@ -56,7 +57,7 @@ # for some reason the CI will raise an import error if we try to import # these from tests.artifacts.testdata TWITTER_COMPLAINTS_JSON_FORMAT = os.path.join( - os.path.dirname(__file__), "../artifacts/testdata/twitter_complaints_json.json" + os.path.dirname(__file__), "../artifacts/testdata/json/twitter_complaints_json.json" ) TWITTER_COMPLAINTS_TOKENIZED = os.path.join( os.path.dirname(__file__), @@ -87,6 +88,10 @@ # Third Party from fms_acceleration_aadp import PaddingFreeAccelerationPlugin + if is_fms_accelerate_available(plugins="moe"): + # Third Party + from fms_acceleration_moe import ScatterMoEAccelerationPlugin + # There are more extensive unit tests in the # https://github.com/foundation-model-stack/fms-acceleration @@ -360,7 +365,7 @@ def test_framework_raises_due_to_invalid_arguments( acceleration_configs_map, ids=["bitsandbytes", "auto_gptq"], ) -def test_framework_intialized_properly_peft( +def test_framework_initialized_properly_peft( quantized_lora_config, model_name_or_path, mock_and_spy ): """Ensure that specifying a properly configured acceleration dataclass @@ -412,7 +417,7 @@ def test_framework_intialized_properly_peft( "and foak plugins" ), ) -def test_framework_intialized_properly_foak(): +def test_framework_initialized_properly_foak(): """Ensure that specifying a properly configured acceleration dataclass properly activates the framework plugin and runs the train sucessfully. """ @@ -477,6 +482,60 @@ def test_framework_intialized_properly_foak(): assert spy2["get_ready_for_train_calls"] == 1 +@pytest.mark.skipif( + not is_fms_accelerate_available(plugins="moe"), + reason="Only runs if fms-accelerate is installed along with accelerated-moe plugin", +) +def test_framework_initialized_properly_moe(): + """Ensure that specifying a properly configured acceleration dataclass + properly activates the framework plugin and runs the train sucessfully. 
+ """ + + with tempfile.TemporaryDirectory() as tempdir: + + model_args = copy.deepcopy(MODEL_ARGS) + model_args.model_name_or_path = "Isotonic/TinyMixtral-4x248M-MoE" + model_args.torch_dtype = torch.bfloat16 + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + train_args.save_strategy = "no" + train_args.bf16 = True + data_args = copy.deepcopy(DATA_ARGS) + data_args.training_data_path = TWITTER_COMPLAINTS_JSON_FORMAT + data_args.response_template = "\n\n### Label:" + data_args.dataset_text_field = "output" + + # initialize a config + moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=1)) + + # create mocked plugin class for spying + MockedPlugin1, spy = create_mock_plugin_class_and_spy( + "FastMoeMock", ScatterMoEAccelerationPlugin + ) + + # 1. mock a plugin class + # 2. register the mocked plugins + # 3. call sft_trainer.train + with build_framework_and_maybe_instantiate( + [ + (["training.moe.scattermoe"], MockedPlugin1), + ], + instantiate=False, + ): + with instantiate_model_patcher(): + sft_trainer.train( + model_args, + data_args, + train_args, + fast_moe_config=moe_config, + ) + + # spy inside the train to ensure that the ilab plugin is called + assert spy["model_loader_calls"] == 1 + assert spy["augmentation_calls"] == 0 + assert spy["get_ready_for_train_calls"] == 1 + + @pytest.mark.skipif( not is_fms_accelerate_available(plugins="aadp"), reason="Only runs if fms-accelerate is installed along with \ @@ -661,6 +720,100 @@ def test_error_raised_with_fused_lora_enabled_without_quantized_argument(): ) +@pytest.mark.skipif( + not is_fms_accelerate_available(plugins="moe"), + reason="Only runs if fms-accelerate is installed along with accelerated-moe plugin", +) +def test_error_raised_with_undividable_fastmoe_argument(): + """ + Ensure error is thrown when `--fast_moe` is passed and world_size + is not divisible by ep_degree + """ + with pytest.raises( + AssertionError, match="world size \\(1\\) not divisible by ep_size \\(3\\)" + ): + with tempfile.TemporaryDirectory() as tempdir: + + model_args = copy.deepcopy(MODEL_ARGS) + model_args.model_name_or_path = "Isotonic/TinyMixtral-4x248M-MoE" + model_args.torch_dtype = torch.bfloat16 + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + train_args.save_strategy = "no" + train_args.bf16 = True + data_args = copy.deepcopy(DATA_ARGS) + data_args.training_data_path = TWITTER_COMPLAINTS_JSON_FORMAT + data_args.response_template = "\n\n### Label:" + data_args.dataset_text_field = "output" + + # initialize a config + moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=3)) + + # 1. mock a plugin class + # 2. register the mocked plugins + # 3. 
call sft_trainer.train + with build_framework_and_maybe_instantiate( + [ + (["training.moe.scattermoe"], ScatterMoEAccelerationPlugin), + ], + instantiate=False, + ): + with instantiate_model_patcher(): + sft_trainer.train( + model_args, + data_args, + train_args, + fast_moe_config=moe_config, + ) + + +@pytest.mark.skipif( + not is_fms_accelerate_available(plugins="moe"), + reason="Only runs if fms-accelerate is installed along with accelerated-moe plugin", +) +def test_error_raised_fast_moe_with_non_moe_model(): + """ + Ensure error is thrown when `--fast_moe` is passed and model is not MoE + """ + with pytest.raises( + AttributeError, + match="'LlamaConfig' object has no attribute 'num_local_experts'", + ): + with tempfile.TemporaryDirectory() as tempdir: + + model_args = copy.deepcopy(MODEL_ARGS) + model_args.model_name_or_path = "TinyLlama/TinyLlama-1.1B-Chat-v0.3" + model_args.torch_dtype = torch.bfloat16 + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + train_args.save_strategy = "no" + train_args.bf16 = True + data_args = copy.deepcopy(DATA_ARGS) + data_args.training_data_path = TWITTER_COMPLAINTS_JSON_FORMAT + data_args.response_template = "\n\n### Label:" + data_args.dataset_text_field = "output" + + # initialize a config + moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=1)) + + # 1. mock a plugin class + # 2. register the mocked plugins + # 3. call sft_trainer.train + with build_framework_and_maybe_instantiate( + [ + (["training.moe.scattermoe"], ScatterMoEAccelerationPlugin), + ], + instantiate=False, + ): + with instantiate_model_patcher(): + sft_trainer.train( + model_args, + data_args, + train_args, + fast_moe_config=moe_config, + ) + + @pytest.mark.skipif( not is_fms_accelerate_available(plugins="foak"), reason="Only runs if fms-accelerate is installed along with \ diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 0bca40afb..f2d4a1ee1 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -362,6 +362,7 @@ def test_parse_arguments(job_config): _, _, _, + _, ) = sft_trainer.parse_arguments(parser, job_config_copy) assert str(model_args.torch_dtype) == "torch.bfloat16" assert data_args.dataset_text_field == "output" @@ -388,6 +389,7 @@ def test_parse_arguments_defaults(job_config): _, _, _, + _, ) = sft_trainer.parse_arguments(parser, job_config_defaults) assert str(model_args.torch_dtype) == "torch.bfloat16" assert model_args.use_flash_attn is False @@ -398,14 +400,14 @@ def test_parse_arguments_peft_method(job_config): parser = sft_trainer.get_parser() job_config_pt = copy.deepcopy(job_config) job_config_pt["peft_method"] = "pt" - _, _, _, _, tune_config, _, _, _, _, _, _, _ = sft_trainer.parse_arguments( + _, _, _, _, tune_config, _, _, _, _, _, _, _, _ = sft_trainer.parse_arguments( parser, job_config_pt ) assert isinstance(tune_config, peft_config.PromptTuningConfig) job_config_lora = copy.deepcopy(job_config) job_config_lora["peft_method"] = "lora" - _, _, _, _, tune_config, _, _, _, _, _, _, _ = sft_trainer.parse_arguments( + _, _, _, _, tune_config, _, _, _, _, _, _, _, _ = sft_trainer.parse_arguments( parser, job_config_lora ) assert isinstance(tune_config, peft_config.LoraConfig) diff --git a/tuning/config/acceleration_configs/__init__.py b/tuning/config/acceleration_configs/__init__.py index 4f20a0afe..98be34240 100644 --- a/tuning/config/acceleration_configs/__init__.py +++ b/tuning/config/acceleration_configs/__init__.py @@ -15,5 +15,6 @@ # Local from .acceleration_framework_config import 
AccelerationFrameworkConfig from .attention_and_distributed_packing import AttentionAndDistributedPackingConfig +from .fast_moe import FastMoeConfig from .fused_ops_and_kernels import FusedOpsAndKernelsConfig from .quantized_lora_config import QuantizedLoraConfig diff --git a/tuning/config/acceleration_configs/acceleration_framework_config.py b/tuning/config/acceleration_configs/acceleration_framework_config.py index 76fef1a78..a5a685897 100644 --- a/tuning/config/acceleration_configs/acceleration_framework_config.py +++ b/tuning/config/acceleration_configs/acceleration_framework_config.py @@ -22,6 +22,7 @@ # Local from .attention_and_distributed_packing import MultiPack, PaddingFree +from .fast_moe import FastMoe from .fused_ops_and_kernels import FastKernelsConfig, FusedLoraConfig from .quantized_lora_config import AutoGPTQLoraConfig, BNBQLoraConfig from tuning.utils.import_utils import is_fms_accelerate_available @@ -65,6 +66,7 @@ class AccelerationFrameworkConfig: PACKAGE_PREFIX = "fms_acceleration_" # each field will a single-level use case dataclass + auto_gptq: Annotated[ AutoGPTQLoraConfig, ConfigAnnotation( @@ -89,6 +91,17 @@ class AccelerationFrameworkConfig: ), ] = None + fast_moe: Annotated[ + FastMoe, + ConfigAnnotation( + path="training.moe", + key="scattermoe", + standalone=True, + experimental=True, + required_packages=["moe"], + ), + ] = None + fast_kernels: Annotated[ FastKernelsConfig, ConfigAnnotation( diff --git a/tuning/config/acceleration_configs/attention_and_distributed_packing.py b/tuning/config/acceleration_configs/attention_and_distributed_packing.py index e1ed83a58..803c6a40b 100644 --- a/tuning/config/acceleration_configs/attention_and_distributed_packing.py +++ b/tuning/config/acceleration_configs/attention_and_distributed_packing.py @@ -1,3 +1,17 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # Standard from dataclasses import dataclass diff --git a/tuning/config/acceleration_configs/fast_moe.py b/tuning/config/acceleration_configs/fast_moe.py new file mode 100644 index 000000000..14a44f929 --- /dev/null +++ b/tuning/config/acceleration_configs/fast_moe.py @@ -0,0 +1,36 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Standard +from dataclasses import dataclass + +# Local +from .utils import ensure_nested_dataclasses_initialized, parsable_dataclass + + +@parsable_dataclass +@dataclass +class FastMoe: + + ep_degree: int = 1 + + +@dataclass +class FastMoeConfig: + + fast_moe: FastMoe = None + + def __post_init__(self): + # ensure nested dataclasses initialized + ensure_nested_dataclasses_initialized(self) diff --git a/tuning/config/acceleration_configs/utils.py b/tuning/config/acceleration_configs/utils.py index 3085a9761..4a6fc316f 100644 --- a/tuning/config/acceleration_configs/utils.py +++ b/tuning/config/acceleration_configs/utils.py @@ -14,7 +14,7 @@ # Standard from dataclasses import fields, is_dataclass -from typing import Dict, List, Type, get_type_hints +from typing import List, Type, get_type_hints # Third Party from transformers.hf_argparser import DataClass, string_to_bool @@ -26,12 +26,16 @@ def ensure_nested_dataclasses_initialized(dataclass: DataClass): this is to be called at the top-level class to init all the nested dataclasses. """ - type_hints: Dict[str, type] = get_type_hints(dataclass) + type_hints = get_type_hints(dataclass) for f in fields(dataclass): nested_type = type_hints[f.name] values = getattr(dataclass, f.name) if values is not None and not is_dataclass(values): - values = nested_type(*values) + # Handle primitive data directly, unpack iterable data + if isinstance(values, (int, float, bool)): + values = nested_type(values) + else: + values = nested_type(*values) setattr(dataclass, f.name, values) diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 2afdd2dac..6864016fc 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -45,6 +45,7 @@ from tuning.config.acceleration_configs import ( AccelerationFrameworkConfig, AttentionAndDistributedPackingConfig, + FastMoeConfig, FusedOpsAndKernelsConfig, QuantizedLoraConfig, ) @@ -86,6 +87,7 @@ def train( attention_and_distributed_packing_config: Optional[ AttentionAndDistributedPackingConfig ] = None, + fast_moe_config: Optional[FastMoeConfig] = None, additional_data_handlers: Optional[Dict[str, Callable]] = None, ) -> tuple[SFTTrainer, dict]: """Call the SFTTrainer @@ -114,7 +116,8 @@ def train( fusedops_kernels_config: tuning.config.acceleration_configs.FusedOpsAndKernelsConfig \ Should be used in combination with quantized_lora_config. Also currently fused_lora and fast_kernels must used together (may change in future). \ - attention_and_distributed_packing_config: Used for padding-free attention and multipack. + attention_and_distributed_packing_config: Used for padding-free attention and multipack. \ + fast_moe_config: Used for ScatterMoE to run MoE models in parallel. additional_data_handlers: Dict [str:Callable] of any extra data handlers \ to be registered with the data preprocessor Returns: @@ -203,9 +206,10 @@ def train( trainer_callbacks.append(cb) framework = AccelerationFrameworkConfig.from_dataclasses( + fast_moe_config, + attention_and_distributed_packing_config, quantized_lora_config, fusedops_kernels_config, - attention_and_distributed_packing_config, ).get_framework() model_loader = AutoModelForCausalLM.from_pretrained @@ -445,6 +449,7 @@ def get_parser(): QuantizedLoraConfig, FusedOpsAndKernelsConfig, AttentionAndDistributedPackingConfig, + FastMoeConfig, MLflowConfig, ) ) @@ -495,6 +500,8 @@ def parse_arguments(parser, json_config=None): Configuration for fused operations and kernels. AttentionAndDistributedPackingConfig Configuration for padding free and packing. 
+ FastMoeConfig + Configuration for accelerated MoE. MLflowConfig Configuration for mlflow tracker. dict[str, str] @@ -513,6 +520,7 @@ def parse_arguments(parser, json_config=None): quantized_lora_config, fusedops_kernels_config, attention_and_distributed_packing_config, + fast_moe_config, mlflow_config, ) = parser.parse_dict(json_config, allow_extra_keys=True) peft_method = json_config.get("peft_method") @@ -530,6 +538,7 @@ def parse_arguments(parser, json_config=None): quantized_lora_config, fusedops_kernels_config, attention_and_distributed_packing_config, + fast_moe_config, mlflow_config, additional, _, @@ -556,6 +565,7 @@ def parse_arguments(parser, json_config=None): quantized_lora_config, fusedops_kernels_config, attention_and_distributed_packing_config, + fast_moe_config, mlflow_config, exp_metadata, ) @@ -578,6 +588,7 @@ def main(): quantized_lora_config, fusedops_kernels_config, attention_and_distributed_packing_config, + fast_moe_config, mlflow_config, exp_metadata, ) = parse_arguments(parser, job_config) @@ -590,8 +601,9 @@ def main(): model_args %s, data_args %s, training_args %s, trainer_controller_args %s, \ tune_config %s, file_logger_config, %s aim_config %s, \ quantized_lora_config %s, fusedops_kernels_config %s, \ - attention_and_distributed_packing_config %s,\ - mlflow_config %s, exp_metadata %s", + attention_and_distributed_packing_config, %s,\ + mlflow_config %s, fast_moe_config %s, \ + exp_metadata %s", model_args, data_args, training_args, @@ -602,6 +614,7 @@ def main(): quantized_lora_config, fusedops_kernels_config, attention_and_distributed_packing_config, + fast_moe_config, mlflow_config, exp_metadata, ) @@ -649,6 +662,7 @@ def main(): quantized_lora_config=quantized_lora_config, fusedops_kernels_config=fusedops_kernels_config, attention_and_distributed_packing_config=attention_and_distributed_packing_config, + fast_moe_config=fast_moe_config, ) except (MemoryError, OutOfMemoryError) as e: logger.error(traceback.format_exc()) From 53a9d186f3e3f76ce66f8a441ff65d643dcd1b3a Mon Sep 17 00:00:00 2001 From: Dushyant Behl Date: Thu, 9 Jan 2025 14:53:01 +0530 Subject: [PATCH 4/8] feat: allow for padding free plugin to be used without response template (#430) * fix: allow for padding free + pretraining Signed-off-by: Harikrishnan Balagopal * add data collator for padding free plugin scenario to be used for extended pretraining Signed-off-by: Dushyant Behl * fix: update value error Signed-off-by: Mehant Kammakomati * fix: delete images only when exists Signed-off-by: Mehant Kammakomati --------- Signed-off-by: Harikrishnan Balagopal Signed-off-by: Dushyant Behl Signed-off-by: Mehant Kammakomati Co-authored-by: Harikrishnan Balagopal Co-authored-by: Mehant Kammakomati --- .github/workflows/image.yaml | 3 +- tests/data/test_data_preprocessing_utils.py | 62 +++++++++++++++---- .../attention_and_distributed_packing.py | 4 ++ tuning/data/data_preprocessing_utils.py | 13 ++++ tuning/data/setup_dataprocessor.py | 28 ++++++--- tuning/sft_trainer.py | 12 +++- 6 files changed, 99 insertions(+), 23 deletions(-) diff --git a/.github/workflows/image.yaml b/.github/workflows/image.yaml index 6c1e043c6..bc21039bc 100644 --- a/.github/workflows/image.yaml +++ b/.github/workflows/image.yaml @@ -15,9 +15,8 @@ jobs: sudo swapoff -a sudo rm -f /swapfile sudo apt clean - docker rmi $(docker image ls -aq) + if [ "$(docker image ls -q)" ]; then docker rmi $(docker image ls -aq); fi df -h - name: Build image run: | docker build -t fms-hf-tuning:dev . 
-f build/Dockerfile - \ No newline at end of file diff --git a/tests/data/test_data_preprocessing_utils.py b/tests/data/test_data_preprocessing_utils.py index 578daffbf..8de5dfc36 100644 --- a/tests/data/test_data_preprocessing_utils.py +++ b/tests/data/test_data_preprocessing_utils.py @@ -489,7 +489,7 @@ def test_is_pretokenized_data(data, result): @pytest.mark.parametrize( "packing, response_template, formatted_train_dataset,\ - max_seq_length, instruction_template, expected_collator", + max_seq_length, instruction_template, is_padding_free, expected_collator", [ ( False, @@ -501,6 +501,7 @@ def test_is_pretokenized_data(data, result): ), 1024, None, + False, DataCollatorForCompletionOnlyLM, ), ( @@ -517,6 +518,7 @@ def test_is_pretokenized_data(data, result): ), 1024, None, + False, DataCollatorForSeq2Seq, ), ( @@ -529,6 +531,7 @@ def test_is_pretokenized_data(data, result): ), 1024, "\n### Text:", + False, DataCollatorForCompletionOnlyLM, ), ( @@ -545,6 +548,20 @@ def test_is_pretokenized_data(data, result): ), 1024, "\n### Text:", + False, + DataCollatorForSeq2Seq, + ), + ( + False, + None, + datasets.load_dataset( + "json", + data_files=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, + split="train", + ), + 1024, + None, + True, DataCollatorForSeq2Seq, ), ], @@ -555,6 +572,7 @@ def test_get_data_collator( formatted_train_dataset, max_seq_length, instruction_template, + is_padding_free, expected_collator, ): """Ensure that the correct collator type is fetched based on the data args""" @@ -565,6 +583,7 @@ def test_get_data_collator( is_pretokenized_dataset(formatted_train_dataset), max_seq_length, instruction_template, + is_padding_free, ) assert isinstance(collator, expected_collator) @@ -1044,7 +1063,7 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling( @pytest.mark.parametrize( - "data_args", + "data_args, is_padding_free", [ # single sequence JSON and response template ( @@ -1053,7 +1072,8 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling( validation_data_path=TWITTER_COMPLAINTS_DATA_JSON, dataset_text_field="output", response_template="\n### Label:", - ) + ), + False, ), # single sequence JSONL and response template ( @@ -1062,7 +1082,8 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling( validation_data_path=TWITTER_COMPLAINTS_DATA_JSONL, dataset_text_field="output", response_template="\n### Label:", - ) + ), + False, ), # single sequence PARQUET and response template ( @@ -1071,7 +1092,8 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling( validation_data_path=TWITTER_COMPLAINTS_DATA_PARQUET, dataset_text_field="output", response_template="\n### Label:", - ) + ), + False, ), # data formatter template with input/output JSON ( @@ -1080,7 +1102,8 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling( validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, data_formatter_template="### Text:{{input}} \n\n### Label: {{output}}", response_template="\n### Label:", - ) + ), + False, ), # data formatter template with input/output JSONL ( @@ -1089,7 +1112,8 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling( validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, data_formatter_template="### Text:{{input}} \n\n### Label: {{output}}", response_template="\n### Label:", - ) + ), + False, ), # data formatter template with input/output PARQUET ( @@ -1098,32 +1122,44 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling( 
validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, data_formatter_template="### Text:{{input}} \n\n### Label: {{output}}", response_template="\n### Label:", - ) + ), + False, ), # input/output JSON with masking on input ( configs.DataArguments( training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, - ) + ), + False, ), # input/output JSONL with masking on input ( configs.DataArguments( training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, - ) + ), + False, ), # input/output PARQUET with masking on input ( configs.DataArguments( training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, - ) + ), + False, + ), + ( + configs.DataArguments( + training_data_path=TWITTER_COMPLAINTS_DATA_JSON, + validation_data_path=TWITTER_COMPLAINTS_DATA_JSON, + dataset_text_field="output", + ), + True, ), ], ) -def test_process_dataargs(data_args): +def test_process_dataargs(data_args, is_padding_free): """Ensure that the train/eval data are properly formatted based on the data args / text field""" tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) TRAIN_ARGS = configs.TrainingArguments( @@ -1132,7 +1168,7 @@ def test_process_dataargs(data_args): output_dir="tmp", # Not needed but positional ) (train_set, eval_set, dataset_text_field, _, _, _) = process_dataargs( - data_args, tokenizer, TRAIN_ARGS + data_args, tokenizer, TRAIN_ARGS, is_padding_free=is_padding_free ) assert isinstance(train_set, Dataset) assert isinstance(eval_set, Dataset) diff --git a/tuning/config/acceleration_configs/attention_and_distributed_packing.py b/tuning/config/acceleration_configs/attention_and_distributed_packing.py index 803c6a40b..3c62c0597 100644 --- a/tuning/config/acceleration_configs/attention_and_distributed_packing.py +++ b/tuning/config/acceleration_configs/attention_and_distributed_packing.py @@ -47,3 +47,7 @@ class AttentionAndDistributedPackingConfig: def __post_init__(self): # ensure nested dataclasses initialized ensure_nested_dataclasses_initialized(self) + + @property + def is_padding_free(self): + return self.padding_free is not None diff --git a/tuning/data/data_preprocessing_utils.py b/tuning/data/data_preprocessing_utils.py index 2c3386e34..b77fdba1d 100644 --- a/tuning/data/data_preprocessing_utils.py +++ b/tuning/data/data_preprocessing_utils.py @@ -29,6 +29,7 @@ def get_data_collator( is_traindata_tokenized: bool, max_seq_length: int, instruction_template: Optional[str], + is_padding_free: bool = False, ) -> Callable: """Create and return the the appropriate collator type based on the configuration for packing, response_template, and dataset_text_field. @@ -46,6 +47,8 @@ def get_data_collator( Max sequence length expected instruction_template: str str representing the human response in a chat template + is_padding_free: bool + if padding free plugin is used or not Returns: Callable @@ -74,6 +77,16 @@ def get_data_collator( tokenizer=tokenizer, ignore_index=configs.IGNORE_INDEX, ) + + if is_padding_free: + # when packing is false but padding_free is used and + # no response template is used then its a pretrained scenario. + # Current plugin in fms-acceleration is compatible with + # `DataCollatorForSeq2Seq` collator hence we use this. 
+ return DataCollatorForSeq2Seq( + tokenizer=tokenizer, padding=False, max_length=max_seq_length + ) + # Note that this automatically pads labels with -100 # TODO check if this is sufficient for preprocessed if is_traindata_tokenized: diff --git a/tuning/data/setup_dataprocessor.py b/tuning/data/setup_dataprocessor.py index b6f09c323..7921652b8 100644 --- a/tuning/data/setup_dataprocessor.py +++ b/tuning/data/setup_dataprocessor.py @@ -107,15 +107,22 @@ def _get_pretokenized_dataset_handlers(data_args, packing, is_eval_tokenized): ### Data format 2 -def _get_dataset_formatting_handlers(data_args, packing): +def _get_dataset_formatting_handlers(data_args, packing, is_padding_free=False): if data_args.response_template is None: if packing is False: - raise ValueError( - "Since dataset_text_field or data_formatter_template \ - is provided and packing is disabled, \ - needs a corresponding response template for masking" - ) + if is_padding_free: + logger.debug( + "Assuming pretraining scenario (loss over all tokens) " + + "because, packing is false," + + " padding_free plugin is used and no response template was provided." + ) + else: + raise ValueError( + "Since response_template is not provided for masking, \ + either use packing or padding_free to enable \ + pretraining scenario (loss over all tokens)." + ) if data_args.response_template: # To use Response template, pass datasets with single sequence instances \ @@ -209,6 +216,7 @@ def _process_raw_data_args( packing: bool, max_seq_length: int, additional_data_handlers: Dict[str, Callable] = None, + is_padding_free: bool = False, ): # Create a data processor with default processor config @@ -248,6 +256,7 @@ def _process_raw_data_args( tokenizer_kwargs = {} tokenizer_kwargs["max_length"] = max_seq_length tokenizer_kwargs["truncation"] = True + # Lets not pad in tokenizer...we can handle that in the collator tokenizer_kwargs["padding"] = False handlers = None @@ -266,7 +275,7 @@ def _process_raw_data_args( elif data_args.data_formatter_template or data_args.dataset_text_field: # Data Format 3: Single Sequence Dataset handlers, dataset_text_field = _get_dataset_formatting_handlers( - data_args, packing + data_args, packing, is_padding_free ) else: # Default Data Format: Dataset with Input/Output Fields @@ -300,6 +309,7 @@ def process_dataargs( tokenizer: AutoTokenizer, train_args: TrainingArguments, additional_data_handlers: Dict[str, Callable] = None, + is_padding_free: bool = False, ): """ Args: @@ -310,6 +320,8 @@ def process_dataargs( Used for packing and max_seq_length additional_data_handlers: A Dict of [str, callable] data handlers which need to be registered with the data preprocessor + is_padding_free: A bool representing if Padding free plugin is enabled. + Defaults to False. Returns: Tuple(Dataset, Dataset, str, DataCollator, int, Dict) tuple containing @@ -345,6 +357,7 @@ def process_dataargs( train_args.packing, max_seq_length, additional_data_handlers, + is_padding_free, ) # Note: This check should not be removed. 
@@ -359,6 +372,7 @@ def process_dataargs( is_tokenized_dataset, max_seq_length, data_args.instruction_template, + is_padding_free=is_padding_free, ) dataset_kwargs = {} diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 6864016fc..32b8735cb 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -306,6 +306,10 @@ def train( data_collator = None logger.info("Packing is set to %s ", train_args.packing) + is_padding_free = False + if attention_and_distributed_packing_config is not None: + is_padding_free = attention_and_distributed_packing_config.is_padding_free + data_preprocessing_time = time.time() ( formatted_train_dataset, @@ -314,7 +318,13 @@ def train( data_collator, train_args.max_seq_length, dataset_kwargs, - ) = process_dataargs(data_args, tokenizer, train_args, additional_data_handlers) + ) = process_dataargs( + data_args, + tokenizer, + train_args, + additional_data_handlers, + is_padding_free=is_padding_free, + ) additional_metrics["data_preprocessing_time"] = ( time.time() - data_preprocessing_time ) From 6eb541df657eee9e298ef155913a4aec4e6f0ea3 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Thu, 9 Jan 2025 09:26:36 -0500 Subject: [PATCH 5/8] fix: function name from requires_agumentation to requires_augmentation (#434) * fix: function name from requires_agumentation to requires_augmentation Signed-off-by: Will Johnson * fix: file path Signed-off-by: Will Johnson * fmt Signed-off-by: Will Johnson --------- Signed-off-by: Will Johnson --- tests/acceleration/test_acceleration_framework.py | 3 ++- tuning/sft_trainer.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/acceleration/test_acceleration_framework.py b/tests/acceleration/test_acceleration_framework.py index 94198d52e..80a445304 100644 --- a/tests/acceleration/test_acceleration_framework.py +++ b/tests/acceleration/test_acceleration_framework.py @@ -57,7 +57,8 @@ # for some reason the CI will raise an import error if we try to import # these from tests.artifacts.testdata TWITTER_COMPLAINTS_JSON_FORMAT = os.path.join( - os.path.dirname(__file__), "../artifacts/testdata/json/twitter_complaints_json.json" + os.path.dirname(__file__), + "../artifacts/testdata/json/twitter_complaints_small.json", ) TWITTER_COMPLAINTS_TOKENIZED = os.path.join( os.path.dirname(__file__), diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 32b8735cb..85058e098 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -329,7 +329,7 @@ def train( time.time() - data_preprocessing_time ) - if framework is not None and framework.requires_agumentation: + if framework is not None and framework.requires_augmentation: model, (peft_config,) = framework.augmentation( model, train_args, modifiable_args=(peft_config,) ) From 24f7e42bb73d293b50bf71c3757871128895d91f Mon Sep 17 00:00:00 2001 From: Abhishek Maurya <124327945+Abhishek-TAMU@users.noreply.github.com> Date: Thu, 9 Jan 2025 16:26:24 -0500 Subject: [PATCH 6/8] add tokens to special_tokens_dict (#436) Signed-off-by: Abhishek --- tuning/sft_trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 85058e098..b3e28f686 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -290,8 +290,10 @@ def train( ) if tokenizer.eos_token != configs.DEFAULT_PAD_TOKEN: tokenizer.pad_token = configs.DEFAULT_PAD_TOKEN + special_tokens_dict["pad_token"] = configs.DEFAULT_PAD_TOKEN else: tokenizer.eos_token = configs.DEFAULT_EOS_TOKEN + special_tokens_dict["eos_token"] = 
configs.DEFAULT_EOS_TOKEN # TODO: lower priority but understand if resizing impacts inference quality and why its needed. # It makes sense if we manipulate tokenizer that we also save it and provide it to inference. From 28c3d3843007f47b1741b9f79f306385314d5010 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Thu, 16 Jan 2025 13:06:27 -0500 Subject: [PATCH 7/8] deps: upgrade fms-acceleration to >= 0.6 (#440) Signed-off-by: Will Johnson --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8301ce253..b930f7680 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ dev = ["wheel>=0.42.0,<1.0", "packaging>=23.2,<25", "ninja>=1.11.1.1,<2.0", "sci flash-attn = ["flash-attn>=2.5.3,<3.0"] aim = ["aim>=3.19.0,<4.0"] mlflow = ["mlflow"] -fms-accel = ["fms-acceleration>=0.1"] +fms-accel = ["fms-acceleration>=0.6"] gptq-dev = ["auto_gptq>0.4.2", "optimum>=1.15.0"] From d03072bbe34e50d1862c3a268e787ef4ec19ae2a Mon Sep 17 00:00:00 2001 From: Anh Uong Date: Thu, 16 Jan 2025 11:28:11 -0700 Subject: [PATCH 8/8] docs: update granite3 model support (#441) Signed-off-by: Anh Uong --- README.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 742d043b5..f7e3c2c28 100644 --- a/README.md +++ b/README.md @@ -200,6 +200,10 @@ For advanced data preprocessing support including mixing and custom preprocessin Model Name & Size | Model Architecture | Full Finetuning | Low Rank Adaptation (i.e. LoRA) | qLoRA(quantized LoRA) | -------------------- | ---------------- | --------------- | ------------------------------- | --------------------- | Granite PowerLM 3B | GraniteForCausalLM | ✅* | ✅* | ✅* | +Granite 3.1 1B | GraniteForCausalLM | ✔️* | ✔️* | ✔️* | +Granite 3.1 2B | GraniteForCausalLM | ✔️* | ✔️* | ✔️* | +Granite 3.1 3B | GraniteForCausalLM | ✔️* | ✔️* | ✔️* | +Granite 3.1 8B | GraniteForCausalLM | ✔️* | ✔️* | ✔️* | Granite 3.0 2B | GraniteForCausalLM | ✔️* | ✔️* | ✔️* | Granite 3.0 8B | GraniteForCausalLM | ✅* | ✅* | ✔️ | GraniteMoE 1B | GraniteMoeForCausalLM | ✅ | ✅** | ? | @@ -219,7 +223,7 @@ Mixtral 8x7B | Mixtral | ✅ | ✅ | ✅ | Mistral-7b | Mistral | ✅ | ✅ | ✅ |   Mistral large | Mistral | 🚫 | 🚫 | 🚫 | -(*) - Supported with `fms-hf-tuning` v2.0.1 or later +(*) - Supported with `fms-hf-tuning` v2.4.0 or later. (**) - Supported for q,k,v,o layers . `all-linear` target modules does not infer on vLLM yet.
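Taken together, the pieces added in this series can be wired up from Python roughly as follows. This is a sketch under stated assumptions, not a definitive recipe: it assumes fms-acceleration is installed with the accelerated-moe plugin, and the model id, data path, field names and templates are borrowed from the tests above, while the output directory and precision settings are illustrative.

```python
# Sketch: enable the new FastMoe acceleration config when calling sft_trainer.train.
import torch

from tuning import sft_trainer
from tuning.config import configs
from tuning.config.acceleration_configs.fast_moe import FastMoe, FastMoeConfig

model_args = configs.ModelArguments(
    model_name_or_path="Isotonic/TinyMixtral-4x248M-MoE",  # must be an MoE model
    torch_dtype=torch.bfloat16,
)
data_args = configs.DataArguments(
    training_data_path="tests/artifacts/testdata/json/twitter_complaints_small.json",
    dataset_text_field="output",
    response_template="\n\n### Label:",
)
train_args = configs.TrainingArguments(output_dir="/tmp/fast-moe-run", bf16=True)

# ep_degree=1 keeps expert parallelism trivial on a single GPU; on multi-GPU runs,
# choose a value that divides the world size, otherwise the plugin asserts.
moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=1))

# Checkpoints produced this way still need the fms_acceleration_moe
# checkpoint_utils post-processing step (see the README note in patch 3)
# before they can be served with HF or vLLM.
sft_trainer.train(model_args, data_args, train_args, fast_moe_config=moe_config)
```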