From 73f3f664873cff81590cc0dd66fdd7645fa46deb Mon Sep 17 00:00:00 2001 From: Alexander Condello Date: Tue, 19 Nov 2024 12:51:59 -0800 Subject: [PATCH] Refactor Symbol iteration methods --- dwave/optimization/libcpp/graph.pxd | 38 +++++++++++++++++++---------- dwave/optimization/model.pyx | 27 ++++++++------------ dwave/optimization/symbols.pyx | 8 +++--- 3 files changed, 40 insertions(+), 33 deletions(-) diff --git a/dwave/optimization/libcpp/graph.pxd b/dwave/optimization/libcpp/graph.pxd index 251aaaff..6c9d561f 100644 --- a/dwave/optimization/libcpp/graph.pxd +++ b/dwave/optimization/libcpp/graph.pxd @@ -19,15 +19,38 @@ from libcpp cimport bool from libcpp.memory cimport shared_ptr, unique_ptr from libcpp.vector cimport vector -from dwave.optimization.libcpp.array cimport Array, span +from dwave.optimization.libcpp cimport span +from dwave.optimization.libcpp.array cimport Array from dwave.optimization.libcpp.state cimport State +cdef extern from "dwave-optimization/graph.hpp" namespace "dwave::optimization" nogil: + cdef cppclass Node: + struct SuccessorView: + Node* ptr + shared_ptr[bool] expired_ptr() const + const vector[Node*]& predecessors() const + const vector[SuccessorView]& successors() const + Py_ssize_t topological_index() + + cdef cppclass ArrayNode(Node, Array): + pass + + cdef cppclass DecisionNode(Node): + pass + +# Sometimes Cython isn't able to reason about pointers as template inputs, so +# we make a few aliases for convenience +ctypedef Node* NodePtr +ctypedef ArrayNode* ArrayNodePtr +ctypedef DecisionNode* DecisionNodePtr + cdef extern from "dwave-optimization/graph.hpp" namespace "dwave::optimization" nogil: cdef cppclass Graph: T* emplace_node[T](...) except+ void initialize_state(State&) except+ span[const unique_ptr[Node]] nodes() const - span[ArrayNode*] constraints() const + span[const ArrayNodePtr] constraints() + span[const DecisionNodePtr] decisions() Py_ssize_t num_nodes() Py_ssize_t num_decisions() Py_ssize_t num_constraints() @@ -41,14 +64,3 @@ cdef extern from "dwave-optimization/graph.hpp" namespace "dwave::optimization" void topological_sort() bool topologically_sorted() const Py_ssize_t remove_unused_nodes() - - cdef cppclass Node: - struct SuccessorView: - Node* ptr - shared_ptr[bool] expired_ptr() const - const vector[Node*]& predecessors() const - const vector[SuccessorView]& successors() const - Py_ssize_t topological_index() - - cdef cppclass ArrayNode(Node, Array): - pass diff --git a/dwave/optimization/model.pyx b/dwave/optimization/model.pyx index 3df7aa26..43e666a3 100644 --- a/dwave/optimization/model.pyx +++ b/dwave/optimization/model.pyx @@ -45,6 +45,7 @@ from libcpp.utility cimport move from libcpp.vector cimport vector from dwave.optimization.libcpp.array cimport Array as cppArray +from dwave.optimization.libcpp.graph cimport DecisionNode as cppDecisionNode from dwave.optimization.symbols cimport symbol_from_ptr @@ -573,8 +574,8 @@ cdef class Model: >>> constraints = next(model.iter_constraints()) """ - for i in range(self._graph.num_constraints()): - yield symbol_from_ptr(self, self._graph.constraints()[i]) + for ptr in self._graph.constraints(): + yield symbol_from_ptr(self, ptr) def iter_decisions(self): """Iterate over all decision variables in the model. @@ -590,19 +591,8 @@ cdef class Model: >>> decisions = next(model.iter_decisions()) """ - cdef Py_ssize_t num_decisions = self.num_decisions() - cdef Py_ssize_t seen_decisions = 0 - - cdef Symbol symbol - for symbol in self.iter_symbols(): - if 0 <= symbol.node_ptr.topological_index() < num_decisions: - # we found a decision! - yield symbol - seen_decisions += 1 - - if seen_decisions >= num_decisions: - # we found them all - return + for ptr in self._graph.decisions(): + yield symbol_from_ptr(self, ptr) def iter_symbols(self): """Iterate over all symbols in the model. @@ -616,8 +606,11 @@ cdef class Model: >>> c = model.constant([[2, 3], [5, 6]]) >>> symbol_1, symbol_2 = model.iter_symbols() """ - for i in range(self._graph.num_nodes()): - yield symbol_from_ptr(self, self._graph.nodes()[i].get()) + # Because nodes() is a span of unique_ptr, we can't just iterate over + # it cythonically. Cython would try to do a copy/move assignment. + nodes = self._graph.nodes() + for i in range(nodes.size()): + yield symbol_from_ptr(self, nodes[i].get()) def list(self, n : int): """Create a list symbol as a decision variable. diff --git a/dwave/optimization/symbols.pyx b/dwave/optimization/symbols.pyx index 23322cbd..501d9f21 100644 --- a/dwave/optimization/symbols.pyx +++ b/dwave/optimization/symbols.pyx @@ -41,7 +41,11 @@ from dwave.optimization.libcpp.array cimport ( SizeInfo as cppSizeInfo, Slice as cppSlice, ) -from dwave.optimization.libcpp.graph cimport ArrayNode as cppArrayNode, Node as cppNode +from dwave.optimization.libcpp.graph cimport ( + ArrayNode as cppArrayNode, + ArrayNodePtr as cppArrayNodePtr, + Node as cppNode, + ) from dwave.optimization.libcpp.nodes cimport ( AbsoluteNode as cppAbsoluteNode, AddNode as cppAddNode, @@ -92,8 +96,6 @@ from dwave.optimization.libcpp.nodes cimport ( ) from dwave.optimization.model cimport ArraySymbol, Model, Symbol -ctypedef cppArrayNode* cppArrayNodePtr # Cython gets confused when templating pointers - __all__ = [ "Absolute", "Add",