Skip to content

Commit 927ec02

Browse files
committed
- Added support for dataclasses nested in a type.
- Fixed add_dataclass_arguments not forwarding sub_configs parameter.
1 parent 17fcf0a commit 927ec02

13 files changed

+167
-81
lines changed

CHANGELOG.rst

+2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ v4.21.0 (2023-04-??)
1717

1818
Added
1919
^^^^^
20+
- Support for dataclasses nested in a type.
2021
- Support for pydantic models and attr defines similar to dataclasses.
2122

2223
Fixed
@@ -27,6 +28,7 @@ Fixed
2728
<https://github.com/Lightning-AI/lightning/issues/17254>`__).
2829
- ``dataclass`` from pydantic not working (`#100 (comment)
2930
<https://github.com/omni-us/jsonargparse/issues/100#issuecomment-1408413796>`__).
31+
- ``add_dataclass_arguments`` not forwarding ``sub_configs`` parameter.
3032

3133
Changed
3234
^^^^^^^

README.rst

+6-7
Original file line numberDiff line numberDiff line change
@@ -430,19 +430,18 @@ Some notes about this support are:
430430
config files and environment variables, tuples and sets are represented as an
431431
array.
432432

433-
- ``dataclasses`` are supported as a type but only for pure data classes and not
434-
nested in a type. By pure it is meant that the class only inherits from data
435-
classes. Not a mixture of normal classes and data classes. Data classes as
436-
fields of other data classes is supported. Pydantic's ``dataclass`` decorator
437-
and ``BaseModel`` classes, and attrs' ``define`` decorator are supported
438-
like standard dataclasses. Though, this support is currently experimental.
439-
440433
- To set a value to ``None`` it is required to use ``null`` since this is how
441434
json/yaml defines it. To avoid confusion in the help, ``NoneType`` is
442435
displayed as ``null``. For example a function argument with type and default
443436
``Optional[str] = None`` would be shown in the help as ``type: Union[str,
444437
null], default: null``.
445438

439+
- ``dataclasses`` are supported even when nested. Final classes, attrs'
440+
``define`` decorator, and pydantic's ``dataclass`` decorator and ``BaseModel``
441+
classes are supported and behave like standard dataclasses. If a dataclass
442+
inherits from a normal class, the type is considered a subclass instead of a
443+
dataclass, see :ref:`sub-classes`.
444+
446445
- Normal classes can be used as a type, which are specified with a dict
447446
containing ``class_path`` and optionally ``init_args``.
448447
:py:meth:`.ArgumentParser.instantiate_classes` can be used to instantiate all

jsonargparse/_common.py

+34
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
import dataclasses
2+
import inspect
13
from contextlib import contextmanager
24
from contextvars import ContextVar
35
from typing import Optional, Union
46

57
from .namespace import Namespace
8+
from .optionals import attrs_support, import_attrs, import_pydantic, pydantic_support
69
from .type_checking import ArgumentParser
710

811
parent_parser: ContextVar['ArgumentParser'] = ContextVar('parent_parser')
@@ -33,3 +36,34 @@ def parser_context(**kwargs):
3336
finally:
3437
for context_var, token in context_var_tokens:
3538
context_var.reset(token)
39+
40+
41+
def is_subclass(cls, class_or_tuple) -> bool:
42+
"""Extension of issubclass that supports non-class arguments."""
43+
try:
44+
return inspect.isclass(cls) and issubclass(cls, class_or_tuple)
45+
except TypeError:
46+
return False
47+
48+
49+
def is_final_class(cls) -> bool:
50+
"""Checks whether a class is final, i.e. decorated with ``typing.final``."""
51+
return getattr(cls, '__final__', False)
52+
53+
54+
def is_dataclass_like(cls) -> bool:
55+
if not inspect.isclass(cls):
56+
return False
57+
if is_final_class(cls):
58+
return True
59+
classes = [c for c in inspect.getmro(cls) if c != object]
60+
all_dataclasses = all(dataclasses.is_dataclass(c) for c in classes)
61+
if not all_dataclasses and pydantic_support:
62+
pydantic = import_pydantic('is_dataclass_like')
63+
classes = [c for c in classes if c != pydantic.utils.Representation]
64+
all_dataclasses = all(is_subclass(c, pydantic.BaseModel) for c in classes)
65+
if not all_dataclasses and attrs_support:
66+
attrs = import_attrs('is_dataclass_like')
67+
if attrs.has(cls):
68+
return True
69+
return all_dataclasses

jsonargparse/actions.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from contextvars import ContextVar
1111
from typing import Any, Dict, List, Optional, Tuple, Type, Union
1212

13-
from ._common import parser_context
13+
from ._common import is_subclass, parser_context
1414
from .loaders_dumpers import get_loader_exceptions, load_value
1515
from .namespace import Namespace, split_key, split_key_root
1616
from .optionals import FilesCompleterMethod, get_config_read_mode
@@ -25,7 +25,6 @@
2525
get_typehint_origin,
2626
import_object,
2727
indent_text,
28-
is_subclass,
2928
iter_to_set_str,
3029
parse_value_or_config,
3130
)

jsonargparse/core.py

+5-11
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
Union,
2323
)
2424

25-
from ._common import lenient_check, parser_context
25+
from ._common import is_dataclass_like, is_subclass, lenient_check, parser_context
2626
from .actions import (
2727
ActionConfigFile,
2828
ActionParser,
@@ -71,16 +71,14 @@
7171
import_jsonnet,
7272
)
7373
from .parameter_resolvers import UnknownDefault
74-
from .signatures import SignatureArguments, is_pure_dataclass
74+
from .signatures import SignatureArguments
7575
from .typehints import ActionTypeHint, is_subclass_spec
76-
from .typing import is_final_class
7776
from .util import (
7877
Path,
7978
argument_error,
8079
change_to_path_dir,
8180
get_private_kwargs,
8281
identity,
83-
is_subclass,
8482
return_parser_if_captured,
8583
)
8684

@@ -118,14 +116,10 @@ def add_argument(self, *args, enable_path: bool = False, **kwargs):
118116
if is_subclass(kwargs['action'], ActionConfigFile) and any(isinstance(a, ActionConfigFile) for a in self._actions):
119117
raise ValueError('A parser is only allowed to have a single ActionConfigFile argument.')
120118
if 'type' in kwargs:
121-
if is_final_class(kwargs['type']) or is_pure_dataclass(kwargs['type']):
119+
if is_dataclass_like(kwargs['type']):
122120
theclass = kwargs.pop('type')
123121
nested_key = re.sub('^--', '', args[0])
124-
if is_final_class(theclass):
125-
kwargs.pop('help', None)
126-
self.add_class_arguments(theclass, nested_key, **kwargs)
127-
else:
128-
self.add_dataclass_arguments(theclass, nested_key, **kwargs)
122+
self.add_dataclass_arguments(theclass, nested_key, **kwargs)
129123
return _find_action(parser, nested_key)
130124
if ActionTypeHint.is_supported_typehint(kwargs['type']):
131125
args = ActionTypeHint.prepare_add_argument(
@@ -1109,7 +1103,7 @@ def instantiate_classes(
11091103
components: List[Union[ActionTypeHint, _ActionConfigLoad, _ArgumentGroup]] = []
11101104
for action in filter_default_actions(self._actions):
11111105
if isinstance(action, ActionTypeHint) or \
1112-
(isinstance(action, _ActionConfigLoad) and is_pure_dataclass(action.basetype)):
1106+
(isinstance(action, _ActionConfigLoad) and is_dataclass_like(action.basetype)):
11131107
components.append(action)
11141108

11151109
if instantiate_groups:

jsonargparse/deprecated.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ class ActionEnum:
182182

183183
def __init__(self, **kwargs):
184184
if 'enum' in kwargs:
185-
from .util import is_subclass
185+
from ._common import is_subclass
186186
if not is_subclass(kwargs['enum'], Enum):
187187
raise ValueError('Expected enum to be an subclass of Enum.')
188188
self._type = kwargs['enum']

jsonargparse/parameter_resolvers.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010
from functools import partial
1111
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
1212

13+
from ._common import is_dataclass_like, is_subclass
1314
from ._stubs_resolver import get_stub_types
1415
from .optionals import parse_docs
1516
from .util import (
1617
ClassFromFunctionBase,
1718
LoggerProperty,
1819
get_import_path,
19-
is_subclass,
2020
iter_to_set_str,
2121
parse_logger,
2222
unique,
@@ -325,8 +325,7 @@ def get_kwargs_pop_or_get_parameter(node, component, parent, doc_params, log_deb
325325

326326

327327
def is_param_subclass_instance_default(param: ParamData) -> bool:
328-
from .signatures import is_pure_dataclass
329-
if is_pure_dataclass(type(param.default)):
328+
if is_dataclass_like(type(param.default)):
330329
return False
331330
from .typehints import ActionTypeHint, get_subclass_types
332331
class_types = get_subclass_types(param.annotation)

jsonargparse/signatures.py

+17-37
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,16 @@
77
from contextlib import suppress
88
from typing import Any, Callable, List, Optional, Set, Tuple, Type, Union
99

10+
from ._common import is_dataclass_like, is_subclass
1011
from .actions import _ActionConfigLoad
11-
from .optionals import (
12-
attrs_support,
13-
get_doc_short_description,
14-
import_attrs,
15-
import_pydantic,
16-
pydantic_support,
17-
)
12+
from .optionals import get_doc_short_description, import_pydantic, pydantic_support
1813
from .parameter_resolvers import (
1914
ParamData,
2015
get_parameter_origins,
2116
get_signature_parameters,
2217
)
2318
from .typehints import ActionTypeHint, LazyInitBaseClass, is_optional
24-
from .typing import is_final_class
25-
from .util import LoggerProperty, get_import_path, is_subclass, iter_to_set_str
19+
from .util import LoggerProperty, get_import_path, iter_to_set_str
2620

2721
__all__ = [
2822
'compose_dataclasses',
@@ -93,7 +87,7 @@ def add_class_arguments(
9387
if default:
9488
skip = skip or set()
9589
prefix = nested_key+'.' if nested_key else ''
96-
defaults = default.lazy_get_init_data().as_dict()
90+
defaults = default.lazy_get_init_args()
9791
if defaults:
9892
defaults = {prefix+k: v for k, v in defaults.items() if k not in skip}
9993
self.set_defaults(**defaults) # type: ignore
@@ -317,8 +311,7 @@ def _add_signature_parameter(
317311
elif not as_positional:
318312
kwargs['required'] = True
319313
is_subclass_typehint = False
320-
is_final_class_typehint = is_final_class(annotation)
321-
is_pure_dataclass_typehint = is_pure_dataclass(annotation)
314+
is_dataclass_like_typehint = is_dataclass_like(annotation)
322315
dest = (nested_key+'.' if nested_key else '') + name
323316
args = [dest if is_required and as_positional else '--'+dest]
324317
if param.origin:
@@ -332,8 +325,7 @@ def _add_signature_parameter(
332325
)
333326
if annotation in {str, int, float, bool} or \
334327
is_subclass(annotation, (str, int, float)) or \
335-
is_final_class_typehint or \
336-
is_pure_dataclass_typehint:
328+
is_dataclass_like_typehint:
337329
kwargs['type'] = annotation
338330
elif annotation != inspect_empty:
339331
try:
@@ -360,7 +352,7 @@ def _add_signature_parameter(
360352
'sub_configs': sub_configs,
361353
'instantiate': instantiate,
362354
}
363-
if is_final_class_typehint or is_pure_dataclass_typehint:
355+
if is_dataclass_like_typehint:
364356
kwargs.update(sub_add_kwargs)
365357
action = group.add_argument(*args, **kwargs)
366358
action.sub_add_kwargs = sub_add_kwargs
@@ -401,8 +393,8 @@ def add_dataclass_arguments(
401393
ValueError: When not given a dataclass.
402394
ValueError: When default is not instance of or kwargs for theclass.
403395
"""
404-
if not is_pure_dataclass(theclass):
405-
raise ValueError(f'Expected "theclass" argument to be a pure dataclass, given {theclass}')
396+
if not is_dataclass_like(theclass):
397+
raise ValueError(f'Expected "theclass" argument to be a dataclass-like, given {theclass}')
406398

407399
doc_group = get_doc_short_description(theclass, logger=self.logger)
408400
for key in ['help', 'title']:
@@ -420,6 +412,7 @@ def add_dataclass_arguments(
420412
defaults = dataclass_to_dict(default)
421413

422414
added_args: List[str] = []
415+
param_kwargs = {k: v for k, v in kwargs.items() if k == 'sub_configs'}
423416
for param in get_signature_parameters(theclass, None, logger=self.logger):
424417
self._add_signature_parameter(
425418
group,
@@ -428,6 +421,7 @@ def add_dataclass_arguments(
428421
added_args,
429422
fail_untyped=fail_untyped,
430423
default=defaults.get(param.name, inspect_empty),
424+
**param_kwargs,
431425
)
432426

433427
return added_args
@@ -467,8 +461,8 @@ def add_subclass_arguments(
467461
Raises:
468462
ValueError: When given an invalid base class.
469463
"""
470-
if is_final_class(baseclass):
471-
raise ValueError("Not allowed for classes that are final.")
464+
if is_dataclass_like(baseclass):
465+
raise ValueError("Not allowed for dataclass-like classes.")
472466
if type(baseclass) is not tuple:
473467
baseclass = (baseclass,) # type: ignore
474468
if not all(inspect.isclass(c) for c in baseclass):
@@ -550,32 +544,18 @@ def is_factory_class(value):
550544
return value.__class__ == dataclasses._HAS_DEFAULT_FACTORY_CLASS
551545

552546

553-
def is_pure_dataclass(value):
554-
if not inspect.isclass(value):
555-
return False
556-
classes = [c for c in inspect.getmro(value) if c != object]
557-
all_dataclasses = all(dataclasses.is_dataclass(c) for c in classes)
558-
if not all_dataclasses and pydantic_support:
559-
pydantic = import_pydantic('is_pure_dataclass')
560-
classes = [c for c in classes if c != pydantic.utils.Representation]
561-
all_dataclasses = all(is_subclass(c, pydantic.BaseModel) for c in classes)
562-
if not all_dataclasses and attrs_support:
563-
attrs = import_attrs('is_pure_dataclass')
564-
if attrs.has(value):
565-
return True
566-
return all_dataclasses
567-
568-
569-
def dataclass_to_dict(value):
547+
def dataclass_to_dict(value) -> dict:
570548
if pydantic_support:
571549
pydantic = import_pydantic('dataclass_to_dict')
572550
if isinstance(value, pydantic.BaseModel):
573551
return value.dict()
552+
if isinstance(value, LazyInitBaseClass):
553+
return value.lazy_get_init_data().as_dict()
574554
return dataclasses.asdict(value)
575555

576556

577557
def compose_dataclasses(*args):
578-
"""Returns a pure dataclass inheriting all given dataclasses and properly handling __post_init__."""
558+
"""Returns a dataclass inheriting all given dataclasses and properly handling __post_init__."""
579559

580560
@dataclasses.dataclass
581561
class ComposedDataclass(*args):

0 commit comments

Comments
 (0)