diff --git a/fltk/core/node.py b/fltk/core/node.py index 6d8aac0c..6cd0958e 100644 --- a/fltk/core/node.py +++ b/fltk/core/node.py @@ -56,7 +56,7 @@ def _config(self, config: Config): self.config.rank = self.rank self.config.world_size = self.world_size self.cuda = config.cuda - self.init_device() + self.device = self.init_device() self.distributed = config.distributed self.set_net(self.load_default_model()) diff --git a/fltk/util/config.py b/fltk/util/config.py index 5682059a..c83356db 100644 --- a/fltk/util/config.py +++ b/fltk/util/config.py @@ -70,6 +70,10 @@ class Config: def __init__(self, **kwargs) -> None: enum_fields = [x for x in self.__dataclass_fields__.items() if isinstance(x[1].type, Enum) or isinstance(x[1].type, EnumMeta)] + if 'dataset' in kwargs and 'dataset_name' not in kwargs: + kwargs['dataset_name'] = kwargs['dataset'] + if 'net' in kwargs and 'net_name' not in kwargs: + kwargs['net_name'] = kwargs['net'] for name, field in enum_fields: if name in kwargs and isinstance(kwargs[name], str): kwargs[name] = field.type(kwargs[name])