From adea9df1001acd405052b21b8a74350926904310 Mon Sep 17 00:00:00 2001 From: Yongyi <1004473299@qq.com> Date: Wed, 27 Dec 2023 00:07:46 -0500 Subject: [PATCH] add argparser --- setup.py | 2 +- test/__init__.py | 1 + test/argparser/__init__.py | 1 + test/argparser/argparser.py | 95 +++++++++ test/random/__init__.py | 2 +- test/random/{my_random.py => fix_random.py} | 4 +- test/try/try_argv.py | 19 ++ test/try/try_namespace.py | 6 + xingyun/__init__.py | 1 + xingyun/argparser/__init__.py | 1 + xingyun/argparser/argparser.py | 221 ++++++++++++++++++++ xingyun/random/__init__.py | 2 +- xingyun/random/fix_random.py | 97 +++++++++ xingyun/random/my_random.py | 97 ++------- xingyun/universal/get_subdict.py | 17 ++ 15 files changed, 476 insertions(+), 90 deletions(-) create mode 100644 test/argparser/__init__.py create mode 100644 test/argparser/argparser.py rename test/random/{my_random.py => fix_random.py} (90%) create mode 100644 test/try/try_argv.py create mode 100644 test/try/try_namespace.py create mode 100644 xingyun/argparser/__init__.py create mode 100644 xingyun/argparser/argparser.py create mode 100644 xingyun/random/fix_random.py create mode 100644 xingyun/universal/get_subdict.py diff --git a/setup.py b/setup.py index cd72734..c76e4c9 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ setup( name="xingyun", - version="0.0.3", + version="0.0.4", url="http://github.com/FFTYYY/XingYun", description="", long_description=readme, diff --git a/test/__init__.py b/test/__init__.py index 36dde8a..a5812e1 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -4,6 +4,7 @@ from .time import * from .cloud import * from .logger import * +from .argparser import * if __name__ == "__main__": testsuite = unittest.TestLoader().discover(".") diff --git a/test/argparser/__init__.py b/test/argparser/__init__.py new file mode 100644 index 0000000..51a0d3f --- /dev/null +++ b/test/argparser/__init__.py @@ -0,0 +1 @@ +from .argparser import * \ No newline at end of file diff --git a/test/argparser/argparser.py b/test/argparser/argparser.py new file mode 100644 index 0000000..dc158ea --- /dev/null +++ b/test/argparser/argparser.py @@ -0,0 +1,95 @@ +from xingyun.argparser import ArgumentParser, PreCondition, MyDict +import unittest +import pdb + +def make_parser(): + argp = ArgumentParser() + + # ---- model ---- + argp.add_argument("model/model" , type = str, default = "TWIRLS", verify = lambda v: v in ["TWIRLS", "UGNN", "IGNN", "CNN"] ) + argp.add_alias ("model" , "model/model") + + argp.add_argument("model/num_layers", type = int, default = 12 ) + + with PreCondition(lambda C: C["model/model"] in ["TWIRLS" , "UGNN", "IGNN"]): + argp.add_argument("model/GNN-spec/prop_method", type = str, default = "message", verify = lambda v: v in ["message", "identity"] ) + + with PreCondition(lambda C: C["model/model"] in ["CNN"]): + argp.add_bool ("model/CNN-spec/bn") + + + # ---- optim ---- + argp.add_argument("optim/lr", type = float, default = 0.1, verify = lambda v: v >= 0 ) + argp.add_argument("optim/wd", type = float, default = 0.0, verify = lambda v: v >= 0 ) + + # ---- general ---- + argp.add_bool ("no_activation") + + return argp + + + + +class TestArgParser(unittest.TestCase): + def setUp(self) -> None: + self.argp = make_parser() + return super().setUp() + + def test_main(self): + argv = [ + "--model=UGNN" , + "--model/num_layers=3", + "--model/GNN-spec/prop_method=message", + "--model/CNN-spec/bn", + "--optim/lr=0.1", + "--optim/wd=5e-5", + "--no_activation" + ] + + C = self.argp.parse(argv) + + print (C.sub("model")) + + self.assertTrue( C.sub("model")["model"] == "UGNN" ) + self.assertTrue( C.sub("model")["num_layers"] == 3 ) + self.assertTrue( C.sub("model")["GNN-spec/prop_method"] == "message" ) + self.assertTrue( C.sub("model")["CNN-spec/bn"] is None ) + self.assertTrue( C("model")["CNN-spec/bn"] == C("model")("CNN-spec")["bn"] ) + self.assertTrue( C.sub("optim")["lr"] == 0.1) + self.assertTrue( C.sub("optim")["wd"] == 5e-5 ) + self.assertTrue( C["no_activation"] == True ) + + def test_raise(self): + argv = [ + "--model=UGNN" , + "--model/num_layers=3", + "--model/GNN-spec/prop_method=UGNN_message", + "--model/CNN-spec/bn", + "--optim/lr=0.1", + "--optim/wd=5e-5", + "--no_activation" + ] + self.assertRaises(RuntimeError, lambda : self.argp.parse(argv)) + + def test_default(self): + argv = [ + "--model=CNN" , + "--model/num_layers=12", + "--model/GNN-spec/prop_method=UGNN_message", + "--model/CNN-spec/bn", + "--optim/lr=0.1", + ] + C = self.argp.parse(argv) + + self.assertTrue( C.sub("model")["model"] == "CNN" ) + self.assertTrue( C.sub("model")["num_layers"] == 12 ) + self.assertTrue( C.sub("model")["GNN-spec/prop_method"] is None ) + self.assertTrue( C.sub("model")["CNN-spec/bn"] == True ) + self.assertTrue( C.sub("optim")["lr"] == 0.1) + self.assertTrue( C.sub("optim")["wd"] == 0.0 ) + self.assertTrue( C["no_activation"] == False ) + + + +if __name__ == "__main__": + unittest.main() diff --git a/test/random/__init__.py b/test/random/__init__.py index a1998d5..359e34d 100644 --- a/test/random/__init__.py +++ b/test/random/__init__.py @@ -1 +1 @@ -from .my_random import * \ No newline at end of file +from .fix_random import * \ No newline at end of file diff --git a/test/random/my_random.py b/test/random/fix_random.py similarity index 90% rename from test/random/my_random.py rename to test/random/fix_random.py index 85b4fe9..cb296e2 100644 --- a/test/random/my_random.py +++ b/test/random/fix_random.py @@ -1,7 +1,7 @@ import random import torch from xingyun.random.set_random_seed import set_random_seed -from xingyun.random.my_random import MyRandom +from xingyun.random.fix_random import FixRandom import unittest def randomint(): @@ -17,7 +17,7 @@ def test_main(self): set_random_seed(2333) rn_3 = randomint() - with MyRandom(): + with FixRandom(): rand_list = [randomint() for _ in range(10)] rn_4 = randomint() diff --git a/test/try/try_argv.py b/test/try/try_argv.py new file mode 100644 index 0000000..39f9293 --- /dev/null +++ b/test/try/try_argv.py @@ -0,0 +1,19 @@ +import sys +import re +print (sys.argv[1:]) + +pattern = r"^--([^=]+)(=(.+)|)$" + +a = "--model=layers" +b = "--model/layers=3" +c = "--model-layers=7" +d = "--yes" +e = "sd--aa=layers" +f = "sd--yes" + +for x in [a,b,c,d,e,f,]: + match = re.match(pattern, x) + if match is not None: + print (match.group(1), match.group(3)) + else: + print ("None") \ No newline at end of file diff --git a/test/try/try_namespace.py b/test/try/try_namespace.py new file mode 100644 index 0000000..312eccc --- /dev/null +++ b/test/try/try_namespace.py @@ -0,0 +1,6 @@ +from argparse import Namespace + +a = Namespace(a = 12, b = "3") +a.__dict__["c.2"] = 100 + +print (a.c) \ No newline at end of file diff --git a/xingyun/__init__.py b/xingyun/__init__.py index c3d6a26..555870c 100644 --- a/xingyun/__init__.py +++ b/xingyun/__init__.py @@ -8,3 +8,4 @@ from .random import * from .logger import * from .cloud import * +from.argparser import * \ No newline at end of file diff --git a/xingyun/argparser/__init__.py b/xingyun/argparser/__init__.py new file mode 100644 index 0000000..f54783b --- /dev/null +++ b/xingyun/argparser/__init__.py @@ -0,0 +1 @@ +from .argparser import Argument, PreCondition, Condition, ArgumentParser, MyDict \ No newline at end of file diff --git a/xingyun/argparser/argparser.py b/xingyun/argparser/argparser.py new file mode 100644 index 0000000..4e38d8e --- /dev/null +++ b/xingyun/argparser/argparser.py @@ -0,0 +1,221 @@ +import sys +from typing import Callable, Any +import re +import warnings +from argparse import Namespace +from ..random.my_random import my_randint +from ..universal.get_subdict import get_subdict + +class Condition: + def __init__(self, condition: Callable[[ dict[str, Any] ], bool] = lambda _: True): + self.condition = condition + + def test(self, config: dict[str, Any]): + return self.condition(config) + + def __and__(self, other: "Condition"): + return Condition(lambda v: self.condition(v) and other.condition(v)) + + def __or__(self, other: "Condition"): + return Condition(lambda v: self.condition(v) or other.condition(v)) + + @classmethod + def and_all(cls, conditions: list["Condition"]): + r = Condition() + for c in conditions: + r = r & c + return r + + +condition_env = {} +class PreCondition: + def __init__(self, cond: Condition | Callable[[ dict[str, Any] ], bool]): + + if not isinstance(cond, Condition): + cond = Condition(cond) + + self.cond = cond + self.id = my_randint(0,2333333) + def __enter__(self): + condition_env[self.id] = self.cond + def __exit__(self, *arg, **kwarg): + condition_env.pop(self.id) + +class MyDict(dict): + def __init__(self, d: dict = {}): + super().__init__(d) + + def sub(self, prefix: str, splitter: str = "/"): + return MyDict( get_subdict(self, prefix, splitter) ) + + def __call__(self, prefix: str, splitter: str = "/"): + return self.sub(prefix, splitter) + + +class Argument: + '''The class that describe an argument. + + ### Properties + - name: the name of the argument. + - type: how to convert string to a value. + - default: the value set to the argument if it is not provided. + - present: the value set to the argument if it is provided but without a value. + - help: help string. + - pre_condition: Before an argument is parsed, the pre-condition must be satisfied. + - post_condition: After an argument is parsed, if the post-condition is not satisfied, then will raise an error. + ''' + def __init__(self, + name: str, + type: Callable[[str] , Any], + default: Any = None, + present: Any = None, + help:str = "" , + pre_condition : Condition = Condition(), + post_condition: Condition = Condition(), + ): + self.name = name + self.type = type + self.default = default + self.present = present + self.pre_condition = pre_condition + self.post_condition = post_condition + self.help = help + +class ArgumentParser: + '''This class is a modification of python `argparse.ArgumentParser` class. + + For each argument, there are two conditions: pre-condition and post-condition. + Before an argument is parsed, the pre-condition must be satisfied. + After an argument is parsed, if the post-condition is not satisfied, then will raise an error. + ''' + + def __init__(self, help: str = ""): + + self.help = help + + self.arguments: dict[str, Argument] = {} + self.alias : dict[str, str] = {} # redirects to + + @property + def now_preconds(self): + return [c for id,c in condition_env.items() if c is not None] + + @classmethod + def get_subconfig(cls, C: dict | Namespace, prefix: str, splitter: str = "/"): + if isinstance(C, Namespace): + C = C.__dict__ + return get_subdict(C,prefix,splitter) + + def add_alias(self, alias: str, original: str): + '''add an alias for an argument.''' + self.alias[alias] = original + + def add_argument(self, + name: str, + type: Callable[[str] , Any], + default: Any = None, + present: Any = None, + help: str = "" , + verify : Callable[ [Any] , bool] = lambda _: True , + ): + post_cond = Condition(lambda C: verify(C.get(name))) + self.arguments[name] = Argument(name, type, default, present, help, Condition.and_all(self.now_preconds), post_cond) + + def add_bool(self, + name: str, + help: str = "" , + verify : Callable[[ Any ], bool] = lambda _: True , + ): + self.add_argument(name, bool, False, True, help, verify) + + def parse_namespace(self, + args: list[str] | None = None, + pattern = r"^--([^=]+)(=(.+)|)$", + get_match: Callable[[re.Match], tuple[str,str]] = lambda m: (m.group(1), m.group(3)) , + ) -> Namespace: + return Namespace(**self.parse(args, pattern, get_match)) + + def parse(self, + args: list[str] | None = None, + pattern = r"^--([^=]+)(=(.+)|)$", + get_match: Callable[[re.Match], tuple[str,str]] = lambda m: (m.group(1), m.group(3)) , + ) -> MyDict: + ''' Parse argument. + + ### Parameters + -- args: arguments to be parsed. + -- pattern: a regular expression to match name and value. + -- get_match: a callable object that get name and value from the get_match. The input callable is a `re.Match` object, + the output should be a 2-tuple, with the first element being the name and the second element being the value of the argument. + ''' + if args is None: + args = sys.argv[1:] + + # get value of each argument + name_vals = {} + for s in args: + # get name val pairs + match = re.match(pattern, s) + if match is None: + continue + name, val = get_match(match) + + # apply alias + alias_tar = self.alias.get(name) + if alias_tar is not None: + name = alias_tar + + # apply present val + arg = self.arguments.get(name) + if arg is None: + continue + if val is None: + val = arg.present + + # record name, val pair + name_vals[name] = val + + # apply default val + for name , arg in self.arguments.items(): + if not (name in name_vals): + name_vals[name] = arg.default + + # actually assign values + parsed = {} + _t = 0 + for _t in range(100): + now_len = len(parsed) + for name, val in name_vals.items(): + + # get Argument object + arg = self.arguments.get(name) + if arg is None: + continue + + # check pre condition + if not arg.pre_condition.test(parsed): + continue + + # store value + parsed[name] = arg.type(val) + + # check post condition + if not arg.post_condition.test(parsed): + raise RuntimeError(f"bad argument {name}: value can not be {parsed[name]}") + + new_len = len(parsed) + if new_len == now_len: # no new argument is parsed + break + now_len = new_len + + if _t >= 98: + warnings.warn("Too deep nested logic. Only performed 100 iterations.") + + # for those who forbidded by precondition, assign `None`. + for name in self.arguments: + if not (name in parsed): + parsed[name] = None + + return MyDict(parsed) + + diff --git a/xingyun/random/__init__.py b/xingyun/random/__init__.py index 8ca8caa..31feeca 100644 --- a/xingyun/random/__init__.py +++ b/xingyun/random/__init__.py @@ -1,4 +1,4 @@ '''This module provides some utils about random numbers.''' -from .my_random import MyRandom +from .fix_random import FixRandom from .set_random_seed import set_random_seed, set_module_seed diff --git a/xingyun/random/fix_random.py b/xingyun/random/fix_random.py new file mode 100644 index 0000000..a6d7db0 --- /dev/null +++ b/xingyun/random/fix_random.py @@ -0,0 +1,97 @@ +import time +import copy +import random +from typing import Any +from xingyun.universal.import_module import my_import_module + +from .set_random_seed import RandomAllowedModule + +def get_random_state(module: RandomAllowedModule) -> Any: + if module == "random": + return random.getstate() + + if module == "numpy": + np = my_import_module("numpy") + if np is None: + return None + return np.random.get_state() + + if module == "torch": + torch = my_import_module("torch") + cuda = my_import_module("torch.cuda") + + if (torch is None) or (cuda is None): + return None + + return { + "torch": torch.random.get_rng_state(), + "cuda" : cuda .random.get_rng_state(), + } + +def set_random_state(state: Any, module: RandomAllowedModule) -> bool: + flag = True + if module == "random": + + try: + random.setstate(state) + except: + flag = False + + if module == "numpy": + + np = my_import_module("numpy") + if (np is None) or (state is None): + return False + + try: + np.random.set_state(state) + except: + flag = False + + if module == "torch": + torch = my_import_module("torch") + cuda = my_import_module("torch.cuda") + + if (torch is None) or (cuda is None): + return False + try: + torch.random.set_rng_state(state["torch"]) + cuda.random.set_rng_state(state["cuda"]) + except: + flag = False + + + if not flag: + raise RuntimeError(f"set random state of module {module} bad.") + + return flag + +class FixRandom: + def __init__(self, random_seed: int | None = None, modules: list[RandomAllowedModule] = ["random" , "torch" , "numpy"]): + '''This class create a temporary environment, inside which the random seed + is set to a given value while not affecting the global random seed. + + Notice that, to make this class work, the global random seed must be also managed by `xingyun`. + ''' + if random_seed is None: + random_seed = int( time.time() ) + + self.random_seed = random_seed + self.modules = modules + + self.entering_state = {} + + def __enter__(self): + for m in self.modules: + self.entering_state[m] = get_random_state(m) + + def __exit__(self, *args, **kwargs): + for m in self.modules: + set_random_state(self.entering_state[m],m) + + + + + + + diff --git a/xingyun/random/my_random.py b/xingyun/random/my_random.py index 1af0516..ec7d174 100644 --- a/xingyun/random/my_random.py +++ b/xingyun/random/my_random.py @@ -1,85 +1,12 @@ -import time -import copy -import random -from typing import Any -from xingyun.universal.import_module import my_import_module - -from .set_random_seed import RandomAllowedModule - -def get_random_state(module: RandomAllowedModule) -> Any: - if module == "random": - return random.getstate() - - if module == "numpy": - np = my_import_module("numpy") - return np.random.get_state() - - if module == "torch": - torch = my_import_module("torch") - cuda = my_import_module("torch.cuda") - - return { - "torch": torch.random.get_rng_state(), - "cuda" : cuda .random.get_rng_state(), - } - -def set_random_state(state: Any, module: RandomAllowedModule) -> bool: - flag = True - if module == "random": - try: - random.setstate(state) - except: - flag = False - - if module == "numpy": - try: - np = my_import_module("numpy") - np.random.set_state(state) - except: - flag = False - - if module == "torch": - try: - torch = my_import_module("torch") - cuda = my_import_module("torch.cuda") - - torch.random.set_rng_state(state["torch"]) - cuda.random.set_rng_state(state["cuda"]) - - except: - flag = False - - if not flag: - raise RuntimeError(f"set random state of module {module} bad.") - - return flag - -class MyRandom: - def __init__(self, random_seed: int | None = None, modules: list[RandomAllowedModule] = ["random" , "torch" , "numpy"]): - '''This class create a temporary environment, inside which the random seed - is set to a given value while not affecting the global random seed. - - Notice that, to make this class work, the global random seed must be also managed by `xingyun`. - ''' - if random_seed is None: - random_seed = int( time.time() ) - - self.random_seed = random_seed - self.modules = modules - - self.entering_state = {} - - def __enter__(self): - for m in self.modules: - self.entering_state[m] = get_random_state(m) - - def __exit__(self, *args, **kwargs): - for m in self.modules: - set_random_state(self.entering_state[m],m) - - - - - - - +_seed = 2333 +def my_rand(): + global _seed + _seed = _seed * 233 + 23333 + _seed = _seed % 233333333 + return _seed + +def my_randint(low: int, high: int): + if high - low <= 0: + return low + r = my_rand() + return (r % (high - low)) + low \ No newline at end of file diff --git a/xingyun/universal/get_subdict.py b/xingyun/universal/get_subdict.py new file mode 100644 index 0000000..13863fc --- /dev/null +++ b/xingyun/universal/get_subdict.py @@ -0,0 +1,17 @@ +from typing import Any + +def get_subdict(d: dict[str, Any], prefix: str, splitter: str = "/"): + '''Find a subdict from a dict. + + ### Example + >>> dic = {"a": 1, "a/b": 2, "b": 3} + >>> get_subdict(dic, "a") + {"b": 2} + + ''' + ret = {} + for x,y in d.items(): + splitted = x.split(splitter) + if splitted[0] == prefix: + ret[splitter.join(splitted[1:])] = y + return ret \ No newline at end of file