7
7
from contextlib import suppress
8
8
from typing import Any , Callable , List , Optional , Set , Tuple , Type , Union
9
9
10
+ from ._common import is_dataclass_like , is_subclass
10
11
from .actions import _ActionConfigLoad
11
- from .optionals import (
12
- attrs_support ,
13
- get_doc_short_description ,
14
- import_attrs ,
15
- import_pydantic ,
16
- pydantic_support ,
17
- )
12
+ from .optionals import get_doc_short_description , import_pydantic , pydantic_support
18
13
from .parameter_resolvers import (
19
14
ParamData ,
20
15
get_parameter_origins ,
21
16
get_signature_parameters ,
22
17
)
23
18
from .typehints import ActionTypeHint , LazyInitBaseClass , is_optional
24
- from .typing import is_final_class
25
- from .util import LoggerProperty , get_import_path , is_subclass , iter_to_set_str
19
+ from .util import LoggerProperty , get_import_path , iter_to_set_str
26
20
27
21
__all__ = [
28
22
'compose_dataclasses' ,
@@ -93,7 +87,7 @@ def add_class_arguments(
93
87
if default :
94
88
skip = skip or set ()
95
89
prefix = nested_key + '.' if nested_key else ''
96
- defaults = default .lazy_get_init_data (). as_dict ()
90
+ defaults = default .lazy_get_init_args ()
97
91
if defaults :
98
92
defaults = {prefix + k : v for k , v in defaults .items () if k not in skip }
99
93
self .set_defaults (** defaults ) # type: ignore
@@ -317,8 +311,7 @@ def _add_signature_parameter(
317
311
elif not as_positional :
318
312
kwargs ['required' ] = True
319
313
is_subclass_typehint = False
320
- is_final_class_typehint = is_final_class (annotation )
321
- is_pure_dataclass_typehint = is_pure_dataclass (annotation )
314
+ is_dataclass_like_typehint = is_dataclass_like (annotation )
322
315
dest = (nested_key + '.' if nested_key else '' ) + name
323
316
args = [dest if is_required and as_positional else '--' + dest ]
324
317
if param .origin :
@@ -332,8 +325,7 @@ def _add_signature_parameter(
332
325
)
333
326
if annotation in {str , int , float , bool } or \
334
327
is_subclass (annotation , (str , int , float )) or \
335
- is_final_class_typehint or \
336
- is_pure_dataclass_typehint :
328
+ is_dataclass_like_typehint :
337
329
kwargs ['type' ] = annotation
338
330
elif annotation != inspect_empty :
339
331
try :
@@ -360,7 +352,7 @@ def _add_signature_parameter(
360
352
'sub_configs' : sub_configs ,
361
353
'instantiate' : instantiate ,
362
354
}
363
- if is_final_class_typehint or is_pure_dataclass_typehint :
355
+ if is_dataclass_like_typehint :
364
356
kwargs .update (sub_add_kwargs )
365
357
action = group .add_argument (* args , ** kwargs )
366
358
action .sub_add_kwargs = sub_add_kwargs
@@ -401,8 +393,8 @@ def add_dataclass_arguments(
401
393
ValueError: When not given a dataclass.
402
394
ValueError: When default is not instance of or kwargs for theclass.
403
395
"""
404
- if not is_pure_dataclass (theclass ):
405
- raise ValueError (f'Expected "theclass" argument to be a pure dataclass, given { theclass } ' )
396
+ if not is_dataclass_like (theclass ):
397
+ raise ValueError (f'Expected "theclass" argument to be a dataclass-like , given { theclass } ' )
406
398
407
399
doc_group = get_doc_short_description (theclass , logger = self .logger )
408
400
for key in ['help' , 'title' ]:
@@ -420,6 +412,7 @@ def add_dataclass_arguments(
420
412
defaults = dataclass_to_dict (default )
421
413
422
414
added_args : List [str ] = []
415
+ param_kwargs = {k : v for k , v in kwargs .items () if k == 'sub_configs' }
423
416
for param in get_signature_parameters (theclass , None , logger = self .logger ):
424
417
self ._add_signature_parameter (
425
418
group ,
@@ -428,6 +421,7 @@ def add_dataclass_arguments(
428
421
added_args ,
429
422
fail_untyped = fail_untyped ,
430
423
default = defaults .get (param .name , inspect_empty ),
424
+ ** param_kwargs ,
431
425
)
432
426
433
427
return added_args
@@ -467,8 +461,8 @@ def add_subclass_arguments(
467
461
Raises:
468
462
ValueError: When given an invalid base class.
469
463
"""
470
- if is_final_class (baseclass ):
471
- raise ValueError ("Not allowed for classes that are final ." )
464
+ if is_dataclass_like (baseclass ):
465
+ raise ValueError ("Not allowed for dataclass-like classes ." )
472
466
if type (baseclass ) is not tuple :
473
467
baseclass = (baseclass ,) # type: ignore
474
468
if not all (inspect .isclass (c ) for c in baseclass ):
@@ -550,32 +544,18 @@ def is_factory_class(value):
550
544
return value .__class__ == dataclasses ._HAS_DEFAULT_FACTORY_CLASS
551
545
552
546
553
- def is_pure_dataclass (value ):
554
- if not inspect .isclass (value ):
555
- return False
556
- classes = [c for c in inspect .getmro (value ) if c != object ]
557
- all_dataclasses = all (dataclasses .is_dataclass (c ) for c in classes )
558
- if not all_dataclasses and pydantic_support :
559
- pydantic = import_pydantic ('is_pure_dataclass' )
560
- classes = [c for c in classes if c != pydantic .utils .Representation ]
561
- all_dataclasses = all (is_subclass (c , pydantic .BaseModel ) for c in classes )
562
- if not all_dataclasses and attrs_support :
563
- attrs = import_attrs ('is_pure_dataclass' )
564
- if attrs .has (value ):
565
- return True
566
- return all_dataclasses
567
-
568
-
569
- def dataclass_to_dict (value ):
547
+ def dataclass_to_dict (value ) -> dict :
570
548
if pydantic_support :
571
549
pydantic = import_pydantic ('dataclass_to_dict' )
572
550
if isinstance (value , pydantic .BaseModel ):
573
551
return value .dict ()
552
+ if isinstance (value , LazyInitBaseClass ):
553
+ return value .lazy_get_init_data ().as_dict ()
574
554
return dataclasses .asdict (value )
575
555
576
556
577
557
def compose_dataclasses (* args ):
578
- """Returns a pure dataclass inheriting all given dataclasses and properly handling __post_init__."""
558
+ """Returns a dataclass inheriting all given dataclasses and properly handling __post_init__."""
579
559
580
560
@dataclasses .dataclass
581
561
class ComposedDataclass (* args ):
0 commit comments