diff --git a/fltk/util/config/__init__.py b/fltk/util/config/__init__.py
index d5d4bf75..a9fc76b2 100644
--- a/fltk/util/config/__init__.py
+++ b/fltk/util/config/__init__.py
@@ -1,15 +1,10 @@
 from __future__ import annotations
 
-import json
 from pathlib import Path
 from typing import Optional, Union, Type, Dict
 
-import torch
-import yaml
 import logging
 
-from torch.nn.modules.loss import _Loss
-
 from fltk.util.config.definitions import Loss
 from fltk.util.config.distributed_config import DistributedConfig
 
@@ -45,15 +40,15 @@ def get_distributed_config(args, alt_path: str = None) -> Optional[DistributedConfig]:
 
 
 def get_learning_param_config(args, alt_path: str = None) -> Optional[DistLearningConfig]:
+    """
+    Retrieve learning parameter configuration from disk for distributed learning experiments.
+    """
     if args:
         config_path = args.experiment_config
     else:
         config_path = alt_path
-    safe_loader = get_safe_loader()
     try:
-        with open(config_path) as f:
-            learning_params_dict = yaml.load(f, Loader=safe_loader)
-        learning_params = DistLearningConfig.from_dict(learning_params_dict)
+        learning_params = DistLearningConfig.from_yaml(Path(config_path))
     except Exception as e:
         msg = f"Failed to get learning parameter configuration for distributed experiments: {e}"
         logging.info(msg)
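
The refactor relies on `DistLearningConfig` exposing a `from_yaml` classmethod, which is not shown in this diff (presumably it is added elsewhere in the same change). Below is a minimal sketch of what such a method could look like, folding the removed `yaml.load` / `from_dict` call-site logic into the config class. It uses PyYAML's standard `SafeLoader` in place of the repo-local `get_safe_loader()` helper, and the dataclass fields are hypothetical placeholders, not taken from the repository:

```python
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path

import yaml


@dataclass
class DistLearningConfig:
    # Hypothetical fields for illustration only; the real class in
    # fltk.util.config defines its own learning parameters.
    batch_size: int = 32
    learning_rate: float = 0.01

    @classmethod
    def from_dict(cls, config: dict) -> DistLearningConfig:
        # Build the config object from a plain parameter dictionary.
        return cls(**config)

    @classmethod
    def from_yaml(cls, path: Path) -> DistLearningConfig:
        # Safe-load the YAML file and delegate to from_dict, mirroring
        # the loading logic this diff removes from the call site.
        with open(path) as file:
            return cls.from_dict(yaml.load(file, Loader=yaml.SafeLoader))
```

Centralising the YAML parsing in the config class leaves the `try`/`except` at the call site responsible only for error reporting, and lets the module drop its direct `yaml` and `torch` imports.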