Skip to content

Commit

Permalink
one more comment and simplification
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed May 2, 2024
1 parent efc75a2 commit 0884107
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/jax_finufft/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,14 @@ def lowering(
n_transf = source_shape[1]
n_j = points_shape[0][1]

# Dispatch to the correct custom call target depending on the dimension,
# dtype, and NUFFT type.
if output_shape is None: # Type 2
op_name = f"nufft{ndim}d2{suffix}".encode("ascii")
n_k = np.array(source_shape[2:], dtype=np.int64)
full_output_shape = tuple(source_shape[:2]) + (n_j,)
else: # Type 1
op_name = f"nufft{ndim}d1{suffix}".encode("ascii")
n_k = np.array(output_shape, dtype=np.int64)
full_output_shape = tuple(source_shape[:2]) + tuple(output_shape)

# The backend expects the output shape in Fortran order, so we'll just
# fake it here, by sending in n_k and x in the reverse order.
Expand All @@ -87,7 +87,7 @@ def lowering(
# Reverse points because backend uses Fortran order
operands=[descriptor, source, *points[::-1]],
operand_layouts=default_layouts([0], source_shape, *points_shape[::-1]),
result_layouts=default_layouts(full_output_shape),
result_layouts=default_layouts(ctx.avals_out[0].shape),
).results

else:
Expand All @@ -102,5 +102,5 @@ def lowering(
operands=[source, *points[::-1]],
backend_config=descriptor_bytes,
operand_layouts=default_layouts(source_shape, *points_shape[::-1]),
result_layouts=default_layouts(full_output_shape),
result_layouts=default_layouts(ctx.avals_out[0].shape),
).results

0 comments on commit 0884107

Please sign in to comment.