Skip to content

Commit

Permalink
Refactor Symbol iteration methods
Browse files Browse the repository at this point in the history
  • Loading branch information
arcondello committed Nov 19, 2024
1 parent d24acaf commit 73f3f66
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 33 deletions.
38 changes: 25 additions & 13 deletions dwave/optimization/libcpp/graph.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,38 @@ from libcpp cimport bool
from libcpp.memory cimport shared_ptr, unique_ptr
from libcpp.vector cimport vector

from dwave.optimization.libcpp.array cimport Array, span
from dwave.optimization.libcpp cimport span
from dwave.optimization.libcpp.array cimport Array
from dwave.optimization.libcpp.state cimport State

cdef extern from "dwave-optimization/graph.hpp" namespace "dwave::optimization" nogil:
cdef cppclass Node:
struct SuccessorView:
Node* ptr
shared_ptr[bool] expired_ptr() const
const vector[Node*]& predecessors() const
const vector[SuccessorView]& successors() const
Py_ssize_t topological_index()

cdef cppclass ArrayNode(Node, Array):
pass

cdef cppclass DecisionNode(Node):
pass

# Sometimes Cython isn't able to reason about pointers as template inputs, so
# we make a few aliases for convenience
ctypedef Node* NodePtr
ctypedef ArrayNode* ArrayNodePtr
ctypedef DecisionNode* DecisionNodePtr

cdef extern from "dwave-optimization/graph.hpp" namespace "dwave::optimization" nogil:
cdef cppclass Graph:
T* emplace_node[T](...) except+
void initialize_state(State&) except+
span[const unique_ptr[Node]] nodes() const
span[ArrayNode*] constraints() const
span[const ArrayNodePtr] constraints()
span[const DecisionNodePtr] decisions()
Py_ssize_t num_nodes()
Py_ssize_t num_decisions()
Py_ssize_t num_constraints()
Expand All @@ -41,14 +64,3 @@ cdef extern from "dwave-optimization/graph.hpp" namespace "dwave::optimization"
void topological_sort()
bool topologically_sorted() const
Py_ssize_t remove_unused_nodes()

cdef cppclass Node:
struct SuccessorView:
Node* ptr
shared_ptr[bool] expired_ptr() const
const vector[Node*]& predecessors() const
const vector[SuccessorView]& successors() const
Py_ssize_t topological_index()

cdef cppclass ArrayNode(Node, Array):
pass
27 changes: 10 additions & 17 deletions dwave/optimization/model.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ from libcpp.utility cimport move
from libcpp.vector cimport vector

from dwave.optimization.libcpp.array cimport Array as cppArray
from dwave.optimization.libcpp.graph cimport DecisionNode as cppDecisionNode
from dwave.optimization.symbols cimport symbol_from_ptr


Expand Down Expand Up @@ -573,8 +574,8 @@ cdef class Model:
<dwave.optimization.symbols.LessEqual at ...>
>>> constraints = next(model.iter_constraints())
"""
for i in range(self._graph.num_constraints()):
yield symbol_from_ptr(self, self._graph.constraints()[i])
for ptr in self._graph.constraints():
yield symbol_from_ptr(self, ptr)

def iter_decisions(self):
"""Iterate over all decision variables in the model.
Expand All @@ -590,19 +591,8 @@ cdef class Model:
<dwave.optimization.symbols.LessEqual at ...>
>>> decisions = next(model.iter_decisions())
"""
cdef Py_ssize_t num_decisions = self.num_decisions()
cdef Py_ssize_t seen_decisions = 0

cdef Symbol symbol
for symbol in self.iter_symbols():
if 0 <= symbol.node_ptr.topological_index() < num_decisions:
# we found a decision!
yield symbol
seen_decisions += 1

if seen_decisions >= num_decisions:
# we found them all
return
for ptr in self._graph.decisions():
yield symbol_from_ptr(self, ptr)

def iter_symbols(self):
"""Iterate over all symbols in the model.
Expand All @@ -616,8 +606,11 @@ cdef class Model:
>>> c = model.constant([[2, 3], [5, 6]])
>>> symbol_1, symbol_2 = model.iter_symbols()
"""
for i in range(self._graph.num_nodes()):
yield symbol_from_ptr(self, self._graph.nodes()[i].get())
# Because nodes() is a span of unique_ptr, we can't just iterate over
# it cythonically. Cython would try to do a copy/move assignment.
nodes = self._graph.nodes()
for i in range(nodes.size()):
yield symbol_from_ptr(self, nodes[i].get())

def list(self, n : int):
"""Create a list symbol as a decision variable.
Expand Down
8 changes: 5 additions & 3 deletions dwave/optimization/symbols.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,11 @@ from dwave.optimization.libcpp.array cimport (
SizeInfo as cppSizeInfo,
Slice as cppSlice,
)
from dwave.optimization.libcpp.graph cimport ArrayNode as cppArrayNode, Node as cppNode
from dwave.optimization.libcpp.graph cimport (
ArrayNode as cppArrayNode,
ArrayNodePtr as cppArrayNodePtr,
Node as cppNode,
)
from dwave.optimization.libcpp.nodes cimport (
AbsoluteNode as cppAbsoluteNode,
AddNode as cppAddNode,
Expand Down Expand Up @@ -92,8 +96,6 @@ from dwave.optimization.libcpp.nodes cimport (
)
from dwave.optimization.model cimport ArraySymbol, Model, Symbol

ctypedef cppArrayNode* cppArrayNodePtr # Cython gets confused when templating pointers

__all__ = [
"Absolute",
"Add",
Expand Down

0 comments on commit 73f3f66

Please sign in to comment.