@@ -15,7 +15,7 @@
 
 """`tfds build` command."""
 
-import argparse
+import dataclasses
 import functools
 import importlib
 import itertools
@@ -25,84 +25,84 @@
 from typing import Any, Dict, Iterator, Optional, Tuple, Type, Union
 
 from absl import logging
+import simple_parsing
 import tensorflow_datasets as tfds
 from tensorflow_datasets.scripts.cli import cli_utils
 
 # pylint: disable=logging-fstring-interpolation
 
 
-def register_subparser(parsers: argparse._SubParsersAction) -> None:  # pylint: disable=protected-access
-  """Add subparser for `build` command."""
-  build_parser = parsers.add_parser(
-      'build', help='Commands for downloading and preparing datasets.'
-  )
-  build_parser.add_argument(
-      'datasets',  # Positional arguments
-      type=str,
-      nargs='*',
-      help=(
-          'Name(s) of the dataset(s) to build. Default to current dir. '
-          'See https://www.tensorflow.org/datasets/cli for accepted values.'
-      ),
-  )
-  build_parser.add_argument(  # Also accept keyword arguments
-      '--datasets',
-      type=str,
-      nargs='+',
-      dest='datasets_keyword',
-      help='Datasets can also be provided as keyword argument.',
-  )
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class _AutomationGroup:
+  """Used by automated scripts.
 
-  cli_utils.add_debug_argument_group(build_parser)
-  cli_utils.add_path_argument_group(build_parser)
-  cli_utils.add_generation_argument_group(build_parser)
-  cli_utils.add_publish_argument_group(build_parser)
+  Attributes:
+    exclude_datasets: If set, generate all datasets except the one defined here.
+      Comma separated list of datasets to exclude.
+    experimental_latest_version: Build the latest Version(experiments=...)
+      available rather than default version.
+  """
 
-  # **** Automation options ****
-  automation_group = build_parser.add_argument_group(
-      'Automation', description='Used by automated scripts.'
-  )
-  automation_group.add_argument(
-      '--exclude_datasets',
-      type=str,
-      help=(
-          'If set, generate all datasets except the one defined here. '
-          'Comma separated list of datasets to exclude. '
-      ),
+  exclude_datasets: list[str] = cli_utils.comma_separated_list_field()
+  experimental_latest_version: bool = False
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class CmdArgs:
+  """Commands for downloading and preparing datasets.
+
+  Attributes:
+    datasets: Name(s) of the dataset(s) to build. Default to current dir. See
+      https://www.tensorflow.org/datasets/cli for accepted values.
+    datasets_keyword: Datasets can also be provided as keyword argument.
+    debug: Debug & tests options.
+    path: Paths options.
+    generation: Generation options.
+    publish: Publishing options.
+    automation: Automation options.
+  """
+
+  datasets: list[str] = simple_parsing.field(
+      positional=True, default_factory=list, nargs='*'
   )
-  automation_group.add_argument(
-      '--experimental_latest_version',
-      action='store_true',
-      help=(
-          'Build the latest Version(experiments=...) available rather than '
-          'default version.'
-      ),
+  datasets_keyword: list[str] = simple_parsing.field(
+      alias='datasets', default_factory=list, nargs='*'
   )
+  debug: cli_utils.DebugGroup = simple_parsing.field(prefix='')
+  path: cli_utils.PathGroup = simple_parsing.field(prefix='')
+  generation: cli_utils.GenerationGroup = simple_parsing.field(prefix='')
+  publish: cli_utils.PublishGroup = simple_parsing.field(prefix='')
+  automation: _AutomationGroup = simple_parsing.field(prefix='')
 
-  build_parser.set_defaults(subparser_fn=_build_datasets)
+  def execute(self):
+    _build_datasets(self)
 
 
-def _build_datasets(args: argparse.Namespace) -> None:
+def _build_datasets(args: CmdArgs) -> None:
   """Build the given datasets."""
   # Eventually register additional datasets imports
-  if args.imports:
-    list(importlib.import_module(m) for m in args.imports.split(','))
+  if args.generation.imports:
+    list(importlib.import_module(m) for m in args.generation.imports)
 
   # Select datasets to generate
-  datasets = (args.datasets or []) + (args.datasets_keyword or [])
-  if args.exclude_datasets:  # Generate all datasets if `--exclude_datasets` set
+  datasets = args.datasets + args.datasets_keyword
+  if (
+      args.automation.exclude_datasets
+  ):  # Generate all datasets if `--exclude_datasets` set
     if datasets:
       raise ValueError("--exclude_datasets can't be used with `datasets`")
     datasets = set(tfds.list_builders(with_community_datasets=False)) - set(
-        args.exclude_datasets.split(',')
+        args.automation.exclude_datasets
    )
     datasets = sorted(datasets)  # `set` is not deterministic
   else:
     datasets = datasets or ['']  # Empty string for default
 
   # Import builder classes
   builders_cls_and_kwargs = [
-      _get_builder_cls_and_kwargs(dataset, has_imports=bool(args.imports))
+      _get_builder_cls_and_kwargs(
+          dataset, has_imports=bool(args.generation.imports)
+      )
      for dataset in datasets
   ]
 
@@ -112,19 +112,20 @@ def _build_datasets(args: argparse.Namespace) -> None:
       for (builder_cls, builder_kwargs) in builders_cls_and_kwargs
   ))
   process_builder_fn = functools.partial(
-      _download if args.download_only else _download_and_prepare, args
+      _download if args.generation.download_only else _download_and_prepare,
+      args,
   )
 
-  if args.num_processes == 1:
+  if args.generation.num_processes == 1:
     for builder in builders:
       process_builder_fn(builder)
   else:
-    with multiprocessing.Pool(args.num_processes) as pool:
+    with multiprocessing.Pool(args.generation.num_processes) as pool:
       pool.map(process_builder_fn, builders)
 
 
 def _make_builders(
-    args: argparse.Namespace,
+    args: CmdArgs,
     builder_cls: Type[tfds.core.DatasetBuilder],
     builder_kwargs: Dict[str, Any],
 ) -> Iterator[tfds.core.DatasetBuilder]:
@@ -139,7 +140,7 @@ def _make_builders(
     Initialized dataset builders.
   """
   # Eventually overwrite version
-  if args.experimental_latest_version:
+  if args.automation.experimental_latest_version:
     if 'version' in builder_kwargs:
       raise ValueError(
           "Can't have both `--experimental_latest` and version set (`:1.0.0`)"
@@ -150,19 +151,19 @@ def _make_builders(
     builder_kwargs['config'] = _get_config_name(
         builder_cls=builder_cls,
         config_kwarg=builder_kwargs.get('config'),
-        config_name=args.config,
-        config_idx=args.config_idx,
+        config_name=args.generation.config,
+        config_idx=args.generation.config_idx,
     )
 
-  if args.file_format:
-    builder_kwargs['file_format'] = args.file_format
+  if args.generation.file_format:
+    builder_kwargs['file_format'] = args.generation.file_format
 
   make_builder = functools.partial(
       _make_builder,
       builder_cls,
-      overwrite=args.overwrite,
-      fail_if_exists=args.fail_if_exists,
-      data_dir=args.data_dir,
+      overwrite=args.debug.overwrite,
+      fail_if_exists=args.debug.fail_if_exists,
+      data_dir=args.path.data_dir,
       **builder_kwargs,
   )
 
@@ -301,7 +302,7 @@ def _make_builder(
 
 
 def _download(
-    args: argparse.Namespace,
+    args: CmdArgs,
     builder: tfds.core.DatasetBuilder,
 ) -> None:
   """Downloads all files of the given builder."""
@@ -323,7 +324,7 @@ def _download(
   if builder.MAX_SIMULTANEOUS_DOWNLOADS is not None:
     max_simultaneous_downloads = builder.MAX_SIMULTANEOUS_DOWNLOADS
 
-  download_dir = args.download_dir or os.path.join(
+  download_dir = args.path.download_dir or os.path.join(
       builder._data_dir_root, 'downloads'  # pylint: disable=protected-access
   )
   dl_manager = tfds.download.DownloadManager(
@@ -345,51 +346,51 @@
 
 
 def _download_and_prepare(
-    args: argparse.Namespace,
+    args: CmdArgs,
     builder: tfds.core.DatasetBuilder,
 ) -> None:
   """Generate a single builder."""
   cli_utils.download_and_prepare(
       builder=builder,
       download_config=_make_download_config(args, dataset_name=builder.name),
-      download_dir=args.download_dir,
-      publish_dir=args.publish_dir,
-      skip_if_published=args.skip_if_published,
-      overwrite=args.overwrite,
+      download_dir=args.path.download_dir,
+      publish_dir=args.publish.publish_dir,
+      skip_if_published=args.publish.skip_if_published,
+      overwrite=args.debug.overwrite,
   )
 
 
 def _make_download_config(
-    args: argparse.Namespace,
+    args: CmdArgs,
     dataset_name: str,
 ) -> tfds.download.DownloadConfig:
   """Generate the download and prepare configuration."""
   # Load the download config
-  manual_dir = args.manual_dir
-  if args.add_name_to_manual_dir:
+  manual_dir = args.path.manual_dir
+  if args.path.add_name_to_manual_dir:
     manual_dir = manual_dir / dataset_name
 
   kwargs = {}
-  if args.max_shard_size_mb:
-    kwargs['max_shard_size'] = args.max_shard_size_mb << 20
-  if args.download_config:
-    kwargs.update(json.loads(args.download_config))
+  if args.generation.max_shard_size_mb:
+    kwargs['max_shard_size'] = args.generation.max_shard_size_mb << 20
+  if args.generation.download_config:
+    kwargs.update(json.loads(args.generation.download_config))
 
   if 'download_mode' in kwargs:
     kwargs['download_mode'] = tfds.download.GenerateMode(
         kwargs['download_mode']
     )
   else:
     kwargs['download_mode'] = tfds.download.GenerateMode.REUSE_DATASET_IF_EXISTS
-  if args.update_metadata_only:
+  if args.generation.update_metadata_only:
     kwargs['download_mode'] = tfds.download.GenerateMode.UPDATE_DATASET_INFO
 
   dl_config = tfds.download.DownloadConfig(
-      extract_dir=args.extract_dir,
+      extract_dir=args.path.extract_dir,
       manual_dir=manual_dir,
-      max_examples_per_split=args.max_examples_per_split,
-      register_checksums=args.register_checksums,
-      force_checksums_validation=args.force_checksums_validation,
+      max_examples_per_split=args.debug.max_examples_per_split,
+      register_checksums=args.generation.register_checksums,
+      force_checksums_validation=args.generation.force_checksums_validation,
       **kwargs,
   )
 
@@ -400,9 +401,9 @@ def _make_download_config(
     beam = None
 
   if beam is not None:
-    if args.beam_pipeline_options:
+    if args.generation.beam_pipeline_options:
       dl_config.beam_options = beam.options.pipeline_options.PipelineOptions(
-          flags=[f'--{opt}' for opt in args.beam_pipeline_options.split(',')]
+          flags=[f'--{opt}' for opt in args.generation.beam_pipeline_options]
       )
 
   return dl_config
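
For context, a minimal sketch of how the new dataclass-based flags could be exercised with `simple_parsing`. The flag values are hypothetical, and it assumes the `cli_utils` group dataclasses (`DebugGroup`, `PathGroup`, etc.) default-construct cleanly when their flags are omitted:

```python
import simple_parsing

# `CmdArgs` is the dataclass defined in this module
# (tensorflow_datasets/scripts/cli/build.py).
parser = simple_parsing.ArgumentParser()
parser.add_arguments(CmdArgs, dest='args')

# Roughly equivalent to: `tfds build mnist --experimental_latest_version`
namespace = parser.parse_args(['mnist', '--experimental_latest_version'])
args = namespace.args

assert args.datasets == ['mnist']
assert args.automation.experimental_latest_version
args.execute()  # Dispatches to _build_datasets(args).
```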
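`cli_utils.comma_separated_list_field()` itself is not part of this diff; assuming it splits a single comma-separated flag value into a `list[str]` (matching the old `args.exclude_datasets.split(',')` behaviour), the exclusion path would parse as:

```python
# Reusing the parser from the sketch above.
# Roughly equivalent to: `tfds build --exclude_datasets=mnist,cifar10`
namespace = parser.parse_args(['--exclude_datasets=mnist,cifar10'])
assert namespace.args.automation.exclude_datasets == ['mnist', 'cifar10']
```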