Skip to content

Commit c52e6e5

Browse files
committed
add metadata
1 parent e3a89d7 commit c52e6e5

File tree

3 files changed

+82
-3
lines changed

3 files changed

+82
-3
lines changed

src/lib.rs

+28
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use log::Logger;
66
use f128::f128;
77
use float::FloatLike;
88
use itertools::Itertools;
9+
use matrix::{DecompositionResult, SquareMatrix};
910
use preprocessing::{TropicalGraph, TropicalSubgraphTable};
1011
use rand::Rng;
1112
use sampling::{sample, SamplingError};
@@ -30,6 +31,7 @@ pub struct TropicalSamplingSettings {
3031
pub upcast_on_failure: bool,
3132
pub matrix_stability_test: Option<f64>,
3233
pub print_debug_info: bool,
34+
pub return_metadata: bool,
3335
}
3436

3537
impl Default for TropicalSamplingSettings {
@@ -38,6 +40,7 @@ impl Default for TropicalSamplingSettings {
3840
upcast_on_failure: true,
3941
matrix_stability_test: None,
4042
print_debug_info: false,
43+
return_metadata: false,
4144
}
4245
}
4346
}
@@ -72,6 +75,30 @@ pub struct TropicalSampleResult<T: FloatLike, const D: usize> {
7275
pub u: T,
7376
pub v: T,
7477
pub jacobian: T,
78+
pub metadata: Option<Metadata<T, D>>,
79+
}
80+
81+
#[derive(Debug, Clone)]
82+
pub struct Metadata<T: FloatLike, const D: usize> {
83+
pub q_vectors: Vec<Vector<T, D>>,
84+
pub lambda: T,
85+
pub l_matrix: SquareMatrix<T>,
86+
pub decompoisiton_result: DecompositionResult<T>,
87+
pub u_vectors: Vec<Vector<T, D>>,
88+
pub shift: Vec<Vector<T, D>>,
89+
}
90+
91+
impl<const D: usize> Metadata<f128, D> {
92+
pub fn downcast(&self) -> Metadata<f64, D> {
93+
Metadata {
94+
q_vectors: self.q_vectors.iter().map(|v| v.downcast()).collect_vec(),
95+
lambda: self.lambda.into(),
96+
l_matrix: self.l_matrix.downcast(),
97+
decompoisiton_result: self.decompoisiton_result.downcast(),
98+
u_vectors: self.u_vectors.iter().map(|v| v.downcast()).collect_vec(),
99+
shift: self.shift.iter().map(|v| v.downcast()).collect(),
100+
}
101+
}
75102
}
76103

77104
impl<const D: usize> TropicalSampleResult<f128, D> {
@@ -83,6 +110,7 @@ impl<const D: usize> TropicalSampleResult<f128, D> {
83110
u: self.u.into(),
84111
v: self.v.into(),
85112
jacobian: self.jacobian.into(),
113+
metadata: self.metadata.as_ref().map(|metadata| metadata.downcast()),
86114
}
87115
}
88116
}

src/matrix.rs

+21-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use smallvec::SmallVec;
22
use std::ops::{Add, Index, IndexMut, Mul, Sub};
33

44
use crate::{float::FloatLike, TropicalSamplingSettings};
5+
use f128::f128;
56

67
#[derive(Debug, Clone, Default)]
78
/// square symmetric matrix for use in the tropical sampling algorithm
@@ -264,18 +265,36 @@ impl<T: FloatLike> SquareMatrix<T> {
264265
}
265266
}
266267

267-
#[derive(Debug)]
268+
impl SquareMatrix<f128> {
269+
pub fn downcast(&self) -> SquareMatrix<f64> {
270+
SquareMatrix {
271+
data: self.data.iter().map(|&x| x.into()).collect(),
272+
dim: self.dim,
273+
}
274+
}
275+
}
276+
277+
#[derive(Debug, Clone)]
268278
pub struct DecompositionResult<T> {
269279
pub determinant: T,
270280
pub inverse: SquareMatrix<T>,
271281
pub q_transposed_inverse: SquareMatrix<T>,
272282
}
273283

284+
impl DecompositionResult<f128> {
285+
pub fn downcast(&self) -> DecompositionResult<f64> {
286+
DecompositionResult {
287+
determinant: self.determinant.into(),
288+
inverse: self.inverse.downcast(),
289+
q_transposed_inverse: self.q_transposed_inverse.downcast(),
290+
}
291+
}
292+
}
293+
274294
#[cfg(test)]
275295
mod tests {
276296
use super::*;
277297
use crate::assert_approx_eq;
278-
use f128::f128;
279298

280299
const EPSILON: f64 = 1e-12;
281300

src/sampling.rs

+33-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use crate::matrix::{MatrixError, SquareMatrix};
77
use crate::mimic_rng::MimicRng;
88
use crate::vector::Vector;
99
use crate::{float::FloatLike, gamma::inverse_gamma_lr};
10-
use crate::{TropicalSampleResult, TropicalSamplingSettings};
10+
use crate::{Metadata, TropicalSampleResult, TropicalSamplingSettings};
1111

1212
fn box_muller<T: FloatLike>(x1: T, x2: T) -> (T, T) {
1313
let r = (-Into::<T>::into(2.) * x1.ln()).sqrt();
@@ -116,13 +116,27 @@ pub fn sample<T: FloatLike + Into<f64>, const D: usize, #[cfg(feature = "log")]
116116
.powf(Into::<T>::into(tropical_subgraph_table.tropical_graph.dod))
117117
* Into::<T>::into(tropical_subgraph_table.cached_factor);
118118

119+
let metadata = if settings.return_metadata {
120+
Some(Metadata {
121+
l_matrix,
122+
q_vectors,
123+
lambda,
124+
shift: compute_only_shift(&decomposed_l_matrix.inverse, &u_vectors),
125+
decompoisiton_result: decomposed_l_matrix,
126+
u_vectors,
127+
})
128+
} else {
129+
None
130+
};
131+
119132
Ok(TropicalSampleResult {
120133
loop_momenta,
121134
u_trop,
122135
v_trop,
123136
u,
124137
v,
125138
jacobian,
139+
metadata,
126140
})
127141
}
128142

@@ -378,3 +392,21 @@ fn compute_loop_momenta<T: FloatLike, const D: usize>(
378392
})
379393
.collect_vec()
380394
}
395+
396+
fn compute_only_shift<T: FloatLike, const D: usize>(
397+
l_inverse: &SquareMatrix<T>,
398+
u_vectors: &[Vector<T, D>],
399+
) -> Vec<Vector<T, D>> {
400+
let num_loops = l_inverse.get_dim();
401+
(0..num_loops)
402+
.map(|l| {
403+
u_vectors
404+
.iter()
405+
.enumerate()
406+
.fold(Vector::new(), |acc, (l_prime, u)| {
407+
let u_part: Vector<T, D> = u * l_inverse[(l, l_prime)];
408+
&acc + &u_part
409+
})
410+
})
411+
.collect_vec()
412+
}

0 commit comments

Comments
 (0)