Skip to content

Commit

Permalink
add argparser
Browse files Browse the repository at this point in the history
  • Loading branch information
FFTYYY committed Dec 27, 2023
1 parent e9500b7 commit adea9df
Show file tree
Hide file tree
Showing 15 changed files with 476 additions and 90 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

setup(
name="xingyun",
version="0.0.3",
version="0.0.4",
url="http://github.com/FFTYYY/XingYun",
description="",
long_description=readme,
Expand Down
1 change: 1 addition & 0 deletions test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .time import *
from .cloud import *
from .logger import *
from .argparser import *

if __name__ == "__main__":
testsuite = unittest.TestLoader().discover(".")
Expand Down
1 change: 1 addition & 0 deletions test/argparser/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .argparser import *
95 changes: 95 additions & 0 deletions test/argparser/argparser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from xingyun.argparser import ArgumentParser, PreCondition, MyDict
import unittest
import pdb

def make_parser():
argp = ArgumentParser()

# ---- model ----
argp.add_argument("model/model" , type = str, default = "TWIRLS", verify = lambda v: v in ["TWIRLS", "UGNN", "IGNN", "CNN"] )
argp.add_alias ("model" , "model/model")

argp.add_argument("model/num_layers", type = int, default = 12 )

with PreCondition(lambda C: C["model/model"] in ["TWIRLS" , "UGNN", "IGNN"]):
argp.add_argument("model/GNN-spec/prop_method", type = str, default = "message", verify = lambda v: v in ["message", "identity"] )

with PreCondition(lambda C: C["model/model"] in ["CNN"]):
argp.add_bool ("model/CNN-spec/bn")


# ---- optim ----
argp.add_argument("optim/lr", type = float, default = 0.1, verify = lambda v: v >= 0 )
argp.add_argument("optim/wd", type = float, default = 0.0, verify = lambda v: v >= 0 )

# ---- general ----
argp.add_bool ("no_activation")

return argp




class TestArgParser(unittest.TestCase):
def setUp(self) -> None:
self.argp = make_parser()
return super().setUp()

def test_main(self):
argv = [
"--model=UGNN" ,
"--model/num_layers=3",
"--model/GNN-spec/prop_method=message",
"--model/CNN-spec/bn",
"--optim/lr=0.1",
"--optim/wd=5e-5",
"--no_activation"
]

C = self.argp.parse(argv)

print (C.sub("model"))

self.assertTrue( C.sub("model")["model"] == "UGNN" )
self.assertTrue( C.sub("model")["num_layers"] == 3 )
self.assertTrue( C.sub("model")["GNN-spec/prop_method"] == "message" )
self.assertTrue( C.sub("model")["CNN-spec/bn"] is None )
self.assertTrue( C("model")["CNN-spec/bn"] == C("model")("CNN-spec")["bn"] )
self.assertTrue( C.sub("optim")["lr"] == 0.1)
self.assertTrue( C.sub("optim")["wd"] == 5e-5 )
self.assertTrue( C["no_activation"] == True )

def test_raise(self):
argv = [
"--model=UGNN" ,
"--model/num_layers=3",
"--model/GNN-spec/prop_method=UGNN_message",
"--model/CNN-spec/bn",
"--optim/lr=0.1",
"--optim/wd=5e-5",
"--no_activation"
]
self.assertRaises(RuntimeError, lambda : self.argp.parse(argv))

def test_default(self):
argv = [
"--model=CNN" ,
"--model/num_layers=12",
"--model/GNN-spec/prop_method=UGNN_message",
"--model/CNN-spec/bn",
"--optim/lr=0.1",
]
C = self.argp.parse(argv)

self.assertTrue( C.sub("model")["model"] == "CNN" )
self.assertTrue( C.sub("model")["num_layers"] == 12 )
self.assertTrue( C.sub("model")["GNN-spec/prop_method"] is None )
self.assertTrue( C.sub("model")["CNN-spec/bn"] == True )
self.assertTrue( C.sub("optim")["lr"] == 0.1)
self.assertTrue( C.sub("optim")["wd"] == 0.0 )
self.assertTrue( C["no_activation"] == False )



if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion test/random/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .my_random import *
from .fix_random import *
4 changes: 2 additions & 2 deletions test/random/my_random.py → test/random/fix_random.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import random
import torch
from xingyun.random.set_random_seed import set_random_seed
from xingyun.random.my_random import MyRandom
from xingyun.random.fix_random import FixRandom
import unittest

def randomint():
Expand All @@ -17,7 +17,7 @@ def test_main(self):

set_random_seed(2333)
rn_3 = randomint()
with MyRandom():
with FixRandom():
rand_list = [randomint() for _ in range(10)]
rn_4 = randomint()

Expand Down
19 changes: 19 additions & 0 deletions test/try/try_argv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import sys
import re
print (sys.argv[1:])

pattern = r"^--([^=]+)(=(.+)|)$"

a = "--model=layers"
b = "--model/layers=3"
c = "--model-layers=7"
d = "--yes"
e = "sd--aa=layers"
f = "sd--yes"

for x in [a,b,c,d,e,f,]:
match = re.match(pattern, x)
if match is not None:
print (match.group(1), match.group(3))
else:
print ("None")
6 changes: 6 additions & 0 deletions test/try/try_namespace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from argparse import Namespace

a = Namespace(a = 12, b = "3")
a.__dict__["c.2"] = 100

print (a.c)
1 change: 1 addition & 0 deletions xingyun/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
from .random import *
from .logger import *
from .cloud import *
from.argparser import *
1 change: 1 addition & 0 deletions xingyun/argparser/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .argparser import Argument, PreCondition, Condition, ArgumentParser, MyDict
221 changes: 221 additions & 0 deletions xingyun/argparser/argparser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
import sys
from typing import Callable, Any
import re
import warnings
from argparse import Namespace
from ..random.my_random import my_randint
from ..universal.get_subdict import get_subdict

class Condition:
def __init__(self, condition: Callable[[ dict[str, Any] ], bool] = lambda _: True):
self.condition = condition

def test(self, config: dict[str, Any]):
return self.condition(config)

def __and__(self, other: "Condition"):
return Condition(lambda v: self.condition(v) and other.condition(v))

def __or__(self, other: "Condition"):
return Condition(lambda v: self.condition(v) or other.condition(v))

@classmethod
def and_all(cls, conditions: list["Condition"]):
r = Condition()
for c in conditions:
r = r & c
return r


condition_env = {}
class PreCondition:
def __init__(self, cond: Condition | Callable[[ dict[str, Any] ], bool]):

if not isinstance(cond, Condition):
cond = Condition(cond)

self.cond = cond
self.id = my_randint(0,2333333)
def __enter__(self):
condition_env[self.id] = self.cond
def __exit__(self, *arg, **kwarg):
condition_env.pop(self.id)

class MyDict(dict):
def __init__(self, d: dict = {}):
super().__init__(d)

def sub(self, prefix: str, splitter: str = "/"):
return MyDict( get_subdict(self, prefix, splitter) )

def __call__(self, prefix: str, splitter: str = "/"):
return self.sub(prefix, splitter)


class Argument:
'''The class that describe an argument.
### Properties
- name: the name of the argument.
- type: how to convert string to a value.
- default: the value set to the argument if it is not provided.
- present: the value set to the argument if it is provided but without a value.
- help: help string.
- pre_condition: Before an argument is parsed, the pre-condition must be satisfied.
- post_condition: After an argument is parsed, if the post-condition is not satisfied, then will raise an error.
'''
def __init__(self,
name: str,
type: Callable[[str] , Any],
default: Any = None,
present: Any = None,
help:str = "" ,
pre_condition : Condition = Condition(),
post_condition: Condition = Condition(),
):
self.name = name
self.type = type
self.default = default
self.present = present
self.pre_condition = pre_condition
self.post_condition = post_condition
self.help = help

class ArgumentParser:
'''This class is a modification of python `argparse.ArgumentParser` class.
For each argument, there are two conditions: pre-condition and post-condition.
Before an argument is parsed, the pre-condition must be satisfied.
After an argument is parsed, if the post-condition is not satisfied, then will raise an error.
'''

def __init__(self, help: str = ""):

self.help = help

self.arguments: dict[str, Argument] = {}
self.alias : dict[str, str] = {} # redirects to

@property
def now_preconds(self):
return [c for id,c in condition_env.items() if c is not None]

@classmethod
def get_subconfig(cls, C: dict | Namespace, prefix: str, splitter: str = "/"):
if isinstance(C, Namespace):
C = C.__dict__
return get_subdict(C,prefix,splitter)

def add_alias(self, alias: str, original: str):
'''add an alias for an argument.'''
self.alias[alias] = original

def add_argument(self,
name: str,
type: Callable[[str] , Any],
default: Any = None,
present: Any = None,
help: str = "" ,
verify : Callable[ [Any] , bool] = lambda _: True ,
):
post_cond = Condition(lambda C: verify(C.get(name)))
self.arguments[name] = Argument(name, type, default, present, help, Condition.and_all(self.now_preconds), post_cond)

def add_bool(self,
name: str,
help: str = "" ,
verify : Callable[[ Any ], bool] = lambda _: True ,
):
self.add_argument(name, bool, False, True, help, verify)

def parse_namespace(self,
args: list[str] | None = None,
pattern = r"^--([^=]+)(=(.+)|)$",
get_match: Callable[[re.Match], tuple[str,str]] = lambda m: (m.group(1), m.group(3)) ,
) -> Namespace:
return Namespace(**self.parse(args, pattern, get_match))

def parse(self,
args: list[str] | None = None,
pattern = r"^--([^=]+)(=(.+)|)$",
get_match: Callable[[re.Match], tuple[str,str]] = lambda m: (m.group(1), m.group(3)) ,
) -> MyDict:
''' Parse argument.
### Parameters
-- args: arguments to be parsed.
-- pattern: a regular expression to match name and value.
-- get_match: a callable object that get name and value from the get_match. The input callable is a `re.Match` object,
the output should be a 2-tuple, with the first element being the name and the second element being the value of the argument.
'''
if args is None:
args = sys.argv[1:]

# get value of each argument
name_vals = {}
for s in args:
# get name val pairs
match = re.match(pattern, s)
if match is None:
continue
name, val = get_match(match)

# apply alias
alias_tar = self.alias.get(name)
if alias_tar is not None:
name = alias_tar

# apply present val
arg = self.arguments.get(name)
if arg is None:
continue
if val is None:
val = arg.present

# record name, val pair
name_vals[name] = val

# apply default val
for name , arg in self.arguments.items():
if not (name in name_vals):
name_vals[name] = arg.default

# actually assign values
parsed = {}
_t = 0
for _t in range(100):
now_len = len(parsed)
for name, val in name_vals.items():

# get Argument object
arg = self.arguments.get(name)
if arg is None:
continue

# check pre condition
if not arg.pre_condition.test(parsed):
continue

# store value
parsed[name] = arg.type(val)

# check post condition
if not arg.post_condition.test(parsed):
raise RuntimeError(f"bad argument {name}: value can not be {parsed[name]}")

new_len = len(parsed)
if new_len == now_len: # no new argument is parsed
break
now_len = new_len

if _t >= 98:
warnings.warn("Too deep nested logic. Only performed 100 iterations.")

# for those who forbidded by precondition, assign `None`.
for name in self.arguments:
if not (name in parsed):
parsed[name] = None

return MyDict(parsed)


Loading

0 comments on commit adea9df

Please sign in to comment.