Skip to content

Commit

Permalink
Add type annotations to model.py
Browse files Browse the repository at this point in the history
  • Loading branch information
arcondello committed Dec 4, 2024
1 parent 66e3ac3 commit f3c1c86
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 15 deletions.
1 change: 0 additions & 1 deletion dwave/optimization/_graph.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import tempfile
import typing

import numpy
import numpy.typing

from dwave.optimization.states import States
from dwave.optimization.symbols import *
Expand Down
59 changes: 45 additions & 14 deletions dwave/optimization/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,28 @@
packed.
"""

from __future__ import annotations

import collections
import contextlib
import tempfile
import typing

from dwave.optimization._graph import ArraySymbol, _Graph, Symbol
from dwave.optimization.states import States

if typing.TYPE_CHECKING:
import numpy.typing

from dwave.optimization.symbols import *

_ShapeLike: typing.TypeAlias = typing.Union[int, collections.abc.Sequence[int]]

__all__ = ["Model"]


@contextlib.contextmanager
def locked(model):
def locked(model: _Graph):
"""Context manager that hold a locked model and unlocks it when the context is exited."""
try:
yield
Expand Down Expand Up @@ -101,7 +111,7 @@ def __init__(self):
self.objective = None
self.states = States(self)

def binary(self, shape=None):
def binary(self, shape: typing.Optional[_ShapeLike] = None) -> BinaryVariable:
r"""Create a binary symbol as a decision variable.
Args:
Expand All @@ -120,7 +130,7 @@ def binary(self, shape=None):
from dwave.optimization.symbols import BinaryVariable # avoid circular import
return BinaryVariable(self, shape)

def constant(self, array_like):
def constant(self, array_like: numpy.typing.ArrayLike) -> Constant:
r"""Create a constant symbol.
Args:
Expand All @@ -142,7 +152,11 @@ def constant(self, array_like):
from dwave.optimization.symbols import Constant # avoid circular import
return Constant(self, array_like)

def disjoint_bit_sets(self, primary_set_size, num_disjoint_sets):
def disjoint_bit_sets(
self,
primary_set_size: int,
num_disjoint_sets: int,
) -> tuple[DisjointBitSets, tuple[DisjointBitSet, ...]]:
"""Create a disjoint-sets symbol as a decision variable.
Divides a set of the elements of ``range(primary_set_size)`` into
Expand Down Expand Up @@ -171,13 +185,18 @@ def disjoint_bit_sets(self, primary_set_size, num_disjoint_sets):
>>> model = Model()
>>> parts_set, parts_subsets = model.disjoint_bit_sets(10, 4)
"""
# avoid circular import
from dwave.optimization.symbols import DisjointBitSets, DisjointBitSet

from dwave.optimization.symbols import DisjointBitSets, DisjointBitSet # avoid circular import
main = DisjointBitSets(self, primary_set_size, num_disjoint_sets)
sets = tuple(DisjointBitSet(main, i) for i in range(num_disjoint_sets))
return main, sets

def disjoint_lists(self, primary_set_size, num_disjoint_lists):
def disjoint_lists(
self,
primary_set_size: int,
num_disjoint_lists: int,
) -> tuple[DisjointLists, tuple[DisjointList, ...]]:
"""Create a disjoint-lists symbol as a decision variable.
Divides a set of the elements of ``range(primary_set_size)`` into
Expand Down Expand Up @@ -208,7 +227,7 @@ def disjoint_lists(self, primary_set_size, num_disjoint_lists):
lists = [DisjointList(main, i) for i in range(num_disjoint_lists)]
return main, lists

def feasible(self, index: int = 0):
def feasible(self, index: int = 0) -> bool:
"""Check the feasibility of the state at the input index.
Args:
Expand Down Expand Up @@ -238,7 +257,12 @@ def feasible(self, index: int = 0):
"""
return all(sym.state(index) for sym in self.iter_constraints())

def integer(self, shape=None, lower_bound=None, upper_bound=None):
def integer(
self,
shape: typing.Optional[_ShapeLike] = None,
lower_bound: typing.Optional[int] = None,
upper_bound: typing.Optional[int] = None,
) -> IntegerVariable:
r"""Create an integer symbol as a decision variable.
Args:
Expand All @@ -264,7 +288,7 @@ def integer(self, shape=None, lower_bound=None, upper_bound=None):
from dwave.optimization.symbols import IntegerVariable # avoid circular import
return IntegerVariable(self, shape, lower_bound, upper_bound)

def list(self, n: int):
def list(self, n: int) -> ListVariable:
"""Create a list symbol as a decision variable.
Args:
Expand All @@ -283,7 +307,7 @@ def list(self, n: int):
from dwave.optimization.symbols import ListVariable # avoid circular import
return ListVariable(self, n)

def lock(self):
def lock(self) -> contextlib.AbstractContextManager:
"""Lock the model.
No new symbols can be added to a locked model.
Expand Down Expand Up @@ -327,7 +351,9 @@ def minimize(self, value: ArraySymbol):
super().minimize(value)
self.objective = value

def quadratic_model(self, x, quadratic, linear=None):
# dev note: the typing is underspecified, but it would be quite complex to fully
# specify the linear/quadratic so let's leave it alone for now.
def quadratic_model(self, x: ArraySymbol, quadratic, linear=None) -> QuadraticModel:
"""Create a quadratic model from an array and a quadratic model.
Args:
Expand All @@ -353,7 +379,11 @@ def quadratic_model(self, x, quadratic, linear=None):
from dwave.optimization.symbols import QuadraticModel
return QuadraticModel(x, quadratic, linear)

def set(self, n, min_size=0, max_size=None):
def set(self,
n: int,
min_size: int = 0,
max_size: typing.Optional[int] = None,
) -> SetVariable:
"""Create a set symbol as a decision variable.
Args:
Expand All @@ -375,7 +405,7 @@ def set(self, n, min_size=0, max_size=None):
from dwave.optimization.symbols import SetVariable # avoid circular import
return SetVariable(self, n, min_size, n if max_size is None else max_size)

def to_file(self, **kwargs):
def to_file(self, **kwargs) -> typing.BinaryIO:
"""Serialize the model to a new file-like object.
See also:
Expand All @@ -386,7 +416,8 @@ def to_file(self, **kwargs):
file.seek(0)
return file

def to_networkx(self):
# NetworkX might not be installed so we just say we return an object
def to_networkx(self) -> object:
"""Convert the model to a NetworkX graph.
Returns:
Expand Down

0 comments on commit f3c1c86

Please sign in to comment.