Skip to content

Commit

Permalink
Merge pull request #167 from arcondello/fix/BinaryOpNode-sizeinfo
Browse files Browse the repository at this point in the history
Implement `BinaryOpNode::sizeinfo()`
  • Loading branch information
arcondello authored Nov 14, 2024
2 parents d772b41 + d2f5cdc commit 9a8849d
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ class BinaryOpNode : public ArrayOutputMixin<ArrayNode> {

ssize_t size_diff(const State& state) const override;

SizeInfo sizeinfo() const override;

void commit(State& state) const override;
void revert(State& state) const override;
void initialize_state(State& state) const override;
Expand Down
24 changes: 24 additions & 0 deletions dwave/optimization/src/nodes/mathematical.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,30 @@ ssize_t BinaryOpNode<BinaryOp>::size_diff(const State& state) const {
return data_ptr<ArrayNodeStateData>(state)->size_diff();
}

template <class BinaryOp>
SizeInfo BinaryOpNode<BinaryOp>::sizeinfo() const {
if (!dynamic()) return SizeInfo(size());

const Array* lhs_ptr = operands_[0];
const Array* rhs_ptr = operands_[1];

if (lhs_ptr->dynamic() && rhs_ptr->dynamic()) {
// not (yet) possible for both predecessors to be dynamic
assert(false && "not implemeted");
unreachable();
} else if (lhs_ptr->dynamic()) {
assert(rhs_ptr->size() == 1);
return SizeInfo(lhs_ptr);
} else if (rhs_ptr->dynamic()) {
assert(lhs_ptr->size() == 1);
return SizeInfo(rhs_ptr);
}

// not possible for us to be dynamic and none of our predecessors to be
assert(false && "not implemeted");
unreachable();
}

// Uncommented are the tested specializations
template class BinaryOpNode<std::plus<double>>;
template class BinaryOpNode<std::minus<double>>;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
features:
- Implement C++ ``BinaryOpNode::sizeinfo()`` overload.
fixes:
- |
Fix serializing models with binary operations over dynamic predecessors.
Previously it was not possible to estimate the state size which caused
serialization to fail.
4 changes: 4 additions & 0 deletions tests/cpp/nodes/mathematical/test_binaryop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,10 @@ TEST_CASE("BinaryOpNode - LessEqualNode") {
THEN("We have the shape we expect") {
CHECK(std::ranges::equal(le_ptr->shape(), std::vector{-1}));
CHECK(std::ranges::equal(ge_ptr->shape(), std::vector{-1}));

// derives its size from the dynamic node
CHECK(le_ptr->sizeinfo() == SizeInfo(y_ptr));
CHECK(ge_ptr->sizeinfo() == SizeInfo(y_ptr));
}

// let's also toss an ArrayValidationNode on there to do most of the
Expand Down

0 comments on commit 9a8849d

Please sign in to comment.