Skip to content

Commit

Permalink
Refactor code to use State::get_symbol instead of State::get_or_inser…
Browse files Browse the repository at this point in the history
…t_fn
  • Loading branch information
lcnbr committed Mar 21, 2024
1 parent 2c918a8 commit 6f4cb22
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 68 deletions.
2 changes: 1 addition & 1 deletion benches/evaluate_net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ fn gamma_net_param(
.into_iter()
.collect();
i += 1;
let pid = State::get_or_insert_fn(&format!("p{}", i), None).unwrap();
let pid = State::get_symbol(&format!("p{}", i));

result.push(p.shadow_with(pid).into());

Expand Down
2 changes: 1 addition & 1 deletion examples/Rust/Tensors/evaluate_network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ fn gamma_net_param(
.into_iter()
.collect();
i += 1;
let pid = State::get_or_insert_fn(&format!("p{}", i), None).unwrap();
let pid = State::get_symbol(&format!("p{}", i));

result.push(p.shadow_with(pid).into());

Expand Down
17 changes: 10 additions & 7 deletions src/tensor/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -686,10 +686,11 @@ where
pub fn symbolic_shadow(&mut self, name: &str) -> TensorNetwork<MixedTensors> {
{
for (i, n) in &mut self.graph.nodes {
n.mut_structure().set_name(
&State::get_or_insert_fn(format!("{}{}", name, i.data().as_ffi()), None)
.unwrap(),
);
n.mut_structure().set_name(&State::get_symbol(format!(
"{}{}",
name,
i.data().as_ffi()
)));
}
}

Expand Down Expand Up @@ -751,9 +752,11 @@ where
{
pub fn namesym(&mut self, name: &str) {
for (id, n) in &mut self.graph.nodes {
n.set_name(
&State::get_or_insert_fn(format!("{}{}", name, id.data().as_ffi()), None).unwrap(),
);
n.set_name(&State::get_symbol(format!(
"{}{}",
name,
id.data().as_ffi()
)));
}
}
}
Expand Down
52 changes: 26 additions & 26 deletions src/tensor/structure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,17 +183,17 @@ impl Representation {
#[allow(clippy::cast_possible_wrap)]
pub fn to_fnbuilder<'a, 'b: 'a>(&'a self) -> FunctionBuilder {
let (value, id) = match *self {
Self::Euclidean(value) => (value, State::get_or_insert_fn("euc", None)),
Self::Lorentz(value) => (value, State::get_or_insert_fn("lor", None)),
Self::Spin(value) => (value, State::get_or_insert_fn("spin", None)),
Self::ColorAdjoint(value) => (value, State::get_or_insert_fn("CAdj", None)),
Self::ColorFundamental(value) => (value, State::get_or_insert_fn("CF", None)),
Self::ColorAntiFundamental(value) => (value, State::get_or_insert_fn("CAF", None)),
Self::ColorSextet(value) => (value, State::get_or_insert_fn("CS", None)),
Self::ColorAntiSextet(value) => (value, State::get_or_insert_fn("CAS", None)),
Self::Euclidean(value) => (value, State::get_symbol("euc")),
Self::Lorentz(value) => (value, State::get_symbol("lor")),
Self::Spin(value) => (value, State::get_symbol("spin")),
Self::ColorAdjoint(value) => (value, State::get_symbol("CAdj")),
Self::ColorFundamental(value) => (value, State::get_symbol("CF")),
Self::ColorAntiFundamental(value) => (value, State::get_symbol("CAF")),
Self::ColorSextet(value) => (value, State::get_symbol("CS")),
Self::ColorAntiSextet(value) => (value, State::get_symbol("CAS")),
};

let mut value_builder = FunctionBuilder::new(id.unwrap_or_else(|_| unreachable!()));
let mut value_builder = FunctionBuilder::new(id);

value_builder =
value_builder.add_arg(Atom::new_num(usize::from(value) as i64).as_atom_view());
Expand Down Expand Up @@ -375,14 +375,14 @@ impl TryFrom<AtomView<'_>> for Slot {
return Err("Too many arguments");
}

let euc = State::get_or_insert_fn("euc", None).unwrap();
let lor = State::get_or_insert_fn("lor", None).unwrap();
let spin = State::get_or_insert_fn("spin", None).unwrap();
let cadj = State::get_or_insert_fn("CAdj", None).unwrap();
let cf = State::get_or_insert_fn("CF", None).unwrap();
let caf = State::get_or_insert_fn("CAF", None).unwrap();
let cs = State::get_or_insert_fn("CS", None).unwrap();
let cas = State::get_or_insert_fn("CAS", None).unwrap();
let euc = State::get_symbol("euc");
let lor = State::get_symbol("lor");
let spin = State::get_symbol("spin");
let cadj = State::get_symbol("CAdj");
let cf = State::get_symbol("CF");
let caf = State::get_symbol("CAF");
let cs = State::get_symbol("CS");
let cas = State::get_symbol("CAS");

let representation = if let AtomView::Fun(f) = value {
let sym = f.get_symbol();
Expand Down Expand Up @@ -767,12 +767,12 @@ pub trait TensorStructure {
Self: std::marker::Sized,
Self::Structure: Clone + TensorStructure,
{
let id = State::get_or_insert_fn("id", None).unwrap();
let gamma = State::get_or_insert_fn("γ", None).unwrap();
let gamma5 = State::get_or_insert_fn("γ5", None).unwrap();
let proj_m = State::get_or_insert_fn("ProjM", None).unwrap();
let proj_p = State::get_or_insert_fn("ProjP", None).unwrap();
let sigma = State::get_or_insert_fn("σ", None).unwrap();
let id = State::get_symbol("id");
let gamma = State::get_symbol("γ");
let gamma5 = State::get_symbol("γ5");
let proj_m = State::get_symbol("ProjM");
let proj_p = State::get_symbol("ProjP");
let sigma = State::get_symbol("σ");

match f_id {
_ if f_id == id => {
Expand Down Expand Up @@ -1603,7 +1603,7 @@ pub trait IntoId {

impl IntoId for SmartString<LazyCompact> {
fn into_id(self) -> Symbol {
State::get_or_insert_fn(self, None).unwrap()
State::get_symbol(self)
}
}

Expand All @@ -1615,13 +1615,13 @@ impl IntoId for Symbol {

impl IntoId for &str {
fn into_id(self) -> Symbol {
State::get_or_insert_fn(self, None).unwrap()
State::get_symbol(self)
}
}

impl IntoId for std::string::String {
fn into_id(self) -> Symbol {
State::get_or_insert_fn(self, None).unwrap()
State::get_symbol(self)
}
}

Expand Down
6 changes: 1 addition & 5 deletions src/tensor/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@ use rand::{distributions::Uniform, Rng, SeedableRng};
use rand_xoshiro::Xoroshiro64Star;
use smartstring::alias::String;
use symbolica::domains::float::Complex;
use symbolica::{
representations::Atom,
state::{State, Workspace},
};
use symbolica::{representations::Atom, state::State};

use super::{
symbolic::SymbolicTensor, ufo, AbstractIndex, DataTensor, Dimension, HistoryStructure,
Expand Down Expand Up @@ -719,7 +716,6 @@ fn evaluate() {

#[test]
fn convert_sym() {
let _ws = Workspace::new();
let i = Complex::new(0.0, 1.0);
let mut data_b = vec![i * Complex::from(5.0), Complex::from(2.6) + i];
data_b.append(
Expand Down
50 changes: 22 additions & 28 deletions src/tensor/ufo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,21 @@ use symbolica::{
};

// pub fn init_state() {
// assert!(EUC == State::get_or_insert_fn("euc", None).unwrap());
// assert!(LOR == State::get_or_insert_fn("lor", None).unwrap());
// assert!(SPIN == State::get_or_insert_fn("spin", None).unwrap());
// assert!(CADJ == State::get_or_insert_fn("CAdj", None).unwrap());
// assert!(CF == State::get_or_insert_fn("CF", None).unwrap());
// assert!(CAF == State::get_or_insert_fn("CAF", None).unwrap());
// assert!(CS == State::get_or_insert_fn("CS", None).unwrap());
// assert!(CAS == State::get_or_insert_fn("CAS", None).unwrap());

// assert!(ID == State::get_or_insert_fn("id", None).unwrap());
// assert!(GAMMA == State::get_or_insert_fn("γ", None).unwrap());
// assert!(GAMMA5 == State::get_or_insert_fn("γ5", None).unwrap());
// assert!(PROJM == State::get_or_insert_fn("ProjM", None).unwrap());
// assert!(PROJP == State::get_or_insert_fn("ProjP", None).unwrap());
// assert!(SIGMA == State::get_or_insert_fn("σ", None).unwrap());
// assert!(EUC == State::get_symbol("euc", None).unwrap());
// assert!(LOR == State::get_symbol("lor", None).unwrap());
// assert!(SPIN == State::get_symbol("spin", None).unwrap());
// assert!(CADJ == State::get_symbol("CAdj", None).unwrap());
// assert!(CF == State::get_symbol("CF", None).unwrap());
// assert!(CAF == State::get_symbol("CAF", None).unwrap());
// assert!(CS == State::get_symbol("CS", None).unwrap());
// assert!(CAS == State::get_symbol("CAS", None).unwrap());

// assert!(ID == State::get_symbol("id", None).unwrap());
// assert!(GAMMA == State::get_symbol("γ", None).unwrap());
// assert!(GAMMA5 == State::get_symbol("γ5", None).unwrap());
// assert!(PROJM == State::get_symbol("ProjM", None).unwrap());
// assert!(PROJP == State::get_symbol("ProjP", None).unwrap());
// assert!(SIGMA == State::get_symbol("σ", None).unwrap());
// }

#[allow(dead_code)]
Expand Down Expand Up @@ -124,10 +124,7 @@ where
{
DenseTensor::from_data(
p,
HistoryStructure::new(
&[(index, Lorentz(4.into()))],
State::get_or_insert_fn("p", None).unwrap_or_else(|_| unreachable!()),
),
HistoryStructure::new(&[(index, Lorentz(4.into()))], State::get_symbol("p")),
)
.unwrap_or_else(|_| unreachable!())
}
Expand Down Expand Up @@ -155,10 +152,7 @@ where
{
DenseTensor::from_data(
p,
HistoryStructure::new(
&[(index, Euclidean(4.into()))],
State::get_or_insert_fn("p", None).unwrap_or_else(|_| unreachable!()),
),
HistoryStructure::new(&[(index, Euclidean(4.into()))], State::get_symbol("p")),
)
.unwrap_or_else(|_| unreachable!())
}
Expand Down Expand Up @@ -236,7 +230,7 @@ where
(indices.1, Euclidean(4.into())),
(minkindex, Lorentz(4.into())),
],
State::get_or_insert_fn("γ", None).unwrap_or_else(|_| unreachable!()),
State::get_symbol("γ"),
);

gamma_data(structure)
Expand Down Expand Up @@ -306,7 +300,7 @@ where
(indices.0, Euclidean(4.into())),
(indices.1, Euclidean(4.into())),
],
State::get_or_insert_fn("γ5", None).unwrap_or_else(|_| unreachable!()),
State::get_symbol("γ5"),
);

gamma5_data(structure)
Expand Down Expand Up @@ -357,7 +351,7 @@ where
(indices.0, Euclidean(4.into())),
(indices.1, Euclidean(4.into())),
],
State::get_or_insert_fn("ProjM", None).unwrap_or_else(|_| unreachable!()),
State::get_symbol("ProjM"),
);

proj_m_data(structure)
Expand Down Expand Up @@ -416,7 +410,7 @@ where
(indices.0, Euclidean(4.into())),
(indices.1, Euclidean(4.into())),
],
State::get_or_insert_fn("ProjP", None).unwrap_or_else(|_| unreachable!()),
State::get_symbol("ProjP"),
);

proj_p_data(structure)
Expand Down Expand Up @@ -496,7 +490,7 @@ where
(minkdices.0, Lorentz(4.into())),
(minkdices.1, Lorentz(4.into())),
],
State::get_or_insert_fn("σ", None).unwrap_or_else(|_| unreachable!()),
State::get_symbol("σ"),
);

sigma_data(structure)
Expand Down

0 comments on commit 6f4cb22

Please sign in to comment.