Skip to content

Commit

Permalink
refactor expand defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
trevorb1 committed Feb 13, 2024
1 parent 551b8bf commit d6aaf4e
Showing 1 changed file with 46 additions and 78 deletions.
124 changes: 46 additions & 78 deletions src/otoole/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def write(
logger.debug(default_values)

self.inputs = inputs # parameter/set data OR result data
input_data = kwargs.get("input_data", None)
self.input_params = kwargs.get("input_data", None) # parameter/set data

for name, df in sorted(self.inputs.items()):
logger.debug("%s has %s columns: %s", name, len(df.index.names), df.columns)
Expand All @@ -271,26 +271,16 @@ def write(
raise KeyError("Cannot find %s in input or results config", name)

if entity_type != "set":
default_value = default_values[name]
# This should be moved inside the loop and performed once for each parameter
if self.write_defaults and "Annual" in name:
try:
logger.info(f"Expanding {name} with default values")
df_expand = self._expand_dataframe(
df, default_value, input_data
)
except KeyError as ex:
logger.info(
f"Unable to write default values due to missing {ex} data"
)
if self.write_defaults:
df_out = self._expand_dataframe(name, df)
else:
df_expand = df
df_out = df

self._write_parameter(
df_expand,
df_out,
name,
handle,
default=default_value,
default=default_values[name],
input_data=self.inputs,
)
else:
Expand All @@ -301,84 +291,62 @@ def write(
if isinstance(handle, TextIO):
handle.close()

def _expand_dataframe(
self, data: pd.DataFrame, default: float, input_data: dict[str, pd.DataFrame]
) -> pd.DataFrame:
"""Expand an individual dataframe with default values"""
# save set information for each parameter
index_data = {}
for index in data.index.names:
index_data[index] = input_data[index]["VALUE"].to_list()

# set index
if len(index_data) > 1:
new_index = pd.MultiIndex.from_product(
list(index_data.values()), names=list(index_data.keys())
)
else:
new_index = pd.Index(
list(index_data.values())[0], name=list(index_data.keys())[0]
)
df_default = pd.DataFrame(index=new_index, dtype="float16")

# save default result value
df_default["VALUE"] = default

# combine result and default value dataframe
if not data.empty:
df = pd.concat([data, df_default])
df = df[~df.index.duplicated(keep="first")]
else:
df = df_default
df = df.sort_index()
return df

def _expand_defaults(
self,
inputs: Dict[str, pd.DataFrame],
default_values: Dict[str, float],
input_data: Dict[str, pd.DataFrame] = None,
) -> Dict[str, pd.DataFrame]:
def _expand_dataframe(self, name: str, df: pd.DataFrame) -> Dict[str, pd.DataFrame]:
"""Populates default value entry rows in dataframes
Parameters
----------
inputs : Dict[str, pd.DataFrame],
param/set data or result data
default_values : Dict[str, float]
defaults of param/result data
input_data: Dict[str, pd.DataFrame]
param/set data needed for expanding result data
name: str
Name of parameter/result to expand
df: pd.DataFrame,
input parameter/result data to be expanded
Returns
-------
Dict[str, pd.DataFrame]
pd.DataFrame,
Input data with expanded default values replacing missing entries
"""

sets = [x for x in self.user_config if self.user_config[x]["type"] == "set"]
input_data = input_data if input_data else inputs.copy()
# TODO: Issue with how otoole handles trade route right now.
# The double definition of REGION throws an error.
if name == "TradeRoute":
return df

default_df = self._get_default_dataframe(name)

output = {}
for name, data in inputs.items():
df = pd.concat([df, default_df])
df = df[~df.index.duplicated(keep="first")]
return df.sort_index()

# skip sets
if name in sets:
output[name] = data
continue
# default_df.update(df)
# return default_df.sort_index()

# TODO
# Issue with how otoole handles trade route right now.
# The double definition of REGION throws an error.
if name == "TradeRoute":
output[name] = data
continue
def _get_default_dataframe(self, name: str) -> pd.DataFrame:
"""Creates default dataframe"""

index_data = {}
indices = self.user_config[name]["indices"]
try: # result data
for index in indices:
index_data[index] = self.input_params[index]["VALUE"].to_list()
except (TypeError, KeyError): # parameter data
for index in indices:
index_data[index] = self.inputs[index]["VALUE"].to_list()

output[name] = self._expand_dataframe(
data, default_values[name], input_data
if len(index_data) > 1:
new_index = pd.MultiIndex.from_product(
list(index_data.values()), names=list(index_data.keys())
)
else:
new_index = pd.Index(
list(index_data.values())[0], name=list(index_data.keys())[0]
)

return output
df = pd.DataFrame(index=new_index)
df["VALUE"] = self.default_values[name]
df["VALUE"] = df.VALUE.astype(self.user_config[name]["dtype"])

return df


class ReadStrategy(Strategy):
Expand Down

0 comments on commit d6aaf4e

Please sign in to comment.