Skip to content

Commit

Permalink
Merge pull request #172 from mdcoury/add-node-div
Browse files Browse the repository at this point in the history
Add divide node
  • Loading branch information
arcondello authored Dec 4, 2024
2 parents 3fd8dfa + dc74f8d commit b8d10fc
Show file tree
Hide file tree
Showing 12 changed files with 269 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/reference/math.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Mathematical Functions

~add
~concatenate
~divide
~logical
~logical_and
~logical_or
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ class BinaryOpNode : public ArrayOutputMixin<ArrayNode> {
// https://numpy.org/doc/stable/reference/routines.math.html
using AddNode = BinaryOpNode<std::plus<double>>;
using AndNode = BinaryOpNode<std::logical_and<double>>;
using DivideNode = BinaryOpNode<std::divides<double>>;
using EqualNode = BinaryOpNode<std::equal_to<double>>;
using LessEqualNode = BinaryOpNode<std::less_equal<double>>;
using MultiplyNode = BinaryOpNode<std::multiplies<double>>;
Expand Down
3 changes: 3 additions & 0 deletions dwave/optimization/libcpp/nodes.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ cdef extern from "dwave-optimization/nodes/mathematical.hpp" namespace "dwave::o
cdef cppclass AnyNode(ArrayNode):
pass

cdef cppclass DivideNode(ArrayNode):
pass

cdef cppclass EqualNode(ArrayNode):
pass

Expand Down
37 changes: 37 additions & 0 deletions dwave/optimization/mathematical.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Add,
And,
Concatenate,
Divide,
Logical,
Maximum,
Minimum,
Expand All @@ -41,6 +42,7 @@
__all__ = [
"add",
"concatenate",
"divide",
"logical",
"logical_and",
"logical_not",
Expand Down Expand Up @@ -154,6 +156,41 @@ def concatenate(array_likes : typing.Union[collections.abc.Iterable, ArraySymbol
raise TypeError("concatenate takes one or more ArraySymbol as input")


def divide(x1: ArraySymbol, x2: ArraySymbol) -> Divide:
r"""Return an element-wise division on the given symbols.
In the underlying directed acyclic expression graph, produces a
``Divide`` node if two array nodes are provided.
Args:
x1, x2: Input array symbol.
Returns:
A symbol that divides the given symbols element-wise.
Dividing two symbols returns a
:class:`~dwave.optimization.symbols.Divide`.
Examples:
This example divides two integer symbols.
Equivalently, you can use the ``/`` operator (e.g., :code:`i / j`).
>>> from dwave.optimization import Model
>>> from dwave.optimization.mathematical import divide
...
>>> model = Model()
>>> i = model.integer(2, lower_bound=1)
>>> j = model.integer(2, lower_bound=1)
>>> k = divide(i, j) # alternatively: k = i / j
>>> with model.lock():
... model.states.resize(1)
... i.set_state(0, [21, 10])
... j.set_state(0, [7, 2])
... print(k.state(0))
[3. 5.]
"""
return Divide(x1, x2)


def logical(x: ArraySymbol) -> Logical:
r"""Return the element-wise truth value on the given symbol.
Expand Down
1 change: 1 addition & 0 deletions dwave/optimization/model.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ class ArraySymbol(Symbol):
def __neg__(self) -> Negative: ...
def __pow__(self, exponent: int) -> ArraySymbol: ...
def __sub__(self, rhs: ArraySymbol) -> Subtract: ...
def __truediv__(self, rhs: ArraySymbol) -> Divide: ...
def all(self) -> All: ...
def any(self) -> Any: ...
def max(self) -> Max: ...
Expand Down
7 changes: 7 additions & 0 deletions dwave/optimization/model.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1634,6 +1634,13 @@ cdef class ArraySymbol(Symbol):

return NotImplemented

def __truediv__(self, rhs):
if isinstance(rhs, ArraySymbol):
from dwave.optimization.symbols import Divide # avoid circular import
return Divide(self, rhs)

return NotImplemented

def all(self):
"""Create an :class:`~dwave.optimization.symbols.All` symbol.
Expand Down
43 changes: 42 additions & 1 deletion dwave/optimization/src/nodes/mathematical.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ BinaryOpNode<BinaryOp>::BinaryOpNode(ArrayNode* a_ptr, ArrayNode* b_ptr)
throw std::invalid_argument("arrays must have the same shape or one must be a scalar");
}

if constexpr (std::is_same<BinaryOp, std::divides<double>>::value) {
bool strictly_negative = rhs_ptr->min() < 0 && rhs_ptr->max() < 0;
bool strictly_positive = rhs_ptr->min() > 0 && rhs_ptr->max() > 0;
if (!strictly_negative && !strictly_positive) {
throw std::invalid_argument("Divide's denominator predecessor must be either strictly positive or strictly negative");
}
}

this->add_predecessor(a_ptr);
this->add_predecessor(b_ptr);
}
Expand Down Expand Up @@ -123,6 +131,9 @@ bool BinaryOpNode<BinaryOp>::integral() const {
auto lhs_ptr = operands_[0];
auto rhs_ptr = operands_[1];

if constexpr (std::is_same<BinaryOp, std::divides<double>>::value) {
return false;
}
if constexpr (std::is_same<BinaryOp, functional::max<double>>::value ||
std::is_same<BinaryOp, functional::min<double>>::value ||
std::is_same<BinaryOp, std::minus<double>>::value ||
Expand Down Expand Up @@ -151,6 +162,19 @@ double BinaryOpNode<BinaryOp>::max() const {

// these can result in inf. If we update propagation/initialization to handle
// that case we should update these as well.
if constexpr (std::is_same<BinaryOp, std::divides<double>>::value) {
double lhs_low = lhs_ptr->min();
double lhs_high = lhs_ptr->max();
double rhs_low = rhs_ptr->min();
double rhs_high = rhs_ptr->max();

assert(lhs_low != 0);
assert(lhs_high != 0);
assert(rhs_low != 0);
assert(rhs_high != 0);
return std::max(
{lhs_low / rhs_low, lhs_low / rhs_high, lhs_high / rhs_low, lhs_high / rhs_high});
}
if constexpr (std::is_same<BinaryOp, functional::max<double>>::value ||
std::is_same<BinaryOp, functional::min<double>>::value ||
std::is_same<BinaryOp, std::plus<double>>::value) {
Expand Down Expand Up @@ -194,6 +218,16 @@ double BinaryOpNode<BinaryOp>::min() const {

// these can result in inf. If we update propagation/initialization to handle
// that case we should update these as well.
if constexpr (std::is_same<BinaryOp, std::divides<double>>::value) {
double lhs_low = lhs_ptr->min();
double lhs_high = lhs_ptr->max();
double rhs_low = rhs_ptr->min();
double rhs_high = rhs_ptr->max();

// TODO: How do we want to handle cases where a denominator is zero?
return std::min(
{lhs_low / rhs_low, lhs_low / rhs_high, lhs_high / rhs_low, lhs_high / rhs_high});
}
if constexpr (std::is_same<BinaryOp, functional::max<double>>::value ||
std::is_same<BinaryOp, functional::min<double>>::value ||
std::is_same<BinaryOp, std::plus<double>>::value) {
Expand Down Expand Up @@ -388,7 +422,7 @@ SizeInfo BinaryOpNode<BinaryOp>::sizeinfo() const {
template class BinaryOpNode<std::plus<double>>;
template class BinaryOpNode<std::minus<double>>;
template class BinaryOpNode<std::multiplies<double>>;
// template class BinaryOpNode<std::divides<double>>;
template class BinaryOpNode<std::divides<double>>;
template class BinaryOpNode<functional::modulus<double>>;
template class BinaryOpNode<std::equal_to<double>>;
// template class BinaryOpNode<std::not_equal_to<double>>;
Expand Down Expand Up @@ -421,6 +455,13 @@ struct InverseOp<std::plus<double>> {
double op(const double& x, const double& y) { return x - y; }
};

template <>
struct InverseOp<std::divides<double>> {
static bool constexpr exists() { return true; }

double op(const double& x, const double& y) { return x * y; }
};

template <>
struct InverseOp<std::multiplies<double>> {
static bool constexpr exists() { return true; }
Expand Down
4 changes: 4 additions & 0 deletions dwave/optimization/symbols.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ class DisjointList(ArraySymbol):
...


class Divide(ArraySymbol):
...


class Equal(ArraySymbol):
...

Expand Down
40 changes: 40 additions & 0 deletions dwave/optimization/symbols.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ from dwave.optimization.libcpp.nodes cimport (
DisjointBitSetsNode as cppDisjointBitSetsNode,
DisjointListNode as cppDisjointListNode,
DisjointListsNode as cppDisjointListsNode,
DivideNode as cppDivideNode,
EqualNode as cppEqualNode,
IntegerNode as cppIntegerNode,
LessEqualNode as cppLessEqualNode,
Expand Down Expand Up @@ -114,6 +115,7 @@ __all__ = [
"DisjointBitSet",
"DisjointLists",
"DisjointList",
"Divide",
"Equal",
"IntegerVariable",
"LessEqual",
Expand Down Expand Up @@ -1555,6 +1557,44 @@ cdef class DisjointList(ArraySymbol):
_register(DisjointList, typeid(cppDisjointListNode))


cdef class Divide(ArraySymbol):
"""Division element-wise between two symbols.
Examples:
This example divides two integer symbols.
>>> from dwave.optimization.model import Model
>>> model = Model()
>>> i = model.integer(10, lower_bound=-50, upper_bound=-1)
>>> j = model.integer(10, lower_bound=1, upper_bound=10)
>>> k = i/j
>>> type(k)
<class 'dwave.optimization.symbols.Divide'>
"""
def __init__(self, ArraySymbol lhs, ArraySymbol rhs):
if lhs.model is not rhs.model:
raise ValueError("lhs and rhs do not share the same underlying model")

cdef Model model = lhs.model

self.ptr = model._graph.emplace_node[cppDivideNode](lhs.array_ptr, rhs.array_ptr)
self.initialize_arraynode(model, self.ptr)

@staticmethod
def _from_symbol(Symbol symbol):
cdef cppDivideNode* ptr = dynamic_cast_ptr[cppDivideNode](symbol.node_ptr)
if not ptr:
raise TypeError("given symbol cannot be used to construct a Divide")
cdef Divide x = Divide.__new__(Divide)
x.ptr = ptr
x.initialize_arraynode(symbol.model, ptr)
return x

cdef cppDivideNode* ptr

_register(Divide, typeid(cppDivideNode))


cdef class Equal(ArraySymbol):
"""Equality comparison element-wise between two symbols.
Expand Down
7 changes: 7 additions & 0 deletions releasenotes/notes/add-div-node-852711d8127d0cc2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
features:
- |
Add C++ ``DivideNode`` and Python ``Divide`` symbol, and overriding
__truediv__. ``Divide`` propagates the division of its predecessors
element-wise. Note that predecessors of Divide must be either
strictly positive or strictly negative.
60 changes: 59 additions & 1 deletion tests/cpp/nodes/mathematical/test_binaryop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@

namespace dwave::optimization {

TEMPLATE_TEST_CASE("BinaryOpNode", "", std::equal_to<double>, std::less_equal<double>,
// NOTE: divides test is disabled because the template-tests have invalid denominators.
TEMPLATE_TEST_CASE("BinaryOpNode", "",
// std::divides<double>,
std::equal_to<double>, std::less_equal<double>,
std::plus<double>, std::minus<double>, functional::modulus<double>,
std::multiplies<double>, functional::max<double>, functional::min<double>,
std::logical_and<double>, std::logical_or<double>, functional::logical_xor<double>) {
Expand Down Expand Up @@ -394,6 +397,61 @@ TEST_CASE("BinaryOpNode - MultiplyNode") {
}
}

TEST_CASE("BinaryOpNode - DivideNode") {
auto graph = Graph();

GIVEN("x = IntegerNode(-5, 5), a = 3, y = x / a") {
auto x_ptr = graph.emplace_node<IntegerNode>(std::vector<ssize_t>{}, -5, 5);
auto a_ptr = graph.emplace_node<ConstantNode>(3);

auto y_ptr = graph.emplace_node<DivideNode>(x_ptr, a_ptr);

THEN("y's max/min/integral are as expected") {
CHECK(std::abs(y_ptr->max() - 5.0/3.0) < 10e-16);
CHECK(std::abs(y_ptr->min() - -5.0/3.0) < 10e-16);
CHECK_FALSE(y_ptr->integral());
}
}

GIVEN("x = IntegerNode(-5, 5), a = -3, y = x / a") {
auto x_ptr = graph.emplace_node<IntegerNode>(std::vector<ssize_t>{}, -5, 5);
auto a_ptr = graph.emplace_node<ConstantNode>(-3);

auto y_ptr = graph.emplace_node<DivideNode>(x_ptr, a_ptr);

THEN("y's max/min/integral are as expected") {
CHECK(y_ptr->max() == -5.0/-3.0);
CHECK(y_ptr->min() == 5.0/-3.0);
CHECK_FALSE(y_ptr->integral());
}
}
GIVEN("x = IntegerNode(-5, 5), a = 0, y = x / a") {
auto x_ptr = graph.emplace_node<IntegerNode>(std::vector<ssize_t>{}, -5, 5);
auto a_ptr = graph.emplace_node<ConstantNode>(0);

THEN("Check division-by-zero") {
CHECK_THROWS(graph.emplace_node<DivideNode>(x_ptr, a_ptr));
}
}
GIVEN("x = IntegerNode(-5, 5), a = IntegerNode(-5, 0), y = x / a") {
auto x_ptr = graph.emplace_node<IntegerNode>(std::vector<ssize_t>{}, -5, 5);
auto a_ptr = graph.emplace_node<IntegerNode>(std::vector<ssize_t>{}, -5, 0);

THEN("Check division-by-zero") {
CHECK_THROWS(graph.emplace_node<DivideNode>(x_ptr, a_ptr));
}
}
GIVEN("x = IntegerNode(-5, 5), a = IntegerNode(0, 5), y = x / a") {
auto x_ptr = graph.emplace_node<IntegerNode>(std::vector<ssize_t>{}, -5, 5);
auto a_ptr = graph.emplace_node<IntegerNode>(std::vector<ssize_t>{}, 0, 5);

THEN("Check division-by-zero") {
CHECK_THROWS(graph.emplace_node<DivideNode>(x_ptr, a_ptr));
}
}

}

TEST_CASE("BinaryOpNode - SubtractNode") {
auto graph = Graph();

Expand Down
Loading

0 comments on commit b8d10fc

Please sign in to comment.