Add ConcatenateNode (#153)
samuelbodin authored Nov 26, 2024
1 parent 7c6d56b commit 8fd89d3
Showing 10 changed files with 595 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/reference/math.rst
@@ -12,6 +12,7 @@ Mathematical Functions
   :toctree: generated/

   ~add
   ~concatenate
   ~logical
   ~logical_and
   ~logical_or
dwave-optimization/nodes/manipulation.hpp
@@ -25,6 +25,27 @@

namespace dwave::optimization {

class ConcatenateNode : public ArrayOutputMixin<ArrayNode> {
 public:
    explicit ConcatenateNode(std::span<ArrayNode*> array_ptrs, ssize_t axis);
    explicit ConcatenateNode(std::ranges::contiguous_range auto&& array_ptrs, ssize_t axis)
            : ConcatenateNode(std::span<ArrayNode*>(array_ptrs), axis) {}

    double const* buff(const State& state) const override;
    void commit(State& state) const override;
    std::span<const Update> diff(const State& state) const override;
    void initialize_state(State& state) const override;
    void propagate(State& state) const override;
    void revert(State& state) const override;

    ssize_t axis() const { return axis_; }

 private:
    ssize_t axis_;
    std::vector<ArrayNode*> array_ptrs_;
    std::vector<ssize_t> array_starts_;
};

class ReshapeNode : public ArrayOutputMixin<ArrayNode> {
 public:
    ReshapeNode(ArrayNode* node_ptr, std::span<const ssize_t> shape);
3 changes: 3 additions & 0 deletions dwave/optimization/libcpp/nodes.pxd
@@ -78,6 +78,9 @@ cdef extern from "dwave-optimization/nodes/indexing.hpp" namespace "dwave::optimization" nogil:


cdef extern from "dwave-optimization/nodes/manipulation.hpp" namespace "dwave::optimization" nogil:
    cdef cppclass ConcatenateNode(ArrayNode):
        Py_ssize_t axis()

    cdef cppclass ReshapeNode(ArrayNode):
        pass

47 changes: 47 additions & 0 deletions dwave/optimization/mathematical.py
@@ -12,13 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import functools
import typing

from dwave.optimization.model import ArraySymbol
from dwave.optimization.symbols import (
    Add,
    And,
    Concatenate,
    Logical,
    Maximum,
    Minimum,
@@ -38,6 +40,7 @@

__all__ = [
"add",
"concatenate",
"logical",
"logical_and",
"logical_not",
@@ -107,6 +110,50 @@ def add(x1: ArraySymbol, x2: ArraySymbol, *xi: ArraySymbol) -> typing.Union[Add,
    raise RuntimeError("implemented by the op() decorator")


def concatenate(array_likes: typing.Union[collections.abc.Iterable, ArraySymbol],
                axis: int = 0) -> ArraySymbol:
    r"""Return the concatenation of one or more symbols on the given axis.

    Args:
        array_likes: Array symbols to concatenate.
        axis: The concatenation axis.

    Returns:
        A symbol that is the concatenation of the given symbols along the
        specified axis.

    Examples:
        This example concatenates two constant symbols along the first axis.

        >>> from dwave.optimization import Model
        >>> from dwave.optimization.mathematical import concatenate
        ...
        >>> model = Model()
        >>> a = model.constant([[0, 1], [2, 3]])
        >>> b = model.constant([[4, 5]])
        >>> a_b = concatenate((a, b), axis=0)
        >>> a_b.shape()
        (3, 2)
        >>> type(a_b)
        <class 'dwave.optimization.symbols.Concatenate'>
        >>> with model.lock():
        ...     model.states.resize(1)
        ...     print(a_b.state(0))
        [[0. 1.]
         [2. 3.]
         [4. 5.]]
    """
    if isinstance(array_likes, ArraySymbol):
        return array_likes

    if isinstance(array_likes, collections.abc.Sequence) and 0 < len(array_likes):
        if isinstance(array_likes[0], ArraySymbol):
            if len(array_likes) == 1:
                return array_likes[0]

            return Concatenate(tuple(array_likes), axis)

    raise TypeError("concatenate takes one or more ArraySymbols as input")


def logical(x: ArraySymbol) -> Logical:
r"""Return the element-wise truth value on the given symbol.
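A short usage sketch (not part of the commit) exercising the dispatch rules above: a lone symbol, or a length-1 sequence, is returned unchanged, while two or more symbols produce a Concatenate; the axis=1 shapes follow the rules enforced in make_concatenate_shape.

    from dwave.optimization import Model
    from dwave.optimization.mathematical import concatenate

    model = Model()
    a = model.constant([[0, 1], [2, 3]])  # shape (2, 2)
    b = model.constant([[4], [5]])        # shape (2, 1)

    assert concatenate(a) is a     # a single symbol is returned unchanged
    assert concatenate((a,)) is a  # ... as is the element of a length-1 sequence

    ab = concatenate((a, b), axis=1)  # off-axis dimensions match, so this is valid
    assert ab.shape() == (2, 3)       # axis sizes sum: 2 + 1 = 3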
1 change: 1 addition & 0 deletions dwave/optimization/src/nodes/_state.hpp
@@ -74,6 +74,7 @@ class ArrayNodeStateData : public NodeStateData {
        return !updates.empty();
    }

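    // Mutable overload: lets node implementations (e.g. ConcatenateNode::propagate)
    // write updated values directly into the state buffer.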
    double* buff() noexcept { return buffer.data(); }
    const double* buff() const noexcept { return buffer.data(); }

    void commit() noexcept {
152 changes: 152 additions & 0 deletions dwave/optimization/src/nodes/manipulation.cpp
@@ -18,6 +18,158 @@

namespace dwave::optimization {

std::vector<ssize_t> make_concatenate_shape(std::span<ArrayNode*> array_ptrs, ssize_t axis);

double const* ConcatenateNode::buff(const State& state) const {
    return data_ptr<ArrayNodeStateData>(state)->buff();
}

void ConcatenateNode::commit(State& state) const {
    data_ptr<ArrayNodeStateData>(state)->commit();
}

ConcatenateNode::ConcatenateNode(std::span<ArrayNode*> array_ptrs, const ssize_t axis)
        : ArrayOutputMixin(make_concatenate_shape(array_ptrs, axis)),
          axis_(axis),
          array_ptrs_(array_ptrs.begin(), array_ptrs.end()) {
    // Compute the buffer start position for each input array
    array_starts_.reserve(array_ptrs.size());
    array_starts_.emplace_back(0);
    for (ssize_t arr_i = 1, stop = array_ptrs.size(); arr_i < stop; ++arr_i) {
        auto subshape = array_ptrs_[arr_i - 1]->shape().last(this->ndim() - axis_);
        // Accumulate with a ssize_t initial value so the product is not computed in int
        ssize_t prod = std::accumulate(subshape.begin(), subshape.end(), ssize_t(1),
                                       std::multiplies<ssize_t>());
        array_starts_.emplace_back(prod + array_starts_[arr_i - 1]);
    }

    for (auto it = array_ptrs.begin(), stop = array_ptrs.end(); it != stop; ++it) {
        if ((*it)->dynamic()) {
            throw std::invalid_argument("concatenate input arrays cannot be dynamic");
        }

        this->add_predecessor(*it);
    }
}
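// Worked example (illustrative, not part of the commit): concatenating inputs of
// shape (2, 2) and (1, 2) on axis 0 yields output shape (3, 2). The trailing
// subshape of the first input from the axis onward is (2, 2), whose product is 4,
// so array_starts_ == {0, 4}: the second input's values begin at flat offset 4.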

std::span<const Update> ConcatenateNode::diff(const State& state) const {
    return data_ptr<ArrayNodeStateData>(state)->diff();
}

void ConcatenateNode::initialize_state(State& state) const {
    int index = topological_index();
    assert(index >= 0 && "must be topologically sorted");
    assert(static_cast<int>(state.size()) > index && "unexpected state length");
    assert(state[index] == nullptr && "already initialized state");

    std::vector<double> values;
    values.resize(size());

    for (ssize_t arr_i = 0, stop = array_ptrs_.size(); arr_i < stop; ++arr_i) {
        // Create a view into our buffer with the same shape as
        // our input array, starting at the correct place
        auto view_it = Array::iterator(
                values.data() + array_starts_[arr_i],
                this->ndim(),
                array_ptrs_[arr_i]->shape().data(),
                this->strides().data());

        std::copy(array_ptrs_[arr_i]->begin(state), array_ptrs_[arr_i]->end(state), view_it);
    }

    state[index] = std::make_unique<ArrayNodeStateData>(std::move(values));
}

std::vector<ssize_t> make_concatenate_shape(std::span<ArrayNode*> array_ptrs, ssize_t axis) {
    // One or more arrays must be given
    if (array_ptrs.size() < 1) {
        throw std::invalid_argument("need at least one array to concatenate");
    }

    for (auto it = std::next(array_ptrs.begin()), stop = array_ptrs.end(); it != stop; ++it) {
        // Arrays must have the same number of dimensions
        if ((*std::prev(it))->ndim() != (*it)->ndim()) {
            throw std::invalid_argument(
                    "all the input arrays must have the same number of dimensions, but "
                    "the array at index " +
                    std::to_string(std::distance(array_ptrs.begin(), std::prev(it))) +
                    " has " + std::to_string((*std::prev(it))->ndim()) +
                    " dimension(s) and the array at index " +
                    std::to_string(std::distance(array_ptrs.begin(), it)) + " has " +
                    std::to_string((*it)->ndim()) + " dimension(s)");
        }

        // Array shapes must be the same except on the concatenation axis
        for (ssize_t i = 0, stop = (*it)->ndim(); i < stop; ++i) {
            if (i == axis) continue;
            if ((*std::prev(it))->shape()[i] != (*it)->shape()[i]) {
                throw std::invalid_argument(
                        "all the input array dimensions except for the concatenation "
                        "axis must match exactly, but along dimension " +
                        std::to_string(i) + ", the array at index " +
                        std::to_string(std::distance(array_ptrs.begin(), std::prev(it))) +
                        " has size " + std::to_string((*std::prev(it))->shape()[i]) +
                        " and the array at index " +
                        std::to_string(std::distance(array_ptrs.begin(), it)) +
                        " has size " + std::to_string((*it)->shape()[i]));
            }
        }
    }

    // The axis must be in range 0..ndim-1. We can check this on the first input
    // array since at this point we know they all have the same number of dimensions
    if (!(0 <= axis && axis < array_ptrs.front()->ndim())) {
        throw std::invalid_argument(
                "axis " + std::to_string(axis) +
                " is out of bounds for array of dimension " +
                std::to_string(array_ptrs.front()->ndim()));
    }

    // The output shape matches the input shapes except on the
    // concatenation axis, where the axis sizes are summed
    std::span<const ssize_t> shape0 = array_ptrs.front()->shape();
    std::vector<ssize_t> shape(shape0.begin(), shape0.end());

    for (auto it = std::next(array_ptrs.begin()), stop = array_ptrs.end(); it != stop; ++it) {
        shape[axis] += (*it)->shape()[axis];
    }

    return shape;
}
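// Examples (illustrative, not part of the commit):
//   shapes (2, 2) and (1, 2), axis = 0  ->  output shape (3, 2)
//   shapes (2, 2) and (2, 3), axis = 1  ->  output shape (2, 5)
//   shapes (2, 2) and (3, 3), any axis  ->  throws (off-axis sizes differ)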

void ConcatenateNode::propagate(State& state) const {
    auto ptr = data_ptr<ArrayNodeStateData>(state);

    for (ssize_t arr_i = 0, stop = array_ptrs_.size(); arr_i < stop; ++arr_i) {
        // View into our buffer with the same shape as the input
        // array, starting at that array's offset
        auto view_it = Array::iterator(
                ptr->buff() + array_starts_[arr_i],
                this->ndim(),
                array_ptrs_[arr_i]->shape().data(),
                this->strides().data());

        for (const auto& diff : array_ptrs_[arr_i]->diff(state)) {
            assert(!diff.placed() && !diff.removed() && "no dynamic support implemented");
            auto update_it = view_it + diff.index;
            ssize_t buffer_index = &*update_it - ptr->buffer.data();
            assert(*update_it == diff.old);
            // Record (index, old value, new value); the old value is the one at
            // the updated location, i.e. *update_it == diff.old
            ptr->updates.emplace_back(buffer_index, *update_it, diff.value);
            *update_it = diff.value;
        }
    }
}
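// Worked example (illustrative): with inputs of shape (2, 2) and (1, 2) on
// axis 0, an update at flat index 3 of the first input maps to flat index 3
// of the (3, 2) output buffer, while an update at flat index 1 of the second
// input maps to array_starts_[1] + 1 == 5.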

void ConcatenateNode::revert(State& state) const {
    data_ptr<ArrayNodeStateData>(state)->revert();
}

ReshapeNode::ReshapeNode(ArrayNode* node_ptr, std::span<const ssize_t> shape)
        : ArrayOutputMixin(shape), array_ptr_(node_ptr) {
    // Don't (yet) support non-contiguous predecessors.
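For fixed-shape inputs, ConcatenateNode's state matches numpy.concatenate. A minimal numpy sketch (illustrative only; numpy is not used by the C++ above) of the shape and offset bookkeeping:

    import numpy as np

    axis = 0
    a = np.arange(4).reshape(2, 2)  # first input,  shape (2, 2)
    b = np.array([[4, 5]])          # second input, shape (1, 2)

    out = np.concatenate((a, b), axis=axis)  # shape (3, 2), as make_concatenate_shape computes

    # array_starts_ bookkeeping: each input's flat offset is the running sum of
    # the previous inputs' trailing-subshape products (dimensions axis..ndim-1)
    starts = [0]
    for x in (a, b)[:-1]:
        starts.append(starts[-1] + int(np.prod(x.shape[axis:])))
    assert starts == [0, 4]

    # with axis == 0 each input occupies a contiguous block of the output buffer
    assert (out.ravel()[:4] == a.ravel()).all()
    assert (out.ravel()[4:] == b.ravel()).all()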
4 changes: 4 additions & 0 deletions dwave/optimization/symbols.pyi
@@ -51,6 +51,10 @@ class BinaryVariable(ArraySymbol):
    def set_state(self, index: int, state: numpy.typing.ArrayLike): ...


class Concatenate(ArraySymbol):
    ...


class Constant(ArraySymbol):
    def __bool__(self) -> bool: ...
    def __index__(self) -> int: ...
66 changes: 66 additions & 0 deletions dwave/optimization/symbols.pyx
@@ -56,6 +56,7 @@ from dwave.optimization.libcpp.nodes cimport (
    ArrayValidationNode as cppArrayValidationNode,
    BasicIndexingNode as cppBasicIndexingNode,
    BinaryNode as cppBinaryNode,
    ConcatenateNode as cppConcatenateNode,
    ConstantNode as cppConstantNode,
    DisjointBitSetNode as cppDisjointBitSetNode,
    DisjointBitSetsNode as cppDisjointBitSetsNode,
@@ -106,6 +107,7 @@ __all__ = [
    "BasicIndexing",
    "BinaryVariable",
    "_CombinedIndexing",
    "Concatenate",
    "Constant",
    "DisjointBitSets",
    "DisjointBitSet",
@@ -781,6 +783,70 @@ cdef class BinaryVariable(ArraySymbol):
_register(BinaryVariable, typeid(cppBinaryNode))


cdef class Concatenate(ArraySymbol):
    """Concatenate symbol.

    Examples:
        This example creates a Concatenate symbol.

        >>> from dwave.optimization.model import Model
        >>> from dwave.optimization.symbols import Concatenate
        >>> model = Model()
        >>> a = model.constant([[1, 2], [3, 4]])
        >>> b = model.constant([[5, 6]])
        >>> a_b = Concatenate((a, b), axis=0)
        >>> type(a_b)
        <class 'dwave.optimization.symbols.Concatenate'>
    """
    def __init__(self, tuple inputs, int axis = 0):
        if len(inputs) < 1:
            raise TypeError("must have at least one predecessor node")

        cdef Model model = inputs[0].model
        cdef vector[cppArrayNode*] cppinputs

        cdef ArraySymbol array
        for node in inputs:
            if node.model != model:
                raise ValueError("all predecessors must be from the same model")
            array = <ArraySymbol?>node
            cppinputs.push_back(array.array_ptr)

        self.ptr = model._graph.emplace_node[cppConcatenateNode](cppinputs, axis)
        self.initialize_arraynode(model, self.ptr)

    @staticmethod
    def _from_symbol(Symbol symbol):
        cdef cppConcatenateNode* ptr = dynamic_cast_ptr[cppConcatenateNode](symbol.node_ptr)
        if not ptr:
            raise TypeError("given symbol cannot be used to construct a Concatenate")

        cdef Concatenate m = Concatenate.__new__(Concatenate)
        m.ptr = ptr
        m.initialize_arraynode(symbol.model, ptr)
        return m

    @classmethod
    def _from_zipfile(cls, zf, directory, Model model, predecessors):
        if len(predecessors) < 1:
            raise ValueError("Concatenate must have at least one predecessor")

        with zf.open(directory + "axis.json", "r") as f:
            return Concatenate(tuple(predecessors), axis=json.load(f))

    def _into_zipfile(self, zf, directory):
        encoder = json.JSONEncoder(separators=(',', ':'))
        zf.writestr(directory + "axis.json", encoder.encode(self.axis()))

    def axis(self):
        return self.ptr.axis()

    cdef cppConcatenateNode* ptr

_register(Concatenate, typeid(cppConcatenateNode))


cdef class Constant(ArraySymbol):
    """Constant symbol.
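A sketch of the axis.json round trip implemented by _into_zipfile/_from_zipfile above, using only the standard-library json and zipfile modules; the nodes/5/ directory name is hypothetical:

    import io
    import json
    import zipfile

    buf = io.BytesIO()
    with zipfile.ZipFile(buf, "w") as zf:
        # what _into_zipfile writes for a Concatenate with axis=1
        encoder = json.JSONEncoder(separators=(',', ':'))
        zf.writestr("nodes/5/axis.json", encoder.encode(1))

    with zipfile.ZipFile(buf, "r") as zf:
        with zf.open("nodes/5/axis.json", "r") as f:
            assert json.load(f) == 1  # what _from_zipfile reads back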