diff --git a/src/otoole/input.py b/src/otoole/input.py index a9efc11..ff828a7 100644 --- a/src/otoole/input.py +++ b/src/otoole/input.py @@ -257,7 +257,7 @@ def write( logger.debug(default_values) self.inputs = inputs # parameter/set data OR result data - input_data = kwargs.get("input_data", None) + self.input_params = kwargs.get("input_data", None) # parameter/set data for name, df in sorted(self.inputs.items()): logger.debug("%s has %s columns: %s", name, len(df.index.names), df.columns) @@ -271,26 +271,16 @@ def write( raise KeyError("Cannot find %s in input or results config", name) if entity_type != "set": - default_value = default_values[name] - # This should be moved inside the loop and performed once for each parameter - if self.write_defaults and "Annual" in name: - try: - logger.info(f"Expanding {name} with default values") - df_expand = self._expand_dataframe( - df, default_value, input_data - ) - except KeyError as ex: - logger.info( - f"Unable to write default values due to missing {ex} data" - ) + if self.write_defaults: + df_out = self._expand_dataframe(name, df) else: - df_expand = df + df_out = df self._write_parameter( - df_expand, + df_out, name, handle, - default=default_value, + default=default_values[name], input_data=self.inputs, ) else: @@ -301,84 +291,62 @@ def write( if isinstance(handle, TextIO): handle.close() - def _expand_dataframe( - self, data: pd.DataFrame, default: float, input_data: dict[str, pd.DataFrame] - ) -> pd.DataFrame: - """Expand an individual dataframe with default values""" - # save set information for each parameter - index_data = {} - for index in data.index.names: - index_data[index] = input_data[index]["VALUE"].to_list() - - # set index - if len(index_data) > 1: - new_index = pd.MultiIndex.from_product( - list(index_data.values()), names=list(index_data.keys()) - ) - else: - new_index = pd.Index( - list(index_data.values())[0], name=list(index_data.keys())[0] - ) - df_default = pd.DataFrame(index=new_index, dtype="float16") - - # save default result value - df_default["VALUE"] = default - - # combine result and default value dataframe - if not data.empty: - df = pd.concat([data, df_default]) - df = df[~df.index.duplicated(keep="first")] - else: - df = df_default - df = df.sort_index() - return df - - def _expand_defaults( - self, - inputs: Dict[str, pd.DataFrame], - default_values: Dict[str, float], - input_data: Dict[str, pd.DataFrame] = None, - ) -> Dict[str, pd.DataFrame]: + def _expand_dataframe(self, name: str, df: pd.DataFrame) -> Dict[str, pd.DataFrame]: """Populates default value entry rows in dataframes Parameters ---------- - inputs : Dict[str, pd.DataFrame], - param/set data or result data - default_values : Dict[str, float] - defaults of param/result data - input_data: Dict[str, pd.DataFrame] - param/set data needed for expanding result data + name: str + Name of parameter/result to expand + df: pd.DataFrame, + input parameter/result data to be expanded Returns ------- - Dict[str, pd.DataFrame] + pd.DataFrame, Input data with expanded default values replacing missing entries """ - sets = [x for x in self.user_config if self.user_config[x]["type"] == "set"] - input_data = input_data if input_data else inputs.copy() + # TODO: Issue with how otoole handles trade route right now. + # The double definition of REGION throws an error. + if name == "TradeRoute": + return df + + default_df = self._get_default_dataframe(name) - output = {} - for name, data in inputs.items(): + df = pd.concat([df, default_df]) + df = df[~df.index.duplicated(keep="first")] + return df.sort_index() - # skip sets - if name in sets: - output[name] = data - continue + # default_df.update(df) + # return default_df.sort_index() - # TODO - # Issue with how otoole handles trade route right now. - # The double definition of REGION throws an error. - if name == "TradeRoute": - output[name] = data - continue + def _get_default_dataframe(self, name: str) -> pd.DataFrame: + """Creates default dataframe""" + + index_data = {} + indices = self.user_config[name]["indices"] + try: # result data + for index in indices: + index_data[index] = self.input_params[index]["VALUE"].to_list() + except (TypeError, KeyError): # parameter data + for index in indices: + index_data[index] = self.inputs[index]["VALUE"].to_list() - output[name] = self._expand_dataframe( - data, default_values[name], input_data + if len(index_data) > 1: + new_index = pd.MultiIndex.from_product( + list(index_data.values()), names=list(index_data.keys()) + ) + else: + new_index = pd.Index( + list(index_data.values())[0], name=list(index_data.keys())[0] ) - return output + df = pd.DataFrame(index=new_index) + df["VALUE"] = self.default_values[name] + df["VALUE"] = df.VALUE.astype(self.user_config[name]["dtype"]) + + return df class ReadStrategy(Strategy):