@@ -493,65 +493,65 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
     """
     if len(node.outputs) > 1:
         return
-    try:
-        shape_i = fgraph.shape_feature.shape_i
-    except AttributeError:
-        shape_i = None
-    if isinstance(node.op, Elemwise):
-        scalar_op = node.op.scalar_op
-        # print "aa", scalar_op.output_types_preference
-        if getattr(scalar_op, "output_types_preference", None) in (
-            ps.upgrade_to_float,
-            ps.upcast_out,
-        ):
-            # this is the kind of op that we can screw with the input
-            # dtypes by upcasting explicitly
-            output_dtype = node.outputs[0].type.dtype
-            new_inputs = []
-            for i in node.inputs:
-                if i.type.dtype == output_dtype:
-                    new_inputs.append(i)
-                else:
-                    try:
-                        cval_i = get_underlying_scalar_constant_value(
-                            i, only_process_constants=True
+
+    if all(isinstance(i, Constant) for i in node.inputs):
+        # If all inputs are constant, constant_fold will take care of it
+        return
+
+    if getattr(node.op.scalar_op, "output_types_preference", None) in (
+        ps.upgrade_to_float,
+        ps.upcast_out,
+    ):
+        # this is the kind of op that we can screw with the input
+        # dtypes by upcasting explicitly
+        output_dtype = node.outputs[0].type.dtype
+        new_inputs = []
+        for i in node.inputs:
+            if i.type.dtype == output_dtype:
+                new_inputs.append(i)
+            else:
+                try:
+                    cval_i = get_underlying_scalar_constant_value(
+                        i, only_process_constants=True
+                    )
+                    if all(i.broadcastable):
+                        new_inputs.append(
+                            shape_padleft(cast(cval_i, output_dtype), i.ndim)
                         )
-                        if all(i.broadcastable):
-                            new_inputs.append(
-                                shape_padleft(cast(cval_i, output_dtype), i.ndim)
-                            )
-                        else:
-                            if shape_i is None:
-                                return
-                            new_inputs.append(
-                                alloc(
-                                    cast(cval_i, output_dtype),
-                                    *[shape_i(d)(i) for d in range(i.ndim)],
-                                )
+                    else:
+                        try:
+                            shape_i = fgraph.shape_feature.shape_i
+                        except AttributeError:
+                            return
+                        new_inputs.append(
+                            alloc(
+                                cast(cval_i, output_dtype),
+                                *[shape_i(d)(i) for d in range(i.ndim)],
                             )
-                            # print >> sys.stderr, "AAA",
-                            # *[Shape_i(d)(i) for d in range(i.ndim)]
-                    except NotScalarConstantError:
-                        # for the case of a non-scalar
-                        if isinstance(i, TensorConstant):
-                            new_inputs.append(cast(i, output_dtype))
-                        else:
-                            new_inputs.append(i)
+                        )
+                        # print >> sys.stderr, "AAA",
+                        # *[Shape_i(d)(i) for d in range(i.ndim)]
+                except NotScalarConstantError:
+                    # for the case of a non-scalar
+                    if isinstance(i, TensorConstant):
+                        new_inputs.append(cast(i, output_dtype))
+                    else:
+                        new_inputs.append(i)

-            if new_inputs != node.inputs:
-                rval = [node.op(*new_inputs)]
-                if not node.outputs[0].type.is_super(rval[0].type):
-                    # This can happen for example when floatX=float32
-                    # and we do the true division between and int64
-                    # and a constant that will get typed as int8.
+        if new_inputs != node.inputs:
+            rval = [node.op(*new_inputs)]
+            if not node.outputs[0].type.is_super(rval[0].type):
+                # This can happen for example when floatX=float32
+                # and we do the true division between and int64
+                # and a constant that will get typed as int8.

-                    # As this is just to allow merging more case, if
-                    # the upcast don't work, we can just skip it.
-                    return
+                # As this is just to allow merging more case, if
+                # the upcast don't work, we can just skip it.
+                return

-                # Copy over output stacktrace from before upcasting
-                copy_stack_trace(node.outputs[0], rval)
-                return rval
+            # Copy over output stacktrace from before upcasting
+            copy_stack_trace(node.outputs[0], rval)
+            return rval


 @node_rewriter([Elemwise])
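Note: as a hypothetical illustration of the pattern this rewrite targets (the script below is not part of the change; it only assumes PyTensor's public tensor API), an Elemwise node such as an addition between a float64 vector and a narrow integer literal can have that constant input replaced by an explicitly cast version, which lets later rewrites merge and constant-fold more nodes:

import pytensor
import pytensor.tensor as pt

# x is float64; the literal 2 is wrapped as a low-precision integer constant,
# so the Elemwise add upcasts its output to float64.
x = pt.vector("x", dtype="float64")
y = x + 2

# After graph rewriting, the constant input of the add may show up as an
# explicit cast of 2 to float64 rather than the original narrow constant.
fn = pytensor.function([x], y)
pytensor.dprint(fn)  # inspect the rewritten graph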