Refactor expand defaults to reduce memory use
willu47 committed Feb 7, 2024
1 parent 4258955 commit 7849353
Showing 1 changed file with 54 additions and 35 deletions.
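
The memory saving comes from where the expansion happens: previously every parameter was expanded to its full dense index up front, so all of the dense frames were alive at once, whereas after this change each parameter is expanded just before it is written and can be released afterwards. A minimal sketch of the two patterns, using hypothetical stand-ins `expand` and `write_one` rather than the real otoole methods:

from typing import Callable, Dict

import pandas as pd


def write_all_upfront(
    inputs: Dict[str, pd.DataFrame], expand: Callable, write_one: Callable
) -> None:
    # Old pattern: every expanded (dense) frame is held in memory at the same time
    expanded = {name: expand(df) for name, df in inputs.items()}
    for name, df in expanded.items():
        write_one(name, df)


def write_per_parameter(
    inputs: Dict[str, pd.DataFrame], expand: Callable, write_one: Callable
) -> None:
    # New pattern: only one dense frame exists at a time and can be freed once written
    for name, df in inputs.items():
        write_one(name, expand(df))
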
89 changes: 54 additions & 35 deletions src/otoole/input.py
@@ -259,12 +259,6 @@ def write(
         self.inputs = inputs  # parameter/set data OR result data
         input_data = kwargs.get("input_data", None)

-        if self.write_defaults:
-            try:
-                self.inputs = self._expand_defaults(inputs, default_values, input_data)
-            except KeyError as ex:
-                logger.debug(f"Can not write default values due to missing {ex} data")
-
         for name, df in sorted(self.inputs.items()):
             logger.debug("%s has %s columns: %s", name, len(df.index.names), df.columns)

@@ -278,8 +272,26 @@

             if entity_type != "set":
                 default_value = default_values[name]
+                # This should be moved inside the loop and performed once for each parameter
+                if self.write_defaults:
+                    try:
+                        logger.info(f"Expanding {name} with default values")
+                        df_expand = self._expand_dataframe(
+                            df, default_value, input_data
+                        )
+                    except KeyError as ex:
+                        logger.info(
+                            f"Unable to write default values due to missing {ex} data"
+                        )
+                else:
+                    df_expand = df
+
                 self._write_parameter(
-                    df, name, handle, default=default_value, input_data=self.inputs
+                    df_expand,
+                    name,
+                    handle,
+                    default=default_value,
+                    input_data=self.inputs,
                 )
             else:
                 self._write_set(df, name, handle)
@@ -289,6 +301,38 @@
         if isinstance(handle, TextIO):
             handle.close()

+    def _expand_dataframe(
+        self, data: pd.DataFrame, default: float, input_data: dict[str, pd.DataFrame]
+    ) -> pd.DataFrame:
+        """Expand an individual dataframe with default values"""
+        # save set information for each parameter
+        index_data = {}
+        for index in data.index.names:
+            index_data[index] = input_data[index]["VALUE"].to_list()
+
+        # set index
+        if len(index_data) > 1:
+            new_index = pd.MultiIndex.from_product(
+                list(index_data.values()), names=list(index_data.keys())
+            )
+        else:
+            new_index = pd.Index(
+                list(index_data.values())[0], name=list(index_data.keys())[0]
+            )
+        df_default = pd.DataFrame(index=new_index, dtype="float16")
+
+        # save default result value
+        df_default["VALUE"] = default
+
+        # combine result and default value dataframe
+        if not data.empty:
+            df = pd.concat([data, df_default])
+            df = df[~df.index.duplicated(keep="first")]
+        else:
+            df = df_default
+        df = df.sort_index()
+        return df
+
     def _expand_defaults(
         self,
         inputs: Dict[str, pd.DataFrame],
@@ -317,7 +361,6 @@ def _expand_defaults(

         output = {}
         for name, data in inputs.items():
-            logger.info(f"Writing defaults for {name}")

             # skip sets
             if name in sets:
@@ -331,33 +374,9 @@
                 output[name] = data
                 continue

-            # save set information for each parameter
-            index_data = {}
-            for index in data.index.names:
-                index_data[index] = input_data[index]["VALUE"].to_list()
-
-            # set index
-            if len(index_data) > 1:
-                new_index = pd.MultiIndex.from_product(
-                    list(index_data.values()), names=list(index_data.keys())
-                )
-            else:
-                new_index = pd.Index(
-                    list(index_data.values())[0], name=list(index_data.keys())[0]
-                )
-            df_default = pd.DataFrame(index=new_index)
-
-            # save default result value
-            df_default["VALUE"] = default_values[name]
-
-            # combine result and default value dataframe
-            if not data.empty:
-                df = pd.concat([data, df_default])
-                df = df[~df.index.duplicated(keep="first")]
-            else:
-                df = df_default
-            df = df.sort_index()
-            output[name] = df
+            output[name] = self._expand_dataframe(
+                data, default_values[name], input_data
+            )

         return output

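For reference, a minimal, self-contained sketch of what the new _expand_dataframe method does, using toy data (the real method reads each index set from input_data and takes the default from the user config): the supplied sparse values are overlaid on a dense frame built from the cartesian product of the parameter's index sets.

import pandas as pd

# Toy index sets standing in for the otoole sets referenced by the parameter
regions = ["R1", "R2"]
years = [2020, 2021]

# Sparse parameter data: only one of the four (REGION, YEAR) combinations is supplied
data = pd.DataFrame(
    {"REGION": ["R1"], "YEAR": [2020], "VALUE": [5.0]}
).set_index(["REGION", "YEAR"])

# Dense frame holding the default value for every combination of the index sets
new_index = pd.MultiIndex.from_product([regions, years], names=["REGION", "YEAR"])
df_default = pd.DataFrame({"VALUE": 0.0}, index=new_index)

# Keep supplied values where they exist, fall back to the default elsewhere
df = pd.concat([data, df_default])
df = df[~df.index.duplicated(keep="first")].sort_index()
print(df)  # 5.0 for (R1, 2020), 0.0 for the other three combinations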
