From f3c1c8603ac1fe16e9ff3c9fadac2a44cc2e7b3f Mon Sep 17 00:00:00 2001 From: Alexander Condello Date: Wed, 4 Dec 2024 12:33:50 -0800 Subject: [PATCH] Add type annotations to model.py --- dwave/optimization/_graph.pyi | 1 - dwave/optimization/model.py | 59 ++++++++++++++++++++++++++--------- 2 files changed, 45 insertions(+), 15 deletions(-) diff --git a/dwave/optimization/_graph.pyi b/dwave/optimization/_graph.pyi index 0b94fbb1..7e9d0f57 100644 --- a/dwave/optimization/_graph.pyi +++ b/dwave/optimization/_graph.pyi @@ -18,7 +18,6 @@ import tempfile import typing import numpy -import numpy.typing from dwave.optimization.states import States from dwave.optimization.symbols import * diff --git a/dwave/optimization/model.py b/dwave/optimization/model.py index a55ac2f8..6f9f7388 100644 --- a/dwave/optimization/model.py +++ b/dwave/optimization/model.py @@ -23,6 +23,9 @@ packed. """ +from __future__ import annotations + +import collections import contextlib import tempfile import typing @@ -30,11 +33,18 @@ from dwave.optimization._graph import ArraySymbol, _Graph, Symbol from dwave.optimization.states import States +if typing.TYPE_CHECKING: + import numpy.typing + + from dwave.optimization.symbols import * + + _ShapeLike: typing.TypeAlias = typing.Union[int, collections.abc.Sequence[int]] + __all__ = ["Model"] @contextlib.contextmanager -def locked(model): +def locked(model: _Graph): """Context manager that hold a locked model and unlocks it when the context is exited.""" try: yield @@ -101,7 +111,7 @@ def __init__(self): self.objective = None self.states = States(self) - def binary(self, shape=None): + def binary(self, shape: typing.Optional[_ShapeLike] = None) -> BinaryVariable: r"""Create a binary symbol as a decision variable. Args: @@ -120,7 +130,7 @@ def binary(self, shape=None): from dwave.optimization.symbols import BinaryVariable # avoid circular import return BinaryVariable(self, shape) - def constant(self, array_like): + def constant(self, array_like: numpy.typing.ArrayLike) -> Constant: r"""Create a constant symbol. Args: @@ -142,7 +152,11 @@ def constant(self, array_like): from dwave.optimization.symbols import Constant # avoid circular import return Constant(self, array_like) - def disjoint_bit_sets(self, primary_set_size, num_disjoint_sets): + def disjoint_bit_sets( + self, + primary_set_size: int, + num_disjoint_sets: int, + ) -> tuple[DisjointBitSets, tuple[DisjointBitSet, ...]]: """Create a disjoint-sets symbol as a decision variable. Divides a set of the elements of ``range(primary_set_size)`` into @@ -171,13 +185,18 @@ def disjoint_bit_sets(self, primary_set_size, num_disjoint_sets): >>> model = Model() >>> parts_set, parts_subsets = model.disjoint_bit_sets(10, 4) """ + # avoid circular import + from dwave.optimization.symbols import DisjointBitSets, DisjointBitSet - from dwave.optimization.symbols import DisjointBitSets, DisjointBitSet # avoid circular import main = DisjointBitSets(self, primary_set_size, num_disjoint_sets) sets = tuple(DisjointBitSet(main, i) for i in range(num_disjoint_sets)) return main, sets - def disjoint_lists(self, primary_set_size, num_disjoint_lists): + def disjoint_lists( + self, + primary_set_size: int, + num_disjoint_lists: int, + ) -> tuple[DisjointLists, tuple[DisjointList, ...]]: """Create a disjoint-lists symbol as a decision variable. Divides a set of the elements of ``range(primary_set_size)`` into @@ -208,7 +227,7 @@ def disjoint_lists(self, primary_set_size, num_disjoint_lists): lists = [DisjointList(main, i) for i in range(num_disjoint_lists)] return main, lists - def feasible(self, index: int = 0): + def feasible(self, index: int = 0) -> bool: """Check the feasibility of the state at the input index. Args: @@ -238,7 +257,12 @@ def feasible(self, index: int = 0): """ return all(sym.state(index) for sym in self.iter_constraints()) - def integer(self, shape=None, lower_bound=None, upper_bound=None): + def integer( + self, + shape: typing.Optional[_ShapeLike] = None, + lower_bound: typing.Optional[int] = None, + upper_bound: typing.Optional[int] = None, + ) -> IntegerVariable: r"""Create an integer symbol as a decision variable. Args: @@ -264,7 +288,7 @@ def integer(self, shape=None, lower_bound=None, upper_bound=None): from dwave.optimization.symbols import IntegerVariable # avoid circular import return IntegerVariable(self, shape, lower_bound, upper_bound) - def list(self, n: int): + def list(self, n: int) -> ListVariable: """Create a list symbol as a decision variable. Args: @@ -283,7 +307,7 @@ def list(self, n: int): from dwave.optimization.symbols import ListVariable # avoid circular import return ListVariable(self, n) - def lock(self): + def lock(self) -> contextlib.AbstractContextManager: """Lock the model. No new symbols can be added to a locked model. @@ -327,7 +351,9 @@ def minimize(self, value: ArraySymbol): super().minimize(value) self.objective = value - def quadratic_model(self, x, quadratic, linear=None): + # dev note: the typing is underspecified, but it would be quite complex to fully + # specify the linear/quadratic so let's leave it alone for now. + def quadratic_model(self, x: ArraySymbol, quadratic, linear=None) -> QuadraticModel: """Create a quadratic model from an array and a quadratic model. Args: @@ -353,7 +379,11 @@ def quadratic_model(self, x, quadratic, linear=None): from dwave.optimization.symbols import QuadraticModel return QuadraticModel(x, quadratic, linear) - def set(self, n, min_size=0, max_size=None): + def set(self, + n: int, + min_size: int = 0, + max_size: typing.Optional[int] = None, + ) -> SetVariable: """Create a set symbol as a decision variable. Args: @@ -375,7 +405,7 @@ def set(self, n, min_size=0, max_size=None): from dwave.optimization.symbols import SetVariable # avoid circular import return SetVariable(self, n, min_size, n if max_size is None else max_size) - def to_file(self, **kwargs): + def to_file(self, **kwargs) -> typing.BinaryIO: """Serialize the model to a new file-like object. See also: @@ -386,7 +416,8 @@ def to_file(self, **kwargs): file.seek(0) return file - def to_networkx(self): + # NetworkX might not be installed so we just say we return an object + def to_networkx(self) -> object: """Convert the model to a NetworkX graph. Returns: