Skip to content

Commit 99c249e

Browse files
committed
Allow uncontracted indices
- Strip top-level constants before canonization - Yield an error when an index is repeated more than twice
1 parent 1dee648 commit 99c249e

File tree

2 files changed

+78
-54
lines changed

2 files changed

+78
-54
lines changed

src/atom/core.rs

+3-5
Original file line numberDiff line numberDiff line change
@@ -1025,8 +1025,7 @@ pub trait AtomCore {
10251025
}
10261026

10271027
/// Canonize (products of) tensors in the expression by relabeling repeated indices.
1028-
/// The tensors must be written as functions, with its indices are the arguments.
1029-
/// The repeated indices should be provided in `contracted_indices`.
1028+
/// The tensors must be written as functions, with its indices as the arguments.
10301029
///
10311030
/// If the contracted indices are distinguishable (for example in their dimension),
10321031
/// you can provide an optional group marker for each index using `index_group`.
@@ -1054,11 +1053,10 @@ pub trait AtomCore {
10541053
/// yields `fs(mu1,mu2)*fc(mu1,k1,mu3,k1,mu2,mu3)`.
10551054
fn canonize_tensors(
10561055
&self,
1057-
contracted_indices: &[AtomView],
1056+
indices: &[AtomView],
10581057
index_group: Option<&[AtomView]>,
10591058
) -> Result<Atom, String> {
1060-
self.as_atom_view()
1061-
.canonize_tensors(contracted_indices, index_group)
1059+
self.as_atom_view().canonize_tensors(indices, index_group)
10621060
}
10631061

10641062
fn to_pattern(&self) -> Pattern {

src/tensors.rs

+75-49
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,23 @@ pub mod matrix;
1010

1111
impl<'a> AtomView<'a> {
1212
/// Canonize (products of) tensors in the expression by relabeling repeated indices.
13-
/// The tensors must be written as functions, with its indices are the arguments.
14-
/// The repeated indices should be provided in `contracted_indices`.
13+
/// The tensors must be written as functions, with its indices as the arguments.
14+
/// Indices should be provided in `indices`.
1515
///
1616
/// If the contracted indices are distinguishable (for example in their dimension),
1717
/// you can provide an optional group marker for each index using `index_group`.
1818
/// This makes sure that an index will not be renamed to an index from a different group.
1919
pub(crate) fn canonize_tensors(
2020
&self,
21-
contracted_indices: &[AtomView],
21+
indices: &[AtomView],
2222
index_group: Option<&[AtomView]>,
2323
) -> Result<Atom, String> {
2424
if self.is_zero() {
2525
return Ok(self.to_owned());
2626
}
2727

2828
if let Some(c) = index_group {
29-
if c.len() != contracted_indices.len() {
29+
if c.len() != indices.len() {
3030
return Err(
3131
"Index group must have the same length as contracted indices".to_owned(),
3232
);
@@ -39,18 +39,15 @@ impl<'a> AtomView<'a> {
3939
let add = aa.to_add();
4040

4141
for a in a.iter() {
42-
add.extend(
43-
a.canonize_tensor_product(contracted_indices, index_group)?
44-
.as_view(),
45-
);
42+
add.extend(a.canonize_tensor_product(indices, index_group)?.as_view());
4643
}
4744

4845
let mut out = Atom::new();
4946
aa.as_view().normalize(ws, &mut out);
5047
Ok(out)
5148
} else {
5249
Ok(self
53-
.canonize_tensor_product(contracted_indices, index_group)?
50+
.canonize_tensor_product(indices, index_group)?
5451
.into_inner())
5552
}
5653
})
@@ -59,30 +56,49 @@ impl<'a> AtomView<'a> {
5956
/// Canonize a tensor product by relabeling repeated indices.
6057
fn canonize_tensor_product(
6158
&self,
62-
contracted_indices: &[AtomView],
59+
indices: &[AtomView],
6360
index_group: Option<&[AtomView]>,
6461
) -> Result<RecycledAtom, String> {
6562
let mut g = Graph::new();
66-
let mut connections = vec![vec![]; contracted_indices.len()];
63+
let mut connections = vec![(vec![], false); indices.len()];
64+
65+
// strip all top-level factors that do not have any indices, so that
66+
// they do not influence the canonization
67+
let mut stripped = Atom::new();
68+
if let AtomView::Mul(m) = self {
69+
let mm = stripped.to_mul();
70+
for a in m.iter() {
71+
if indices.iter().any(|x| a.contains(*x)) {
72+
mm.extend(a);
73+
}
74+
}
6775

68-
// TODO: strip all top-level products that do not have any contracted indices
69-
// this ensures that graphs that are the same up to multiplication of constants
70-
// map to the same graph
76+
stripped.as_view().tensor_to_graph_impl(
77+
indices,
78+
index_group,
79+
&mut connections,
80+
&mut g,
81+
)?;
82+
} else {
83+
self.tensor_to_graph_impl(indices, index_group, &mut connections, &mut g)?;
84+
}
7185

72-
self.tensor_to_graph_impl(contracted_indices, index_group, &mut connections, &mut g)?;
86+
let mut used_indices = vec![false; indices.len()];
87+
let mut map = vec![None; indices.len()];
7388

74-
for (i, f) in contracted_indices.iter().zip(&connections) {
75-
if !f.is_empty() {
76-
return Err(format!("Index {} is not contracted", i));
89+
for (i, (ii, (f, used))) in indices.iter().zip(&connections).enumerate() {
90+
if !f.is_empty() && *used {
91+
return Err(format!("Index {} is contracted more than once", ii));
92+
} else if f.len() == 1 && !used {
93+
used_indices[i] = true;
94+
map[i] = Some(i);
7795
}
7896
}
7997

8098
let gc = g.canonize().graph;
8199

82100
// connect dummy indices
83101
// TODO: recycle dummy indices that are contracted on a deeper level?
84-
let mut used_indices = vec![false; contracted_indices.len()];
85-
let mut map = vec![None; contracted_indices.len()];
86102
for e in gc.edges() {
87103
if e.directed {
88104
continue;
@@ -112,9 +128,9 @@ impl<'a> AtomView<'a> {
112128
// map the contracted indices
113129
Ok(self
114130
.replace_map(&|a, _ctx, out| {
115-
if let Some(p) = contracted_indices.iter().position(|x| *x == a) {
131+
if let Some(p) = indices.iter().position(|x| *x == a) {
116132
if let Some(q) = map[p] {
117-
out.set_from_view(&contracted_indices[q]);
133+
out.set_from_view(&indices[q]);
118134
true
119135
} else {
120136
unreachable!()
@@ -128,12 +144,12 @@ impl<'a> AtomView<'a> {
128144

129145
fn tensor_to_graph_impl(
130146
&self,
131-
contracted_indices: &[AtomView],
147+
indices: &[AtomView],
132148
index_group: Option<&[AtomView<'a>]>,
133-
connections: &mut [Vec<usize>],
149+
connections: &mut [(Vec<usize>, bool)],
134150
g: &mut Graph<AtomOrView<'a>, (HiddenData<bool, usize>, Option<AtomOrView<'a>>)>,
135151
) -> Result<usize, String> {
136-
if !contracted_indices.iter().any(|a| self.contains(*a)) {
152+
if !indices.iter().any(|a| self.contains(*a)) {
137153
let node = g.add_node(self.into());
138154
return Ok(node);
139155
}
@@ -150,7 +166,7 @@ impl<'a> AtomView<'a> {
150166
let mut nodes = vec![];
151167
for _ in 0..n {
152168
nodes.push(b.tensor_to_graph_impl(
153-
contracted_indices,
169+
indices,
154170
index_group,
155171
connections,
156172
g,
@@ -184,19 +200,27 @@ impl<'a> AtomView<'a> {
184200
let fff = ff.to_fun(f.get_symbol());
185201

186202
for a in f.iter() {
187-
if !contracted_indices.contains(&a) {
203+
if !indices.contains(&a) {
188204
fff.add_arg(a);
189205
}
190206
}
191207
fff.set_normalized(true);
192208

193209
let n = g.add_node(ff.into());
194210
for a in f.iter() {
195-
if let Some(p) = contracted_indices.iter().position(|x| x == &a) {
196-
if connections[p].is_empty() {
197-
connections[p].push(n);
211+
if let Some(p) = indices.iter().position(|x| x == &a) {
212+
if connections[p].1 {
213+
return Err(format!(
214+
"Index {} is contracted more than once",
215+
indices[p]
216+
));
217+
}
218+
219+
if connections[p].0.is_empty() {
220+
connections[p].0.push(n);
198221
} else {
199-
for n2 in connections[p].drain(..) {
222+
connections[p].1 = true;
223+
for n2 in connections[p].0.drain(..) {
200224
g.add_edge(
201225
n,
202226
n2,
@@ -225,14 +249,22 @@ impl<'a> AtomView<'a> {
225249
let mut ff = Atom::new();
226250
let fff = ff.to_fun(f.get_symbol());
227251

228-
if let Some(p) = contracted_indices.iter().position(|x| x == &a) {
252+
if let Some(p) = indices.iter().position(|x| x == &a) {
229253
ff.set_normalized(true);
230254
g.add_node(Atom::Zero.into());
231255

232-
if connections[p].is_empty() {
233-
connections[p].push(start + i);
256+
if connections[p].1 {
257+
return Err(format!(
258+
"Index {} is contracted more than once",
259+
indices[p]
260+
));
261+
}
262+
263+
if connections[p].0.is_empty() {
264+
connections[p].0.push(start + i);
234265
} else {
235-
for n2 in connections[p].drain(..) {
266+
for n2 in connections[p].0.drain(..) {
267+
connections[p].1 = true;
236268
g.add_edge(
237269
start + i,
238270
n2,
@@ -282,13 +314,11 @@ impl<'a> AtomView<'a> {
282314
}
283315
AtomView::Mul(m) => {
284316
let mut nodes = vec![];
317+
318+
// TODO: check for a -1 and absorb it into an antisymmetric factor
319+
// by rearranging its arguments
285320
for a in m.iter() {
286-
nodes.push(a.tensor_to_graph_impl(
287-
contracted_indices,
288-
index_group,
289-
connections,
290-
g,
291-
)?);
321+
nodes.push(a.tensor_to_graph_impl(indices, index_group, connections, g)?);
292322
}
293323
let node = g.add_node(
294324
Symbol::new_with_attributes("PROD", &[FunctionAttribute::Symmetric])
@@ -309,20 +339,16 @@ impl<'a> AtomView<'a> {
309339
for arg in a {
310340
let mut sub_connections = connections.to_vec();
311341

312-
let node = arg.tensor_to_graph_impl(
313-
contracted_indices,
314-
index_group,
315-
&mut sub_connections,
316-
g,
317-
)?;
342+
let node =
343+
arg.tensor_to_graph_impl(indices, index_group, &mut sub_connections, g)?;
318344

319345
subgraphs.push((node, sub_connections));
320346
}
321347

322348
if subgraphs.iter().any(|x| {
323349
x.1.iter()
324350
.zip(&subgraphs[0].1)
325-
.any(|(a, b)| a.len() != b.len())
351+
.any(|(a, b)| a.0.len() != b.0.len() || a.1 != b.1)
326352
}) {
327353
return Err(
328354
"All components of nested sums must have the same open indices".to_owned(),
@@ -344,7 +370,7 @@ impl<'a> AtomView<'a> {
344370
for c in connections.iter_mut().zip(cons) {
345371
// add new open indices from this subgraph
346372
if *c.0 != c.1 {
347-
c.0.extend(c.1);
373+
c.0 .0.extend(c.1 .0);
348374
}
349375
}
350376
}

0 commit comments

Comments
 (0)