@@ -479,13 +479,20 @@ impl Builder<'_, '_> {
479
479
let u32 = SpirvType :: Integer ( 32 , false ) . def ( self . span ( ) , self ) ;
480
480
481
481
let glsl = self . ext_inst . borrow_mut ( ) . import_glsl ( self ) ;
482
- let find_xsb = |arg| {
482
+ let find_xsb = |arg, offset : i32 | {
483
483
if trailing {
484
- self . emit ( )
484
+ let lsb = self
485
+ . emit ( )
485
486
. ext_inst ( u32, None , glsl, GLOp :: FindILsb as u32 , [ Operand :: IdRef (
486
487
arg,
487
488
) ] )
488
- . unwrap ( )
489
+ . unwrap ( ) ;
490
+ if offset == 0 {
491
+ lsb
492
+ } else {
493
+ let const_offset = self . constant_i32 ( self . span ( ) , offset) . def ( self ) ;
494
+ self . emit ( ) . i_add ( u32, None , const_offset, lsb) . unwrap ( )
495
+ }
489
496
} else {
490
497
// rust is always unsigned, so FindUMsb
491
498
let msb_bit = self
@@ -496,25 +503,21 @@ impl Builder<'_, '_> {
496
503
. unwrap ( ) ;
497
504
// the glsl op returns the Msb bit, not the amount of leading zeros of this u32
498
505
// leading zeros = 31 - Msb bit
499
- let u32_31 = self . constant_u32 ( self . span ( ) , 31 ) . def ( self ) ;
500
- self . emit ( ) . i_sub ( u32, None , u32_31 , msb_bit) . unwrap ( )
506
+ let const_offset = self . constant_i32 ( self . span ( ) , 31 - offset ) . def ( self ) ;
507
+ self . emit ( ) . i_sub ( u32, None , const_offset , msb_bit) . unwrap ( )
501
508
}
502
509
} ;
503
510
504
511
let converted = match bits {
505
512
8 | 16 => {
513
+ let arg = self . emit ( ) . u_convert ( u32, None , arg. def ( self ) ) . unwrap ( ) ;
506
514
if trailing {
507
- let arg = self . emit ( ) . u_convert ( u32, None , arg. def ( self ) ) . unwrap ( ) ;
508
- find_xsb ( arg)
515
+ find_xsb ( arg, 0 )
509
516
} else {
510
- let arg = arg. def ( self ) ;
511
- let arg = self . emit ( ) . u_convert ( u32, None , arg) . unwrap ( ) ;
512
- let xsb = find_xsb ( arg) ;
513
- let subtrahend = self . constant_u32 ( self . span ( ) , 32 - bits) . def ( self ) ;
514
- self . emit ( ) . i_sub ( u32, None , xsb, subtrahend) . unwrap ( )
517
+ find_xsb ( arg, bits as i32 - 32 )
515
518
}
516
519
}
517
- 32 => find_xsb ( arg. def ( self ) ) ,
520
+ 32 => find_xsb ( arg. def ( self ) , 0 ) ,
518
521
64 => {
519
522
let u32_0 = self . constant_int ( u32, 0 ) . def ( self ) ;
520
523
let u32_32 = self . constant_u32 ( self . span ( ) , 32 ) . def ( self ) ;
@@ -527,20 +530,17 @@ impl Builder<'_, '_> {
527
530
. unwrap ( ) ;
528
531
let higher = self . emit ( ) . u_convert ( u32, None , higher) . unwrap ( ) ;
529
532
530
- let lower_bits = find_xsb ( lower) ;
531
- let higher_bits = find_xsb ( higher) ;
532
-
533
533
if trailing {
534
534
let use_lower = self . emit ( ) . i_equal ( bool, None , higher, u32_0) . unwrap ( ) ;
535
- let lower_bits =
536
- self . emit ( ) . i_add ( u32 , None , lower_bits , u32_32 ) . unwrap ( ) ;
535
+ let lower_bits = find_xsb ( lower , 32 ) ;
536
+ let higher_bits = find_xsb ( higher , 0 ) ;
537
537
self . emit ( )
538
538
. select ( u32, None , use_lower, lower_bits, higher_bits)
539
539
. unwrap ( )
540
540
} else {
541
541
let use_higher = self . emit ( ) . i_equal ( bool, None , lower, u32_0) . unwrap ( ) ;
542
- let higher_bits =
543
- self . emit ( ) . i_add ( u32 , None , higher_bits , u32_32 ) . unwrap ( ) ;
542
+ let lower_bits = find_xsb ( lower , 0 ) ;
543
+ let higher_bits = find_xsb ( higher , 32 ) ;
544
544
self . emit ( )
545
545
. select ( u32, None , use_higher, higher_bits, lower_bits)
546
546
. unwrap ( )
0 commit comments