@@ -91,32 +91,35 @@ where
91
91
T ,
92
92
Out : FallibleAddAssign < U :: Out > + FallibleSubAssign < U :: Out > + Clone + RefZero + IsZero ,
93
93
> ,
94
+ T : TrySmallestUpgrade < U , LCM = U :: Out > ,
94
95
I : TensorStructure + Clone + StructureContract ,
95
96
{
96
97
type LCM = DenseTensor < U :: Out , I > ;
97
98
98
99
fn exterior_product ( & self , other : & DenseTensor < T , I > ) -> Result < Self :: LCM , ContractionError > {
99
100
let mut final_structure = self . structure ( ) . clone ( ) ;
100
101
final_structure. merge ( other. structure ( ) ) ;
101
- if let Some ( ( _, s) ) = self . flat_iter ( ) . next ( ) {
102
- let zero = s. try_upgrade ( ) . unwrap ( ) . as_ref ( ) . ref_zero ( ) ;
103
- let mut out = DenseTensor {
104
- data : vec ! [ zero. clone( ) ; final_structure. size( ) ?] ,
105
- structure : final_structure,
106
- } ;
102
+ let zero = if let Some ( ( _, s) ) = self . flat_iter ( ) . next ( ) {
103
+ s. try_upgrade ( ) . unwrap ( ) . as_ref ( ) . ref_zero ( )
104
+ } else if let Some ( ( _, o) ) = other. iter_flat ( ) . next ( ) {
105
+ o. try_upgrade ( ) . unwrap ( ) . as_ref ( ) . ref_zero ( )
106
+ } else {
107
+ return Err ( ContractionError :: EmptySparse ) ;
108
+ } ;
109
+ let mut out = DenseTensor {
110
+ data : vec ! [ zero. clone( ) ; final_structure. size( ) ?] ,
111
+ structure : final_structure,
112
+ } ;
107
113
108
- let stride = other. size ( ) ?;
114
+ let stride = other. size ( ) ?;
109
115
110
- for ( i, u) in self . flat_iter ( ) {
111
- for ( j, t) in other. flat_iter ( ) {
112
- let _ = out. set_flat ( i * stride + j, u. mul_fallible ( t) . unwrap ( ) ) ;
113
- }
116
+ for ( i, u) in self . flat_iter ( ) {
117
+ for ( j, t) in other. flat_iter ( ) {
118
+ let _ = out. set_flat ( i * stride + j, u. mul_fallible ( t) . unwrap ( ) ) ;
114
119
}
115
-
116
- Ok ( out)
117
- } else {
118
- Err ( ContractionError :: EmptySparse )
119
120
}
121
+
122
+ Ok ( out)
120
123
}
121
124
}
122
125
@@ -529,6 +532,7 @@ where
529
532
T ,
530
533
Out : FallibleAddAssign < U :: Out > + FallibleSubAssign < U :: Out > + Clone + RefZero + IsZero ,
531
534
> ,
535
+ T : TrySmallestUpgrade < U , LCM = U :: Out > ,
532
536
I : TensorStructure + Clone + StructureContract ,
533
537
{
534
538
type LCM = DenseTensor < U :: Out , I > ;
@@ -540,45 +544,48 @@ where
540
544
j : usize ,
541
545
) -> Result < Self :: LCM , ContractionError > {
542
546
trace ! ( "single contract sparse dense" ) ;
543
- if let Some ( ( _, s) ) = self . flat_iter ( ) . next ( ) {
544
- let zero = s. try_upgrade ( ) . unwrap ( ) . as_ref ( ) . ref_zero ( ) ;
545
- let final_structure = self . structure . merge_at ( & other. structure , ( i, j) ) ;
546
- let mut result_data = vec ! [ zero. clone( ) ; final_structure. size( ) ?] ;
547
- let mut result_index = 0 ;
547
+ let zero = if let Some ( ( _, s) ) = self . flat_iter ( ) . next ( ) {
548
+ s. try_upgrade ( ) . unwrap ( ) . as_ref ( ) . ref_zero ( )
549
+ } else if let Some ( ( _, o) ) = other. iter_flat ( ) . next ( ) {
550
+ o. try_upgrade ( ) . unwrap ( ) . as_ref ( ) . ref_zero ( )
551
+ } else {
552
+ return Err ( ContractionError :: EmptySparse ) ;
553
+ } ;
548
554
549
- let mut self_iter = self . fiber_class ( i. into ( ) ) . iter ( ) ;
550
- let mut other_iter = other. fiber_class ( j. into ( ) ) . iter ( ) ;
555
+ let final_structure = self . structure . merge_at ( & other. structure , ( i, j) ) ;
556
+ let mut result_data = vec ! [ zero. clone( ) ; final_structure. size( ) ?] ;
557
+ let mut result_index = 0 ;
551
558
552
- let fiber_representation = self . reps ( ) [ i] ;
559
+ let mut self_iter = self . fiber_class ( i. into ( ) ) . iter ( ) ;
560
+ let mut other_iter = other. fiber_class ( j. into ( ) ) . iter ( ) ;
553
561
554
- for mut fiber_a in self_iter. by_ref ( ) {
555
- for mut fiber_b in other_iter. by_ref ( ) {
556
- for ( k, ( a, skip, _) ) in fiber_a. by_ref ( ) . enumerate ( ) {
557
- if let Some ( ( b, _) ) = fiber_b. by_ref ( ) . nth ( skip) {
558
- if fiber_representation. is_neg ( k + skip) {
559
- result_data[ result_index]
560
- . sub_assign_fallible ( & a. mul_fallible ( b) . unwrap ( ) ) ;
561
- } else {
562
- result_data[ result_index]
563
- . add_assign_fallible ( & a. mul_fallible ( b) . unwrap ( ) ) ;
564
- }
562
+ let fiber_representation = self . reps ( ) [ i] ;
563
+
564
+ for mut fiber_a in self_iter. by_ref ( ) {
565
+ for mut fiber_b in other_iter. by_ref ( ) {
566
+ for ( k, ( a, skip, _) ) in fiber_a. by_ref ( ) . enumerate ( ) {
567
+ if let Some ( ( b, _) ) = fiber_b. by_ref ( ) . nth ( skip) {
568
+ if fiber_representation. is_neg ( k + skip) {
569
+ result_data[ result_index]
570
+ . sub_assign_fallible ( & a. mul_fallible ( b) . unwrap ( ) ) ;
571
+ } else {
572
+ result_data[ result_index]
573
+ . add_assign_fallible ( & a. mul_fallible ( b) . unwrap ( ) ) ;
565
574
}
566
575
}
567
- result_index += 1 ;
568
- fiber_a. reset ( ) ;
569
576
}
570
- other_iter. reset ( ) ;
577
+ result_index += 1 ;
578
+ fiber_a. reset ( ) ;
571
579
}
580
+ other_iter. reset ( ) ;
581
+ }
572
582
573
- let result = DenseTensor {
574
- data : result_data,
575
- structure : final_structure,
576
- } ;
583
+ let result = DenseTensor {
584
+ data : result_data,
585
+ structure : final_structure,
586
+ } ;
577
587
578
- Ok ( result)
579
- } else {
580
- Err ( ContractionError :: EmptySparse )
581
- }
588
+ Ok ( result)
582
589
}
583
590
}
584
591
@@ -643,56 +650,59 @@ where
643
650
T ,
644
651
Out : FallibleAddAssign < U :: Out > + FallibleSubAssign < U :: Out > + Clone + RefZero + IsZero ,
645
652
> ,
653
+ T : TrySmallestUpgrade < U , LCM = U :: Out > ,
646
654
I : TensorStructure + Clone + StructureContract ,
647
655
{
648
656
type LCM = DenseTensor < U :: Out , I > ;
649
657
fn multi_contract ( & self , other : & DenseTensor < T , I > ) -> Result < Self :: LCM , ContractionError > {
650
658
trace ! ( "multi contract sparse dense" ) ;
651
- if let Some ( ( _, s) ) = self . flat_iter ( ) . next ( ) {
652
- let zero = s. try_upgrade ( ) . unwrap ( ) . as_ref ( ) . ref_zero ( ) ;
653
- // let zero = other.data[0].try_upgrade().unwrap().as_ref().ref_zero();
654
- let ( permutation, self_matches, other_matches) =
655
- self . structure ( ) . match_indices ( other. structure ( ) ) . unwrap ( ) ;
659
+ let zero = if let Some ( ( _, s) ) = self . flat_iter ( ) . next ( ) {
660
+ s. try_upgrade ( ) . unwrap ( ) . as_ref ( ) . ref_zero ( )
661
+ } else if let Some ( ( _, o) ) = other. iter_flat ( ) . next ( ) {
662
+ o. try_upgrade ( ) . unwrap ( ) . as_ref ( ) . ref_zero ( )
663
+ } else {
664
+ return Err ( ContractionError :: EmptySparse ) ;
665
+ } ;
666
+ // let zero = other.data[0].try_upgrade().unwrap().as_ref().ref_zero();
667
+ let ( permutation, self_matches, other_matches) =
668
+ self . structure ( ) . match_indices ( other. structure ( ) ) . unwrap ( ) ;
656
669
657
- let mut final_structure = self . structure . clone ( ) ;
658
- let _ = final_structure. merge ( & other. structure ) ;
670
+ let mut final_structure = self . structure . clone ( ) ;
671
+ let _ = final_structure. merge ( & other. structure ) ;
659
672
660
- let mut result_data = vec ! [ zero. clone( ) ; final_structure. size( ) ?] ;
661
- let mut result_index = 0 ;
673
+ let mut result_data = vec ! [ zero. clone( ) ; final_structure. size( ) ?] ;
674
+ let mut result_index = 0 ;
662
675
663
- let selfiter = self
664
- . fiber_class ( self_matches. as_slice ( ) . into ( ) )
665
- . iter_perm_metric ( permutation) ;
666
- let mut other_iter = other. fiber_class ( other_matches. as_slice ( ) . into ( ) ) . iter ( ) ;
676
+ let selfiter = self
677
+ . fiber_class ( self_matches. as_slice ( ) . into ( ) )
678
+ . iter_perm_metric ( permutation) ;
679
+ let mut other_iter = other. fiber_class ( other_matches. as_slice ( ) . into ( ) ) . iter ( ) ;
667
680
668
- for mut fiber_a in selfiter {
669
- for mut fiber_b in other_iter. by_ref ( ) {
670
- for ( a, skip, ( neg, _) ) in fiber_a. by_ref ( ) {
671
- if let Some ( ( b, _) ) = fiber_b. by_ref ( ) . nth ( skip) {
672
- if neg {
673
- result_data[ result_index]
674
- . sub_assign_fallible ( & a. mul_fallible ( b) . unwrap ( ) ) ;
675
- } else {
676
- result_data[ result_index]
677
- . add_assign_fallible ( & a. mul_fallible ( b) . unwrap ( ) ) ;
678
- }
681
+ for mut fiber_a in selfiter {
682
+ for mut fiber_b in other_iter. by_ref ( ) {
683
+ for ( a, skip, ( neg, _) ) in fiber_a. by_ref ( ) {
684
+ if let Some ( ( b, _) ) = fiber_b. by_ref ( ) . nth ( skip) {
685
+ if neg {
686
+ result_data[ result_index]
687
+ . sub_assign_fallible ( & a. mul_fallible ( b) . unwrap ( ) ) ;
688
+ } else {
689
+ result_data[ result_index]
690
+ . add_assign_fallible ( & a. mul_fallible ( b) . unwrap ( ) ) ;
679
691
}
680
692
}
681
- result_index += 1 ;
682
- fiber_a. reset ( ) ;
683
693
}
684
- other_iter. reset ( ) ;
694
+ result_index += 1 ;
695
+ fiber_a. reset ( ) ;
685
696
}
697
+ other_iter. reset ( ) ;
698
+ }
686
699
687
- let result = DenseTensor {
688
- data : result_data,
689
- structure : final_structure,
690
- } ;
700
+ let result = DenseTensor {
701
+ data : result_data,
702
+ structure : final_structure,
703
+ } ;
691
704
692
- Ok ( result)
693
- } else {
694
- Err ( ContractionError :: EmptySparse )
695
- }
705
+ Ok ( result)
696
706
}
697
707
}
698
708
0 commit comments