Skip to content

Commit

Permalink
fix expand defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
trevorb1 committed Feb 1, 2024
1 parent 15bdc0b commit 324dfcd
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 47 deletions.
28 changes: 18 additions & 10 deletions src/otoole/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,14 +256,17 @@ def write(
handle = self._header()
logger.debug(default_values)

self.input_data = inputs
self.inputs = inputs # parameter/set data OR result data
input_data = kwargs.get("input_data", None)

if self.write_defaults:
try:
self.input_data = self._expand_defaults(inputs, default_values)
self.inputs = self._expand_defaults(inputs, default_values, input_data)
except KeyError as ex:
logger.debug(ex)
logger.debug(f"Can not write default values due to missing {ex} data")
print(f"Can not write default values due to missing {ex} data")

for name, df in sorted(self.input_data.items()):
for name, df in sorted(self.inputs.items()):
logger.debug("%s has %s columns: %s", name, len(df.index.names), df.columns)

try:
Expand All @@ -277,7 +280,7 @@ def write(
if entity_type != "set":
default_value = default_values[name]
self._write_parameter(
df, name, handle, default=default_value, input_data=self.input_data
df, name, handle, default=default_value, input_data=self.inputs
)
else:
self._write_set(df, name, handle)
Expand All @@ -288,25 +291,30 @@ def write(
handle.close()

def _expand_defaults(
self, data_to_expand: Dict[str, pd.DataFrame], default_values: Dict[str, float]
self, inputs: Dict[str, pd.DataFrame], default_values: Dict[str, float], input_data: Dict[str, pd.DataFrame] = None
) -> Dict[str, pd.DataFrame]:
"""Populates default value entry rows in dataframes
Parameters
----------
data_to_expand : Dict[str, pd.DataFrame],
inputs : Dict[str, pd.DataFrame],
param/set data or result data
default_values : Dict[str, float]
defaults of param/result data
input_data: Dict[str, pd.DataFrame]
param/set data needed for expanding result data
Returns
-------
Dict[str, pd.DataFrame]
Input data with expanded default values replacing missing entries
"""

sets = [x for x in self.user_config if self.user_config[x]["type"] == "set"]
input_data = input_data if input_data else inputs.copy()

output = {}
for name, data in data_to_expand.items():
for name, data in inputs.items():
logger.info(f"Writing defaults for {name}")

# skip sets
Expand All @@ -324,7 +332,7 @@ def _expand_defaults(
# save set information for each parameter
index_data = {}
for index in data.index.names:
index_data[index] = self.input_data[index]["VALUE"].to_list()
index_data[index] = input_data[index]["VALUE"].to_list()

# set index
if len(index_data) > 1:
Expand Down
82 changes: 45 additions & 37 deletions tests/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@ def capital_cost():
).set_index(["REGION", "TECHNOLOGY", "YEAR"])
return df

@fixture
def new_capacity():
df = pd.DataFrame(
data=[
["SIMPLICITY", "NGCC", 2016, 1.23],
["SIMPLICITY", "HYD1", 2014, 2.34],
["SIMPLICITY", "HYD1", 2015, 3.45],
],
columns=["REGION", "TECHNOLOGY", "YEAR", "VALUE"],
).set_index(["REGION", "TECHNOLOGY", "YEAR"])
return df

@fixture()
def simple_default_values():
Expand All @@ -58,6 +69,12 @@ def simple_input_data(region, year, technology, capital_cost):
}


@fixture
def simple_result_data(new_capacity):
return {
"NewCapacity": new_capacity
}

@fixture
def simple_user_config():
return {
Expand All @@ -80,13 +97,19 @@ def simple_user_config():
"dtype": "int",
"type": "set",
},
"NewCapacity": {
"indices": ["REGION", "TECHNOLOGY", "YEAR"],
"type": "result",
"dtype": "float",
"default": 0,
},
}


# To instantiate abstract class WriteStrategy
class DummyWriteStrategy(WriteStrategy):
def _header(self) -> Union[TextIO, Any]:
raise NotImplementedError()
pass

def _write_parameter(
self,
Expand All @@ -96,13 +119,13 @@ def _write_parameter(
default: float,
**kwargs
) -> pd.DataFrame:
raise NotImplementedError()
pass

def _write_set(self, df: pd.DataFrame, set_name, handle: TextIO) -> pd.DataFrame:
raise NotImplementedError()
pass

def _footer(self, handle: TextIO):
raise NotImplementedError()
pass


# To instantiate abstract class ReadStrategy
Expand Down Expand Up @@ -229,34 +252,6 @@ def input_data_single_index_empty(region):
}
return data, "DiscountRate", discount_rate_out

@fixture
def result_data(region):
new_capacity_in = pd.DataFrame(
[
["SIMPLICITY", "HYD1", 2015, 100],
["SIMPLICITY", "HYD1", 2016, 0.1],
["SIMPLICITY", "NGCC", 2014, 0.5],
["SIMPLICITY", "NGCC", 2015, 100],
],
columns=["REGION", "TECHNOLOGY", "YEAR", "VALUE"],
).set_index(["REGION", "TECHNOLOGY", "YEAR"])
new_capacity_out = pd.DataFrame(
[
["SIMPLICITY", "HYD1", 2014, 20],
["SIMPLICITY", "HYD1", 2015, 100],
["SIMPLICITY", "HYD1", 2016, 0.1],
["SIMPLICITY", "NGCC", 2014, 0.5],
["SIMPLICITY", "NGCC", 2015, 100],
["SIMPLICITY", "NGCC", 2016, 20],
],
columns=["REGION", "TECHNOLOGY", "YEAR", "VALUE"],
).set_index(["REGION", "TECHNOLOGY", "YEAR"])

data = {
"NewCapacity": new_capacity_in,
}
return data, "NewCapacity", new_capacity_out

parameter_test_data = [
input_data_multi_index_no_defaults(region, technology, year),
input_data_multi_index(region, technology, year),
Expand Down Expand Up @@ -290,16 +285,28 @@ def test_expand_parameters_defaults(
assert_frame_equal(actual[parameter], expected)

def test_expand_result_defaults(
self, user_config, simple_default_values, simple_input_data, result_data
self, simple_user_config, simple_default_values, simple_input_data, simple_result_data
):
write_strategy = DummyWriteStrategy(
user_config=user_config, default_values=simple_default_values
user_config=simple_user_config, default_values=simple_default_values
)
write_strategy.input_data = simple_input_data
actual = write_strategy._expand_defaults(
result_data[0], write_strategy.default_values
simple_result_data, write_strategy.default_values, simple_input_data
)
assert_frame_equal(actual[result_data[1]], result_data[2])

expected = pd.DataFrame(
data=[
["SIMPLICITY", "HYD1", 2014, 2.34],
["SIMPLICITY", "HYD1", 2015, 3.45],
["SIMPLICITY", "HYD1", 2016, 20],
["SIMPLICITY", "NGCC", 2014, 20],
["SIMPLICITY", "NGCC", 2015, 20],
["SIMPLICITY", "NGCC", 2016, 1.23],
],
columns=["REGION", "TECHNOLOGY", "YEAR", "VALUE"],
).set_index(["REGION", "TECHNOLOGY", "YEAR"])

assert_frame_equal(actual["NewCapacity"], expected)


class TestReadStrategy:
Expand Down Expand Up @@ -524,3 +531,4 @@ def test_compare_read_to_expected_exception(self, simple_user_config, expected):
reader = DummyReadStrategy(simple_user_config)
with raises(OtooleNameMismatchError):
reader._compare_read_to_expected(names=expected)

0 comments on commit 324dfcd

Please sign in to comment.