Skip to content

Commit

Permalink
expand defaults keyerror test
Browse files Browse the repository at this point in the history
  • Loading branch information
trevorb1 committed Feb 1, 2024
1 parent 324dfcd commit 2007bfe
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 14 deletions.
16 changes: 9 additions & 7 deletions src/otoole/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,15 +256,14 @@ def write(
handle = self._header()
logger.debug(default_values)

self.inputs = inputs # parameter/set data OR result data
self.inputs = inputs # parameter/set data OR result data
input_data = kwargs.get("input_data", None)

if self.write_defaults:
try:
self.inputs = self._expand_defaults(inputs, default_values, input_data)
except KeyError as ex:
logger.debug(f"Can not write default values due to missing {ex} data")
print(f"Can not write default values due to missing {ex} data")

for name, df in sorted(self.inputs.items()):
logger.debug("%s has %s columns: %s", name, len(df.index.names), df.columns)
Expand All @@ -291,13 +290,16 @@ def write(
handle.close()

def _expand_defaults(
self, inputs: Dict[str, pd.DataFrame], default_values: Dict[str, float], input_data: Dict[str, pd.DataFrame] = None
self,
inputs: Dict[str, pd.DataFrame],
default_values: Dict[str, float],
input_data: Dict[str, pd.DataFrame] = None,
) -> Dict[str, pd.DataFrame]:
"""Populates default value entry rows in dataframes
Parameters
----------
inputs : Dict[str, pd.DataFrame],
inputs : Dict[str, pd.DataFrame],
param/set data or result data
default_values : Dict[str, float]
defaults of param/result data
Expand All @@ -311,8 +313,8 @@ def _expand_defaults(
"""

sets = [x for x in self.user_config if self.user_config[x]["type"] == "set"]
input_data = input_data if input_data else inputs.copy()
input_data = input_data if input_data else inputs.copy()

output = {}
for name, data in inputs.items():
logger.info(f"Writing defaults for {name}")
Expand Down
33 changes: 26 additions & 7 deletions tests/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def capital_cost():
).set_index(["REGION", "TECHNOLOGY", "YEAR"])
return df


@fixture
def new_capacity():
df = pd.DataFrame(
Expand All @@ -50,6 +51,7 @@ def new_capacity():
).set_index(["REGION", "TECHNOLOGY", "YEAR"])
return df


@fixture()
def simple_default_values():
default_values = {}
Expand All @@ -71,9 +73,8 @@ def simple_input_data(region, year, technology, capital_cost):

@fixture
def simple_result_data(new_capacity):
return {
"NewCapacity": new_capacity
}
return {"NewCapacity": new_capacity}


@fixture
def simple_user_config():
Expand Down Expand Up @@ -285,15 +286,19 @@ def test_expand_parameters_defaults(
assert_frame_equal(actual[parameter], expected)

def test_expand_result_defaults(
self, simple_user_config, simple_default_values, simple_input_data, simple_result_data
self,
simple_user_config,
simple_default_values,
simple_input_data,
simple_result_data,
):
write_strategy = DummyWriteStrategy(
user_config=simple_user_config, default_values=simple_default_values
)
actual = write_strategy._expand_defaults(
simple_result_data, write_strategy.default_values, simple_input_data
)

expected = pd.DataFrame(
data=[
["SIMPLICITY", "HYD1", 2014, 2.34],
Expand All @@ -305,9 +310,24 @@ def test_expand_result_defaults(
],
columns=["REGION", "TECHNOLOGY", "YEAR", "VALUE"],
).set_index(["REGION", "TECHNOLOGY", "YEAR"])

assert_frame_equal(actual["NewCapacity"], expected)

def test_expand_results_key_error(
self, simple_user_config, simple_result_data, simple_default_values
):
"""When input data is just the result data"""
write_strategy = DummyWriteStrategy(
user_config=simple_user_config,
default_values=simple_default_values,
write_defaults=True,
)

with raises(KeyError, match="REGION"):
write_strategy._expand_defaults(
simple_result_data, write_strategy.default_values
)


class TestReadStrategy:

Expand Down Expand Up @@ -531,4 +551,3 @@ def test_compare_read_to_expected_exception(self, simple_user_config, expected):
reader = DummyReadStrategy(simple_user_config)
with raises(OtooleNameMismatchError):
reader._compare_read_to_expected(names=expected)

0 comments on commit 2007bfe

Please sign in to comment.