Skip to content

Commit 25cddba

Browse files
committed
fix contraction sparse dense to always obtain a zero
1 parent 44fb65a commit 25cddba

File tree

4 files changed

+133
-90
lines changed

4 files changed

+133
-90
lines changed

.python-version

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
3.13

flake.nix

+3
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,9 @@
180180
pkgs.git
181181
# pkgs.ripgrep
182182
pkgs.quarto
183+
pkgs.nodejs
184+
pkgs.uv
185+
pkgs.python313
183186
pkgs.deno
184187
pkgs.jujutsu
185188
pkgs.nixd

src/contraction.rs

+90-80
Original file line numberDiff line numberDiff line change
@@ -91,32 +91,35 @@ where
9191
T,
9292
Out: FallibleAddAssign<U::Out> + FallibleSubAssign<U::Out> + Clone + RefZero + IsZero,
9393
>,
94+
T: TrySmallestUpgrade<U, LCM = U::Out>,
9495
I: TensorStructure + Clone + StructureContract,
9596
{
9697
type LCM = DenseTensor<U::Out, I>;
9798

9899
fn exterior_product(&self, other: &DenseTensor<T, I>) -> Result<Self::LCM, ContractionError> {
99100
let mut final_structure = self.structure().clone();
100101
final_structure.merge(other.structure());
101-
if let Some((_, s)) = self.flat_iter().next() {
102-
let zero = s.try_upgrade().unwrap().as_ref().ref_zero();
103-
let mut out = DenseTensor {
104-
data: vec![zero.clone(); final_structure.size()?],
105-
structure: final_structure,
106-
};
102+
let zero = if let Some((_, s)) = self.flat_iter().next() {
103+
s.try_upgrade().unwrap().as_ref().ref_zero()
104+
} else if let Some((_, o)) = other.iter_flat().next() {
105+
o.try_upgrade().unwrap().as_ref().ref_zero()
106+
} else {
107+
return Err(ContractionError::EmptySparse);
108+
};
109+
let mut out = DenseTensor {
110+
data: vec![zero.clone(); final_structure.size()?],
111+
structure: final_structure,
112+
};
107113

108-
let stride = other.size()?;
114+
let stride = other.size()?;
109115

110-
for (i, u) in self.flat_iter() {
111-
for (j, t) in other.flat_iter() {
112-
let _ = out.set_flat(i * stride + j, u.mul_fallible(t).unwrap());
113-
}
116+
for (i, u) in self.flat_iter() {
117+
for (j, t) in other.flat_iter() {
118+
let _ = out.set_flat(i * stride + j, u.mul_fallible(t).unwrap());
114119
}
115-
116-
Ok(out)
117-
} else {
118-
Err(ContractionError::EmptySparse)
119120
}
121+
122+
Ok(out)
120123
}
121124
}
122125

@@ -529,6 +532,7 @@ where
529532
T,
530533
Out: FallibleAddAssign<U::Out> + FallibleSubAssign<U::Out> + Clone + RefZero + IsZero,
531534
>,
535+
T: TrySmallestUpgrade<U, LCM = U::Out>,
532536
I: TensorStructure + Clone + StructureContract,
533537
{
534538
type LCM = DenseTensor<U::Out, I>;
@@ -540,45 +544,48 @@ where
540544
j: usize,
541545
) -> Result<Self::LCM, ContractionError> {
542546
trace!("single contract sparse dense");
543-
if let Some((_, s)) = self.flat_iter().next() {
544-
let zero = s.try_upgrade().unwrap().as_ref().ref_zero();
545-
let final_structure = self.structure.merge_at(&other.structure, (i, j));
546-
let mut result_data = vec![zero.clone(); final_structure.size()?];
547-
let mut result_index = 0;
547+
let zero = if let Some((_, s)) = self.flat_iter().next() {
548+
s.try_upgrade().unwrap().as_ref().ref_zero()
549+
} else if let Some((_, o)) = other.iter_flat().next() {
550+
o.try_upgrade().unwrap().as_ref().ref_zero()
551+
} else {
552+
return Err(ContractionError::EmptySparse);
553+
};
548554

549-
let mut self_iter = self.fiber_class(i.into()).iter();
550-
let mut other_iter = other.fiber_class(j.into()).iter();
555+
let final_structure = self.structure.merge_at(&other.structure, (i, j));
556+
let mut result_data = vec![zero.clone(); final_structure.size()?];
557+
let mut result_index = 0;
551558

552-
let fiber_representation = self.reps()[i];
559+
let mut self_iter = self.fiber_class(i.into()).iter();
560+
let mut other_iter = other.fiber_class(j.into()).iter();
553561

554-
for mut fiber_a in self_iter.by_ref() {
555-
for mut fiber_b in other_iter.by_ref() {
556-
for (k, (a, skip, _)) in fiber_a.by_ref().enumerate() {
557-
if let Some((b, _)) = fiber_b.by_ref().nth(skip) {
558-
if fiber_representation.is_neg(k + skip) {
559-
result_data[result_index]
560-
.sub_assign_fallible(&a.mul_fallible(b).unwrap());
561-
} else {
562-
result_data[result_index]
563-
.add_assign_fallible(&a.mul_fallible(b).unwrap());
564-
}
562+
let fiber_representation = self.reps()[i];
563+
564+
for mut fiber_a in self_iter.by_ref() {
565+
for mut fiber_b in other_iter.by_ref() {
566+
for (k, (a, skip, _)) in fiber_a.by_ref().enumerate() {
567+
if let Some((b, _)) = fiber_b.by_ref().nth(skip) {
568+
if fiber_representation.is_neg(k + skip) {
569+
result_data[result_index]
570+
.sub_assign_fallible(&a.mul_fallible(b).unwrap());
571+
} else {
572+
result_data[result_index]
573+
.add_assign_fallible(&a.mul_fallible(b).unwrap());
565574
}
566575
}
567-
result_index += 1;
568-
fiber_a.reset();
569576
}
570-
other_iter.reset();
577+
result_index += 1;
578+
fiber_a.reset();
571579
}
580+
other_iter.reset();
581+
}
572582

573-
let result = DenseTensor {
574-
data: result_data,
575-
structure: final_structure,
576-
};
583+
let result = DenseTensor {
584+
data: result_data,
585+
structure: final_structure,
586+
};
577587

578-
Ok(result)
579-
} else {
580-
Err(ContractionError::EmptySparse)
581-
}
588+
Ok(result)
582589
}
583590
}
584591

@@ -643,56 +650,59 @@ where
643650
T,
644651
Out: FallibleAddAssign<U::Out> + FallibleSubAssign<U::Out> + Clone + RefZero + IsZero,
645652
>,
653+
T: TrySmallestUpgrade<U, LCM = U::Out>,
646654
I: TensorStructure + Clone + StructureContract,
647655
{
648656
type LCM = DenseTensor<U::Out, I>;
649657
fn multi_contract(&self, other: &DenseTensor<T, I>) -> Result<Self::LCM, ContractionError> {
650658
trace!("multi contract sparse dense");
651-
if let Some((_, s)) = self.flat_iter().next() {
652-
let zero = s.try_upgrade().unwrap().as_ref().ref_zero();
653-
// let zero = other.data[0].try_upgrade().unwrap().as_ref().ref_zero();
654-
let (permutation, self_matches, other_matches) =
655-
self.structure().match_indices(other.structure()).unwrap();
659+
let zero = if let Some((_, s)) = self.flat_iter().next() {
660+
s.try_upgrade().unwrap().as_ref().ref_zero()
661+
} else if let Some((_, o)) = other.iter_flat().next() {
662+
o.try_upgrade().unwrap().as_ref().ref_zero()
663+
} else {
664+
return Err(ContractionError::EmptySparse);
665+
};
666+
// let zero = other.data[0].try_upgrade().unwrap().as_ref().ref_zero();
667+
let (permutation, self_matches, other_matches) =
668+
self.structure().match_indices(other.structure()).unwrap();
656669

657-
let mut final_structure = self.structure.clone();
658-
let _ = final_structure.merge(&other.structure);
670+
let mut final_structure = self.structure.clone();
671+
let _ = final_structure.merge(&other.structure);
659672

660-
let mut result_data = vec![zero.clone(); final_structure.size()?];
661-
let mut result_index = 0;
673+
let mut result_data = vec![zero.clone(); final_structure.size()?];
674+
let mut result_index = 0;
662675

663-
let selfiter = self
664-
.fiber_class(self_matches.as_slice().into())
665-
.iter_perm_metric(permutation);
666-
let mut other_iter = other.fiber_class(other_matches.as_slice().into()).iter();
676+
let selfiter = self
677+
.fiber_class(self_matches.as_slice().into())
678+
.iter_perm_metric(permutation);
679+
let mut other_iter = other.fiber_class(other_matches.as_slice().into()).iter();
667680

668-
for mut fiber_a in selfiter {
669-
for mut fiber_b in other_iter.by_ref() {
670-
for (a, skip, (neg, _)) in fiber_a.by_ref() {
671-
if let Some((b, _)) = fiber_b.by_ref().nth(skip) {
672-
if neg {
673-
result_data[result_index]
674-
.sub_assign_fallible(&a.mul_fallible(b).unwrap());
675-
} else {
676-
result_data[result_index]
677-
.add_assign_fallible(&a.mul_fallible(b).unwrap());
678-
}
681+
for mut fiber_a in selfiter {
682+
for mut fiber_b in other_iter.by_ref() {
683+
for (a, skip, (neg, _)) in fiber_a.by_ref() {
684+
if let Some((b, _)) = fiber_b.by_ref().nth(skip) {
685+
if neg {
686+
result_data[result_index]
687+
.sub_assign_fallible(&a.mul_fallible(b).unwrap());
688+
} else {
689+
result_data[result_index]
690+
.add_assign_fallible(&a.mul_fallible(b).unwrap());
679691
}
680692
}
681-
result_index += 1;
682-
fiber_a.reset();
683693
}
684-
other_iter.reset();
694+
result_index += 1;
695+
fiber_a.reset();
685696
}
697+
other_iter.reset();
698+
}
686699

687-
let result = DenseTensor {
688-
data: result_data,
689-
structure: final_structure,
690-
};
700+
let result = DenseTensor {
701+
data: result_data,
702+
structure: final_structure,
703+
};
691704

692-
Ok(result)
693-
} else {
694-
Err(ContractionError::EmptySparse)
695-
}
705+
Ok(result)
696706
}
697707
}
698708

src/network.rs

+39-10
Original file line numberDiff line numberDiff line change
@@ -2655,6 +2655,8 @@ pub enum TensorNetworkError {
26552655
FailedContract(ContractionError),
26562656
#[error("negative exponent not yet supported")]
26572657
NegativeExponent,
2658+
#[error("failed to contract: {0}")]
2659+
FailedContractMsg(String),
26582660
#[error(transparent)]
26592661
Other(#[from] anyhow::Error),
26602662
#[error("Io error")]
@@ -2920,7 +2922,13 @@ where
29202922
for arg in value.iter() {
29212923
let mut net = Self::try_from(arg)?;
29222924
// trace!("mul net: {}", net.dot_nodes());
2923-
net.contract();
2925+
//
2926+
if net.contract().is_err() {
2927+
return Err(TensorNetworkError::FailedContractMsg(
2928+
format!("Mul failed: {}", arg).into(),
2929+
));
2930+
}
2931+
29242932
if let Some(ref s) = net.scalar {
29252933
has_scalar = true;
29262934
scalars = scalars.mul_fallible(s).unwrap();
@@ -3014,7 +3022,11 @@ where
30143022
}
30153023
let mut net = Self::try_from(base)?;
30163024

3017-
net.contract();
3025+
if net.contract().is_err() {
3026+
return Err(TensorNetworkError::FailedContractMsg(
3027+
format!("Pow failed: {}", base).into(),
3028+
));
3029+
}
30183030

30193031
match net.result() {
30203032
Ok((res, _s)) => {
@@ -3025,7 +3037,13 @@ where
30253037
} else {
30263038
new.push(res.clone());
30273039
}
3028-
new.contract();
3040+
3041+
if new.contract().is_err() {
3042+
return Err(TensorNetworkError::FailedContractMsg(
3043+
value.as_view().to_string().into(),
3044+
));
3045+
}
3046+
30293047
n -= 1;
30303048
}
30313049
}
@@ -3108,7 +3126,13 @@ where
31083126
for summand in value.iter() {
31093127
// trace!("summand: {}", summand);
31103128
let mut net = Self::try_from(summand)?;
3111-
net.contract();
3129+
3130+
if net.contract().is_err() {
3131+
return Err(TensorNetworkError::FailedContractMsg(
3132+
format!("Sum failed: {}", summand).into(),
3133+
));
3134+
}
3135+
31123136
match net.result() {
31133137
Ok((mut t, s)) => {
31143138
if let Some(s) = s {
@@ -3325,27 +3349,32 @@ where
33253349
T: Contract<T, LCM = T> + HasStructure,
33263350
T::Structure: TensorStructure<Slot: Serialize + for<'a> Deserialize<'a>>,
33273351
{
3328-
pub fn contract_algo(&mut self, edge_choice: impl Fn(&Self) -> Option<HedgeId>) {
3352+
pub fn contract_algo(
3353+
&mut self,
3354+
edge_choice: impl Fn(&Self) -> Option<HedgeId>,
3355+
) -> Result<(), ContractionError> {
33293356
if let Some(e) = edge_choice(self) {
3330-
self.contract_edge(e);
3357+
self.contract_edge(e)?;
33313358

33323359
// println!("{}", self.dot());
3333-
self.contract_algo(edge_choice);
3360+
self.contract_algo(edge_choice)?;
33343361
}
3362+
Ok(())
33353363
}
3336-
fn contract_edge(&mut self, edge_idx: HedgeId) {
3364+
fn contract_edge(&mut self, edge_idx: HedgeId) -> Result<(), ContractionError> {
33373365
let a = self.graph.nodemap[edge_idx];
33383366
let b = self.graph.nodemap[self.graph.involution[edge_idx].data];
33393367

33403368
let ai = self.graph.nodes.get(a).unwrap();
33413369
let bi = self.graph.nodes.get(b).unwrap();
33423370

3343-
let f = ai.contract(bi).unwrap();
3371+
let f = ai.contract(bi)?;
33443372

33453373
self.graph.merge_nodes(a, b, f);
3374+
Ok(())
33463375
}
33473376

3348-
pub fn contract(&mut self) {
3377+
pub fn contract(&mut self) -> std::result::Result<(), ContractionError> {
33493378
self.contract_algo(Self::edge_to_min_degree_node)
33503379
}
33513380
}

0 commit comments

Comments
 (0)