Skip to content

Commit

Permalink
Make kwargs explicit
Browse files Browse the repository at this point in the history
  • Loading branch information
p-snft committed Dec 14, 2023
1 parent 1fe963e commit e449b52
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 31 deletions.
43 changes: 31 additions & 12 deletions src/oemof/network/energy_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import blinker
import dill as pickle

import warnings

from oemof.network.groupings import DEFAULT as BY_UID
from oemof.network.groupings import Entities
from oemof.network.groupings import Grouping
Expand Down Expand Up @@ -134,24 +136,41 @@ class EnergySystem:
.. _blinker: https://blinker.readthedocs.io/en/stable/
"""

def __init__(self, **kwargs):
def __init__(
self,
*,
groupings=None,
results=None,
timeindex=None,
timeincrement=None,
temporal=None,
nodes=None,
entities=None,
):
if groupings is None:
groupings = []
if entities is not None:
warnings.warn(
"Parameter 'entities' is deprecated, use 'nodes'"
+ " instead. Will overwrite nodes.",
FutureWarning,
)
nodes = entities
if nodes is None:
nodes = []

self._first_ungrouped_node_index_ = 0
self._groups = {}
self._groupings = [BY_UID] + [
g if isinstance(g, Grouping) else Entities(g)
for g in kwargs.get("groupings", [])
g if isinstance(g, Grouping) else Entities(g) for g in groupings
]
self._nodes = {}

self.results = kwargs.get("results")

self.timeindex = kwargs.get("timeindex")

self.timeincrement = kwargs.get("timeincrement", None)

self.temporal = kwargs.get("temporal")

self.add(*kwargs.get("entities", ()))
self.results = results
self.timeindex = timeindex
self.timeincrement = timeincrement
self.temporal = temporal
self.add(*nodes)

def add(self, *nodes):
"""Add :class:`nodes <oemof.network.Node>` to this energy system."""
Expand Down
5 changes: 2 additions & 3 deletions src/oemof/network/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,13 @@ def create_nx_graph(
>>> import oemof.network.graph as grph
>>> datetimeindex = pd.date_range('1/1/2017', periods=3, freq='H')
>>> es = EnergySystem(timeindex=datetimeindex)
>>> b_gas = Bus(label='b_gas', balanced=False)
>>> b_gas = Bus(label='b_gas')
>>> bel1 = Bus(label='bel1')
>>> bel2 = Bus(label='bel2')
>>> demand_el = Sink(label='demand_el', inputs = [bel1])
>>> pp_gas = Transformer(label=('pp', 'gas'),
... inputs=[b_gas],
... outputs=[bel1],
... conversion_factors={bel1: 0.5})
... outputs=[bel1])
>>> line_to2 = Transformer(label='line_to2', inputs=[bel1], outputs=[bel2])
>>> line_from2 = Transformer(label='line_from2',
... inputs=[bel2], outputs=[bel1])
Expand Down
2 changes: 1 addition & 1 deletion src/oemof/network/groupings.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def __init__(self, key=None, constant_key=None, filter=None, **kwargs):
+ "one of `key` or `constant_key`."
)
self.filter = filter
for kw in ["value", "merge", "filter"]:
for kw in ["value", "merge"]:
if kw in kwargs:
setattr(self, kw, kwargs[kw])

Expand Down
12 changes: 7 additions & 5 deletions src/oemof/network/network/edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

from .entity import Entity

EdgeLabel = namedtuple("EdgeLabel", ["input", "output"])


class Edge(Entity):
"""
Expand All @@ -41,15 +39,16 @@ class Edge(Entity):
name.
"""

Label = EdgeLabel
Label = namedtuple("EdgeLabel", ["input", "output"])

def __init__(
self,
input_node=None,
output_node=None,
flow=None,
values=None,
**kwargs,
*,
custom_properties=None,
):
if flow is not None and values is not None:
raise ValueError(
Expand All @@ -60,7 +59,10 @@ def __init__(
f" `values`: {values}\n"
"Choose one."
)
super().__init__(label=Edge.Label(input_node, output_node))
super().__init__(
label=Edge.Label(input_node, output_node),
custom_properties=custom_properties,
)
self.values = values if values is not None else flow
if input_node is not None and output_node is not None:
input_node.outputs[output_node] = self
Expand Down
2 changes: 1 addition & 1 deletion src/oemof/network/network/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class Entity:
to easily attach custom information to any Entity.
"""

def __init__(self, label=None, *, custom_properties=None, **kwargs):
def __init__(self, label=None, *, custom_properties=None):
self._label = label
if custom_properties is None:
custom_properties = {}
Expand Down
24 changes: 18 additions & 6 deletions src/oemof/network/network/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,31 +42,43 @@ class Node(Entity):
A dictionary mapping output nodes to corresponding outflows.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(
self,
label=None,
*,
inputs=None,
outputs=None,
custom_properties=None,
):
super().__init__(label=label, custom_properties=custom_properties)

self._inputs = Inputs(self)
self._outputs = Outputs(self)
self._in_edges = set()

if inputs is None:
inputs = {}
if outputs is None:
outputs = {}

msg = "{} {!r} of {!r} not an instance of Node but of {}."

for i in kwargs.get("inputs", {}):
for i in inputs:
if not isinstance(i, Node):
raise ValueError(msg.format("Input", i, self, type(i)))
self._in_edges.add(i)
try:
flow = kwargs["inputs"].get(i)
flow = inputs.get(i)
except AttributeError:
flow = None
edge = Edge.from_object(flow)
edge.input = i
edge.output = self
for o in kwargs.get("outputs", {}):
for o in outputs:
if not isinstance(o, Node):
raise ValueError(msg.format("Output", o, self, type(o)))
try:
flow = kwargs["outputs"].get(o)
flow = outputs.get(o)
except AttributeError:
flow = None
edge = Edge.from_object(flow)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_groupings/test_groupings_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

def test_entity_grouping_on_construction():
bus = Bus(label="test bus")
ensys = es.EnergySystem(entities=[bus])
ensys = es.EnergySystem(nodes=[bus])
assert ensys.groups[bus.label] is bus


Expand All @@ -37,8 +37,8 @@ def by_uid(n):

ensys = es.EnergySystem(groupings=[by_uid])

ungrouped = [Node(uid="Not in 'Group': {}".format(i)) for i in range(10)]
grouped = [Node(uid="In 'Group': {}".format(i)) for i in range(10)]
ungrouped = [Node(label="Not in 'Group': {}".format(i)) for i in range(10)]
grouped = [Node(label="In 'Group': {}".format(i)) for i in range(10)]
assert None not in ensys.groups
for g in ensys.groups.values():
for e in ungrouped:
Expand Down

0 comments on commit e449b52

Please sign in to comment.