Skip to content

Commit

Permalink
support indexing of Argument
Browse files Browse the repository at this point in the history
  • Loading branch information
y1xiaoc committed Apr 11, 2021
1 parent 1e9d64d commit b032ee7
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 9 deletions.
5 changes: 0 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,3 @@ Set `make_anchor=True` when calling `gendoc` function and use standard ref synta
The id is the same as the argument path. Variant types would be in square brackets.

Please refer to test files for detailed usage.


## TODO

- [ ] possibly support of indexing by keys
35 changes: 32 additions & 3 deletions dargs/dargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,32 @@ def __eq__(self, other: "Argument") -> bool:
def __repr__(self) -> str:
return f"<Argument {self.name}: {' | '.join(dd.__name__ for dd in self.dtype)}>"

def __getitem__(self, key: str) -> "Argument":
key = key.lstrip("/")
if key in ("", "."):
return self
if key.startswith("["):
vkey, rkey = key[1:].split("]", 1)
if vkey.count("=") == 1:
fkey, ckey = vkey.split("=")
else:
[fkey] = self.sub_variants.keys()
ckey = vkey
return self.sub_variants[fkey][ckey][rkey]
p1, p2 = key.find("/"), key.find("[")
if max(p1, p2) < 0: # not found
return self.sub_fields[key]
else: # at least one found
p = p1 if p2 < 0 or 0 < p1 < p2 else p2
skey, rkey = key[:p], key[p:]
return self[skey][rkey]

@property
def I(self):
# return a dummy argument that only has self as a sub field
# can be used in indexing
return Argument("_", dict, [self])

def _reorg_dtype(self):
if isinstance(self.dtype, type) or self.dtype is None:
self.dtype = [self.dtype]
Expand Down Expand Up @@ -111,7 +137,7 @@ def add_subfield(self, name: Union[str, "Argument"],
newarg = Argument(name, *args, **kwargs)
self.extend_subfields([newarg])
return newarg

def extend_subvariants(self, sub_variants: Optional[Iterable["Variant"]]):
if sub_variants is None:
return
Expand Down Expand Up @@ -223,7 +249,7 @@ def _check_strict(self, value: dict):
if name not in allowed_keys:
raise KeyError(f"undefined key `{name}` is "
"not allowed in strict mode")

# above are type checking part
# below are normalizing part

Expand Down Expand Up @@ -371,6 +397,9 @@ def __eq__(self, other: "Variant") -> bool:
def __repr__(self) -> str:
return f"<Variant {self.flag_name} in {{ {', '.join(self.choice_dict.keys())} }}>"

def __getitem__(self, key: str) -> "Argument":
return self.choice_dict[key]

def set_default(self, default_tag : Union[bool, str]):
if not default_tag:
self.optional = False
Expand Down Expand Up @@ -424,7 +453,7 @@ def get_choice(self, argdict: dict) -> "Argument":
else:
raise KeyError(f"key `{self.flag_name}` is required "
"to choose variant but not found.")

def flatten_sub(self, argdict: dict) -> Dict[str, "Argument"]:
choice = self.get_choice(argdict)
fields = {self.flag_name: self.dummy_argument(), # as a placeholder
Expand Down
126 changes: 125 additions & 1 deletion tests/test_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,34 @@ def test_sub_fields(self):
ca.set_repeat(True)
self.assertTrue(ca == ref)

def test_idx_fields(self):
s1 = Argument("sub1", int)
vt1 = Argument("type1", dict, [
Argument("shared", str),
Argument("vnt1_1", dict, [
Argument("vnt1_1_1", int)
])
])
vt2 = Argument("type2", dict, [
Argument("shared", int),
])
v1 = Variant("vnt_flag", [vt1, vt2])
ca = Argument("base", dict, [s1], [v1])
self.assertTrue(ca[''] is ca)
self.assertTrue(ca['.'] is ca)
self.assertTrue(ca['sub1'] == ca["./sub1"] == s1)
with self.assertRaises(KeyError):
ca["sub2"]
self.assertTrue(ca['[type1]'] is vt1)
self.assertTrue(ca['[type1]///'] is vt1)
self.assertTrue(ca['[type1]/vnt1_1/vnt1_1_1'] == Argument("vnt1_1_1", int))
self.assertTrue(ca['[type2]//shared'] == Argument("shared", int))
with self.assertRaises(KeyError):
s1["sub1"]
self.assertTrue(s1.I["sub1"] is s1)
self.assertTrue(ca.I["base[type1]"] is vt1)
self.assertTrue(ca.I['base[type2]//shared'] == Argument("shared", int))

def test_sub_variants(self):
ref = Argument("base", dict, [
Argument("sub1", int),
Expand Down Expand Up @@ -64,7 +92,7 @@ def test_sub_variants(self):
vt2s0 = vt2.add_subfield("shared", int)
vt2s1 = vt2.add_subfield("vnt2_1", int)
self.assertTrue(ca == ref)

# make sure we can modify the reference
ref1 = Argument("base", dict, [
Argument("sub1", int),
Argument("sub2", str)
Expand All @@ -88,6 +116,102 @@ def test_sub_variants(self):
v1.set_default(False)
self.assertTrue(ca == ref)

def test_idx_variants(self):
vt1 = Argument("type1", dict, [
Argument("shared", int),
Argument("vnt1_1", int),
Argument("vnt1_2", dict, [
Argument("vnt1_1_1", int)
])
])
vt2 = Argument("type2", dict, [
Argument("shared", int),
Argument("vnt2_1", int),
])
vnt = Variant("vnt_flag", [vt1, vt2])
self.assertTrue(vnt["type1"] is vt1)
self.assertTrue(vnt["type2"] is vt2)
with self.assertRaises(KeyError):
vnt["type3"]

def test_complicated(self):
ref = Argument("base", dict, [
Argument("sub1", int),
Argument("sub2", str)
], [
Variant("vnt_flag", [
Argument("type1", dict, [
Argument("shared", int),
Argument("vnt1_1", int),
Argument("vnt1_2", dict, [
Argument("vnt1_1_1", int)
])
]),
Argument("type2", dict, [
Argument("shared", int),
Argument("vnt2_1", int),
]),
Argument("type3", dict, [
Argument("vnt3_1", int)
], [ # testing cascade variants here
Variant("vnt3_flag1", [
Argument("v3f1t1", dict, [
Argument('v3f1t1_1', int),
Argument('v3f1t1_2', int)
]),
Argument("v3f1t2", dict, [
Argument('v3f1t2_1', int)
])
]),
Variant("vnt3_flag2", [
Argument("v3f2t1", dict, [
Argument('v3f2t1_1', int),
Argument('v3f2t1_2', int)
]),
Argument("v3f2t2", dict, [
Argument('v3f2t2_1', int)
])
])
])
])
])
ca = Argument("base", dict)
s1 = ca.add_subfield("sub1", int)
s2 = ca.add_subfield("sub2", str)
v1 = ca.add_subvariant("vnt_flag")
vt1 = v1.add_choice("type1", dict)
vt1s0 = vt1.add_subfield("shared", int)
vt1s1 = vt1.add_subfield("vnt1_1", int)
vt1s2 = vt1.add_subfield("vnt1_2", dict)
vt1ss = vt1s2.add_subfield("vnt1_1_1", int)
vt2 = v1.add_choice("type2")
vt2s0 = vt2.add_subfield("shared", int)
vt2s1 = vt2.add_subfield("vnt2_1", int)
vt3 = v1.add_choice("type3")
vt3s1 = vt3.add_subfield("vnt3_1", int)
vt3f1 = vt3.add_subvariant('vnt3_flag1')
vt3f1t1 = vt3f1.add_choice("v3f1t1")
vt3f1t1s1 = vt3f1t1.add_subfield("v3f1t1_1", int)
vt3f1t1s2 = vt3f1t1.add_subfield("v3f1t1_2", int)
vt3f1t2 = vt3f1.add_choice("v3f1t2")
vt3f1t2s1 = vt3f1t2.add_subfield("v3f1t2_1", int)
vt3f2 = vt3.add_subvariant('vnt3_flag2')
vt3f2t1 = vt3f2.add_choice("v3f2t1")
vt3f2t1s1 = vt3f2t1.add_subfield("v3f2t1_1", int)
vt3f2t1s2 = vt3f2t1.add_subfield("v3f2t1_2", int)
vt3f2t2 = vt3f2.add_choice("v3f2t2")
vt3f2t2s1 = vt3f2t2.add_subfield("v3f2t2_1", int)
self.assertTrue(ca == ref)
self.assertTrue(ca['[type3][vnt3_flag1=v3f1t1]'] is vt3f1t1)
self.assertTrue(ca.I['base[type3][vnt3_flag1=v3f1t1]/v3f1t1_2'] is vt3f1t1s2)
self.assertTrue(ca.I['base[type3][vnt3_flag1=v3f1t2]/v3f1t2_1'] is vt3f1t2s1)
self.assertTrue(ca.I['base[type3][vnt3_flag2=v3f2t1]/v3f2t1_1'] is vt3f2t1s1)
self.assertTrue(ca.I['base[type3][vnt3_flag2=v3f2t2]/v3f2t2_1'] is vt3f2t2s1)
with self.assertRaises((KeyError, ValueError)):
ca.I['base[type3][v3f2t2]']
with self.assertRaises((KeyError, ValueError)):
ca.I['base[type3][vnt3_flag3=v3f2t2]/v3f2t2_1']


if __name__ == "__main__":
unittest.main()

0 comments on commit b032ee7

Please sign in to comment.