diff --git a/src/jax_finufft/lowering.py b/src/jax_finufft/lowering.py index 9981e33..365973d 100644 --- a/src/jax_finufft/lowering.py +++ b/src/jax_finufft/lowering.py @@ -56,14 +56,14 @@ def lowering( n_transf = source_shape[1] n_j = points_shape[0][1] + # Dispatch to the correct custom call target depending on the dimension, + # dtype, and NUFFT type. if output_shape is None: # Type 2 op_name = f"nufft{ndim}d2{suffix}".encode("ascii") n_k = np.array(source_shape[2:], dtype=np.int64) - full_output_shape = tuple(source_shape[:2]) + (n_j,) else: # Type 1 op_name = f"nufft{ndim}d1{suffix}".encode("ascii") n_k = np.array(output_shape, dtype=np.int64) - full_output_shape = tuple(source_shape[:2]) + tuple(output_shape) # The backend expects the output shape in Fortran order, so we'll just # fake it here, by sending in n_k and x in the reverse order. @@ -87,7 +87,7 @@ def lowering( # Reverse points because backend uses Fortran order operands=[descriptor, source, *points[::-1]], operand_layouts=default_layouts([0], source_shape, *points_shape[::-1]), - result_layouts=default_layouts(full_output_shape), + result_layouts=default_layouts(ctx.avals_out[0].shape), ).results else: @@ -102,5 +102,5 @@ def lowering( operands=[source, *points[::-1]], backend_config=descriptor_bytes, operand_layouts=default_layouts(source_shape, *points_shape[::-1]), - result_layouts=default_layouts(full_output_shape), + result_layouts=default_layouts(ctx.avals_out[0].shape), ).results