Skip to content

Commit 1d306f4

Browse files
committed
Simplify Numba implementation of Alloc
1 parent 0b56ed9 commit 1d306f4

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

pytensor/link/numba/dispatch/tensor_basic.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def numba_funcify_Alloc(op, node, **kwargs):
6868
shape_var_item_names = [f"{name}_item" for name in shape_var_names]
6969
shapes_to_items_src = indent(
7070
"\n".join(
71-
f"{item_name} = to_scalar({shape_name})"
71+
f"{item_name} = {shape_name}.item()"
7272
for item_name, shape_name in zip(
7373
shape_var_item_names, shape_var_names, strict=True
7474
)
@@ -86,12 +86,11 @@ def numba_funcify_Alloc(op, node, **kwargs):
8686

8787
alloc_def_src = f"""
8888
def alloc(val, {", ".join(shape_var_names)}):
89-
val_np = np.asarray(val)
9089
{shapes_to_items_src}
9190
scalar_shape = {create_tuple_string(shape_var_item_names)}
9291
{check_runtime_broadcast_src}
93-
res = np.empty(scalar_shape, dtype=val_np.dtype)
94-
res[...] = val_np
92+
res = np.empty(scalar_shape, dtype=val.dtype)
93+
res[...] = val
9594
return res
9695
"""
9796
alloc_fn = compile_function_src(alloc_def_src, "alloc", {**globals(), **global_env})

0 commit comments

Comments
 (0)