Skip to content

Commit

Permalink
support alias for variant
Browse files Browse the repository at this point in the history
  • Loading branch information
y1xiaoc committed Jan 6, 2021
1 parent 6bec152 commit e55a4bd
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 16 deletions.
67 changes: 53 additions & 14 deletions dargs/dargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

INDENT = " " # doc is indented by four spaces
DUMMYHOOK = lambda a,x: None
class _Flags(Enum): NONE = 0
class _Flags(Enum): NONE = 0 # for no value in dict

class Argument:

Expand Down Expand Up @@ -112,30 +112,37 @@ def add_subvariant(self, flag_name: Union[str, "Variant"],
def traverse(self, argdict: dict,
key_hook: Callable[["Argument", dict], None] = DUMMYHOOK,
value_hook: Callable[["Argument", Any], None] = DUMMYHOOK,
sub_hook: Callable[["Argument", dict], None] = DUMMYHOOK):
sub_hook: Callable[["Argument", dict], None] = DUMMYHOOK,
variant_hook: Callable[["Variant", dict], None] = DUMMYHOOK):
# first, do something with the key
# then, take out the vaule and do something with it
key_hook(self, argdict)
if self.name in argdict:
# this is the key step that we traverse into the tree
self.traverse_value(argdict[self.name], key_hook, value_hook, sub_hook)
self.traverse_value(argdict[self.name],
key_hook, value_hook, sub_hook, variant_hook)

def traverse_value(self, value: Any,
key_hook: Callable[["Argument", dict], None] = DUMMYHOOK,
value_hook: Callable[["Argument", Any], None] = DUMMYHOOK,
sub_hook: Callable[["Argument", dict], None] = DUMMYHOOK):
sub_hook: Callable[["Argument", dict], None] = DUMMYHOOK,
variant_hook: Callable[["Variant", dict], None] = DUMMYHOOK):
# this is not private, and can be called directly
# in the condition where there is no leading key
value_hook(self, value)
if isinstance(value, dict):
sub_hook(self, value)
self._traverse_subfield(value, key_hook, value_hook, sub_hook)
self._traverse_subvariant(value, key_hook, value_hook, sub_hook)
self._traverse_subfield(value,
key_hook, value_hook, sub_hook, variant_hook)
self._traverse_subvariant(value,
key_hook, value_hook, sub_hook, variant_hook)
if isinstance(value, list) and self.repeat:
for item in value:
sub_hook(self, item)
self._traverse_subfield(item, key_hook, value_hook, sub_hook)
self._traverse_subvariant(item, key_hook, value_hook, sub_hook)
self._traverse_subfield(item,
key_hook, value_hook, sub_hook, variant_hook)
self._traverse_subvariant(item,
key_hook, value_hook, sub_hook, variant_hook)

def _traverse_subfield(self, value: dict, *args, **kwargs):
assert isinstance(value, dict)
Expand Down Expand Up @@ -202,9 +209,13 @@ def normalize(self, argdict: dict, inplace: bool = False,
if not inplace:
argdict = deepcopy(argdict)
if do_alias:
self.traverse(argdict, key_hook=Argument._convert_alias)
self.traverse(argdict,
key_hook=Argument._convert_alias,
variant_hook=Variant._convert_alias)
if do_default:
self.traverse(argdict, key_hook=Argument._assign_default)
self.traverse(argdict,
key_hook=Argument._assign_default,
variant_hook=Variant._assign_default)
if trim_pattern is not None:
self._trim_unrequired(argdict, trim_pattern, reserved=[self.name])
self.traverse(argdict, sub_hook=lambda a, d:
Expand All @@ -217,9 +228,13 @@ def normalize_value(self, value: Any, inplace: bool = False,
if not inplace:
value = deepcopy(value)
if do_alias:
self.traverse_value(value, key_hook=Argument._convert_alias)
self.traverse_value(value,
key_hook=Argument._convert_alias,
variant_hook=Variant._convert_alias)
if do_default:
self.traverse_value(value, key_hook=Argument._assign_default)
self.traverse_value(value,
key_hook=Argument._assign_default,
variant_hook=Variant._assign_default)
if trim_pattern is not None:
self.traverse_value(value, sub_hook=lambda a, d:
Argument._trim_unrequired(d, trim_pattern, a._get_allowed_sub(d)))
Expand Down Expand Up @@ -314,6 +329,7 @@ def __init__(self,
doc: str = ""):
self.flag_name = flag_name
self.choice_dict = {}
self.alias_dict = {}
if choices is not None:
self.extend_choices(choices)
self.optional = optional
Expand All @@ -340,6 +356,13 @@ def extend_choices(self, choices: Iterable["Argument"]):
raise ValueError(f"duplicate tag `{tag}` appears in "
f"variant with flag `{self.flag_name}`")
self.choice_dict[tag] = arg
# also update alias here
for atag in arg.alias:
if atag in self.choice_dict or atag in self.alias_dict:
raise ValueError(f"duplicate alias tag `{atag}` appears in "
f"variant with flag `{self.flag_name}` "
f"and choice name `{arg.name}`")
self.alias_dict[atag] = arg.name

def set_default(self, default_tag : Union[bool, str]):
if not default_tag:
Expand All @@ -363,10 +386,16 @@ def add_choice(self, tag: Union[str, "Argument"],
# above are creation part
# below are general traverse part

def traverse(self, argdict: dict, *args, **kwargs):
def traverse(self, argdict: dict,
key_hook: Callable[["Argument", dict], None] = DUMMYHOOK,
value_hook: Callable[["Argument", Any], None] = DUMMYHOOK,
sub_hook: Callable[["Argument", dict], None] = DUMMYHOOK,
variant_hook: Callable[["Variant", dict], None] = DUMMYHOOK):
variant_hook(self, argdict)
choice = self._load_choice(argdict)
# here we use check_value to flatten the tag
choice.traverse_value(argdict, *args, **kwargs)
choice.traverse_value(argdict,
key_hook, value_hook, sub_hook, variant_hook)

def _load_choice(self, argdict: dict) -> "Argument":
if self.flag_name in argdict:
Expand All @@ -384,6 +413,16 @@ def _get_allowed_sub(self, argdict: dict) -> List[str]:
allowed.extend(choice._get_allowed_sub(argdict))
return allowed

def _assign_default(self, argdict: dict):
if self.flag_name not in argdict and self.optional:
argdict[self.flag_name] = self.default_tag

def _convert_alias(self, argdict: dict):
if self.flag_name in argdict:
tag = argdict[self.flag_name]
if tag not in self.choice_dict and tag in self.alias_dict:
argdict[self.flag_name] = self.alias_dict[tag]

# above are type checking part
# below are doc generation part

Expand Down
5 changes: 3 additions & 2 deletions tests/test_normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,15 @@ def test_complicated(self):
Argument("type2", dict, [
Argument("shared", int, optional=True, default=-2, alias=["sharedb"]),
Argument("vnt2", int, optional=True, default=222, alias=["vnt2a"]),
])
], alias = ['type3'])
], optional=True, default_tag="type1")
])
beg1 = {"base": {"sub2": [{}, {}]}}
ref1 = {
'base': {
'sub1': 1,
'sub2': [{'ss1': 21}, {'ss1': 21}],
'vnt_flag': "type1",
'shared': -1,
'vnt1': 111}
}
Expand All @@ -96,7 +97,7 @@ def test_complicated(self):
"base": {
"sub1a": 2,
"sub2a": [{"ss1a":22}, {"_comment1": None}],
"vnt_flag": "type2",
"vnt_flag": "type3",
"sharedb": -3,
"vnt2a": 223,
"_comment2": None}
Expand Down

0 comments on commit e55a4bd

Please sign in to comment.