diff --git a/src/otoole/input.py b/src/otoole/input.py index 28d747f..0b26a3f 100644 --- a/src/otoole/input.py +++ b/src/otoole/input.py @@ -256,15 +256,14 @@ def write( handle = self._header() logger.debug(default_values) - self.inputs = inputs # parameter/set data OR result data + self.inputs = inputs # parameter/set data OR result data input_data = kwargs.get("input_data", None) - + if self.write_defaults: try: self.inputs = self._expand_defaults(inputs, default_values, input_data) except KeyError as ex: logger.debug(f"Can not write default values due to missing {ex} data") - print(f"Can not write default values due to missing {ex} data") for name, df in sorted(self.inputs.items()): logger.debug("%s has %s columns: %s", name, len(df.index.names), df.columns) @@ -291,13 +290,16 @@ def write( handle.close() def _expand_defaults( - self, inputs: Dict[str, pd.DataFrame], default_values: Dict[str, float], input_data: Dict[str, pd.DataFrame] = None + self, + inputs: Dict[str, pd.DataFrame], + default_values: Dict[str, float], + input_data: Dict[str, pd.DataFrame] = None, ) -> Dict[str, pd.DataFrame]: """Populates default value entry rows in dataframes Parameters ---------- - inputs : Dict[str, pd.DataFrame], + inputs : Dict[str, pd.DataFrame], param/set data or result data default_values : Dict[str, float] defaults of param/result data @@ -311,8 +313,8 @@ def _expand_defaults( """ sets = [x for x in self.user_config if self.user_config[x]["type"] == "set"] - input_data = input_data if input_data else inputs.copy() - + input_data = input_data if input_data else inputs.copy() + output = {} for name, data in inputs.items(): logger.info(f"Writing defaults for {name}") diff --git a/tests/test_input.py b/tests/test_input.py index 6637fc0..bfee24c 100644 --- a/tests/test_input.py +++ b/tests/test_input.py @@ -38,6 +38,7 @@ def capital_cost(): ).set_index(["REGION", "TECHNOLOGY", "YEAR"]) return df + @fixture def new_capacity(): df = pd.DataFrame( @@ -50,6 +51,7 @@ def new_capacity(): ).set_index(["REGION", "TECHNOLOGY", "YEAR"]) return df + @fixture() def simple_default_values(): default_values = {} @@ -71,9 +73,8 @@ def simple_input_data(region, year, technology, capital_cost): @fixture def simple_result_data(new_capacity): - return { - "NewCapacity": new_capacity - } + return {"NewCapacity": new_capacity} + @fixture def simple_user_config(): @@ -285,7 +286,11 @@ def test_expand_parameters_defaults( assert_frame_equal(actual[parameter], expected) def test_expand_result_defaults( - self, simple_user_config, simple_default_values, simple_input_data, simple_result_data + self, + simple_user_config, + simple_default_values, + simple_input_data, + simple_result_data, ): write_strategy = DummyWriteStrategy( user_config=simple_user_config, default_values=simple_default_values @@ -293,7 +298,7 @@ def test_expand_result_defaults( actual = write_strategy._expand_defaults( simple_result_data, write_strategy.default_values, simple_input_data ) - + expected = pd.DataFrame( data=[ ["SIMPLICITY", "HYD1", 2014, 2.34], @@ -305,9 +310,24 @@ def test_expand_result_defaults( ], columns=["REGION", "TECHNOLOGY", "YEAR", "VALUE"], ).set_index(["REGION", "TECHNOLOGY", "YEAR"]) - + assert_frame_equal(actual["NewCapacity"], expected) + def test_expand_results_key_error( + self, simple_user_config, simple_result_data, simple_default_values + ): + """When input data is just the result data""" + write_strategy = DummyWriteStrategy( + user_config=simple_user_config, + default_values=simple_default_values, + write_defaults=True, + ) + + with raises(KeyError, match="REGION"): + write_strategy._expand_defaults( + simple_result_data, write_strategy.default_values + ) + class TestReadStrategy: @@ -531,4 +551,3 @@ def test_compare_read_to_expected_exception(self, simple_user_config, expected): reader = DummyReadStrategy(simple_user_config) with raises(OtooleNameMismatchError): reader._compare_read_to_expected(names=expected) - \ No newline at end of file