Skip to content

Commit

Permalink
Create cfg object from Config class
Browse files Browse the repository at this point in the history
  • Loading branch information
mjaehn committed Sep 20, 2023
1 parent c78be0a commit 86976b0
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 94 deletions.
1 change: 1 addition & 0 deletions cases/cosmo-ghg-11km-test/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

# GENERAL SETTINGS
model: cosmo-ghg
constraint: gpu
restart_step: 12
variant: spinup
spinup: 6
Expand Down
201 changes: 107 additions & 94 deletions run_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,102 +86,115 @@ def parse_arguments():

return args

class Config():
def __init__(self, casename):
# Global attributes (initialized with default values)
self.casename = casename
self.user = os.environ['USER']
self.chain_src_dir = os.getcwd()
self.path = os.path.join(self.chain_src_dir, self.casename)
self.work_root = os.path.join(self.chain_src_dir, 'work')

# User-defined attributes from config file
self.load_config_file(casename)

# Derived attributes based on user configuration
self.set_account()
self.set_node_info()


def load_config_file(self, casename):
"""
Load the configuration settings from a YAML file.
This method reads the configuration settings from a YAML file located in
the 'cases/casename' directory and sets them as attributes of the instance.
Parameters:
- casename (str): Name of the folder in 'cases/' where the configuration
files are stored.
Returns:
- self (Config): The same `Config` instance with configuration settings as
attributes.
"""

cfg_file = os.path.join('cases', casename, 'config.yaml')

if not os.path.isfile(cfg_file):
all_cases = [
path.name for path in os.scandir('cases') if path.is_dir()
]
closest_name = min([(tools.levenshtein(casename, name), name)
for name in all_cases],
key=lambda x: x[0])[1]
raise FileNotFoundError(
f"Case-directory '{casename}' not found, did you mean '{closest_name}'?"
)

try:
with open(cfg_file, 'r') as yaml_file:
cfg_data = yaml.load(yaml_file, Loader=yaml.FullLoader)
except FileNotFoundError:
raise FileNotFoundError(
f"No file 'config.yaml' in {os.path.dirname(cfg_file)}")

# Directly assign values to instance attributes
for key, value in cfg_data.items():
setattr(self, key, value)

# Save the user-defined items
self.user_config = cfg_data.items()

return self

def set_account(self):
if self.user == 'jenkins':
# g110 account for Jenkins testing
self.compute_account = 'g110'
elif os.path.exists(os.environ['HOME'] + '/.acct'):
# Use account specified in ~/.acct file
with open(os.environ['HOME'] + '/.acct', 'r') as file:
self.compute_account = file.read().rstrip()
else:
# Use standard account
self.compute_account = os.popen("id -gn").read().splitlines()[0]

return self

def set_node_info(self):
if self.constraint == 'gpu':
self.ntasks_per_node = 12
self.mpich_cuda = ('export MPICH_RDMA_ENABLED_CUDA=1\n'
'export MPICH_G2G_PIPELINE=256\n'
'export CRAY_CUDA_MPS=1\n')
elif self.constraint == 'mc':
self.ntasks_per_node = 36
self.mpich_cuda = ''
else:
raise ValueError("Invalid value for 'constraint' in the configuration."
"It should be either 'gpu' or 'mc'.")

def load_config_file(casename, cfg):
"""
Load the configuration settings from a YAML file.
This function reads the configuration settings from a YAML file located in
the 'cases/casename' directory and sets them as attributes of the provided
`cfg` object.
Parameters:
- casename (str): Name of the folder in 'cases/' where the configuration
files are stored.
- cfg (object): An object to store the configuration settings as attributes.
Returns:
- cfg (object): The same `cfg` object with configuration settings as
attributes.
"""

cfg_file = os.path.join('cases', casename, 'config.yaml')

if not os.path.isfile(cfg_file):
all_cases = [
path.name for path in os.scandir('cases') if path.is_dir()
]
closest_name = min([(tools.levenshtein(casename, name), name)
for name in all_cases],
key=lambda x: x[0])[1]
raise FileNotFoundError(
f"Case-directory '{casename}' not found, did you mean '{closest_name}'?"
)

try:
with open(cfg_file, 'r') as yaml_file:
cfg_data = yaml.load(yaml_file, Loader=yaml.FullLoader)
except FileNotFoundError:
raise FileNotFoundError(
f"No file 'config.yaml' in {os.path.dirname(cfg_file)}")

for key, value in cfg_data.items():
setattr(cfg, key, value)

# Set additional config variables
cfg = set_user_account(cfg)
cfg = set_node_info(cfg)

return cfg


def set_user_account(cfg):
setattr(cfg, user, os.environ['USER'])
if cfg.user == 'jenkins':
# g110 account for Jenkins testing
setattr(cfg, compute_account, 'g110')
elif os.path.exists(os.environ['HOME'] + '/.acct'):
# Use account specified in ~/.acct file
with open(os.environ['HOME'] + '/.acct', 'r') as file:
setattr(cfg, compute_account, file.read().rstrip())
else:
# Use standard account
setattr(cfg, compute_account,
os.popen("id -gn").read().splitlines()[0])

return cfg


def set_node_info(cfg):
if cfg.constraint == 'gpu':
setattr(cfg, ntasks_per_node, 12)
setattr(cfg, mpich_cuda, ('export MPICH_RDMA_ENABLED_CUDA=1\n'
'export MPICH_G2G_PIPELINE=256\n'
'export CRAY_CUDA_MPS=1\n'))
elif cfg.constraint == 'mc':
setattr(cfg, ntasks_per_node, 36)
setattr(cfg, mpich_cuda, '')
else:
raise ValueError("Invalid value for 'constraint' in the configuration."
"It should be either 'gpu' or 'mc'.")

return cfg


def set_paths_and_case(cfg):
# Root directory of the sourcecode of the chain (where run_chain.py is)
setattr(cfg, chain_src_dir, os.getcwd())
return self

# The case name is the name of the case directory
setattr(cfg, casename, os.path.basename(os.path.dirname(path)))

# Path of the case files
setattr(cfg, path, os.path.join(cfg.chain_src.dir, cfg.casename))
def print_config(self):
# Print the configuration
print("Global Attributes:")
print(f"casename: {self.casename}")
print(f"user: {self.user}")
print(f"chain_src_dir: {self.chain_src_dir}")
print(f"path: {self.path}")
print(f"work_root: {self.work_root}")

# Root directory of the working space of the chain
setattr(cfg, work_root, os.path.join(chain_src_dir, 'work'))
print("\nUser-defined attributes:")
for key, value in self.user_config:
print(f"{key}: {value}")

return cfg
print("\nDerived attributes:")
print(f"compute_account: {self.compute_account}")
print(f"ntasks_per_node: {self.ntasks_per_node}")
print(f"mpich_cuda: {self.mpich_cuda}")


def run_chain(work_root, model_cfg, cfg, start_time, hstart, hstop, job_names,
Expand Down Expand Up @@ -637,11 +650,11 @@ def load_model_config_yaml(yamlfile):
if __name__ == '__main__':
args = parse_arguments()

# 'empty' config object to be overwritten by load_config_file
cfg = None
for casename in args.casenames:
model_cfg = load_model_config_yaml('config/models.yaml')
cfg = load_config_file(casename=casename, cfg=cfg)
cfg = Config(casename)
cfg.print_config()
sys.exit()
start_time = datetime.strptime(args.startdate, '%Y-%m-%d')
if args.job_list is None:
args.job_list = model_cfg['models'][cfg.model]['jobs']
Expand Down

0 comments on commit 86976b0

Please sign in to comment.