Skip to content

Commit

Permalink
Merge pull request #179 from arcondello/feature/_graph.pyx-2
Browse files Browse the repository at this point in the history
Make `Model` a Python class by moving Cython parts into a new `_Graph` class
  • Loading branch information
arcondello authored Dec 5, 2024
2 parents b8d10fc + ea7a6f3 commit 0eb1ccc
Show file tree
Hide file tree
Showing 14 changed files with 719 additions and 647 deletions.
86 changes: 86 additions & 0 deletions dwave/optimization/_model.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright 2024 D-Wave Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from libc.stdint cimport uintptr_t
from libcpp cimport bool
from libcpp.memory cimport shared_ptr
from libcpp.vector cimport vector

from dwave.optimization.libcpp.graph cimport ArrayNode as cppArrayNode, Node as cppNode
from dwave.optimization.libcpp.graph cimport Graph as cppGraph
from dwave.optimization.libcpp.state cimport State as cppState

__all__ = []


cdef class _Graph:
cpdef bool is_locked(self) noexcept
cpdef Py_ssize_t num_constraints(self) noexcept
cpdef Py_ssize_t num_decisions(self) noexcept
cpdef Py_ssize_t num_nodes(self) noexcept
cpdef Py_ssize_t num_symbols(self) noexcept

# Make the _Graph class weak referenceable
cdef object __weakref__

cdef cppGraph _graph

# The number of times "lock()" has been called.
cdef readonly Py_ssize_t _lock_count

# Used to keep NumPy arrays that own data alive etc etc
# We could pair each of these with an expired_ptr for the node holding
# memory for easier cleanup later if that becomes a concern.
cdef object _data_sources


cdef class Symbol:
# Inheriting nodes must call this method from their __init__()
cdef void initialize_node(self, _Graph model, cppNode* node_ptr) noexcept

cpdef uintptr_t id(self) noexcept

# Exactly deref(self.expired_ptr)
cpdef bool expired(self) noexcept

@staticmethod
cdef Symbol from_ptr(_Graph model, cppNode* ptr)

# Hold on to a reference to the _Graph, both for access but also, importantly,
# to ensure that the model doesn't get garbage collected unless all of
# the observers have also been garbage collected.
cdef readonly _Graph model

# Hold Node* pointer. This is redundant as most observers will also hold
# a pointer to their observed node with the correct type. But the cost
# of a redundant pointer is quite small for these Python objects and it
# simplifies things quite a bit.
cdef cppNode* node_ptr

# The node's expired flag. If the node is destructed, the boolean value
# pointed to by the expired_ptr will be set to True
cdef shared_ptr[bool] expired_ptr


# Ideally this wouldn't subclass Symbol, but Cython only allows a single
# extension base class, so to support that we assume all ArraySymbols are
# also Symbols (probably a fair assumption)
cdef class ArraySymbol(Symbol):
# Inheriting symbols must call this method from their __init__()
cdef void initialize_arraynode(self, _Graph model, cppArrayNode* array_ptr) noexcept

# Hold ArrayNode* pointer. Again this is redundant, because we're also holding
# a pointer to Node* and we can theoretically dynamic cast each time.
# But again, it's cheap and it simplifies things.
cdef cppArrayNode* array_ptr
61 changes: 7 additions & 54 deletions dwave/optimization/model.pyi → dwave/optimization/_model.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -18,51 +18,29 @@ import tempfile
import typing

import numpy
import numpy.typing

from dwave.optimization.states import States
from dwave.optimization.symbols import *


_ShapeLike: typing.TypeAlias = typing.Union[int, collections.abc.Sequence[int]]

_GraphSubclass = typing.TypeVar("_GraphSubclass", bound="_Graph")

class Model:
def __init__(self): ...

@property
def objective(self) -> ArraySymbol: ...
@property
def states(self) -> States: ...
class _Graph:
def __init__(self, *args, **kwargs) -> typing.NoReturn: ...

def add_constraint(self, value: ArraySymbol) -> ArraySymbol: ...
def binary(self, shape: typing.Optional[_ShapeLike] = None) -> BinaryVariable: ...
def constant(self, array_like: numpy.typing.ArrayLike) -> Constant: ...
def decision_state_size(self) -> int: ...

def disjoint_bit_sets(
self, primary_set_size: int, num_disjoint_sets: int,
) -> tuple[DisjointBitSets, tuple[DisjointBitSet, ...]]: ...

def disjoint_lists(
self, primary_set_size: int, num_disjoint_lists: int,
) -> tuple[DisjointLists, tuple[DisjointList, ...]]: ...

def feasible(self, index: int = 0) -> bool: ...

@classmethod
def from_file(
cls,
cls: typing.Type[_GraphSubclass],
file: typing.Union[typing.BinaryIO, collections.abc.ByteString, str],
*,
check_header: bool = True,
) -> Model: ...

def integer(
self,
shape: typing.Optional[_ShapeLike] = None,
lower_bound: typing.Optional[int] = None,
upper_bound: typing.Optional[int] = None,
) -> IntegerVariable: ...
) -> _GraphSubclass: ...

def into_file(
self,
Expand All @@ -76,39 +54,14 @@ class Model:
def iter_constraints(self) -> collections.abc.Iterator[ArraySymbol]: ...
def iter_decisions(self) -> collections.abc.Iterator[Symbol]: ...
def iter_symbols(self) -> collections.abc.Iterator[Symbol]: ...
def list(self, n: int) -> ListVariable: ...
def lock(self) -> contextlib.AbstractContextManager: ...
def lock(self): ...
def minimize(self, value: ArraySymbol): ...
def num_constraints(self) -> int: ...
def num_decisions(self) -> int: ...
def num_nodes(self) -> int: ...
def num_symbols(self) -> int: ...

# dev note: this is underspecified, but it would be quite complex to fully
# specify the linear/quadratic so let's leave it alone for now.
def quadratic_model(self, x: ArraySymbol, quadratic, linear=None) -> QuadraticModel: ...

def remove_unused_symbols(self) -> int: ...

def set(
self,
n: int,
min_size: int = 0,
max_size: typing.Optional[int] = None,
) -> SetVariable: ...

def state_size(self) -> int: ...

def to_file(
self,
*,
max_num_states: int = 0,
only_decision: bool = False,
) -> typing.BinaryIO: ...

# networkx might not be installed, so we just say we return an object.
def to_networkx(self) -> object: ...

def unlock(self): ...


Expand Down
Loading

0 comments on commit 0eb1ccc

Please sign in to comment.