-
Notifications
You must be signed in to change notification settings - Fork 187
/
data_mixture.py
76 lines (61 loc) · 2.72 KB
/
data_mixture.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import argparse
from data_juicer.core.exporter import Exporter
from data_juicer.format import load_formatter
def parse_args(argv=None):
    """Parse the command-line arguments of the dataset mixture tool.

    :param argv: optional list of argument strings to parse. When it is
        ``None`` (the default) ``argparse`` falls back to ``sys.argv[1:]``,
        which is what the parameterless call has always done.
    :return: the parsed ``argparse.Namespace`` with the attributes
        ``data_path`` (list of strings, or ``None`` when the option is
        absent), ``export_path``, ``export_shard_size``, ``max_samples``
        and ``num_proc``.
    """
    parser = argparse.ArgumentParser(
        description='Mix multiple datasets Arguments')
    # NOTE: adjacent string literals are concatenated verbatim, so every
    # fragment of a help text has to carry its own separating space.
    parser.add_argument('--data_path',
                        nargs='*',
                        default=None,
                        help='Path to datasets. Accepted format: '
                        '1) a single data path, 2) multiple datasets in the '
                        'form: dataset1-weight dataset1-path dataset2-weight '
                        'dataset2-path ...')
    parser.add_argument('--export_path',
                        default='mixed.jsonl',
                        help='Path to save the mixed dataset. '
                        'Supported suffixes include '
                        '["jsonl", "json", "parquet"]')
    parser.add_argument('--export_shard_size',
                        type=int,
                        default=0,
                        help='Shard size of exported dataset in Byte. In '
                        'default, it\'s 0, which means export the whole '
                        'dataset into only one file. If it\'s set a '
                        'positive number, the exported dataset will be '
                        'split into several dataset shards, and the max '
                        'size of each shard won\'t larger than the '
                        'export_shard_size')
    parser.add_argument('--max_samples',
                        type=int,
                        default=None,
                        help='Number of samples of mixed dataset.')
    parser.add_argument('--num_proc',
                        type=int,
                        default=4,
                        help='Number of processes to process dataset.')
    return parser.parse_args(argv)
def run_mixture():
    """
    Mix multiple datasets into one dataset.

    The command-line arguments are parsed with `parse_args`, the values of
    `--data_path` are joined with single spaces and handed to
    `load_formatter` (together with `--max_samples`), the resulting dataset
    is loaded with `--num_proc` processes and finally written to
    `--export_path` by an `Exporter` (without exporting stats).

    Per the original description of this tool, samples are randomly
    selected from every dataset and mixed; that selection is delegated to
    the formatter returned by `load_formatter`.

    `data_path` with optional weight (1.0 as default), e.g.
        1) a single data path
        2) multiple datasets in the format: <w1> dataset1-path
           <w2> dataset2-path <w3> dataset3-path ...

    :raises ValueError: if no dataset path was given on the command line.
    """
    args = parse_args()

    # `--data_path` defaults to None (and may be an empty list when the
    # option is given without values); joining None would fail with an
    # opaque TypeError, so report the actual usage problem instead.
    if not args.data_path:
        raise ValueError(
            'No dataset given: pass --data_path with either a single data '
            'path or a sequence of "<weight> <path>" pairs.')

    # The formatter expects one space-separated string of weights and paths.
    data_path = ' '.join(args.data_path)
    formatter = load_formatter(data_path, max_samples=args.max_samples)
    dataset = formatter.load_dataset(args.num_proc)

    exporter = Exporter(export_path=args.export_path,
                        export_shard_size=args.export_shard_size,
                        num_proc=args.num_proc,
                        export_stats=False)
    exporter.export(dataset)
# Script entry point: run the mixture only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    run_mixture()