From db9898d10484d7680c597c5c6c7db10fa8a5ba6c Mon Sep 17 00:00:00 2001 From: Chime Ogbuji Date: Sat, 2 Nov 2024 18:05:27 -0400 Subject: [PATCH 1/3] Generalize prompt_feature and completion_feature for use in local datasets to facilitate compatibility with many other training dataset formats. --- llms/mlx_lm/LORA.md | 16 ++++++++++--- llms/mlx_lm/lora.py | 2 ++ llms/mlx_lm/tuner/datasets.py | 44 ++++++++++++++++++++++++++++------- 3 files changed, 51 insertions(+), 11 deletions(-) diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md index 156763607..e1a863258 100644 --- a/llms/mlx_lm/LORA.md +++ b/llms/mlx_lm/LORA.md @@ -247,8 +247,18 @@ Refer to the documentation for the model you are fine-tuning for more details. {"text": "This is an example for the model."} ``` -Note, the format is automatically determined by the dataset. Note also, keys in -each line not expected by the loader will be ignored. +Note, the format is automatically determined by the dataset. + +For the completion data format, a different key can be used for the _prompt_ and for the _completion_ by specifying +the following, for example, in the YAML config: + +```yaml +prompt_feature: "input" +completion_feature: "output" +``` + +Here, `input` is now the expected key instead of "prompt" and `output` is the expected key instead of "completion". +Note also, keys in each line not expected by the loader will be ignored. > [!NOTE] > Each example in the datasets must be on a single line. Do not put more than @@ -270,7 +280,7 @@ Otherwise, provide a mapping of keys in the dataset to the features MLX LM expects. Use a YAML config to specify the Hugging Face dataset arguments. For example: -``` +```yaml hf_dataset: name: "billsum" prompt_feature: "text" diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index 4d050bd54..41a618b11 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -61,6 +61,8 @@ "config": None, "grad_checkpoint": False, "lr_schedule": None, + "prompt_feature": "prompt", + "completion_feature": "completion", "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}, } diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index fa848f47e..692bdb5ce 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -81,12 +81,20 @@ def __len__(self): return len(self._data) -def create_dataset(data, tokenizer: PreTrainedTokenizer): +<<<<<<< HEAD +def create_dataset( + data, + tokenizer: PreTrainedTokenizer, + prompt_feature: Optional[str] = None, + completion_feature: Optional[str] = None, +): sample = data[0] + prompt_feature = prompt_feature or "prompt" + completion_feature = completion_feature or "completion" if "messages" in sample: return ChatDataset(data, tokenizer) - elif "prompt" in sample and "completion" in sample: + elif prompt_feature in sample and completion_feature in sample: return CompletionsDataset(data, tokenizer) elif "text" in sample: return Dataset(data, tokenizer) @@ -97,20 +105,30 @@ def create_dataset(data, tokenizer: PreTrainedTokenizer): ) -def load_local_dataset(data_path: Path, tokenizer: PreTrainedTokenizer): +def load_local_dataset( + data_path: Path, + tokenizer: PreTrainedTokenizer, + prompt_feature: str = None, + completion_feature: str = None, +): def load_subset(path): if not path.exists(): return [] with open(path, "r") as fid: data = [json.loads(l) for l in fid] - return create_dataset(data, tokenizer) + return create_dataset(data, tokenizer, prompt_feature, completion_feature) names = ("train", "valid", "test") train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names] return train, valid, test -def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer): +def load_hf_dataset( + data_id: str, + tokenizer: PreTrainedTokenizer, + prompt_feature: str = None, + completion_feature: str = None, +): from datasets import exceptions, load_dataset try: @@ -119,7 +137,13 @@ def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer): names = ("train", "valid", "test") train, valid, test = [ - create_dataset(dataset[n], tokenizer) if n in dataset.keys() else [] + ( + create_dataset( + dataset[n], tokenizer, prompt_feature, completion_feature + ) + if n in dataset.keys() + else [] + ) for n in names ] @@ -176,10 +200,14 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer): else: data_path = Path(args.data) if data_path.exists(): - train, valid, test = load_local_dataset(data_path, tokenizer) + train, valid, test = load_local_dataset( + data_path, tokenizer, args.prompt_feature, args.completion_feature + ) else: print(f"Loading Hugging Face dataset {args.data}.") - train, valid, test = load_hf_dataset(args.data, tokenizer) + train, valid, test = load_hf_dataset( + args.data, tokenizer, args.prompt_feature, args.completion_feature + ) if args.train and len(train) == 0: raise ValueError( From 40438b137143441793e7dc33a2e7b350ad6c9cec Mon Sep 17 00:00:00 2001 From: Chime Ogbuji Date: Sat, 2 Nov 2024 19:02:47 -0400 Subject: [PATCH 2/3] Persist configured prompt/completion key --- llms/mlx_lm/tuner/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index 692bdb5ce..81e5d293a 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -95,7 +95,7 @@ def create_dataset( if "messages" in sample: return ChatDataset(data, tokenizer) elif prompt_feature in sample and completion_feature in sample: - return CompletionsDataset(data, tokenizer) + return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature) elif "text" in sample: return Dataset(data, tokenizer) else: From 7499720b099bbf12fa73422cde68d7a990384ab1 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 13 Jan 2025 09:43:23 -0800 Subject: [PATCH 3/3] rebase + nits --- llms/mlx_lm/LORA.md | 23 ++++++++++++----------- llms/mlx_lm/lora.py | 2 -- llms/mlx_lm/tuner/datasets.py | 25 +++++++++++++------------ 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md index e1a863258..9eac9d7f9 100644 --- a/llms/mlx_lm/LORA.md +++ b/llms/mlx_lm/LORA.md @@ -241,24 +241,25 @@ Refer to the documentation for the model you are fine-tuning for more details. {"prompt": "What is the capital of France?", "completion": "Paris."} ``` -`text`: +For the `completions` data format, a different key can be used for the prompt +and completion by specifying the following in the YAML config: -```jsonl -{"text": "This is an example for the model."} +```yaml +prompt_feature: "input" +completion_feature: "output" ``` -Note, the format is automatically determined by the dataset. +Here, `"input"` is the expected key instead of the default `"prompt"`, and +`"output"` is the expected key instead of `"completion"`. -For the completion data format, a different key can be used for the _prompt_ and for the _completion_ by specifying -the following, for example, in the YAML config: +`text`: -```yaml -prompt_feature: "input" -completion_feature: "output" +```jsonl +{"text": "This is an example for the model."} ``` -Here, `input` is now the expected key instead of "prompt" and `output` is the expected key instead of "completion". -Note also, keys in each line not expected by the loader will be ignored. +Note, the format is automatically determined by the dataset. Note also, keys +in each line not expected by the loader will be ignored. > [!NOTE] > Each example in the datasets must be on a single line. Do not put more than diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index 41a618b11..4d050bd54 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -61,8 +61,6 @@ "config": None, "grad_checkpoint": False, "lr_schedule": None, - "prompt_feature": "prompt", - "completion_feature": "completion", "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}, } diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index 81e5d293a..1b09c7e28 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -1,6 +1,6 @@ import json from pathlib import Path -from typing import Dict, List +from typing import Dict, List, Optional from transformers import PreTrainedTokenizer @@ -61,8 +61,8 @@ def __init__( self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer, - prompt_key: str = "prompt", - completion_key: str = "completion", + prompt_key: str, + completion_key: str, ): self._data = [ tokenizer.apply_chat_template( @@ -81,17 +81,15 @@ def __len__(self): return len(self._data) -<<<<<<< HEAD def create_dataset( data, tokenizer: PreTrainedTokenizer, prompt_feature: Optional[str] = None, completion_feature: Optional[str] = None, ): - sample = data[0] prompt_feature = prompt_feature or "prompt" completion_feature = completion_feature or "completion" - + sample = data[0] if "messages" in sample: return ChatDataset(data, tokenizer) elif prompt_feature in sample and completion_feature in sample: @@ -108,8 +106,8 @@ def create_dataset( def load_local_dataset( data_path: Path, tokenizer: PreTrainedTokenizer, - prompt_feature: str = None, - completion_feature: str = None, + prompt_feature: Optional[str] = None, + completion_feature: Optional[str] = None, ): def load_subset(path): if not path.exists(): @@ -126,8 +124,8 @@ def load_subset(path): def load_hf_dataset( data_id: str, tokenizer: PreTrainedTokenizer, - prompt_feature: str = None, - completion_feature: str = None, + prompt_feature: Optional[str] = None, + completion_feature: Optional[str] = None, ): from datasets import exceptions, load_dataset @@ -199,14 +197,17 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer): train, valid, test = load_custom_hf_dataset(args, tokenizer) else: data_path = Path(args.data) + + prompt_feature = getattr(args, "prompt_feature", None) + completion_feature = getattr(args, "completion_feature", None) if data_path.exists(): train, valid, test = load_local_dataset( - data_path, tokenizer, args.prompt_feature, args.completion_feature + data_path, tokenizer, prompt_feature, completion_feature ) else: print(f"Loading Hugging Face dataset {args.data}.") train, valid, test = load_hf_dataset( - args.data, tokenizer, args.prompt_feature, args.completion_feature + args.data, tokenizer, prompt_feature, completion_feature ) if args.train and len(train) == 0: