Skip to content

Commit

Permalink
Fix device misconfiguration
Browse files Browse the repository at this point in the history
  • Loading branch information
bacox committed Mar 17, 2022
1 parent 672f926 commit ce1936a
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
2 changes: 1 addition & 1 deletion fltk/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _config(self, config: Config):
self.config.rank = self.rank
self.config.world_size = self.world_size
self.cuda = config.cuda
self.init_device()
self.device = self.init_device()
self.distributed = config.distributed
self.set_net(self.load_default_model())

Expand Down
4 changes: 4 additions & 0 deletions fltk/util/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ class Config:

def __init__(self, **kwargs) -> None:
enum_fields = [x for x in self.__dataclass_fields__.items() if isinstance(x[1].type, Enum) or isinstance(x[1].type, EnumMeta)]
if 'dataset' in kwargs and 'dataset_name' not in kwargs:
kwargs['dataset_name'] = kwargs['dataset']
if 'net' in kwargs and 'net_name' not in kwargs:
kwargs['net_name'] = kwargs['net']
for name, field in enum_fields:
if name in kwargs and isinstance(kwargs[name], str):
kwargs[name] = field.type(kwargs[name])
Expand Down

0 comments on commit ce1936a

Please sign in to comment.