
Commit 9a9fbbc

Don't apply local_upcast_elemwise_constant_inputs when all inputs are constant

1 parent aaecd79
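
For context, here is a minimal sketch of the case this commit now skips (hypothetical usage, not part of the commit): when every input of an Elemwise node is a Constant, ordinary constant folding can evaluate the node outright, so explicitly upcasting the constants first is wasted work.

import numpy as np
import pytensor.tensor as pt

# Every input is a Constant, so the true-division node below can be
# constant-folded as a whole; after this commit the upcast rewrite
# returns early instead of rebuilding it with upcast constants.
a = pt.constant(np.int8(1))
b = pt.constant(np.int64(2))
c = a / b  # Elemwise node with only constant inputs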

File tree: 1 file changed (+54 −54 lines)

pytensor/tensor/rewriting/elemwise.py (+54 −54)

@@ -493,65 +493,65 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
     """
     if len(node.outputs) > 1:
         return
-    try:
-        shape_i = fgraph.shape_feature.shape_i
-    except AttributeError:
-        shape_i = None
-    if isinstance(node.op, Elemwise):
-        scalar_op = node.op.scalar_op
-        # print "aa", scalar_op.output_types_preference
-        if getattr(scalar_op, "output_types_preference", None) in (
-            ps.upgrade_to_float,
-            ps.upcast_out,
-        ):
-            # this is the kind of op that we can screw with the input
-            # dtypes by upcasting explicitly
-            output_dtype = node.outputs[0].type.dtype
-            new_inputs = []
-            for i in node.inputs:
-                if i.type.dtype == output_dtype:
-                    new_inputs.append(i)
-                else:
-                    try:
-                        cval_i = get_underlying_scalar_constant_value(
-                            i, only_process_constants=True
+
+    if all(isinstance(i, Constant) for i in node.inputs):
+        # If all inputs are constant, constant_fold will take care of it
+        return
+
+    if getattr(node.op.scalar_op, "output_types_preference", None) in (
+        ps.upgrade_to_float,
+        ps.upcast_out,
+    ):
+        # this is the kind of op that we can screw with the input
+        # dtypes by upcasting explicitly
+        output_dtype = node.outputs[0].type.dtype
+        new_inputs = []
+        for i in node.inputs:
+            if i.type.dtype == output_dtype:
+                new_inputs.append(i)
+            else:
+                try:
+                    cval_i = get_underlying_scalar_constant_value(
+                        i, only_process_constants=True
+                    )
+                    if all(i.broadcastable):
+                        new_inputs.append(
+                            shape_padleft(cast(cval_i, output_dtype), i.ndim)
                         )
-                        if all(i.broadcastable):
-                            new_inputs.append(
-                                shape_padleft(cast(cval_i, output_dtype), i.ndim)
-                            )
-                        else:
-                            if shape_i is None:
-                                return
-                            new_inputs.append(
-                                alloc(
-                                    cast(cval_i, output_dtype),
-                                    *[shape_i(d)(i) for d in range(i.ndim)],
-                                )
+                    else:
+                        try:
+                            shape_i = fgraph.shape_feature.shape_i
+                        except AttributeError:
+                            return
+                        new_inputs.append(
+                            alloc(
+                                cast(cval_i, output_dtype),
+                                *[shape_i(d)(i) for d in range(i.ndim)],
                             )
-                        # print >> sys.stderr, "AAA",
-                        # *[Shape_i(d)(i) for d in range(i.ndim)]
-                    except NotScalarConstantError:
-                        # for the case of a non-scalar
-                        if isinstance(i, TensorConstant):
-                            new_inputs.append(cast(i, output_dtype))
-                        else:
-                            new_inputs.append(i)
+                        )
+                    # print >> sys.stderr, "AAA",
+                    # *[Shape_i(d)(i) for d in range(i.ndim)]
+                except NotScalarConstantError:
+                    # for the case of a non-scalar
+                    if isinstance(i, TensorConstant):
+                        new_inputs.append(cast(i, output_dtype))
+                    else:
+                        new_inputs.append(i)
 
-            if new_inputs != node.inputs:
-                rval = [node.op(*new_inputs)]
-                if not node.outputs[0].type.is_super(rval[0].type):
-                    # This can happen for example when floatX=float32
-                    # and we do the true division between and int64
-                    # and a constant that will get typed as int8.
+        if new_inputs != node.inputs:
+            rval = [node.op(*new_inputs)]
+            if not node.outputs[0].type.is_super(rval[0].type):
+                # This can happen for example when floatX=float32
+                # and we do the true division between and int64
+                # and a constant that will get typed as int8.
 
-                    # As this is just to allow merging more case, if
-                    # the upcast don't work, we can just skip it.
-                    return
+                # As this is just to allow merging more case, if
+                # the upcast don't work, we can just skip it.
+                return
 
-                # Copy over output stacktrace from before upcasting
-                copy_stack_trace(node.outputs[0], rval)
-                return rval
+            # Copy over output stacktrace from before upcasting
+            copy_stack_trace(node.outputs[0], rval)
+            return rval
 
 
 @node_rewriter([Elemwise])
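
By contrast, a rough sketch of the case the rewrite still targets (assumed example, not from the commit): a node mixing a symbolic input with a constant whose dtype is narrower than the output dtype gets its constant explicitly upcast, which helps later rewrites merge graphs that differ only in how the constant was typed.

import pytensor.tensor as pt

x = pt.vector("x", dtype="float64")
# The literal 1 is stored as an int8 constant while the output is
# float64; the rewrite swaps in a float64 constant, so (1 - x) and
# (1.0 - x) canonicalize to the same graph and can be merged.
y = 1 - x
z = 1.0 - x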
