Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Support config with multiple sources in _BASE_ #16

Merged
merged 1 commit into from
Feb 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
* `Python 3.7`

```bash
pip install -r requirements.txt
pip install -r requirements/runtime.txt
# For development
pip install -r requirements/dev.txt
```

## Pre-commit Hook
Expand All @@ -16,3 +18,9 @@ pip install -r requirements.txt
pip install pre-commit
pre-commit install
```

## Config

> Fix the version `fvcore` to `0.1.2.post20210128`

* Support multiple inheritance of config
98 changes: 96 additions & 2 deletions pytorch_trainer/utils/config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,102 @@
from fvcore.common.config import CfgNode as CN
import os
import logging
from typing import Any, Dict

import yaml
from fvcore.common.config import CfgNode as _CfgNode

BASE_KEY = "_BASE_"


class CfgNode(_CfgNode):
"""Support more features by inheritance.

NOTE:
* Support `list` type to `_BASE_` inheritance functionality.
"""

@classmethod
def load_yaml_with_base(cls, filename: str, allow_unsafe: bool = False) -> None:
"""
Just like `yaml.load(open(filename))`, but inherit attributes from its
`_BASE_`.

Args:
filename (str or file-like object): the file name or file of the current config.
Will be used to find the base config file.
allow_unsafe (bool): whether to allow loading the config file with
`yaml.unsafe_load`.

Returns:
(dict): the loaded yaml
"""

with cls._open_cfg(filename) as f:
try:
cfg = yaml.safe_load(f)
except yaml.constructor.ConstructorError:
if not allow_unsafe:
raise
logger = logging.getLogger(__name__)
logger.warning(
"Loading config {} with yaml.unsafe_load. Your machine may "
"be at risk if the file contains malicious content.".format(
filename
)
)
f.close()
with cls._open_cfg(filename) as f:
cfg = yaml.unsafe_load(f)

# pyre-ignore
def merge_a_into_b(a: Dict[Any, Any], b: Dict[Any, Any]) -> None:
# merge dict a into dict b. values in a will overwrite b.
for k, v in a.items():
if isinstance(v, dict) and k in b:
assert isinstance(
b[k], dict
), "Cannot inherit key '{}' from base!".format(k)
merge_a_into_b(v, b[k])
else:
b[k] = v

def _load_base_cfg(base_cfg_file):
if base_cfg_file.startswith("~"):
base_cfg_file = os.path.expanduser(base_cfg_file)
if not any(map(base_cfg_file.startswith, ["/", "https://", "http://"])):
# the path to base cfg is relative to the config file itself.
base_cfg_file = os.path.join(
os.path.dirname(filename),
base_cfg_file)
base_cfg = cls.load_yaml_with_base(
base_cfg_file,
allow_unsafe=allow_unsafe)

return base_cfg

if BASE_KEY in cfg:
BASE_VALUE = cfg.pop(BASE_KEY)
_base_group = {}
if isinstance(BASE_VALUE, list):
# Merge cfg into _base_group first.
merge_a_into_b(cfg, _base_group)

for _BASE_VALUE in BASE_VALUE:
base_cfg = _load_base_cfg(_BASE_VALUE)
# Merge each base_cfg into _base_group
merge_a_into_b(base_cfg, _base_group)
return _base_group
elif isinstance(BASE_VALUE, str):
base_cfg = _load_base_cfg(BASE_VALUE)
merge_a_into_b(cfg, base_cfg)
return base_cfg

return cfg


# [NOTE] Default field is free to add any node.
# Because this base config could be shared among applications, it should be clean to any import.
_C = CN(new_allowed=True)
_C = CfgNode(new_allowed=True)


def get_cfg_defaults():
Expand Down
3 changes: 2 additions & 1 deletion requirements/dev.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
torch>=1.7.1
fvcore
fvcore=0.1.2.post20210128
tensorboard
autopep8
pre-commit
pytest
Expand Down
2 changes: 1 addition & 1 deletion requirements/runtime.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
torch>=1.7.1
fvcore
fvcore=0.1.2.post20210128
tensorboard