Skip to content

Commit fa3b2cf

Browse files
committed
Faster C-implementation of Join
Reuse outputs when possible, and avoid numpy overhead on public facing function
1 parent 158a884 commit fa3b2cf

File tree

1 file changed

+108
-43
lines changed

1 file changed

+108
-43
lines changed

pytensor/tensor/basic.py

+108-43
Original file line numberDiff line numberDiff line change
@@ -2537,58 +2537,123 @@ def perform(self, node, inputs, output_storage):
25372537
)
25382538

25392539
def c_code_cache_version(self):
2540-
return (5,)
2540+
return (6,)
25412541

25422542
def c_code(self, node, name, inputs, outputs, sub):
2543-
axis, tens = inputs[0], inputs[1:]
2544-
view = -1
2545-
non_empty_tensor = tens[view]
2546-
input_1 = tens[0]
2547-
l = len(tens)
2548-
(out,) = outputs
2549-
fail = sub["fail"]
2550-
adtype = node.inputs[0].type.dtype_specs()[1]
2543+
axis, *arrays = inputs
2544+
[out] = outputs
25512545

2552-
copy_to_list = (
2553-
f"""Py_INCREF({inp}); PyList_SetItem(list, {i}, (PyObject*){inp});"""
2554-
for i, inp in enumerate(tens)
2555-
)
2546+
n = len(arrays)
2547+
out_dtype = node.outputs[0].type.dtype_specs()[2]
2548+
ndim = node.outputs[0].type.ndim
2549+
fail = sub["fail"]
25562550

2557-
copy_inputs_to_list = "\n".join(copy_to_list)
2558-
n = len(tens)
2551+
# Most times axis is constant, inline it
2552+
# This is safe to do because the hash of the c_code includes the constant signature
2553+
if isinstance(node.inputs[0], Constant):
2554+
static_axis = int(node.inputs[0].data)
2555+
static_axis = normalize_axis_index(static_axis, ndim)
2556+
axis_def = f"{static_axis};"
2557+
axis_check = ""
2558+
else:
2559+
axis_dtype = node.inputs[0].type.dtype_specs()[1]
2560+
axis_def = f"(({axis_dtype} *)PyArray_DATA({axis}))[0];"
2561+
axis_check = f"""
2562+
if (axis < 0){{
2563+
axis = {ndim} + axis;
2564+
}}
2565+
if (axis >= {ndim} || axis < 0) {{
2566+
PyErr_SetString(PyExc_ValueError, "Join axis is out of bounds");
2567+
{fail}
2568+
}}
2569+
"""
25592570

25602571
code = f"""
2561-
int axis = (({adtype} *)PyArray_DATA({axis}))[0];
2562-
PyObject* list = PyList_New({l});
2563-
{copy_inputs_to_list}
2564-
int tensors_lens_sum;
2565-
if({view} != -1) {{
2566-
tensors_lens_sum = 0;
2567-
2568-
for(int i=0; i < {n}; i++){{
2569-
tensors_lens_sum += PyArray_DIM((PyArrayObject *)(PyList_GetItem(list, i)), axis);
2572+
int axis = {axis_def}
2573+
PyArrayObject* arrays[{n}] = {{{','.join(arrays)}}};
2574+
int out_is_valid = 0;
2575+
npy_intp join_size = 0;
2576+
npy_intp offset = 0;
2577+
2578+
// Validate input shapes and compute join size
2579+
npy_intp *shape = PyArray_SHAPE(arrays[0]);
2580+
2581+
{axis_check}
2582+
2583+
for (int i = 0; i < {n}; i++) {{
2584+
if (PyArray_NDIM(arrays[i]) != {ndim}) {{
2585+
PyErr_SetString(PyExc_ValueError, "Input to join has wrong ndim");
2586+
{fail}
2587+
}}
2588+
2589+
for (int j = 0; j < {ndim}; j++) {{
2590+
if (j == axis){{
2591+
join_size += PyArray_DIM(arrays[i], j);
2592+
}}
2593+
else if(PyArray_DIM(arrays[i], j) != shape[j]) {{
2594+
PyErr_SetString(PyExc_ValueError, "Arrays shape must match along non join axis");
2595+
{fail}
2596+
}}
2597+
}}
25702598
}}
2571-
tensors_lens_sum -= PyArray_DIM({non_empty_tensor}, axis);
2572-
}}
2573-
if({view} != -1 && tensors_lens_sum == 0) {{
2574-
Py_XDECREF({out});
2575-
Py_INCREF({non_empty_tensor});
2576-
{out} = {non_empty_tensor};
2577-
}}else{{
2578-
//PyObject* PyArray_Concatenate(PyObject* obj, int axis)
2579-
int ndim = PyArray_NDIM({input_1});
2580-
if( axis < -ndim ){{
2581-
PyErr_Format(PyExc_IndexError,
2582-
"Join axis %d out of bounds [0, %d)", axis, ndim);
2583-
{fail}
2599+
2600+
// Define dimensions of output array
2601+
npy_intp out_dims[{ndim}];
2602+
memcpy(out_dims, shape, {ndim} * sizeof(npy_intp));
2603+
out_dims[axis] = join_size;
2604+
2605+
// Reuse output or allocate new one
2606+
if ({out} != NULL) {{
2607+
out_is_valid = (PyArray_NDIM({out}) == {ndim});
2608+
for (int i = 0; i < {ndim}; i++) {{
2609+
out_is_valid &= (PyArray_DIM({out}, i) == out_dims[i]);
2610+
}}
25842611
}}
2585-
Py_XDECREF({out});
2586-
{out} = (PyArrayObject *)PyArray_Concatenate(list, axis);
2587-
Py_DECREF(list);
2588-
if(!{out}){{
2589-
{fail}
2612+
2613+
if (!out_is_valid) {{
2614+
Py_XDECREF({out});
2615+
{out} = (PyArrayObject *)PyArray_NewFromDescr(&PyArray_Type,
2616+
PyArray_DescrFromType({out_dtype}),
2617+
{ndim},
2618+
out_dims,
2619+
NULL, /* strides */
2620+
NULL, /* data */
2621+
NPY_ARRAY_DEFAULT,
2622+
NULL);
2623+
2624+
if ({out} == NULL) {{
2625+
{fail}
2626+
}}
2627+
}}
2628+
2629+
/* Copy data into output array */
2630+
for (int i = 0; i < {n}; i++) {{
2631+
PyArrayObject *arr = arrays[i];
2632+
2633+
/* Create temporary view array */
2634+
// PyArray_NewFromDescr steals a reference to descr, so we need to increase it
2635+
Py_INCREF(PyArray_DESCR({out}));
2636+
PyArrayObject *view = (PyArrayObject *)PyArray_NewFromDescr(&PyArray_Type,
2637+
PyArray_DESCR({out}),
2638+
{ndim},
2639+
PyArray_SHAPE(arr),
2640+
PyArray_STRIDES({out}),
2641+
PyArray_BYTES({out}) + (offset * PyArray_STRIDES({out})[axis]),
2642+
NPY_ARRAY_WRITEABLE,
2643+
NULL);
2644+
if (view == NULL) {{
2645+
{fail}
2646+
}}
2647+
2648+
/* Copy data into the correct position */
2649+
int success = PyArray_CopyInto(view, arr);
2650+
Py_DECREF(view);
2651+
if (success != 0) {{
2652+
{fail}
2653+
}}
2654+
2655+
offset += PyArray_DIM(arr, axis);
25902656
}}
2591-
}}
25922657
"""
25932658
return code
25942659

0 commit comments

Comments
 (0)