Skip to content

Commit

Permalink
Merge pull request #24 from aai-institute/unify-docstrings
Browse files Browse the repository at this point in the history
Onnx export compatible LU-trasform
  • Loading branch information
fariedabuzaid authored Sep 27, 2023
2 parents 977b960 + c873221 commit 2efdfc5
Show file tree
Hide file tree
Showing 11 changed files with 843 additions and 473 deletions.
26 changes: 10 additions & 16 deletions scripts/run-expreiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@

from src.experiments.config_parser import read_config

Pathable = T.Union[str, os.PathLike] # In principle one can cast it to os.path.Path
Pathable = T.Union[str, os.PathLike] # In principle one can cast it to os.path.Path


@click.command()
@click.option("--report_dir", default="./reports", help="Report file")
@click.option("--config", default="./config.yaml", help="Prefix for config items")
@click.option("--storage_path", default=None, help="Prefix for config items")
def run(report_dir: Pathable, config: Pathable, storage_path: Pathable):
"""Loads an experiment from config file conducts the experiment it.
Args:
report_dir (str): Directory to save report to.
config (str): Path to config file. The report is expected to be specified in .yaml format with
Expand All @@ -22,22 +23,15 @@ def run(report_dir: Pathable, config: Pathable, storage_path: Pathable):
storage_path (str): Path to Ray storage directory. Defaults to None.
"""
sepline = "\n" + ("-" * 80) + "\n" + ("-" * 80) + "\n"
print(
f"{sepline}Parsing config file:{sepline}"
)
print(f"{sepline}Parsing config file:{sepline}")
config = os.path.abspath(config)
experiment = read_config(config)
print(
f"{sepline}Done.{sepline}"
)
print(
f"{sepline}Conducting experiment{sepline}"
)
print(f"{sepline}Done.{sepline}")
print(f"{sepline}Conducting experiment{sepline}")
# Conduct experiment
experiment.conduct(report_dir, storage_path=storage_path)
print(
f"{sepline}Done.{sepline}"
)

print(f"{sepline}Done.{sepline}")


if __name__ == "__main__":
run()
run()
50 changes: 28 additions & 22 deletions src/experiments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,62 +3,68 @@


class Experiment(object):
"""Base class for experiments.
"""
"""Base class for experiments."""

def __init__(self, name, *args, **kwargs):
super().__init__()
self.name = name
@classmethod

@classmethod
def _init_rec(cls, cfg):
if isinstance(cfg, dict):
if "experiment" in cfg:
experiment_type = cfg["experiment"]["experiment_type"]
params = cls._init_rec(cfg["experiment"]["experiment_params"])

return experiment_type(**params)
else:
return {k: cls._init_rec(v) for k, v in cfg.items()}
elif isinstance(cfg, list):
return [cls._init_rec(v) for v in cfg]
else:
return cfg

@classmethod
def from_dict(cls, config: T.Dict[str, T.Any]) -> "Experiment":
if "experiment" not in config:
raise ValueError("Invalid config file. The config file needs to contain an experiment field.")
raise ValueError(
"Invalid config file. The config file needs to contain an experiment field."
)
return cls._init_rec(config)

def conduct(self, report_dir: os.PathLike, storage_path: os.PathLike = None) -> None:
"""Conducts the experiment and saves the results to the report directory. The method is expected to store all results in report_dir.
"""

def conduct(
self, report_dir: os.PathLike, storage_path: os.PathLike = None
) -> None:
"""Conducts the experiment and saves the results to the report directory. The method is expected to store all results in report_dir."""
raise NotImplementedError



class ExperimentCollection(Experiment):
""" Implements an experiment that consists of several jointly conducted but independent experiments.
"""
"""Implements an experiment that consists of several jointly conducted but independent experiments."""

def __init__(self, experiments: T.Iterable[Experiment], *args, **kwargs) -> None:
"""
The function initializes an object with a list of experiments based on a given configuration.
:param experiments: The "experiments" parameter is an iterable object that contains a list of
experiments. Each experiment is represented by a configuration object
:type experiments: Iterable *args
"""
super().__init__(*args, **kwargs)
self.experiments = experiments

@classmethod
def from_dict(cls, config: T.Dict[str, T.Any]) -> "ExperimentCollection":
config = deepcopy(config)
for i, exp_cfg in enumerate(config["experiment_params"]["experiments"]):
config["experiment_params"]["experiments"][i] = Experiment.from_dict(exp_cfg)

config["experiment_params"]["experiments"][i] = Experiment.from_dict(
exp_cfg
)

return Experiment.from_dict(config)



def conduct(self, report_dir: os.PathLike, storage_path: os.PathLike = None):
for i, exp in enumerate(self.experiments):
exp.conduct(os.path.join(report_dir, f"{i}_{exp.name}"), storage_path=storage_path)

exp.conduct(
os.path.join(report_dir, f"{i}_{exp.name}"), storage_path=storage_path
)
85 changes: 42 additions & 43 deletions src/experiments/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@


def unfold_raw_config(d: Dict[str, Any]):
"""Unfolds an ordered DAG given as a dictionary into a tree given as dictionary.
"""Unfolds an ordered DAG given as a dictionary into a tree given as dictionary.
That means that unfold_dict(d) is bisimilar to d but no two distinct key paths in the resulting
dictionary reference the same object
:param d: The dictionary to unfold
"""
du = dict()
Expand All @@ -25,134 +25,135 @@ def unfold_raw_config(d: Dict[str, Any]):
du[k] = [unfold_raw_config(x) for x in v]
else:
du[k] = deepcopy(v)

return du


def push_overwrites(item: Any, attributes: Dict[str, Any]) -> Any:
"""Pushes the overwrites in the given dictionary to the given item.
If the item already specifies an overwrite, it is updated.
If the item is a dictionary, an overwrite specification for the dictionary is created.
If the item is a dictionary, an overwrite specification for the dictionary is created.
If the item is a list, the overwrites are pushed each element and the processed list is returned.
Otherwise, item is overwritten by attributen
:param item: The item to push the overwrites to.
:param overwrites: The overwrites to push.
"""
try:
if "__exact__" in attributes:
return deepcopy(attributes["__exact__"])
return deepcopy(attributes["__exact__"])
except:
pass

if isinstance(item, dict):
if "__overwrites__" not in item:
result = deepcopy(attributes)
result["__overwrites__"] = item
else:
result = item
result["__overwrites__"].update(attributes)
elif isinstance(item, list):
elif isinstance(item, list):
result = [push_overwrites(x, attributes) for x in item]
else:
result = deepcopy(attributes)

return result


def apply_overwrite(d: Dict[str, Any], recurse: bool = True):
"""Applies the "__overwrites__" keyword sematic to a unfolded raw config dictionary and returns the result.
Except for the special semantics that applies to dictionaries and lists (see below),
all keys $k$ that are present in in the "__overwrites__" dictionary $o$ are overwritten in $d$ by $o[k]$.
all keys $k$ that are present in in the "__overwrites__" dictionary $o$ are overwritten in $d$ by $o[k]$.
** Dict/List overwrites **:
- If $d[k]$ is a dictionary, then $d[k]$ must be a dictionary and overwrites of $o[k]$ are recursively
- If $d[k]$ is a dictionary, then $d[k]$ must be a dictionary and overwrites of $o[k]$ are recursively
to the corresponding $d[k]$, i.e. lower-lever overwrites are specified/updated (see notes on recursion).
The only exception is if $o[k]$ contains the special key "__exact__" with value True. In this case
$d[k$]$ is replaced by $o[k]["__exact__"]$.
- If $o[k]$ is a list, then $o[k]$ is pushed to all list elements.
- If $o[k]$ is a list, then $o[k]$ is pushed to all list elements.
** Recursion **:
If recursion is enabled, overwrites are are fully expanded in infix order where nested overwrites
(see behavior on dict/list overwrites) are pushed (and overwrite) to the next level, i.e.
If recursion is enabled, overwrites are are fully expanded in infix order where nested overwrites
(see behavior on dict/list overwrites) are pushed (and overwrite) to the next level, i.e.
higher level overwrites lower level. Else, only the top-level overwrites are applied.
** Note **: Applying this function to a non-unfolded dictionary
can result in unexpected behavior due to side side-effects.
:param d: The unfolded raw config dictionary.
:pram recurse: If True, the overwrites are applied recursively. Defaults to True.
:pram recurse: If True, the overwrites are applied recursively. Defaults to True.
Can be useful for efficient combination of this method with other parsing methods.
"""

# Apply top-level overwrite
if "__overwrites__" in d:
overwritten_attr = d
overwritten_attr = d
d = overwritten_attr.pop("__overwrites__")

for k, v in overwritten_attr.items():
if k not in d:
d[k] = v
else:
d[k] = push_overwrites(d[k], v)



if recurse:
for k, v in d.items():
if isinstance(v, dict):
d[k] = apply_overwrite(v)
elif isinstance(v, list):
d[k] = [apply_overwrite(x) for x in v]

return d


def read_config(yaml_path: Union[str, Path]) -> dict:
"""Loads a yaml file and returns the corresponding dictionary.
Besides the standard yaml syntax, the function also supports the following
additional functionality:
Special keys:
__class__: The value of this key is interpreted as the class name of the object.
__class__: The value of this key is interpreted as the class name of the object.
The class is imported and stored in the result dictionary under the key <key>.
Example:
entry in yaml: __class__model: laplace_flows.flows.NiceFlow)
entry in result: model: __import__("laplace_flows.flows.NiceFlow")
__tune__<key>: The value of this key is interpreted as a dictionary that contains the
configuration for the hyperparameter optimization using tune sample methods.
__tune__<key>: The value of this key is interpreted as a dictionary that contains the
configuration for the hyperparameter optimization using tune sample methods.
the directive is evaluated and the result in the result dictionary under the key <key>.
Example:
entry in yaml: __tune__lr: loguniform(1e-4, 1e-1)
entry in result: lr: eval("tune.loguniform(1e-4, 1e-1)")
:param yaml_path: Path to the yaml file.
"""

with open(yaml_path, "r") as f:
config = yaml.load(f, Loader=yaml.FullLoader)
config = unfold_raw_config(config)

config = unfold_raw_config(config)
config = parse_raw_config(config)

return config



def parse_raw_config(d: dict):
"""Parses an unfolded raw config dictionary and returns the corresponding dictionary.
Parsing includes the following steps:
- Overwrites are applied (see apply_overwrite)
- The "__object__" key is interpreted as a class name and the corresponding class is imported.
- The "__eval__" key is evaluated.
- The "__class__" key is interpreted as a class name and the corresponding class is imported.
:param d: The raw config dictionary.
"""
if isinstance(d, dict):
d = apply_overwrite(d, recurse=False)

# Depth-first recursion
for k, v in d.items():
d[k] = parse_raw_config(v)

if "__object__" in d:
module, cls = d["__object__"].rsplit(".", 1)
C = getattr(import_module(module), cls)
Expand All @@ -164,12 +165,10 @@ def parse_raw_config(d: dict):
module, cls = d["__class__"].rsplit(".", 1)
C = getattr(import_module(module), cls)
return C
else:
else:
return d
elif isinstance(d, list):
result = [parse_raw_config(x) for x in d]
return result
else:
return d


Loading

0 comments on commit 2efdfc5

Please sign in to comment.