@@ -2537,58 +2537,123 @@ def perform(self, node, inputs, output_storage):
2537
2537
)
2538
2538
2539
2539
def c_code_cache_version(self):
    """Return the version tag of the C implementation produced by ``c_code``.

    The tag was raised from ``(5,)`` to ``(6,)`` together with the rewrite of
    ``c_code``.
    NOTE(review): presumably the compilation cache keys on this tuple, so it
    must be bumped whenever the generated C changes -- confirm against the
    caching layer.
    """
    return (6,)
2541
2541
2542
2542
def c_code(self, node, name, inputs, outputs, sub):
    """Build the C source that joins the input arrays along one axis.

    The generated code validates the inputs, reuses the current output
    buffer when its shape already matches (allocating a new one otherwise),
    and then copies every input into a strided view of the output.

    Parameters
    ----------
    node
        Apply node being compiled. ``node.inputs[0]`` is the axis variable
        and ``node.outputs[0].type`` supplies the output ndim and dtype.
    name
        Unused here.
    inputs
        C variable names; the first names the axis, the rest the arrays.
    outputs
        Single-element sequence holding the C name of the output array.
    sub
        Substitution dictionary; only ``sub["fail"]`` is read.

    Returns
    -------
    str
        The C code snippet.
    """
    axis, *arrays = inputs
    [out] = outputs

    n = len(arrays)
    # Third entry of dtype_specs() is handed to PyArray_DescrFromType below.
    out_dtype = node.outputs[0].type.dtype_specs()[2]
    ndim = node.outputs[0].type.ndim
    fail = sub["fail"]

    # Most times axis is constant, inline it
    # This is safe to do because the hash of the c_code includes the constant signature
    if isinstance(node.inputs[0], Constant):
        # Normalised at code-generation time, so no runtime bounds check is emitted.
        static_axis = int(node.inputs[0].data)
        static_axis = normalize_axis_index(static_axis, ndim)
        axis_def = f"{static_axis};"
        axis_check = ""
    else:
        # Symbolic axis: read it at runtime, wrap negatives, then bounds-check.
        axis_dtype = node.inputs[0].type.dtype_specs()[1]
        axis_def = f"(({axis_dtype} *)PyArray_DATA({axis}))[0];"
        axis_check = f"""
            if (axis < 0){{
                axis = {ndim} + axis;
            }}
            if (axis >= {ndim} || axis < 0) {{
                PyErr_SetString(PyExc_ValueError, "Join axis is out of bounds");
                {fail}
            }}
        """

    # NOTE(review): in the "Reuse output" section the per-dimension loop still
    # reads PyArray_DIM(out, i) for i < ndim when the NDIM comparison has
    # already failed; if an output with fewer dimensions can ever reach this
    # code that is an out-of-bounds read. Guarding the loop would change the
    # generated C and so needs a cache-version bump -- confirm and follow up.
    code = f"""
        int axis = {axis_def}
        PyArrayObject* arrays[{n}] = {{{','.join(arrays)}}};
        int out_is_valid = 0;
        npy_intp join_size = 0;
        npy_intp offset = 0;

        // Validate input shapes and compute join size
        npy_intp *shape = PyArray_SHAPE(arrays[0]);

        {axis_check}

        for (int i = 0; i < {n}; i++) {{
            if (PyArray_NDIM(arrays[i]) != {ndim}) {{
                PyErr_SetString(PyExc_ValueError, "Input to join has wrong ndim");
                {fail}
            }}

            for (int j = 0; j < {ndim}; j++) {{
                if (j == axis){{
                    join_size += PyArray_DIM(arrays[i], j);
                }}
                else if(PyArray_DIM(arrays[i], j) != shape[j]) {{
                    PyErr_SetString(PyExc_ValueError, "Arrays shape must match along non join axis");
                    {fail}
                }}
            }}
        }}

        // Define dimensions of output array
        npy_intp out_dims[{ndim}];
        memcpy(out_dims, shape, {ndim} * sizeof(npy_intp));
        out_dims[axis] = join_size;

        // Reuse output or allocate new one
        if ({out} != NULL) {{
            out_is_valid = (PyArray_NDIM({out}) == {ndim});
            for (int i = 0; i < {ndim}; i++) {{
                out_is_valid &= (PyArray_DIM({out}, i) == out_dims[i]);
            }}
        }}

        if (!out_is_valid) {{
            Py_XDECREF({out});
            {out} = (PyArrayObject *)PyArray_NewFromDescr(&PyArray_Type,
                                                        PyArray_DescrFromType({out_dtype}),
                                                        {ndim},
                                                        out_dims,
                                                        NULL,  /* strides */
                                                        NULL,  /* data */
                                                        NPY_ARRAY_DEFAULT,
                                                        NULL);

            if ({out} == NULL) {{
                {fail}
            }}
        }}

        // Copy data into output buffer
        for (int i = 0; i < {n}; i++) {{
            PyArrayObject *arr = arrays[i];

            // Create view into output buffer
            // PyArray_NewFromDescr steals a reference to descr, so we need to increase it
            Py_INCREF(PyArray_DESCR({out}));
            PyArrayObject *view = (PyArrayObject *)PyArray_NewFromDescr(&PyArray_Type,
                                                        PyArray_DESCR({out}),
                                                        {ndim},
                                                        PyArray_SHAPE(arr),
                                                        PyArray_STRIDES({out}),
                                                        PyArray_BYTES({out}) + (offset * PyArray_STRIDES({out})[axis]),
                                                        NPY_ARRAY_WRITEABLE,
                                                        NULL);
            if (view == NULL) {{
                {fail}
            }}

            // Write to it
            int success = PyArray_CopyInto(view, arr);
            Py_DECREF(view);
            if (success != 0) {{
                {fail}
            }}

            offset += PyArray_DIM(arr, axis);
        }}
    """
    return code
2594
2659
0 commit comments