Skip to content

Commit

Permalink
Merge pull request #49 from kevinchern/fix/maybe_equals
Browse files Browse the repository at this point in the history
Make `maybe_equals` consistently return integers
  • Loading branch information
arcondello authored Jul 11, 2024
2 parents 1d5a0ce + 6e954c4 commit aaed22e
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 36 deletions.
54 changes: 25 additions & 29 deletions dwave/optimization/model.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1359,16 +1359,33 @@ cdef class Symbol:

def maybe_equals(self, other):
"""Compare to another node.
This method exists because a complete equality test can be expensive.
Args:
other: Another node in the model's directed acyclic graph.
Returns: integer
Supported return values are:
* ``0``---Not equal.
* ``1``---Might be equal.
* ``2``---Are equal.
* ``0``---Not equal (with certainty)
* ``1``---Might be equal (no guarantees); a complete equality test is necessary
* ``2``---Are equal (with certainty)
Examples:
This example compares
:class:`~dwave.optimization.symbols.IntegerVariable` symbols
of different sizes.
>>> from dwave.optimization import Model
>>> model = Model()
>>> i = model.integer(3, lower_bound=0, upper_bound=20)
>>> j = model.integer(3, lower_bound=-10, upper_bound=10)
>>> k = model.integer(5, upper_bound=55)
>>> i.maybe_equals(j)
1
>>> i.maybe_equals(k)
0
"""
cdef Py_ssize_t NOT = 0
cdef Py_ssize_t MAYBE = 1
Expand Down Expand Up @@ -1658,35 +1675,14 @@ cdef class ArraySymbol(Symbol):
return Max(self)

def maybe_equals(self, other):
"""Compare to another symbol.
Args:
other: Another symbol in the model.
Returns:
True if the two symbols might be equal.
Examples:
This example compares
:class:`~dwave.optimization.symbols.IntegerVariable` symbols
of different sizes.
>>> from dwave.optimization import Model
>>> model = Model()
>>> i = model.integer(3, lower_bound=0, upper_bound=20)
>>> j = model.integer(3, lower_bound=-10, upper_bound=10)
>>> k = model.integer(5, upper_bound=55)
>>> i.maybe_equals(j)
1
>>> i.maybe_equals(k)
0
"""
# note: docstring inherited from Symbol.maybe_equal()
cdef Py_ssize_t maybe = super().maybe_equals(other)
if maybe != 1:
return True if maybe else False

cdef Py_ssize_t NOT = 0
cdef Py_ssize_t MAYBE = 1
cdef Py_ssize_t DEFINITELY = 2

if maybe != 1:
return DEFINITELY if maybe else NOT

if not isinstance(other, ArraySymbol):
return NOT
Expand Down
9 changes: 6 additions & 3 deletions dwave/optimization/symbols.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -858,13 +858,16 @@ cdef class Constant(ArraySymbol):

def maybe_equals(self, other):
cdef Py_ssize_t maybe = super().maybe_equals(other)
if maybe != 1:
return True if maybe else False
cdef Py_ssize_t NOT = 0
cdef Py_ssize_t MAYBE = 1
cdef Py_ssize_t DEFINITELY = 2
if maybe != MAYBE:
return DEFINITELY if maybe else NOT

# avoid NumPy deprecation warning by casting to bool. But also
# `bool` in this namespace is a C++ class so we do an explicit if else
equal = (np.asarray(self) == np.asarray(other)).all()
return True if equal else False
return DEFINITELY if equal else NOT

def state(self, Py_ssize_t index=0, *, bool copy = True):
"""Return the state of the constant symbol.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
---
fixes:
- Fix return type of ``Symbol.maybe_equals()`` to be integer instead of boolean. See `#23 <https://github.com/dwavesystems/dwave-optimization/issues/23>`_.
10 changes: 6 additions & 4 deletions tests/test_symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,19 @@ def assertLessEqual(self, *args, **kwargs): ...
def assertTrue(self, *args, **kwargs): ...

def test_equality(self):
DEFINITELY = 2
for x in self.generate_symbols():
self.assertTrue(x.maybe_equals(x))
self.assertEqual(DEFINITELY, x.maybe_equals(x))
self.assertTrue(x.equals(x))

for x, y in zip(self.generate_symbols(), self.generate_symbols()):
self.assertTrue(x.maybe_equals(y))
self.assertTrue(DEFINITELY, x.maybe_equals(y))
self.assertTrue(x.equals(y))

def test_inequality(self):
MAYBE = 1
for x, y in itertools.combinations(self.generate_symbols(), 2):
self.assertLessEqual(x.maybe_equals(y), 1)
self.assertLessEqual(x.maybe_equals(y), MAYBE)
self.assertFalse(x.equals(y))

def test_iter_symbols(self):
Expand Down Expand Up @@ -115,7 +117,7 @@ def test_state_serialization(self):
states.append(sym.state())
else:
states.append(None)

with model.states.to_file() as f:
new = Model.from_file(f)

Expand Down

0 comments on commit aaed22e

Please sign in to comment.