Skip to content

Commit

Permalink
Merge pull request #35 from oemof/feature/node_dict
Browse files Browse the repository at this point in the history
Make energy_system.nodes a dict
  • Loading branch information
p-snft authored Dec 14, 2023
2 parents 4f5973f + 55fce7d commit 056b0df
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ Changelog

* Improved code quality
* Add Entity.custom_properties
* Simplify node access (energy_system.nodes[label])
14 changes: 7 additions & 7 deletions src/oemof/network/energy_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def __init__(self, **kwargs):
g if isinstance(g, Grouping) else Entities(g)
for g in kwargs.get("groupings", [])
]
self._nodes = []
self._nodes = {}

self.results = kwargs.get("results")

Expand All @@ -155,7 +155,7 @@ def __init__(self, **kwargs):

def add(self, *nodes):
"""Add :class:`nodes <oemof.network.Node>` to this energy system."""
self.nodes.extend(nodes)
self._nodes.update({node.label: node for node in nodes})
for n in nodes:
self.signals[type(self).add].send(n, EnergySystem=self)

Expand All @@ -168,20 +168,20 @@ def groups(self):
(
g(n, gs)
for g in self._groupings
for n in self.nodes[self._first_ungrouped_node_index_ :]
for n in list(self.nodes)[self._first_ungrouped_node_index_ :]
),
maxlen=0,
)
self._first_ungrouped_node_index_ = len(self.nodes)
return self._groups

@property
def nodes(self):
def node(self):
return self._nodes

@nodes.setter
def nodes(self, value):
self._nodes = value
@property
def nodes(self):
return self._nodes.values()

def flows(self):
return {
Expand Down
8 changes: 5 additions & 3 deletions src/oemof/network/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def create_nx_graph(
grph = nx.DiGraph()

# add nodes
for n in energy_system.nodes:
grph.add_node(str(n.label), label=str(n.label))
for label in energy_system.node.keys():
grph.add_node(str(label), label=str(label))

# add labeled flows on directed edge if an optimization_model has been
# passed or undirected edge otherwise
Expand All @@ -125,7 +125,9 @@ def create_nx_graph(
if remove_nodes_with_substrings is not None:
for i in remove_nodes_with_substrings:
remove_nodes = [
str(v.label) for v in energy_system.nodes if i in str(v.label)
str(label)
for label in energy_system.node.keys()
if i in str(label)
]
grph.remove_nodes_from(remove_nodes)

Expand Down
6 changes: 3 additions & 3 deletions tests/test_network_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,13 +332,13 @@ def setup_method(self):
def test_entity_registration(self):
n1 = Node(label="<B1>")
self.es.add(n1)
assert self.es.nodes[0] == n1
assert self.es.node["<B1>"] == n1
n2 = Node(label="<B2>")
self.es.add(n2)
assert self.es.nodes[1] == n2
assert self.es.node["<B2>"] == n2
n3 = Node(label="<TF1>", inputs=[n1], outputs=[n2])
self.es.add(n3)
assert n3 in self.es.nodes
assert self.es.node["<TF1>"] == n3


def test_deprecated_classes():
Expand Down

0 comments on commit 056b0df

Please sign in to comment.