
Commit cf13c47

fineguy (The TensorFlow Datasets Authors) authored and The TensorFlow Datasets Authors committed
Migrate to simple_parsing
PiperOrigin-RevId: 668933314
1 parent caa4485 commit cf13c47

11 files changed: +351 -563 lines

tensorflow_datasets/scripts/cli/build.py (+86 -85)
@@ -15,7 +15,7 @@
 
 """`tfds build` command."""
 
-import argparse
+import dataclasses
 import functools
 import importlib
 import itertools
@@ -25,84 +25,84 @@
 from typing import Any, Dict, Iterator, Optional, Tuple, Type, Union
 
 from absl import logging
+import simple_parsing
 import tensorflow_datasets as tfds
 from tensorflow_datasets.scripts.cli import cli_utils
 
 # pylint: disable=logging-fstring-interpolation
 
 
-def register_subparser(parsers: argparse._SubParsersAction) -> None:  # pylint: disable=protected-access
-  """Add subparser for `build` command."""
-  build_parser = parsers.add_parser(
-      'build', help='Commands for downloading and preparing datasets.'
-  )
-  build_parser.add_argument(
-      'datasets',  # Positional arguments
-      type=str,
-      nargs='*',
-      help=(
-          'Name(s) of the dataset(s) to build. Default to current dir. '
-          'See https://www.tensorflow.org/datasets/cli for accepted values.'
-      ),
-  )
-  build_parser.add_argument(  # Also accept keyword arguments
-      '--datasets',
-      type=str,
-      nargs='+',
-      dest='datasets_keyword',
-      help='Datasets can also be provided as keyword argument.',
-  )
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class _AutomationGroup:
+  """Used by automated scripts.
 
-  cli_utils.add_debug_argument_group(build_parser)
-  cli_utils.add_path_argument_group(build_parser)
-  cli_utils.add_generation_argument_group(build_parser)
-  cli_utils.add_publish_argument_group(build_parser)
+  Attributes:
+    exclude_datasets: If set, generate all datasets except the one defined here.
+      Comma separated list of datasets to exclude.
+    experimental_latest_version: Build the latest Version(experiments=...)
+      available rather than default version.
+  """
 
-  # **** Automation options ****
-  automation_group = build_parser.add_argument_group(
-      'Automation', description='Used by automated scripts.'
-  )
-  automation_group.add_argument(
-      '--exclude_datasets',
-      type=str,
-      help=(
-          'If set, generate all datasets except the one defined here. '
-          'Comma separated list of datasets to exclude. '
-      ),
+  exclude_datasets: list[str] = cli_utils.comma_separated_list_field()
+  experimental_latest_version: bool = False
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class CmdArgs:
+  """Commands for downloading and preparing datasets.
+
+  Attributes:
+    datasets: Name(s) of the dataset(s) to build. Default to current dir. See
+      https://www.tensorflow.org/datasets/cli for accepted values.
+    datasets_keyword: Datasets can also be provided as keyword argument.
+    debug: Debug & tests options.
+    path: Paths options.
+    generation: Generation options.
+    publish: Publishing options.
+    automation: Automation options.
+  """
+
+  datasets: list[str] = simple_parsing.field(
+      positional=True, default_factory=list, nargs='*'
   )
-  automation_group.add_argument(
-      '--experimental_latest_version',
-      action='store_true',
-      help=(
-          'Build the latest Version(experiments=...) available rather than '
-          'default version.'
-      ),
+  datasets_keyword: list[str] = simple_parsing.field(
+      alias='datasets', default_factory=list, nargs='*'
   )
+  debug: cli_utils.DebugGroup = simple_parsing.field(prefix='')
+  path: cli_utils.PathGroup = simple_parsing.field(prefix='')
+  generation: cli_utils.GenerationGroup = simple_parsing.field(prefix='')
+  publish: cli_utils.PublishGroup = simple_parsing.field(prefix='')
+  automation: _AutomationGroup = simple_parsing.field(prefix='')
 
-  build_parser.set_defaults(subparser_fn=_build_datasets)
+  def execute(self):
+    _build_datasets(self)
 
 
-def _build_datasets(args: argparse.Namespace) -> None:
+def _build_datasets(args: CmdArgs) -> None:
   """Build the given datasets."""
   # Eventually register additional datasets imports
-  if args.imports:
-    list(importlib.import_module(m) for m in args.imports.split(','))
+  if args.generation.imports:
+    list(importlib.import_module(m) for m in args.generation.imports)
 
   # Select datasets to generate
-  datasets = (args.datasets or []) + (args.datasets_keyword or [])
-  if args.exclude_datasets:  # Generate all datasets if `--exclude_datasets` set
+  datasets = args.datasets + args.datasets_keyword
+  if (
+      args.automation.exclude_datasets
+  ):  # Generate all datasets if `--exclude_datasets` set
     if datasets:
       raise ValueError("--exclude_datasets can't be used with `datasets`")
     datasets = set(tfds.list_builders(with_community_datasets=False)) - set(
-        args.exclude_datasets.split(',')
+        args.automation.exclude_datasets
     )
     datasets = sorted(datasets)  # `set` is not deterministic
   else:
     datasets = datasets or ['']  # Empty string for default
 
   # Import builder classes
   builders_cls_and_kwargs = [
-      _get_builder_cls_and_kwargs(dataset, has_imports=bool(args.imports))
+      _get_builder_cls_and_kwargs(
+          dataset, has_imports=bool(args.generation.imports)
+      )
      for dataset in datasets
   ]
 
@@ -112,19 +112,20 @@ def _build_datasets(args: argparse.Namespace) -> None:
       for (builder_cls, builder_kwargs) in builders_cls_and_kwargs
   ))
   process_builder_fn = functools.partial(
-      _download if args.download_only else _download_and_prepare, args
+      _download if args.generation.download_only else _download_and_prepare,
+      args,
   )
 
-  if args.num_processes == 1:
+  if args.generation.num_processes == 1:
     for builder in builders:
       process_builder_fn(builder)
   else:
-    with multiprocessing.Pool(args.num_processes) as pool:
+    with multiprocessing.Pool(args.generation.num_processes) as pool:
       pool.map(process_builder_fn, builders)
 
 
 def _make_builders(
-    args: argparse.Namespace,
+    args: CmdArgs,
     builder_cls: Type[tfds.core.DatasetBuilder],
     builder_kwargs: Dict[str, Any],
 ) -> Iterator[tfds.core.DatasetBuilder]:
@@ -139,7 +140,7 @@ def _make_builders(
     Initialized dataset builders.
   """
   # Eventually overwrite version
-  if args.experimental_latest_version:
+  if args.automation.experimental_latest_version:
     if 'version' in builder_kwargs:
       raise ValueError(
           "Can't have both `--experimental_latest` and version set (`:1.0.0`)"
@@ -150,19 +151,19 @@ def _make_builders(
     builder_kwargs['config'] = _get_config_name(
         builder_cls=builder_cls,
         config_kwarg=builder_kwargs.get('config'),
-        config_name=args.config,
-        config_idx=args.config_idx,
+        config_name=args.generation.config,
+        config_idx=args.generation.config_idx,
     )
 
-  if args.file_format:
-    builder_kwargs['file_format'] = args.file_format
+  if args.generation.file_format:
+    builder_kwargs['file_format'] = args.generation.file_format
 
   make_builder = functools.partial(
       _make_builder,
       builder_cls,
-      overwrite=args.overwrite,
-      fail_if_exists=args.fail_if_exists,
-      data_dir=args.data_dir,
+      overwrite=args.debug.overwrite,
+      fail_if_exists=args.debug.fail_if_exists,
+      data_dir=args.path.data_dir,
       **builder_kwargs,
   )
 
@@ -301,7 +302,7 @@ def _make_builder(
 
 
 def _download(
-    args: argparse.Namespace,
+    args: CmdArgs,
     builder: tfds.core.DatasetBuilder,
 ) -> None:
   """Downloads all files of the given builder."""
@@ -323,7 +324,7 @@ def _download(
   if builder.MAX_SIMULTANEOUS_DOWNLOADS is not None:
     max_simultaneous_downloads = builder.MAX_SIMULTANEOUS_DOWNLOADS
 
-  download_dir = args.download_dir or os.path.join(
+  download_dir = args.path.download_dir or os.path.join(
       builder._data_dir_root, 'downloads'  # pylint: disable=protected-access
   )
   dl_manager = tfds.download.DownloadManager(
@@ -345,51 +346,51 @@
 
 
 def _download_and_prepare(
-    args: argparse.Namespace,
+    args: CmdArgs,
     builder: tfds.core.DatasetBuilder,
 ) -> None:
   """Generate a single builder."""
   cli_utils.download_and_prepare(
       builder=builder,
       download_config=_make_download_config(args, dataset_name=builder.name),
-      download_dir=args.download_dir,
-      publish_dir=args.publish_dir,
-      skip_if_published=args.skip_if_published,
-      overwrite=args.overwrite,
+      download_dir=args.path.download_dir,
+      publish_dir=args.publish.publish_dir,
+      skip_if_published=args.publish.skip_if_published,
+      overwrite=args.debug.overwrite,
   )
 
 
 def _make_download_config(
-    args: argparse.Namespace,
+    args: CmdArgs,
     dataset_name: str,
 ) -> tfds.download.DownloadConfig:
   """Generate the download and prepare configuration."""
   # Load the download config
-  manual_dir = args.manual_dir
-  if args.add_name_to_manual_dir:
+  manual_dir = args.path.manual_dir
+  if args.path.add_name_to_manual_dir:
     manual_dir = manual_dir / dataset_name
 
   kwargs = {}
-  if args.max_shard_size_mb:
-    kwargs['max_shard_size'] = args.max_shard_size_mb << 20
-  if args.download_config:
-    kwargs.update(json.loads(args.download_config))
+  if args.generation.max_shard_size_mb:
+    kwargs['max_shard_size'] = args.generation.max_shard_size_mb << 20
+  if args.generation.download_config:
+    kwargs.update(json.loads(args.generation.download_config))
 
   if 'download_mode' in kwargs:
     kwargs['download_mode'] = tfds.download.GenerateMode(
         kwargs['download_mode']
     )
   else:
     kwargs['download_mode'] = tfds.download.GenerateMode.REUSE_DATASET_IF_EXISTS
-  if args.update_metadata_only:
+  if args.generation.update_metadata_only:
    kwargs['download_mode'] = tfds.download.GenerateMode.UPDATE_DATASET_INFO
 
   dl_config = tfds.download.DownloadConfig(
-      extract_dir=args.extract_dir,
+      extract_dir=args.path.extract_dir,
       manual_dir=manual_dir,
-      max_examples_per_split=args.max_examples_per_split,
-      register_checksums=args.register_checksums,
-      force_checksums_validation=args.force_checksums_validation,
+      max_examples_per_split=args.debug.max_examples_per_split,
+      register_checksums=args.generation.register_checksums,
+      force_checksums_validation=args.generation.force_checksums_validation,
      **kwargs,
   )
 
@@ -400,9 +401,9 @@ def _make_download_config(
     beam = None
 
   if beam is not None:
-    if args.beam_pipeline_options:
+    if args.generation.beam_pipeline_options:
       dl_config.beam_options = beam.options.pipeline_options.PipelineOptions(
-          flags=[f'--{opt}' for opt in args.beam_pipeline_options.split(',')]
+          flags=[f'--{opt}' for opt in args.generation.beam_pipeline_options]
       )
 
   return dl_config
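
The pattern this file adopts is compact enough to show in isolation: simple_parsing maps an annotated dataclass onto argparse options, so the `CmdArgs` field definitions above replace the imperative `add_argument` calls. A minimal, self-contained sketch of that pattern follows; the `Args` class, its fields, and the sample flags are illustrative, not part of the TFDS codebase:

import dataclasses

import simple_parsing


@dataclasses.dataclass(frozen=True, kw_only=True)
class Args:
  """Toy command."""

  # `positional=True` declares a positional argument, as `CmdArgs.datasets`
  # does above; extra kwargs such as `nargs` are forwarded to argparse.
  datasets: list[str] = simple_parsing.field(
      positional=True, default_factory=list, nargs='*'
  )
  # A plain annotated field becomes a regular `--num_processes` flag.
  num_processes: int = 1


parser = simple_parsing.ArgumentParser()
parser.add_arguments(Args, dest='args')
parsed = parser.parse_args(['mnist', 'cifar10', '--num_processes', '2'])
print(parsed.args)  # Args(datasets=['mnist', 'cifar10'], num_processes=2)

Nesting a dataclass inside another (as `CmdArgs` does with `cli_utils.DebugGroup` and friends) groups related flags, and `prefix=''` keeps the generated option names unprefixed so the CLI surface stays unchanged.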

tensorflow_datasets/scripts/cli/build_test.py (+1 -1)
@@ -316,7 +316,7 @@ def test_download_only():
 )
 def test_make_download_config(args: str, download_config_kwargs):
   args = main._parse_flags(f'tfds build x {args}'.split())
-  actual = build_lib._make_download_config(args, dataset_name='x')
+  actual = build_lib._make_download_config(args.command, dataset_name='x')
   # Ignore the beam runner
   actual = actual.replace(beam_runner=None)
   expected = tfds.download.DownloadConfig(**download_config_kwargs)
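
The one-line change mirrors the nesting introduced in build.py: the parsed top-level namespace no longer carries the build flags directly but exposes the command's `CmdArgs` instance under an attribute (here `command`). A rough sketch of that shape, assuming the top-level CLI registers the dataclass under `dest='command'` (the actual wiring lives in main.py and is not shown in this diff):

parser = simple_parsing.ArgumentParser()
parser.add_arguments(build_lib.CmdArgs, dest='command')
args = parser.parse_args(['x'])
args.command             # CmdArgs(datasets=['x'], ...)
args.command.generation  # nested group read by _make_download_config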
