diff --git a/pytensor/link/jax/dispatch/tensor_basic.py b/pytensor/link/jax/dispatch/tensor_basic.py index 2956afad02..e03462bf78 100644 --- a/pytensor/link/jax/dispatch/tensor_basic.py +++ b/pytensor/link/jax/dispatch/tensor_basic.py @@ -87,14 +87,7 @@ def jax_funcify_Join(op, **kwargs): def join(axis, *tensors): # tensors could also be tuples, and in this case they don't have a ndim tensors = [jnp.asarray(tensor) for tensor in tensors] - view = op.view - if (view != -1) and all( - tensor.shape[axis] == 0 for tensor in tensors[0:view] + tensors[view + 1 :] - ): - return tensors[view] - - else: - return jnp.concatenate(tensors, axis=axis) + return jnp.concatenate(tensors, axis=axis) return join diff --git a/pytensor/link/numba/dispatch/tensor_basic.py b/pytensor/link/numba/dispatch/tensor_basic.py index 7daa625794..7749514e03 100644 --- a/pytensor/link/numba/dispatch/tensor_basic.py +++ b/pytensor/link/numba/dispatch/tensor_basic.py @@ -117,17 +117,9 @@ def arange(start, stop, step): @numba_funcify.register(Join) def numba_funcify_Join(op, **kwargs): - view = op.view - - if view != -1: - # TODO: Where (and why) is this `Join.view` even being used? From a - # quick search, the answer appears to be "nowhere", so we should - # probably just remove it. - raise NotImplementedError("The `view` parameter to `Join` is not supported") - @numba_basic.numba_njit def join(axis, *tensors): - return np.concatenate(tensors, numba_basic.to_scalar(axis)) + return np.concatenate(tensors, axis.item()) return join diff --git a/pytensor/scan/checkpoints.py b/pytensor/scan/checkpoints.py index 8c237267d5..d974e8257e 100644 --- a/pytensor/scan/checkpoints.py +++ b/pytensor/scan/checkpoints.py @@ -1,6 +1,5 @@ import pytensor.tensor.basic as ptb from pytensor.scan.basic import scan -from pytensor.tensor.basic import Join from pytensor.tensor.math import ceil, eq, neq from pytensor.tensor.subtensor import set_subtensor @@ -127,14 +126,12 @@ def scan_checkpoints( # Pad the sequences if needed if padding: - # Since padding could be an empty tensor, Join returns a view of s. - join = Join(view=0) for i, s in enumerate(sequences): overshoots_by = s.shape[0] % save_every_N overshoots = neq(overshoots_by, 0) n = (save_every_N - overshoots_by) * overshoots z = ptb.zeros((n, *s.shape[1:]), dtype=s.dtype) - sequences[i] = join(0, s, z) + sequences[i] = ptb.join(0, s, z) # Establish the input variables of the outer scan o_sequences = [ diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 5d6c059c53..9117a0d99d 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -2412,37 +2412,17 @@ class Join(COp): The axis has to be an index into the shape >>> pt.join(2, x, y, z) Traceback (most recent call last): - ValueError: Axis value 2 is out of range for the given input dimensions + numpy.exceptions.AxisError: axis 2 is out of bounds for array of dimension 2 Joined tensors must have the same rank >>> pt.join(0, x, u) Traceback (most recent call last): - TypeError: Only tensors with the same number of dimensions can be joined. Input ndims were: [2, 1]. + TypeError: Only tensors with the same number of dimensions can be joined """ check_input = False - __props__ = ("view",) - - def __init__(self, view=-1): - self.view = view - if view != -1: - # since the first input is always the axis, the tensors - # start from index 1. 
- self.view_map = {0: [1 + view]} - - def __str__(self): - if self.view == -1: - return self.__class__.__name__ - else: - classname = self.__class__.__name__ - args = ", ".join(f"{p}={getattr(self, p)!r}" for p in self.__props__) - return f"{classname}{{{args}}}" - - def __setstate__(self, d): - self.__dict__.update(d) - if not hasattr(self, "view"): - self.view = -1 + __props__ = () def make_node(self, axis, *tensors): """ @@ -2459,74 +2439,60 @@ def make_node(self, axis, *tensors): if not tensors: raise ValueError("Cannot join an empty list of tensors") + axis = as_tensor_variable(axis) + if axis.type.dtype not in int_dtypes: + raise TypeError(f"Axis {axis} must be an integer type.") + if axis.type.ndim > 0: + raise TypeError(f"Axis {axis} must be 0-d.") + tensors = [as_tensor_variable(x) for x in tensors] - out_dtype = ps.upcast(*[x.type.dtype for x in tensors]) - if not builtins.all(targs.type.ndim for targs in tensors): + if not builtins.all(targs.type.ndim > 0 for targs in tensors): raise TypeError( "Join cannot handle arguments of dimension 0." - " Use `stack` to join scalar values." + " Use `stack` to join scalar values and/or increase rank of scalars." ) if len(tensors) == 1: out_shape = tensors[0].type.shape else: - # When the axis is fixed, a dimension should be - # broadcastable if at least one of the inputs is - # broadcastable on that dimension (see justification below), - # except for the axis dimension. - # Initialize bcastable all false, and then fill in some trues with - # the loops. - - if not isinstance(axis, int): - try: - axis = int(get_scalar_constant_value(axis)) - except NotScalarConstantError: - pass - ndim = tensors[0].type.ndim - if isinstance(axis, int): - # Basically, broadcastable -> length 1, but the - # converse does not hold. So we permit e.g. T/F/T - # joins, and if they fail at runtime they fail, but if - # they don't then it means that the argument where - # that broadcastable flag was False had length 1 along - # this dimension, and therefore this dimension should - # be broadcastable for the output. - - if axis < -ndim: - raise IndexError( - f"Axis value {axis} is out of range for the given input dimensions" - ) - if axis < 0: - axis += ndim - if axis > ndim - 1: - raise ValueError( - f"Axis value {axis} is out of range for the given input dimensions" - ) - # NOTE: Constant negative axis can no longer be negative at this point. - - in_shapes = [x.type.shape for x in tensors] - in_ndims = [len(s) for s in in_shapes] - if set(in_ndims) != {ndim}: - raise TypeError( - "Only tensors with the same number of dimensions can be joined." - f" Input ndims were: {in_ndims}." - ) + + if not builtins.all(x.ndim == ndim for x in tensors): + raise TypeError( + "Only tensors with the same number of dimensions can be joined" + ) + + try: + static_axis = int(get_scalar_constant_value(axis)) + except NotScalarConstantError: + static_axis = None + + if static_axis is None: + # When axis isn't static, we can't canclude anything about output dimension + # (unless we had some degenerate zero arrays) that can be removed during rewrites. + # We could also raise errors if any dimensions are pairwise inconsistent across all the axes + # As no matter the join it would be invalid. 
+ # However, dynamic axis is so rare that is not worth the trouble + out_shape = [None] * ndim + + else: # We know the axis statically + static_axis = normalize_axis_index(static_axis, ndim) + static_shapes = [x.type.shape for x in tensors] # Determine output shapes from a matrix of input shapes - in_shapes = np.array(in_shapes) + static_shapes = np.array(static_shapes) out_shape = [None] * ndim for d in range(ndim): - ins = in_shapes[:, d] - if d == axis: - # Any unknown size along the axis means we can't sum + ins = static_shapes[:, d] + if d == static_axis: + # Any unknown size along the axis means we can't infer it if None in ins: out_shape[d] = None else: out_shape[d] = sum(ins) else: - inset = set(in_shapes[:, d]) + inset = set(static_shapes[:, d]) # Other dims must match exactly, # or if a mix of None and ? the output will be ? # otherwise the input shapes are incompatible. @@ -2536,100 +2502,141 @@ def make_node(self, axis, *tensors): (out_shape[d],) = inset - {None} else: raise ValueError( - f"all input array dimensions other than the specified `axis` ({axis})" + f"all input array dimensions other than the specified `axis` ({static_axis})" " must match exactly, or be unknown (None)," f" but along dimension {d}, the inputs shapes are incompatible: {ins}" ) - else: - # When the axis may vary, no dimension can be guaranteed to be - # broadcastable. - out_shape = [None] * tensors[0].type.ndim - if not builtins.all(x.ndim == len(out_shape) for x in tensors): - raise TypeError( - "Only tensors with the same number of dimensions can be joined" - ) - - inputs = [as_tensor_variable(axis), *tensors] + inputs = [axis, *tensors] + out_dtype = ps.upcast(*[x.type.dtype for x in tensors]) + return Apply(self, inputs, [tensor(dtype=out_dtype, shape=out_shape)]) - if inputs[0].type.dtype not in int_dtypes: - raise TypeError(f"Axis value {inputs[0]} must be an integer type") + def perform(self, node, inputs, output_storage): + axis, *arrays = inputs + output_storage[0][0] = np.concatenate( + arrays, axis=axis, dtype=node.outputs[0].type.dtype + ) - return Apply(self, inputs, [tensor(dtype=out_dtype, shape=out_shape)]) + def c_code_cache_version(self): + return (7,) - def perform(self, node, axis_and_tensors, out_): - (out,) = out_ - view = self.view - axis, tens = axis_and_tensors[0], axis_and_tensors[1:] - # we check these tensors for being empty. 
- if (view != -1) and all( - tensor.shape[axis] == 0 for tensor in tens[0:view] + tens[view + 1 :] - ): - out[0] = tens[view] + def c_code(self, node, name, inputs, outputs, sub): + axis, *arrays = inputs + [out] = outputs + n = len(arrays) + ndim = node.outputs[0].type.ndim + fail = sub["fail"] + # Most times axis is constant, inline it + # This is safe to do because the hash of the c_code includes the constant signature + if isinstance(node.inputs[0], Constant): + static_axis = int(node.inputs[0].data) + static_axis = normalize_axis_index(static_axis, ndim) + axis_def = f"{static_axis};" + axis_check = "" else: - ndim = tens[0].ndim - if axis < -ndim: - raise IndexError( - f"Join axis {int(axis)} out of bounds [0, {int(ndim)})" - ) + axis_dtype = node.inputs[0].type.dtype_specs()[1] + axis_def = f"(({axis_dtype} *)PyArray_DATA({axis}))[0];" + axis_check = f""" + if (axis < 0){{ + axis = {ndim} + axis; + }} + if (axis >= {ndim} || axis < 0) {{ + PyErr_SetString(PyExc_ValueError, "Join axis is out of bounds"); + {fail} + }} + """ - out[0] = np.asarray( - np.concatenate(tens, axis=axis), dtype=node.outputs[0].type.dtype + copy_arrays_to_tuple = "\n".join( + ( + f"""Py_INCREF({array}); PyTuple_SetItem(arrays_tuple, {i}, (PyObject*){array});""" + for i, array in enumerate(arrays) ) + ) - def c_code_cache_version(self): - return (5,) + code = f""" + int axis = {axis_def} + PyArrayObject* arrays[{n}] = {{{','.join(arrays)}}}; + int out_is_valid = {out} != NULL; - def c_code(self, node, name, inputs, outputs, sub): - axis, tens = inputs[0], inputs[1:] - view = self.view - non_empty_tensor = tens[view] - input_1 = tens[0] - l = len(tens) - (out,) = outputs - fail = sub["fail"] - adtype = node.inputs[0].type.dtype_specs()[1] + {axis_check} - copy_to_list = ( - f"""Py_INCREF({inp}); PyList_SetItem(list, {i}, (PyObject*){inp});""" - for i, inp in enumerate(tens) - ) + if (out_is_valid) {{ + // Check if we can reuse output + npy_intp join_size = 0; + npy_intp out_shape[{ndim}]; + npy_intp *shape = PyArray_SHAPE(arrays[0]); - copy_inputs_to_list = "\n".join(copy_to_list) - n = len(tens) + for (int i = 0; i < {n}; i++) {{ + if (PyArray_NDIM(arrays[i]) != {ndim}) {{ + PyErr_SetString(PyExc_ValueError, "Input to join has wrong ndim"); + {fail} + }} - code = f""" - int axis = (({adtype} *)PyArray_DATA({axis}))[0]; - PyObject* list = PyList_New({l}); - {copy_inputs_to_list} - int tensors_lens_sum; - if({view} != -1) {{ - tensors_lens_sum = 0; - - for(int i=0; i < {n}; i++){{ - tensors_lens_sum += PyArray_DIM((PyArrayObject *)(PyList_GetItem(list, i)), axis); + join_size += PyArray_SHAPE(arrays[i])[axis]; + + if (i > 0){{ + for (int j = 0; j < {ndim}; j++) {{ + if ((j != axis) && (PyArray_SHAPE(arrays[i])[j] != shape[j])) {{ + PyErr_SetString(PyExc_ValueError, "Arrays shape must match along non join axis"); + {fail} + }} + }} + }} + }} + + memcpy(out_shape, shape, {ndim} * sizeof(npy_intp)); + out_shape[axis] = join_size; + + for (int i = 0; i < {ndim}; i++) {{ + out_is_valid &= (PyArray_SHAPE({out})[i] == out_shape[i]); }} - tensors_lens_sum -= PyArray_DIM({non_empty_tensor}, axis); }} - if({view} != -1 && tensors_lens_sum == 0) {{ + + if (!out_is_valid) {{ + // Use PyArray_Concatenate Py_XDECREF({out}); - Py_INCREF({non_empty_tensor}); - {out} = {non_empty_tensor}; - }}else{{ - //PyObject* PyArray_Concatenate(PyObject* obj, int axis) - int ndim = PyArray_NDIM({input_1}); - if( axis < -ndim ){{ - PyErr_Format(PyExc_IndexError, - "Join axis %d out of bounds [0, %d)", axis, ndim); + PyObject* arrays_tuple = 
PyTuple_New({n}); + {copy_arrays_to_tuple} + {out} = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis); + Py_DECREF(arrays_tuple); + if(!{out}){{ {fail} }} - Py_XDECREF({out}); - {out} = (PyArrayObject *)PyArray_Concatenate(list, axis); - Py_DECREF(list); - if(!{out}){{ + }} + else {{ + // Copy the data to the pre-allocated output buffer + + // Create view into output buffer + PyArrayObject_fields *view; + + // PyArray_NewFromDescr steals a reference to descr, so we need to increase it + Py_INCREF(PyArray_DESCR({out})); + view = (PyArrayObject_fields *)PyArray_NewFromDescr(&PyArray_Type, + PyArray_DESCR({out}), + {ndim}, + PyArray_SHAPE(arrays[0]), + PyArray_STRIDES({out}), + PyArray_DATA({out}), + NPY_ARRAY_WRITEABLE, + NULL); + if (view == NULL) {{ {fail} }} + + // Copy data into output buffer + for (int i = 0; i < {n}; i++) {{ + view->dimensions[axis] = PyArray_SHAPE(arrays[i])[axis]; + + if (PyArray_CopyInto((PyArrayObject*)view, arrays[i]) != 0) {{ + Py_DECREF(view); + {fail} + }} + + view->data += (view->dimensions[axis] * view->strides[axis]); + }} + + Py_DECREF(view); }} """ return code @@ -2639,22 +2646,21 @@ def R_op(self, inputs, eval_points): return [None] return self.make_node(inputs[0], *eval_points[1:]).outputs - def grad(self, axis_and_tensors, grads): + def L_op(self, inputs, outputs, grads): """The gradient wrt a join op is a `Split`, used to partition the gradient along the `axis` which was used for joining. """ - (gz,) = grads - axis, tens = axis_and_tensors[0], axis_and_tensors[1:] + [gz] = grads + [out] = outputs + axis, *tensors = inputs rval = [grad_undefined(self, 0, axis)] - - dtypes = [as_tensor_variable(x).type.dtype for x in tens] - out_dtype = ps.upcast(*dtypes) + out_dtype = out.type.dtype if "float" in out_dtype or "complex" in out_dtype: # assume that this is differentiable - split = Split(len(tens)) - split_gz = split(gz, axis, stack([shape(x)[axis] for x in tens])) + split_sizes = stack([shape(x)[axis] for x in tensors]) + split_gz = split(gz, split_sizes, n_splits=len(tensors), axis=axis) # If there is only one split, it might not be in a list. if not isinstance(split_gz, list): split_gz = [split_gz] @@ -2667,13 +2673,12 @@ def grad(self, axis_and_tensors, grads): else specify_broadcastable( g, *(ax for (ax, s) in enumerate(t.type.shape) if s == 1) ) - for t, g in zip(tens, split_gz, strict=True) + for t, g in zip(tensors, split_gz, strict=True) ] rval = rval + split_gz else: - # the output has integer type, so the gradient through it - # is 0 - rval = rval + [t.zeros_like(dtype=config.floatX) for t in tens] + # the output has integer type, so the gradient through it is 0 + rval = rval + [t.zeros_like(dtype=config.floatX) for t in tensors] return rval @@ -2693,7 +2698,8 @@ def infer_shape(self, fgraph, node, ishapes): # An axis < -n_dim or >= ndim would be invalid, but this is # not checked here. A `CheckAndRaise` `Op` would be a way of # addressing that, but it may disrupt optimizations. 
- join_dim = switch(ge(node.inputs[0], 0), node.inputs[0], node.inputs[0] + n_dim) + axis = node.inputs[0] + join_dim = switch(ge(axis, 0), axis, axis + n_dim) out_shapes = [] for dim in range(n_dim): # we have to deal with 2 possible cases in here : @@ -2716,7 +2722,7 @@ def infer_shape(self, fgraph, node, ishapes): return [tuple(out_shapes)] -join_ = Join() +_join = Join() pprint.assign(Join, printing.FunctionPrinter(["join"])) @@ -2759,7 +2765,7 @@ def join(axis, *tensors_list): if len(tensors_list) == 1: return tensors_list[0] else: - return join_(axis, *tensors_list) + return _join(axis, *tensors_list) @_vectorize_node.register(Join) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 59148fae3b..61db37bd27 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -41,6 +41,7 @@ node_rewriter, ) from pytensor.graph.rewriting.db import RewriteDatabase +from pytensor.npy_2_compat import normalize_axis_index from pytensor.raise_op import Assert, CheckAndRaise, assert_op from pytensor.scalar.basic import Second from pytensor.tensor.basic import ( @@ -817,52 +818,38 @@ def local_join_1(fgraph, node): return [tensors[0]] -# TODO: merge in local_useless_join -@register_infer_shape @register_useless -@register_specialize @register_canonicalize +@register_specialize @node_rewriter([Join]) def local_join_empty(fgraph, node): """Join(i, x, y, empty) => Join(i, x, y) Remove empty inputs to joins. The empty inputs can be anywhere. - """ - if not isinstance(node.op, Join): - return - new_inputs = [] + axis, *tensors = node.inputs + try: - join_idx = get_scalar_constant_value( + static_axis = get_scalar_constant_value( node.inputs[0], only_process_constants=True ) except NotScalarConstantError: return - for idx in range(1, len(node.inputs)): - inp = node.inputs[idx] - # We can not use size == 0,, as this can change shape from 3,0 - # to 2,0. This trigger DebugMode error. This happen with - # stack(...,[]) as this add a dimshuffle on [], that add a - # dimensions with shape 1. - if isinstance(inp, Constant) and inp.data.shape[join_idx] == 0: - continue - new_inputs.append(inp) - if len(new_inputs) < len(node.inputs) - 1: - if len(new_inputs) == 0: - # at.join do not work in that case. - # constant folding will take care of this case. - return - ret = join(node.inputs[0], *new_inputs) - o = node.outputs[0] - if ret.dtype != o.dtype: - # Join can upcast some inputs - return - # Copy over stacktrace from previous output (after join op) - # to new output, because an error in the new op must be caused - # by an error in the old join op. 
- copy_stack_trace(node.outputs, ret) + new_tensors = [tensor for tensor in tensors if tensor.type.shape[static_axis] != 0] + + # If there are zero tensors, the join is useless but so is any other operation + # Another rewrite will (one day) handle all those cases + if 0 < len(new_tensors) < len(tensors): + # join eagerly returns a tensor when there is only one, no need for us to check + ret = join(axis, *new_tensors) + + [old_output] = node.outputs + + if ret.dtype != old_output.dtype: + ret = ret.astype(old_output.dtype) + copy_stack_trace(old_output, ret) return [ret] @@ -1298,7 +1285,7 @@ def local_join_of_alloc(fgraph, node): # Axis can never be lifted # Non-axis allocated dimensions can be lifted if they are all broadcastable [out] = node.outputs - axis = axis.data + static_axis = normalize_axis_index(axis.data, tensors[0].type.ndim) broadcasted_dims = list( zip( @@ -1320,7 +1307,7 @@ def local_join_of_alloc(fgraph, node): lifteable_alloc_dims = { dim for dim in range(out.type.ndim) - if dim != axis and all(broadcasted_dims[dim]) + if dim != static_axis and all(broadcasted_dims[dim]) } if not lifteable_alloc_dims: @@ -1337,13 +1324,13 @@ def local_join_of_alloc(fgraph, node): copy_stack_trace(tensor, new_tensor) new_tensors.append(new_tensor) - new_join = node.op(axis, *new_tensors) + new_join = node.op(static_axis, *new_tensors) copy_stack_trace(node.outputs[0], new_join) # Reintroduce the lifted dims post_join_shape = [] for i, alloc_dims in enumerate(zip(*alloc_shapes, strict=True)): - if i == axis: + if i == static_axis: # The alloc dim along the axis is the sum of all the pre-join alloc dims post_join_shape.append(add(*alloc_dims)) else: diff --git a/tests/link/numba/test_tensor_basic.py b/tests/link/numba/test_tensor_basic.py index 09963f9d36..625246e340 100644 --- a/tests/link/numba/test_tensor_basic.py +++ b/tests/link/numba/test_tensor_basic.py @@ -172,24 +172,6 @@ def test_Join(vals, axis): ) -def test_Join_view(): - vals, vals_test = zip( - *( - (pt.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)), - (pt.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)), - ), - strict=True, - ) - g = ptb.Join(view=1)(1, *vals) - - with pytest.raises(NotImplementedError): - compare_numba_and_py( - vals, - g, - vals_test, - ) - - @pytest.mark.parametrize( "n_splits, axis, values, sizes", [ diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index 1730ae46ac..a959efd6d3 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -1248,65 +1248,41 @@ def test_local_join_1(): def test_local_join_empty(): - # test for vector, vector, empty to vector + # Vector case empty_vec = np.asarray([], dtype=config.floatX) - a = vector("a") - s = pt.join(0, a, a, empty_vec) - f = function([a], s, mode=rewrite_mode) - val = f([1]) - assert np.all(val == [1]) - e = f.maker.fgraph.toposort() - assert len([n for n in e if isinstance(n.op, Join)]) == 1 - assert all( - not isinstance(n.op, Join) or len(n.inputs) == 3 - for n in e - if isinstance(n.op, Join) + vec = vector("vec") + s = pt.join(0, vec, vec, empty_vec) + new_s = rewrite_graph(s) + assert equal_computations([new_s], [join(0, vec, vec)]) + assert new_s.dtype == s.dtype + + # Matrix case + empty_mat = np.zeros((2, 0), dtype=config.floatX) + empty_sym_mat = matrix("m", shape=(2, 0)) + mat = matrix("mat", shape=(2, 10)) + s = join(1, empty_mat, mat, empty_sym_mat, mat, mat) + new_s = rewrite_graph(s) + assert equal_computations([new_s], [join(1, mat, mat, mat)]) + 
assert new_s.dtype == s.dtype + + # Join can be completely removed, but casting and specify_shape are propagated + int_mat = matrix("int_mat", dtype=int) + s = join(-1, empty_mat, int_mat, empty_sym_mat) + new_s = rewrite_graph(s) + assert equal_computations( + [new_s], [specify_shape(int_mat, (2, None)).astype(s.dtype)] ) - assert f.maker.fgraph.outputs[0].dtype == config.floatX - # test for matrix join(1,a) - empty_mat = np.asarray([[]], dtype=config.floatX) - m = matrix("m") - s = join(1, empty_mat, m, m, m) - f = function([m], s, mode=rewrite_mode) - val = f([[1]]) - assert np.all(val == [[1]]) - e = f.maker.fgraph.toposort() - assert len([n for n in e if isinstance(n.op, Join)]) == 1 - assert all( - not isinstance(n.op, Join) or len(n.inputs) == 4 - for n in e - if isinstance(n.op, Join) - ) - assert f.maker.fgraph.outputs[0].dtype == config.floatX - # test for vector, vector, empty to matrix - # We can't rewrite this case. - s = pt.stack([a, a, empty_vec]) - f = function([a], s, mode=rewrite_mode) - val = f([]) - assert np.all(val == [1]) - e = f.maker.fgraph.toposort() - assert len([n for n in e if isinstance(n.op, Join)]) == 1 - assert all( - not isinstance(n.op, Join) or len(n.inputs) == 4 - for n in e - if isinstance(n.op, Join) - ) - assert f.maker.fgraph.outputs[0].dtype == config.floatX - # test for matrix join(0,a) - # We can't rewrite this case. - s = join(0, m, np.asarray([[2.0]], dtype=config.floatX), m) - f = function([m], s, mode=rewrite_mode) - val = f([[1]]) - assert np.all(val == [[1], [2], [1]]) - e = f.maker.fgraph.toposort() - assert len([n for n in e if isinstance(n.op, Join)]) == 1 - assert all( - not isinstance(n.op, Join) or len(n.inputs) == 4 - for n in e - if isinstance(n.op, Join) - ) - assert f.maker.fgraph.outputs[0].dtype == config.floatX + # Dynamic axis, can't apply rewrite + axis = scalar("axis", dtype=int) + s = join(axis, empty_mat, int_mat, empty_sym_mat) + new_s = rewrite_graph(s) + assert equal_computations([new_s], [s]) + + # Stack introduces an expand_dims in the join, that's a nonzero dim! + s = pt.stack([vec, vec, empty_vec]) + new_s = rewrite_graph(s) + assert equal_computations([new_s], [s]) def test_local_join_make_vector(): diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index e29a47691a..c8c2bb224a 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -117,6 +117,7 @@ ivector, lscalar, lvector, + matrices, matrix, row, scalar, @@ -1762,7 +1763,7 @@ def test_join_matrixV_negative_axis(self): got = f(-2) assert np.allclose(got, want) - with pytest.raises(IndexError): + with pytest.raises(ValueError): f(-3) @pytest.mark.parametrize("py_impl", (False, True)) @@ -1805,7 +1806,7 @@ def test_join_matrixC_negative_axis(self, py_impl): got = f() assert np.allclose(got, want) - with pytest.raises(IndexError): + with pytest.raises(ValueError): join(-3, a, b) with impl_ctxt: @@ -2118,28 +2119,6 @@ def test_split_static_shape(self): y = Split(2)(x, 0, [s, 5 - s])[0] assert y.type.shape == (None,) - def test_join_inplace(self): - # Test join to work inplace. - # - # This function tests the case when several elements are passed to the - # join function but all except one of them are empty. In this case join - # should work inplace and the output should be the view of the non-empty - # element. 
-        s = lscalar()
-        x = vector("x")
-        z = ptb.zeros((s,))
-
-        join = Join(view=0)
-        c = join(0, x, z, z)
-
-        f = pytensor.function([In(x, borrow=True), s], Out(c, borrow=True))
-
-        data = np.array([3, 4, 5], dtype=config.floatX)
-
-        if config.mode not in ["DebugMode", "DEBUG_MODE"]:
-            assert f(data, 0) is data
-        assert np.allclose(f(data, 0), [3, 4, 5])
-
     def test_join_oneInput(self):
         # Test join when only 1 input is given.
         #
@@ -2174,6 +2153,32 @@ def test_split_view(self, linker):
             assert np.allclose(r, expected)
             assert r.base is x_test
 
+    @pytest.mark.parametrize("gc", (True, False), ids=lambda x: f"gc={x}")
+    @pytest.mark.parametrize("memory_layout", ["C-contiguous", "F-contiguous", "Mixed"])
+    @pytest.mark.parametrize("axis", (0, 1), ids=lambda x: f"axis={x}")
+    @pytest.mark.parametrize("ndim", (1, 2), ids=["vector", "matrix"])
+    @config.change_flags(cmodule__warn_no_version=False)
+    def test_join_performance(self, ndim, axis, memory_layout, gc, benchmark):
+        if ndim == 1 and not (memory_layout == "C-contiguous" and axis == 0):
+            pytest.skip("Redundant parametrization")
+        n = 64
+        inputs = vectors("abcdef") if ndim == 1 else matrices("abcdef")
+        out = join(axis, *inputs)
+        fn = pytensor.function(inputs, Out(out, borrow=True), trust_input=True)
+        fn.vm.allow_gc = gc
+        test_values = [np.zeros((n, n)[:ndim], dtype=inputs[0].dtype) for _ in inputs]
+        if memory_layout == "C-contiguous":
+            pass
+        elif memory_layout == "F-contiguous":
+            test_values = [t.T for t in test_values]
+        elif memory_layout == "Mixed":
+            test_values = [t if i % 2 else t.T for i, t in enumerate(test_values)]
+        else:
+            raise ValueError
+
+        assert fn(*test_values).shape == ((n * 6, n)[:ndim] if axis == 0 else (n, n * 6))
+        benchmark(fn, *test_values)
+
 
 def test_TensorFromScalar():
     s = ps.constant(56)
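
A minimal usage sketch of the behaviour these changes target, assuming only the public `pytensor.tensor` API already used elsewhere in the diff: static shape inference through `join`, the `local_join_empty` rewrite, and the gradient being routed through `split`.

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x", shape=(2, 3))
y = pt.matrix("y", shape=(2, 5))

z = pt.join(1, x, y)
print(z.type.shape)  # (2, 8): sizes are summed along the static join axis

# Empty inputs along the join axis are dropped by `local_join_empty` at rewrite time.
empty = np.zeros((2, 0), dtype=x.dtype)
fn = pytensor.function([x, y], pt.join(1, x, empty, y))
assert fn(np.ones((2, 3), dtype=x.dtype), np.ones((2, 5), dtype=y.dtype)).shape == (2, 8)

# The gradient of a join is the upstream gradient split back along the same axis.
gx, gy = pytensor.grad(z.sum(), [x, y])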
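
The preallocated-output branch of the new C code copies each input into a shifted view of the output buffer along the join axis. A rough NumPy analogue of that loop, using a hypothetical join_into helper purely for illustration:

import numpy as np

def join_into(out, arrays, axis):
    # Mimic the C path: write each input into the slice of `out` it occupies
    # along the join axis, advancing the offset after each copy.
    offset = 0
    for a in arrays:
        size = a.shape[axis]
        idx = [slice(None)] * out.ndim
        idx[axis] = slice(offset, offset + size)
        out[tuple(idx)] = a  # analogous to PyArray_CopyInto on the shifted view
        offset += size
    return out

a, b = np.ones((2, 3)), np.zeros((2, 5))
out = np.empty((2, 8))
assert np.array_equal(join_into(out, [a, b], axis=1), np.concatenate([a, b], axis=1))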