diff --git a/docs/demo/data/processed/example_override_data.csv.gz b/docs/demo/data/processed/example_override_data.csv.gz new file mode 100644 index 00000000..a88da4fa Binary files /dev/null and b/docs/demo/data/processed/example_override_data.csv.gz differ diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 8906a98a..d9102195 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -414,13 +414,23 @@ def test_deny(self): self.assertDictEqual(d2, reply) def test_data_read(self): - data_path = "../docs/demo/data/processed/timeries_data.csv" + node = Node(name="", data_input_dict={("temperature", 1): 15}) + node.t = 1 + + self.assertEqual(15, node.get_data_input("temperature")) + + def test_data_overrides(self): + data_path = "../docs/demo/data/processed/example_override_data.csv.gz" input_data = pd.read_csv(data_path) - node = Node(name="", data_input_dict=data_path) - node.t = Node.data_input_dict.keys()[0][1] + + overrides = {'data_input_dict': data_path} + node = Node(name="") + node.apply_overrides(overrides) + node.t = list(node.data_input_dict.keys())[0][1] self.assertEqual( - input_data["temperature"].iloc[0], node.get_data_input("temperature") + input_data.groupby("variable").get_group("temperature")["value"].iloc[0], + node.get_data_input("temperature") ) diff --git a/wsimod/nodes/nodes.py b/wsimod/nodes/nodes.py index b3a2dfdf..ef17f0f7 100644 --- a/wsimod/nodes/nodes.py +++ b/wsimod/nodes/nodes.py @@ -89,7 +89,7 @@ def apply_overrides(self, overrides: Dict[str, Any] = {}) -> None: overrides (dict, optional): Dictionary of overrides. Defaults to {}. """ # overrides data_input_dict - + content = overrides.pop("data_input_dict", self.data_input_dict) if isinstance(content, str): self.data_input_dict = read_csv(content)