diff --git a/cadCAD/tools/execution/easy_run.py b/cadCAD/tools/execution/easy_run.py index a449bf41..7d1827b1 100644 --- a/cadCAD/tools/execution/easy_run.py +++ b/cadCAD/tools/execution/easy_run.py @@ -2,7 +2,7 @@ import types from typing import Dict, Union -import pandas as pd # type: ignore +import pandas as pd # type: ignore from cadCAD.configuration import Experiment from cadCAD.configuration.utils import config_sim from cadCAD.engine import ExecutionContext, ExecutionMode, Executor @@ -47,8 +47,9 @@ def easy_run( """ # Set-up sim_config - simulation_parameters = {'N': N_samples, 'T': range(N_timesteps), 'M': params} - sim_config = config_sim(simulation_parameters) # type: ignore + simulation_parameters = {'N': N_samples, + 'T': range(N_timesteps), 'M': params} + sim_config = config_sim(simulation_parameters) # type: ignore # Create a new experiment exp = Experiment() @@ -91,22 +92,27 @@ def easy_run( if assign_params == True: pass else: - params_set &= assign_params # type: ignore + params_set &= assign_params # type: ignore # Logic for getting the assign params criteria if type(assign_params) is list: - selected_params = set(assign_params) & params_set # type: ignore + selected_params = set(assign_params) & params_set # type: ignore elif type(assign_params) is set: selected_params = assign_params & params_set else: selected_params = params_set + # Attribute parameters to each row* + params_dict = select_config_M_dict(configs, 0, selected_params) + + # Handles all cases of parameter types including list + for key, value in params_dict.items(): + df[key] = df.apply(lambda _: value, axis=1) - # Attribute parameters to each row - df = df.assign(**select_config_M_dict(configs, 0, selected_params)) for i, (_, n_df) in enumerate(df.groupby(['simulation', 'subset', 'run'])): - df.loc[n_df.index] = n_df.assign( - **select_config_M_dict(configs, i, selected_params) - ) + params_dict = select_config_M_dict(configs, i, selected_params) + for key, value in params_dict.items(): + df.loc[n_df.index, key] = df.loc[n_df.index].apply( + lambda _: value, axis=1) # Based on Vitor Marthendal (@marthendalnunes) snippet if use_label == True: