-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
15 changed files
with
476 additions
and
90 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .argparser import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
from xingyun.argparser import ArgumentParser, PreCondition, MyDict | ||
import unittest | ||
import pdb | ||
|
||
def make_parser(): | ||
argp = ArgumentParser() | ||
|
||
# ---- model ---- | ||
argp.add_argument("model/model" , type = str, default = "TWIRLS", verify = lambda v: v in ["TWIRLS", "UGNN", "IGNN", "CNN"] ) | ||
argp.add_alias ("model" , "model/model") | ||
|
||
argp.add_argument("model/num_layers", type = int, default = 12 ) | ||
|
||
with PreCondition(lambda C: C["model/model"] in ["TWIRLS" , "UGNN", "IGNN"]): | ||
argp.add_argument("model/GNN-spec/prop_method", type = str, default = "message", verify = lambda v: v in ["message", "identity"] ) | ||
|
||
with PreCondition(lambda C: C["model/model"] in ["CNN"]): | ||
argp.add_bool ("model/CNN-spec/bn") | ||
|
||
|
||
# ---- optim ---- | ||
argp.add_argument("optim/lr", type = float, default = 0.1, verify = lambda v: v >= 0 ) | ||
argp.add_argument("optim/wd", type = float, default = 0.0, verify = lambda v: v >= 0 ) | ||
|
||
# ---- general ---- | ||
argp.add_bool ("no_activation") | ||
|
||
return argp | ||
|
||
|
||
|
||
|
||
class TestArgParser(unittest.TestCase): | ||
def setUp(self) -> None: | ||
self.argp = make_parser() | ||
return super().setUp() | ||
|
||
def test_main(self): | ||
argv = [ | ||
"--model=UGNN" , | ||
"--model/num_layers=3", | ||
"--model/GNN-spec/prop_method=message", | ||
"--model/CNN-spec/bn", | ||
"--optim/lr=0.1", | ||
"--optim/wd=5e-5", | ||
"--no_activation" | ||
] | ||
|
||
C = self.argp.parse(argv) | ||
|
||
print (C.sub("model")) | ||
|
||
self.assertTrue( C.sub("model")["model"] == "UGNN" ) | ||
self.assertTrue( C.sub("model")["num_layers"] == 3 ) | ||
self.assertTrue( C.sub("model")["GNN-spec/prop_method"] == "message" ) | ||
self.assertTrue( C.sub("model")["CNN-spec/bn"] is None ) | ||
self.assertTrue( C("model")["CNN-spec/bn"] == C("model")("CNN-spec")["bn"] ) | ||
self.assertTrue( C.sub("optim")["lr"] == 0.1) | ||
self.assertTrue( C.sub("optim")["wd"] == 5e-5 ) | ||
self.assertTrue( C["no_activation"] == True ) | ||
|
||
def test_raise(self): | ||
argv = [ | ||
"--model=UGNN" , | ||
"--model/num_layers=3", | ||
"--model/GNN-spec/prop_method=UGNN_message", | ||
"--model/CNN-spec/bn", | ||
"--optim/lr=0.1", | ||
"--optim/wd=5e-5", | ||
"--no_activation" | ||
] | ||
self.assertRaises(RuntimeError, lambda : self.argp.parse(argv)) | ||
|
||
def test_default(self): | ||
argv = [ | ||
"--model=CNN" , | ||
"--model/num_layers=12", | ||
"--model/GNN-spec/prop_method=UGNN_message", | ||
"--model/CNN-spec/bn", | ||
"--optim/lr=0.1", | ||
] | ||
C = self.argp.parse(argv) | ||
|
||
self.assertTrue( C.sub("model")["model"] == "CNN" ) | ||
self.assertTrue( C.sub("model")["num_layers"] == 12 ) | ||
self.assertTrue( C.sub("model")["GNN-spec/prop_method"] is None ) | ||
self.assertTrue( C.sub("model")["CNN-spec/bn"] == True ) | ||
self.assertTrue( C.sub("optim")["lr"] == 0.1) | ||
self.assertTrue( C.sub("optim")["wd"] == 0.0 ) | ||
self.assertTrue( C["no_activation"] == False ) | ||
|
||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
from .my_random import * | ||
from .fix_random import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import sys | ||
import re | ||
print (sys.argv[1:]) | ||
|
||
pattern = r"^--([^=]+)(=(.+)|)$" | ||
|
||
a = "--model=layers" | ||
b = "--model/layers=3" | ||
c = "--model-layers=7" | ||
d = "--yes" | ||
e = "sd--aa=layers" | ||
f = "sd--yes" | ||
|
||
for x in [a,b,c,d,e,f,]: | ||
match = re.match(pattern, x) | ||
if match is not None: | ||
print (match.group(1), match.group(3)) | ||
else: | ||
print ("None") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from argparse import Namespace | ||
|
||
a = Namespace(a = 12, b = "3") | ||
a.__dict__["c.2"] = 100 | ||
|
||
print (a.c) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,3 +8,4 @@ | |
from .random import * | ||
from .logger import * | ||
from .cloud import * | ||
from.argparser import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .argparser import Argument, PreCondition, Condition, ArgumentParser, MyDict |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,221 @@ | ||
import sys | ||
from typing import Callable, Any | ||
import re | ||
import warnings | ||
from argparse import Namespace | ||
from ..random.my_random import my_randint | ||
from ..universal.get_subdict import get_subdict | ||
|
||
class Condition: | ||
def __init__(self, condition: Callable[[ dict[str, Any] ], bool] = lambda _: True): | ||
self.condition = condition | ||
|
||
def test(self, config: dict[str, Any]): | ||
return self.condition(config) | ||
|
||
def __and__(self, other: "Condition"): | ||
return Condition(lambda v: self.condition(v) and other.condition(v)) | ||
|
||
def __or__(self, other: "Condition"): | ||
return Condition(lambda v: self.condition(v) or other.condition(v)) | ||
|
||
@classmethod | ||
def and_all(cls, conditions: list["Condition"]): | ||
r = Condition() | ||
for c in conditions: | ||
r = r & c | ||
return r | ||
|
||
|
||
condition_env = {} | ||
class PreCondition: | ||
def __init__(self, cond: Condition | Callable[[ dict[str, Any] ], bool]): | ||
|
||
if not isinstance(cond, Condition): | ||
cond = Condition(cond) | ||
|
||
self.cond = cond | ||
self.id = my_randint(0,2333333) | ||
def __enter__(self): | ||
condition_env[self.id] = self.cond | ||
def __exit__(self, *arg, **kwarg): | ||
condition_env.pop(self.id) | ||
|
||
class MyDict(dict): | ||
def __init__(self, d: dict = {}): | ||
super().__init__(d) | ||
|
||
def sub(self, prefix: str, splitter: str = "/"): | ||
return MyDict( get_subdict(self, prefix, splitter) ) | ||
|
||
def __call__(self, prefix: str, splitter: str = "/"): | ||
return self.sub(prefix, splitter) | ||
|
||
|
||
class Argument: | ||
'''The class that describe an argument. | ||
### Properties | ||
- name: the name of the argument. | ||
- type: how to convert string to a value. | ||
- default: the value set to the argument if it is not provided. | ||
- present: the value set to the argument if it is provided but without a value. | ||
- help: help string. | ||
- pre_condition: Before an argument is parsed, the pre-condition must be satisfied. | ||
- post_condition: After an argument is parsed, if the post-condition is not satisfied, then will raise an error. | ||
''' | ||
def __init__(self, | ||
name: str, | ||
type: Callable[[str] , Any], | ||
default: Any = None, | ||
present: Any = None, | ||
help:str = "" , | ||
pre_condition : Condition = Condition(), | ||
post_condition: Condition = Condition(), | ||
): | ||
self.name = name | ||
self.type = type | ||
self.default = default | ||
self.present = present | ||
self.pre_condition = pre_condition | ||
self.post_condition = post_condition | ||
self.help = help | ||
|
||
class ArgumentParser: | ||
'''This class is a modification of python `argparse.ArgumentParser` class. | ||
For each argument, there are two conditions: pre-condition and post-condition. | ||
Before an argument is parsed, the pre-condition must be satisfied. | ||
After an argument is parsed, if the post-condition is not satisfied, then will raise an error. | ||
''' | ||
|
||
def __init__(self, help: str = ""): | ||
|
||
self.help = help | ||
|
||
self.arguments: dict[str, Argument] = {} | ||
self.alias : dict[str, str] = {} # redirects to | ||
|
||
@property | ||
def now_preconds(self): | ||
return [c for id,c in condition_env.items() if c is not None] | ||
|
||
@classmethod | ||
def get_subconfig(cls, C: dict | Namespace, prefix: str, splitter: str = "/"): | ||
if isinstance(C, Namespace): | ||
C = C.__dict__ | ||
return get_subdict(C,prefix,splitter) | ||
|
||
def add_alias(self, alias: str, original: str): | ||
'''add an alias for an argument.''' | ||
self.alias[alias] = original | ||
|
||
def add_argument(self, | ||
name: str, | ||
type: Callable[[str] , Any], | ||
default: Any = None, | ||
present: Any = None, | ||
help: str = "" , | ||
verify : Callable[ [Any] , bool] = lambda _: True , | ||
): | ||
post_cond = Condition(lambda C: verify(C.get(name))) | ||
self.arguments[name] = Argument(name, type, default, present, help, Condition.and_all(self.now_preconds), post_cond) | ||
|
||
def add_bool(self, | ||
name: str, | ||
help: str = "" , | ||
verify : Callable[[ Any ], bool] = lambda _: True , | ||
): | ||
self.add_argument(name, bool, False, True, help, verify) | ||
|
||
def parse_namespace(self, | ||
args: list[str] | None = None, | ||
pattern = r"^--([^=]+)(=(.+)|)$", | ||
get_match: Callable[[re.Match], tuple[str,str]] = lambda m: (m.group(1), m.group(3)) , | ||
) -> Namespace: | ||
return Namespace(**self.parse(args, pattern, get_match)) | ||
|
||
def parse(self, | ||
args: list[str] | None = None, | ||
pattern = r"^--([^=]+)(=(.+)|)$", | ||
get_match: Callable[[re.Match], tuple[str,str]] = lambda m: (m.group(1), m.group(3)) , | ||
) -> MyDict: | ||
''' Parse argument. | ||
### Parameters | ||
-- args: arguments to be parsed. | ||
-- pattern: a regular expression to match name and value. | ||
-- get_match: a callable object that get name and value from the get_match. The input callable is a `re.Match` object, | ||
the output should be a 2-tuple, with the first element being the name and the second element being the value of the argument. | ||
''' | ||
if args is None: | ||
args = sys.argv[1:] | ||
|
||
# get value of each argument | ||
name_vals = {} | ||
for s in args: | ||
# get name val pairs | ||
match = re.match(pattern, s) | ||
if match is None: | ||
continue | ||
name, val = get_match(match) | ||
|
||
# apply alias | ||
alias_tar = self.alias.get(name) | ||
if alias_tar is not None: | ||
name = alias_tar | ||
|
||
# apply present val | ||
arg = self.arguments.get(name) | ||
if arg is None: | ||
continue | ||
if val is None: | ||
val = arg.present | ||
|
||
# record name, val pair | ||
name_vals[name] = val | ||
|
||
# apply default val | ||
for name , arg in self.arguments.items(): | ||
if not (name in name_vals): | ||
name_vals[name] = arg.default | ||
|
||
# actually assign values | ||
parsed = {} | ||
_t = 0 | ||
for _t in range(100): | ||
now_len = len(parsed) | ||
for name, val in name_vals.items(): | ||
|
||
# get Argument object | ||
arg = self.arguments.get(name) | ||
if arg is None: | ||
continue | ||
|
||
# check pre condition | ||
if not arg.pre_condition.test(parsed): | ||
continue | ||
|
||
# store value | ||
parsed[name] = arg.type(val) | ||
|
||
# check post condition | ||
if not arg.post_condition.test(parsed): | ||
raise RuntimeError(f"bad argument {name}: value can not be {parsed[name]}") | ||
|
||
new_len = len(parsed) | ||
if new_len == now_len: # no new argument is parsed | ||
break | ||
now_len = new_len | ||
|
||
if _t >= 98: | ||
warnings.warn("Too deep nested logic. Only performed 100 iterations.") | ||
|
||
# for those who forbidded by precondition, assign `None`. | ||
for name in self.arguments: | ||
if not (name in parsed): | ||
parsed[name] = None | ||
|
||
return MyDict(parsed) | ||
|
||
|
Oops, something went wrong.