Skip to content

Commit

Permalink
Switch reduction val to first input of expression in NaryReduce
Browse files Browse the repository at this point in the history
  • Loading branch information
wbernoudy committed Dec 23, 2024
1 parent 22a7764 commit 23cefb1
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 29 deletions.
12 changes: 12 additions & 0 deletions dwave/optimization/include/dwave-optimization/nodes/lambda.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@

namespace dwave::optimization {

// Performs an n-ary element-wise reduction operation on the 1d array operands.
//
// The operation is taken in as another (separate) `Graph`, which is expected
// to have n + 1 `InputNode`s, where n is the number of operands. The extra
// `InputNode` will be used for the previous/initial value of the output of
// the reduction. Following the convention of `numpy.ufunc.reduce()` and
// `std::accumulate()`, the special input should be the first `InputNode`
// on the given `Graph`, with the remaining inputs used for the values of
// the operands.
class NaryReduceNode : public ArrayOutputMixin<ArrayNode> {
public:
// Runtime constructor that can be used from Cython/Python
Expand All @@ -47,6 +56,9 @@ class NaryReduceNode : public ArrayOutputMixin<ArrayNode> {
private:
double evaluate_expression(State& register_) const;

std::span<const InputNode* const> operand_inputs() const;
const InputNode* const reduction_input() const;

Graph expression_;
const std::vector<ArrayNode*> operands_;
const ArrayNode* output_;
Expand Down
43 changes: 25 additions & 18 deletions dwave/optimization/src/nodes/lambda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,33 +124,32 @@ NaryReduceNode::NaryReduceNode(Graph&& expression, const std::vector<ArrayNode*>
R"({"message": "expression must have output (objective) set"})");
}

std::span<const InputNode* const> inputs = expression_.inputs();
assert(operand_inputs().size() == operands_.size());

for (ssize_t op_idx = 0; op_idx < static_cast<ssize_t>(operands_.size()); op_idx++) {
if (operands_[op_idx]->min() < inputs[op_idx]->min()) {
if (operands_[op_idx]->min() < operand_inputs()[op_idx]->min()) {
throw std::invalid_argument(
R"({"message": "operand with index )" + std::to_string(op_idx) +
R"( has minimum smaller than corresponding input in expression"})");
} else if (operands_[op_idx]->max() > inputs[op_idx]->max()) {
} else if (operands_[op_idx]->max() > operand_inputs()[op_idx]->max()) {
throw std::invalid_argument(
R"({"message": "operand with index )" + std::to_string(op_idx) +
R"( has maximum larger than corresponding input in expression"})");
} else if (inputs[op_idx]->integral() && !operands_[op_idx]->integral()) {
} else if (operand_inputs()[op_idx]->integral() && !operands_[op_idx]->integral()) {
throw std::invalid_argument(
R"({"message": "operand with index )" + std::to_string(op_idx) +
R"( is non-integral, but corresponding input is integral"})");
}
}

const InputNode* previous = inputs.back();
if (previous->integral() && !output_->integral()) {
if (reduction_input()->integral() && !output_->integral()) {
throw std::invalid_argument(
R"({"message": "if expression output can be non-integral, last input must not be integral"})");
;
} else if (output_->min() < previous->min()) {
} else if (output_->min() < reduction_input()->min()) {
throw std::invalid_argument(
R"({"message": "expression output must not have a lower min than the last input"})");
} else if (output_->max() > previous->max()) {
} else if (output_->max() > reduction_input()->max()) {
throw std::invalid_argument(
R"({"message": "expression output must not have a higher max than the last input"})");
}
Expand Down Expand Up @@ -211,15 +210,15 @@ void NaryReduceNode::initialize_state(State& state) const {

// Compute the expression for each subsequent index
for (ssize_t index = 0; index < start_size; ++index) {
// First input comes from the previous expression
val = std::clamp(val, reduction_input()->min(), reduction_input()->max());
reduction_input()->assign(reg, std::span(&val, 1));

for (ssize_t arg_index = 0; arg_index < num_args; ++arg_index) {
double input_val = *iterators[arg_index];
expression_.inputs()[arg_index]->assign(reg, std::span<double>(&input_val, 1));
operand_inputs()[arg_index]->assign(reg, std::span<double>(&input_val, 1));
iterators[arg_index]++;
}
// Final input comes from the previous expression
val = std::clamp(val, expression_.inputs()[num_args]->min(),
expression_.inputs()[num_args]->max());
expression_.inputs()[num_args]->assign(reg, std::span(&val, 1));
val = evaluate_expression(reg);
values.push_back(val);
}
Expand All @@ -234,6 +233,10 @@ double NaryReduceNode::max() const { return output_->max(); }

double NaryReduceNode::min() const { return output_->min(); }

std::span<const InputNode* const> NaryReduceNode::operand_inputs() const {
return expression_.inputs().subspan(1);
}

void NaryReduceNode::propagate(State& state) const {
NaryReduceNodeData* data = data_ptr<NaryReduceNodeData>(state);
ssize_t new_size = this->size(state);
Expand All @@ -247,22 +250,26 @@ void NaryReduceNode::propagate(State& state) const {
double val = initial;

for (ssize_t index = 0; index < new_size; ++index) {
// First input comes from the previous expression
val = std::clamp(val, reduction_input()->min(), reduction_input()->max());
reduction_input()->assign(data->register_, std::span(&val, 1));

for (ssize_t arg_index = 0; arg_index < num_args; ++arg_index) {
double arg_val = *data->iterators[arg_index];
expression_.inputs()[arg_index]->assign(data->register_, std::span(&arg_val, 1));
operand_inputs()[arg_index]->assign(data->register_, std::span(&arg_val, 1));
data->iterators[arg_index]++;
}
// Final input comes from the previous expression
val = std::clamp(val, expression_.inputs()[num_args]->min(),
expression_.inputs()[num_args]->max());
expression_.inputs()[num_args]->assign(data->register_, std::span(&val, 1));
val = evaluate_expression(data->register_);
data->set(index, val);
}

if (data->diff().size()) Node::propagate(state);
}

const InputNode* const NaryReduceNode::reduction_input() const {
return expression_.inputs()[0];
}

void NaryReduceNode::revert(State& state) const { data_ptr<NaryReduceNodeData>(state)->revert(); }

std::span<const ssize_t> NaryReduceNode::shape(const State& state) const {
Expand Down
14 changes: 8 additions & 6 deletions dwave/optimization/symbols.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2401,10 +2401,11 @@ _register(NaryMultiply, typeid(cppNaryMultiplyNode))

cdef class NaryReduce(ArraySymbol):
"""Using a supplied :class:`~dwave.optimization.model.Expression`, perform
a reduction operation along one or more array operands. The reduction
operation (represented by the ``Expression``) takes as input one value from
each of the operand arrays, as well as the result of the previously
computed operation, and computes a new value at the next output index.
an element-wise reduction operation along one or more array operands. The
reduction operation (represented by the ``Expression``) takes as input one
value from each of the operand arrays, as well as the result of the
previously computed operation, and computes a new value at the next output
index.
This takes inspiration from
`numpy.ufunc.reduce <https://numpy.org/doc/2.1/reference/generated/numpy.ufunc.reduce.html>`_
Expand All @@ -2415,7 +2416,7 @@ cdef class NaryReduce(ArraySymbol):
Args:
expression:
An :class:`~dwave.optimization.model.Expression` representing the
reduction operation. The last input on the expression will be given
reduction operation. The first input on the expression will be given
the previous output of the operation at each iteration over the
values of the operands.
operands:
Expand All @@ -2435,10 +2436,11 @@ cdef class NaryReduce(ArraySymbol):
>>> model = Model() # the main model
>>> x = model.integer(10, lower_bound=0, upper_bound=5)
>>> expr = Expression() # the reduction operation
>>> # first input is used to take the value of the previous output
>>> previous = expr.input()
>>> # xi will take the values of `x`. Provided bounds are necessary
>>> # in this case, but may be helpful in other expressions.
>>> xi = expr.input(0, 5, integral=True)
>>> previous = expr.input()
>>> expr.set_output(xi + previous)
>>> cumulative_sum_x = NaryReduce(expr, (x,))
>>> type(cumulative_sum_x)
Expand Down
10 changes: 5 additions & 5 deletions tests/cpp/nodes/test_lambda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ TEST_CASE("NaryReduceNode") {
expression.emplace_node<InputNode>(),
expression.emplace_node<InputNode>()};
auto output_ptr = expression.emplace_node<AddNode>(
expression.emplace_node<MultiplyNode>(inputs[0], inputs[1]), inputs[2]);
expression.emplace_node<MultiplyNode>(inputs[1], inputs[2]), inputs[0]);
expression.set_objective(output_ptr);
expression.topological_sort();

Expand Down Expand Up @@ -71,9 +71,9 @@ TEST_CASE("NaryReduceNode") {
expression.emplace_node<SubtractNode>(
expression.emplace_node<MultiplyNode>(
expression.emplace_node<AddNode>(
inputs[0], expression.emplace_node<ConstantNode>(1)),
inputs[1]),
inputs[2]),
inputs[1], expression.emplace_node<ConstantNode>(1)),
inputs[2]),
inputs[0]),
expression.emplace_node<ConstantNode>(5));
expression.set_objective(output_ptr);
expression.topological_sort();
Expand Down Expand Up @@ -126,7 +126,7 @@ TEST_CASE("NaryReduceNode") {
expression.emplace_node<InputNode>(),
expression.emplace_node<InputNode>()};
auto output_ptr = expression.emplace_node<MaximumNode>(
expression.emplace_node<AddNode>(inputs[0], inputs[2]), inputs[1]);
expression.emplace_node<AddNode>(inputs[1], inputs[0]), inputs[2]);
expression.set_objective(output_ptr);
expression.topological_sort();

Expand Down

0 comments on commit 23cefb1

Please sign in to comment.