Custom local dataset features #1085

Open · wants to merge 6 commits into main

Changes from all commits:
llms/mlx_lm/LORA.md (13 additions, 3 deletions)

@@ -247,8 +247,18 @@ Refer to the documentation for the model you are fine-tuning for more details.
 {"text": "This is an example for the model."}
 ```

-Note, the format is automatically determined by the dataset. Note also, keys in
-each line not expected by the loader will be ignored.
+Note, the format is automatically determined by the dataset.
+
+For the completion data format, a different key can be used for the _prompt_ and for the _completion_ by specifying
+the following, for example, in the YAML config:
+
+```yaml
+prompt_feature: "input"
+completion_feature: "output"
+```
+
+Here, `input` is now the expected key instead of "prompt" and `output` is the expected key instead of "completion".
+Note also, keys in each line not expected by the loader will be ignored.

 > [!NOTE]
 > Each example in the datasets must be on a single line. Do not put more than
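
With the `input`/`output` mapping shown in this hunk, a matching completion-format dataset line would look like the following (hypothetical contents, in the same style as the examples already in LORA.md):

```
{"input": "What is the capital of France?", "output": "Paris."}
```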
@@ -270,7 +280,7 @@
 Otherwise, provide a mapping of keys in the dataset to the features MLX LM
 expects. Use a YAML config to specify the Hugging Face dataset arguments. For
 example:

-```
+```yaml
 hf_dataset:
   name: "billsum"
   prompt_feature: "text"
llms/mlx_lm/lora.py (2 additions)

@@ -58,6 +58,8 @@
     "test_batches": 500,
     "max_seq_length": 2048,
     "lr_schedule": None,
+    "prompt_feature": "prompt",
+    "completion_feature": "completion",
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
 }
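
Because these defaults live in CONFIG_DEFAULTS, existing configs keep working unchanged; a config only needs to set the two keys when a dataset's field names differ. A hypothetical excerpt of a fine-tuning YAML config:

```yaml
data: "path/to/dataset"       # hypothetical directory with train/valid/test .jsonl files
prompt_feature: "question"    # overrides the default "prompt"
completion_feature: "answer"  # overrides the default "completion"
```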
llms/mlx_lm/tuner/datasets.py (44 additions, 10 deletions)

@@ -76,13 +76,27 @@ def __getitem__(self, idx: int):
         return text


-def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
-    sample = data[0]
+def create_dataset(
+    data,
+    tokenizer: PreTrainedTokenizer = None,
+    prompt_feature: str = None,
+    completion_feature: str = None,
+):
+    from mlx_lm.lora import CONFIG_DEFAULTS
+
+    sample = data[0]
+    prompt_feature = (
+        prompt_feature if prompt_feature else CONFIG_DEFAULTS["prompt_feature"]
+    )
+    completion_feature = (
+        completion_feature
+        if completion_feature
+        else CONFIG_DEFAULTS["completion_feature"]
+    )
     if "messages" in sample:
         return ChatDataset(data, tokenizer)
-    elif "prompt" in sample and "completion" in sample:
-        return CompletionsDataset(data, tokenizer)
+    elif prompt_feature in sample and completion_feature in sample:
+        return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
     elif "text" in sample:
         return Dataset(data)
     else:
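
A minimal sketch of the new dispatch, assuming (as the constructor call above implies) that `CompletionsDataset` only stores the tokenizer and feature names at construction time; the sample data is hypothetical:

```python
from mlx_lm.tuner.datasets import create_dataset

# Hypothetical samples whose keys match the custom feature names.
data = [{"input": "What is 2 + 2?", "output": "4"}]

# Both custom keys are present, so this takes the CompletionsDataset branch.
ds = create_dataset(
    data,
    tokenizer=None,  # a real PreTrainedTokenizer is needed once items are fetched
    prompt_feature="input",
    completion_feature="output",
)
print(type(ds).__name__)  # CompletionsDataset

# Without overrides, the defaults "prompt"/"completion" apply; neither key is
# present in this sample, so the same data would raise the unsupported-format
# ValueError instead.
```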
@@ -92,20 +106,30 @@ def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
         )


-def load_local_dataset(data_path: Path, tokenizer: PreTrainedTokenizer):
+def load_local_dataset(
+    data_path: Path,
+    tokenizer: PreTrainedTokenizer,
+    prompt_feature: str = None,
+    completion_feature: str = None,
+):
     def load_subset(path):
         if not path.exists():
             return []
         with open(path, "r") as fid:
             data = [json.loads(l) for l in fid]
-        return create_dataset(data, tokenizer)
+        return create_dataset(data, tokenizer, prompt_feature, completion_feature)

     names = ("train", "valid", "test")
     train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
     return train, valid, test


-def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer):
+def load_hf_dataset(
+    data_id: str,
+    tokenizer: PreTrainedTokenizer,
+    prompt_feature: str = None,
+    completion_feature: str = None,
+):
     from datasets import exceptions, load_dataset

     try:
@@ -114,7 +138,13 @@ def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer):
         names = ("train", "valid", "test")

         train, valid, test = [
-            create_dataset(dataset[n], tokenizer) if n in dataset.keys() else []
+            (
+                create_dataset(
+                    dataset[n], tokenizer, prompt_feature, completion_feature
+                )
+                if n in dataset.keys()
+                else []
+            )
             for n in names
         ]
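
The Hugging Face path threads the same two names into every split. A hedged usage sketch with the billsum dataset already cited in LORA.md (billsum's fields are `text` and `summary`; this needs the `datasets` package and network access, and passes no tokenizer on the same storage-only assumption as above):

```python
from mlx_lm.tuner.datasets import load_hf_dataset

# billsum ships "train" and "test" splits but no "valid", so that entry
# comes back as an empty list.
train, valid, test = load_hf_dataset(
    "billsum",
    None,  # tokenizer; a real one is required for actual fine-tuning
    prompt_feature="text",
    completion_feature="summary",
)
```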

@@ -171,10 +201,14 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
     else:
         data_path = Path(args.data)
         if data_path.exists():
-            train, valid, test = load_local_dataset(data_path, tokenizer)
+            train, valid, test = load_local_dataset(
+                data_path, tokenizer, args.prompt_feature, args.completion_feature
+            )
         else:
             print(f"Loading Hugging Face dataset {args.data}.")
-            train, valid, test = load_hf_dataset(args.data, tokenizer)
+            train, valid, test = load_hf_dataset(
+                args.data, tokenizer, args.prompt_feature, args.completion_feature
+            )

     if args.train and len(train) == 0:
         raise ValueError(
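
End to end, the local path can now consume nonstandard keys without renaming fields. A minimal sketch with a hypothetical directory and contents (the `len(train)` check above implies the dataset classes support `len`):

```python
import json
from pathlib import Path

from mlx_lm.tuner.datasets import load_local_dataset

# Write a tiny hypothetical dataset with nonstandard keys.
root = Path("tmp_data")
root.mkdir(exist_ok=True)
with open(root / "train.jsonl", "w") as fid:
    fid.write(json.dumps({"input": "2 + 2?", "output": "4"}) + "\n")

# valid.jsonl and test.jsonl do not exist, so those splits come back empty.
train, valid, test = load_local_dataset(
    root, None, prompt_feature="input", completion_feature="output"
)
print(len(train), len(valid), len(test))  # 1 0 0
```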