Skip to content

Commit

Permalink
Improve error messages and add unit tests
Browse files Browse the repository at this point in the history
Now we have 100% coverage! :)
  • Loading branch information
lumbric committed May 16, 2024
1 parent a1f9f0e commit 68fa168
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 26 deletions.
42 changes: 22 additions & 20 deletions syfop/node_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,18 @@ def _preprocess_input_commodities(self, inputs, input_commodities):
if not all(isinstance(node, NodeBase) for node in inputs):
raise ValueError("inputs must be of type NodeBase or some subclass")

# some nodes don't have inputs, because nothing is connected to it, but we still need
# one input commodity (and an input flow). that means we require one input_commmodity if
# there are inputs and otherwise lengths should match up.
num_input_commodities = max(1, len(inputs))

if isinstance(input_commodities, str):
# some nodes don't have inputs, because nothing is connected to it, but we still need
# an input commodity (and an input flow)
input_commodities = max(1, len(inputs)) * [input_commodities]
elif len(inputs) != len(input_commodities) and not (
len(inputs) == 0 and len(input_commodities) == 1
):
input_commodities = num_input_commodities * [input_commodities]
elif num_input_commodities != len(input_commodities):
raise ValueError(
f"invalid number of input_commodities provided for node '{self.name}': "
f"{input_commodities}, does not match number of inputs: "
f"{', '.join(input_.name for input_ in inputs)}"
f"{list(input_.name for input_ in inputs)}"
)

return input_commodities
Expand Down Expand Up @@ -85,12 +86,12 @@ def has_costs(self):

def _get_flows(self, direction, flows, attached_nodes, commodities, commodity):
if len(attached_nodes) == 0:
if commodities != [commodity]:
raise ValueError(
f"node '{self.name}' has no {direction} nodes, therefore "
f"{direction}_commidities should be set to '{[commodity]}', "
f"but it is: {commodities} "
)
# this should never happen... we check this ear
assert commodities == [commodity], (
f"node '{self.name}' has no {direction} nodes, therefore "
f"{direction}_commidities should be set to '{[commodity]}', "
f"but it is: {commodities} "
)
return flows.values()
else:
return (
Expand Down Expand Up @@ -313,13 +314,14 @@ def _create_constraint_inout_flow_balance_commodity(
# operator properly in linopy if lhs is an xarray object in lhs == rhs?
#
if not isinstance(lhs, linopy.Variable) and not isinstance(lhs, linopy.LinearExpression):
if self.storage is not None:
# lhs means that sum of output flow nodes is not a variable, which means that we
# have a self is of type NodeFixOutput. then storage doesn't really make
# sense, so we can simply forbid this case.
# If we want to support it, we need to take care of a wrong sign when adding charge
# and discharge below to lhs.
raise RuntimeError("NodeFixOutput with Storage not supported")
# lhs means that sum of output flow nodes is not a variable, which means that we
# have a self is of type NodeFixOutput. then storage doesn't really make
# sense, so we can simply forbid this case.
# If we want to support it, we need to take care of a wrong sign when adding charge
# and discharge below to lhs.
# This already checked in NodeFixOutput.__init__().
assert self.storage is None, "NodeFixOutput with Storage not supported"

lhs, rhs = rhs, lhs
if isinstance(rhs, linopy.Variable) or isinstance(rhs, linopy.LinearExpression):
lhs = lhs - rhs
Expand Down
1 change: 1 addition & 0 deletions syfop/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def get_version():
# raise RuntimeError(f"Unable to get version from Git: {e}: {e.output.decode()}")

# TODO we don't check if the tag is a valid version string (e.g. v0.1.0).
# TODO --match would be nice extra option for git describe above
version = git_tag.lstrip("v")

# convert string from git-describe to a valid PEP440 version string
Expand Down
34 changes: 30 additions & 4 deletions tests/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,11 +205,10 @@ def test_simple_co2_storage(storage_type):
def test_missing_node():
"""If a node is used as input but not passed to the Network constructor, this is an error.
This might change in future."""
wind = Node(
wind = NodeScalableInput(
name="wind",
inputs=[],
input_commodities=[],
costs=10,
input_profile=const_time_series(0.5),
costs=1 * ureg.EUR / ureg.MW,
)
electricity = Node(
name="electricity",
Expand Down Expand Up @@ -731,3 +730,30 @@ def test_missing_size_commodity_parameter_no_output():

with pytest.raises(ValueError, match=error_msg):
Network([wind, demand, curtailment])


def test_node_type_only():
# we might want to disallow time_coords without time unit, because it does not work for storage
# and input_flow_costs
time_coords = np.arange(10)

node1 = Node(
name="node1",
inputs=[],
input_commodities=["electricity"],
costs=10 * ureg.EUR / ureg.MW,
)
node2 = Node(
name="node2",
inputs=[node1],
input_commodities=["electricity"],
size_commodity="electricity",
costs=0,
)

network = Network([node1, node2], time_coords=time_coords)
network.optimize()

assert list(node1.input_flows.keys()) == [""]
assert list(node2.input_flows.keys()) == ["node1"]
assert network.model.solution.size_node1 == 0.0
22 changes: 20 additions & 2 deletions tests/test_node.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import re

import pytest

from syfop.node import (
Expand Down Expand Up @@ -55,13 +57,15 @@ def test_wrong_input_proportions_commodities(three_example_nodes):
)


def test_wrong_number_of_commodities(three_example_nodes):
def test_invalid_num_input_commodities1(three_example_nodes):
wind, solar_pv, _ = three_example_nodes

error_msg = (
"invalid number of input_commodities provided for node 'electricity': "
"\\['electricity'\\], does not match number of inputs: solar_pv, wind"
"['electricity'], does not match number of inputs: ['solar_pv', 'wind']"
)
error_msg = re.escape(error_msg)

with pytest.raises(ValueError, match=error_msg):
_ = Node(
name="electricity",
Expand All @@ -74,6 +78,20 @@ def test_wrong_number_of_commodities(three_example_nodes):
)


def test_invalid_num_input_commodities2():
error_msg = (
"invalid number of input_commodities provided for node 'gas': "
"[], does not match number of inputs: []"
)
with pytest.raises(ValueError, match=error_msg):
_ = Node(
name="gas",
inputs=[],
input_commodities=[],
costs=0,
)


def test_input_profile_not_capacity_factor():
error_msg = "invalid values in input_profile: must be capacity factors"

Expand Down
16 changes: 16 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import linopy
import numpy as np
import pytest

Expand Down Expand Up @@ -36,6 +37,21 @@ def test_print_constraints(some_model):
print_constraints(some_model)


def test_constraints_to_str_empty_constraints():
model = linopy.Model()
constraints_as_str = constraints_to_str(model)
assert constraints_as_str == ""


def test_constraints_to_str_empty_vars():
# this is a very stupid test to increase coverage to 100% :)
model = linopy.Model()
var = model.add_variables(name="some_var", lower=[])
model.add_constraints(var == 0)
constraints_as_str = constraints_to_str(model)
assert constraints_as_str == ""


def test_constraints_to_str(some_model):
constraints_as_str = constraints_to_str(some_model)
expected_output = """
Expand Down
19 changes: 19 additions & 0 deletions tests/test_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import re
from subprocess import CalledProcessError
from unittest.mock import patch

from syfop.version import get_version, version


def test_version():
pattern = re.compile(r"^\d+\.\d+\.\d+")
assert pattern.match(version)


@patch("subprocess.check_output")
def test_get_version(mock_check_output):
# this simulates a shallow clone with --depth=1 or missing tags
mock_check_output.side_effect = CalledProcessError(128, ["git", "describe", "--tags"])

version = get_version()
assert version == "0.0.0"

0 comments on commit 68fa168

Please sign in to comment.