diff --git a/src/otoole/input.py b/src/otoole/input.py index 0b26a3f..299218a 100644 --- a/src/otoole/input.py +++ b/src/otoole/input.py @@ -259,12 +259,6 @@ def write( self.inputs = inputs # parameter/set data OR result data input_data = kwargs.get("input_data", None) - if self.write_defaults: - try: - self.inputs = self._expand_defaults(inputs, default_values, input_data) - except KeyError as ex: - logger.debug(f"Can not write default values due to missing {ex} data") - for name, df in sorted(self.inputs.items()): logger.debug("%s has %s columns: %s", name, len(df.index.names), df.columns) @@ -278,8 +272,26 @@ def write( if entity_type != "set": default_value = default_values[name] + # This should be moved inside the loop and performed once for each parameter + if self.write_defaults: + try: + logger.info(f"Expanding {name} with default values") + df_expand = self._expand_dataframe( + df, default_value, input_data + ) + except KeyError as ex: + logger.info( + f"Unable to write default values due to missing {ex} data" + ) + else: + df_expand = df + self._write_parameter( - df, name, handle, default=default_value, input_data=self.inputs + df_expand, + name, + handle, + default=default_value, + input_data=self.inputs, ) else: self._write_set(df, name, handle) @@ -289,6 +301,38 @@ def write( if isinstance(handle, TextIO): handle.close() + def _expand_dataframe( + self, data: pd.DataFrame, default: float, input_data: dict[str, pd.DataFrame] + ) -> pd.DataFrame: + """Expand an individual dataframe with default values""" + # save set information for each parameter + index_data = {} + for index in data.index.names: + index_data[index] = input_data[index]["VALUE"].to_list() + + # set index + if len(index_data) > 1: + new_index = pd.MultiIndex.from_product( + list(index_data.values()), names=list(index_data.keys()) + ) + else: + new_index = pd.Index( + list(index_data.values())[0], name=list(index_data.keys())[0] + ) + df_default = pd.DataFrame(index=new_index, dtype="float16") + + # save default result value + df_default["VALUE"] = default + + # combine result and default value dataframe + if not data.empty: + df = pd.concat([data, df_default]) + df = df[~df.index.duplicated(keep="first")] + else: + df = df_default + df = df.sort_index() + return df + def _expand_defaults( self, inputs: Dict[str, pd.DataFrame], @@ -317,7 +361,6 @@ def _expand_defaults( output = {} for name, data in inputs.items(): - logger.info(f"Writing defaults for {name}") # skip sets if name in sets: @@ -331,33 +374,9 @@ def _expand_defaults( output[name] = data continue - # save set information for each parameter - index_data = {} - for index in data.index.names: - index_data[index] = input_data[index]["VALUE"].to_list() - - # set index - if len(index_data) > 1: - new_index = pd.MultiIndex.from_product( - list(index_data.values()), names=list(index_data.keys()) - ) - else: - new_index = pd.Index( - list(index_data.values())[0], name=list(index_data.keys())[0] - ) - df_default = pd.DataFrame(index=new_index) - - # save default result value - df_default["VALUE"] = default_values[name] - - # combine result and default value dataframe - if not data.empty: - df = pd.concat([data, df_default]) - df = df[~df.index.duplicated(keep="first")] - else: - df = df_default - df = df.sort_index() - output[name] = df + output[name] = self._expand_dataframe( + data, default_values[name], input_data + ) return output