@@ -10,23 +10,23 @@ pub mod matrix;
10
10
11
11
impl < ' a > AtomView < ' a > {
12
12
/// Canonize (products of) tensors in the expression by relabeling repeated indices.
13
- /// The tensors must be written as functions, with its indices are the arguments.
14
- /// The repeated indices should be provided in `contracted_indices `.
13
+ /// The tensors must be written as functions, with its indices as the arguments.
14
+ /// Indices should be provided in `indices `.
15
15
///
16
16
/// If the contracted indices are distinguishable (for example in their dimension),
17
17
/// you can provide an optional group marker for each index using `index_group`.
18
18
/// This makes sure that an index will not be renamed to an index from a different group.
19
19
pub ( crate ) fn canonize_tensors (
20
20
& self ,
21
- contracted_indices : & [ AtomView ] ,
21
+ indices : & [ AtomView ] ,
22
22
index_group : Option < & [ AtomView ] > ,
23
23
) -> Result < Atom , String > {
24
24
if self . is_zero ( ) {
25
25
return Ok ( self . to_owned ( ) ) ;
26
26
}
27
27
28
28
if let Some ( c) = index_group {
29
- if c. len ( ) != contracted_indices . len ( ) {
29
+ if c. len ( ) != indices . len ( ) {
30
30
return Err (
31
31
"Index group must have the same length as contracted indices" . to_owned ( ) ,
32
32
) ;
@@ -39,18 +39,15 @@ impl<'a> AtomView<'a> {
39
39
let add = aa. to_add ( ) ;
40
40
41
41
for a in a. iter ( ) {
42
- add. extend (
43
- a. canonize_tensor_product ( contracted_indices, index_group) ?
44
- . as_view ( ) ,
45
- ) ;
42
+ add. extend ( a. canonize_tensor_product ( indices, index_group) ?. as_view ( ) ) ;
46
43
}
47
44
48
45
let mut out = Atom :: new ( ) ;
49
46
aa. as_view ( ) . normalize ( ws, & mut out) ;
50
47
Ok ( out)
51
48
} else {
52
49
Ok ( self
53
- . canonize_tensor_product ( contracted_indices , index_group) ?
50
+ . canonize_tensor_product ( indices , index_group) ?
54
51
. into_inner ( ) )
55
52
}
56
53
} )
@@ -59,30 +56,49 @@ impl<'a> AtomView<'a> {
59
56
/// Canonize a tensor product by relabeling repeated indices.
60
57
fn canonize_tensor_product (
61
58
& self ,
62
- contracted_indices : & [ AtomView ] ,
59
+ indices : & [ AtomView ] ,
63
60
index_group : Option < & [ AtomView ] > ,
64
61
) -> Result < RecycledAtom , String > {
65
62
let mut g = Graph :: new ( ) ;
66
- let mut connections = vec ! [ vec![ ] ; contracted_indices. len( ) ] ;
63
+ let mut connections = vec ! [ ( vec![ ] , false ) ; indices. len( ) ] ;
64
+
65
+ // strip all top-level factors that do not have any indices, so that
66
+ // they do not influence the canonization
67
+ let mut stripped = Atom :: new ( ) ;
68
+ if let AtomView :: Mul ( m) = self {
69
+ let mm = stripped. to_mul ( ) ;
70
+ for a in m. iter ( ) {
71
+ if indices. iter ( ) . any ( |x| a. contains ( * x) ) {
72
+ mm. extend ( a) ;
73
+ }
74
+ }
67
75
68
- // TODO: strip all top-level products that do not have any contracted indices
69
- // this ensures that graphs that are the same up to multiplication of constants
70
- // map to the same graph
76
+ stripped. as_view ( ) . tensor_to_graph_impl (
77
+ indices,
78
+ index_group,
79
+ & mut connections,
80
+ & mut g,
81
+ ) ?;
82
+ } else {
83
+ self . tensor_to_graph_impl ( indices, index_group, & mut connections, & mut g) ?;
84
+ }
71
85
72
- self . tensor_to_graph_impl ( contracted_indices, index_group, & mut connections, & mut g) ?;
86
+ let mut used_indices = vec ! [ false ; indices. len( ) ] ;
87
+ let mut map = vec ! [ None ; indices. len( ) ] ;
73
88
74
- for ( i, f) in contracted_indices. iter ( ) . zip ( & connections) {
75
- if !f. is_empty ( ) {
76
- return Err ( format ! ( "Index {} is not contracted" , i) ) ;
89
+ for ( i, ( ii, ( f, used) ) ) in indices. iter ( ) . zip ( & connections) . enumerate ( ) {
90
+ if !f. is_empty ( ) && * used {
91
+ return Err ( format ! ( "Index {} is contracted more than once" , ii) ) ;
92
+ } else if f. len ( ) == 1 && !used {
93
+ used_indices[ i] = true ;
94
+ map[ i] = Some ( i) ;
77
95
}
78
96
}
79
97
80
98
let gc = g. canonize ( ) . graph ;
81
99
82
100
// connect dummy indices
83
101
// TODO: recycle dummy indices that are contracted on a deeper level?
84
- let mut used_indices = vec ! [ false ; contracted_indices. len( ) ] ;
85
- let mut map = vec ! [ None ; contracted_indices. len( ) ] ;
86
102
for e in gc. edges ( ) {
87
103
if e. directed {
88
104
continue ;
@@ -112,9 +128,9 @@ impl<'a> AtomView<'a> {
112
128
// map the contracted indices
113
129
Ok ( self
114
130
. replace_map ( & |a, _ctx, out| {
115
- if let Some ( p) = contracted_indices . iter ( ) . position ( |x| * x == a) {
131
+ if let Some ( p) = indices . iter ( ) . position ( |x| * x == a) {
116
132
if let Some ( q) = map[ p] {
117
- out. set_from_view ( & contracted_indices [ q] ) ;
133
+ out. set_from_view ( & indices [ q] ) ;
118
134
true
119
135
} else {
120
136
unreachable ! ( )
@@ -128,12 +144,12 @@ impl<'a> AtomView<'a> {
128
144
129
145
fn tensor_to_graph_impl (
130
146
& self ,
131
- contracted_indices : & [ AtomView ] ,
147
+ indices : & [ AtomView ] ,
132
148
index_group : Option < & [ AtomView < ' a > ] > ,
133
- connections : & mut [ Vec < usize > ] ,
149
+ connections : & mut [ ( Vec < usize > , bool ) ] ,
134
150
g : & mut Graph < AtomOrView < ' a > , ( HiddenData < bool , usize > , Option < AtomOrView < ' a > > ) > ,
135
151
) -> Result < usize , String > {
136
- if !contracted_indices . iter ( ) . any ( |a| self . contains ( * a) ) {
152
+ if !indices . iter ( ) . any ( |a| self . contains ( * a) ) {
137
153
let node = g. add_node ( self . into ( ) ) ;
138
154
return Ok ( node) ;
139
155
}
@@ -150,7 +166,7 @@ impl<'a> AtomView<'a> {
150
166
let mut nodes = vec ! [ ] ;
151
167
for _ in 0 ..n {
152
168
nodes. push ( b. tensor_to_graph_impl (
153
- contracted_indices ,
169
+ indices ,
154
170
index_group,
155
171
connections,
156
172
g,
@@ -184,19 +200,27 @@ impl<'a> AtomView<'a> {
184
200
let fff = ff. to_fun ( f. get_symbol ( ) ) ;
185
201
186
202
for a in f. iter ( ) {
187
- if !contracted_indices . contains ( & a) {
203
+ if !indices . contains ( & a) {
188
204
fff. add_arg ( a) ;
189
205
}
190
206
}
191
207
fff. set_normalized ( true ) ;
192
208
193
209
let n = g. add_node ( ff. into ( ) ) ;
194
210
for a in f. iter ( ) {
195
- if let Some ( p) = contracted_indices. iter ( ) . position ( |x| x == & a) {
196
- if connections[ p] . is_empty ( ) {
197
- connections[ p] . push ( n) ;
211
+ if let Some ( p) = indices. iter ( ) . position ( |x| x == & a) {
212
+ if connections[ p] . 1 {
213
+ return Err ( format ! (
214
+ "Index {} is contracted more than once" ,
215
+ indices[ p]
216
+ ) ) ;
217
+ }
218
+
219
+ if connections[ p] . 0 . is_empty ( ) {
220
+ connections[ p] . 0 . push ( n) ;
198
221
} else {
199
- for n2 in connections[ p] . drain ( ..) {
222
+ connections[ p] . 1 = true ;
223
+ for n2 in connections[ p] . 0 . drain ( ..) {
200
224
g. add_edge (
201
225
n,
202
226
n2,
@@ -225,14 +249,22 @@ impl<'a> AtomView<'a> {
225
249
let mut ff = Atom :: new ( ) ;
226
250
let fff = ff. to_fun ( f. get_symbol ( ) ) ;
227
251
228
- if let Some ( p) = contracted_indices . iter ( ) . position ( |x| x == & a) {
252
+ if let Some ( p) = indices . iter ( ) . position ( |x| x == & a) {
229
253
ff. set_normalized ( true ) ;
230
254
g. add_node ( Atom :: Zero . into ( ) ) ;
231
255
232
- if connections[ p] . is_empty ( ) {
233
- connections[ p] . push ( start + i) ;
256
+ if connections[ p] . 1 {
257
+ return Err ( format ! (
258
+ "Index {} is contracted more than once" ,
259
+ indices[ p]
260
+ ) ) ;
261
+ }
262
+
263
+ if connections[ p] . 0 . is_empty ( ) {
264
+ connections[ p] . 0 . push ( start + i) ;
234
265
} else {
235
- for n2 in connections[ p] . drain ( ..) {
266
+ for n2 in connections[ p] . 0 . drain ( ..) {
267
+ connections[ p] . 1 = true ;
236
268
g. add_edge (
237
269
start + i,
238
270
n2,
@@ -282,13 +314,11 @@ impl<'a> AtomView<'a> {
282
314
}
283
315
AtomView :: Mul ( m) => {
284
316
let mut nodes = vec ! [ ] ;
317
+
318
+ // TODO: check for a -1 and absorb it into an antisymmetric factor
319
+ // by rearranging its arguments
285
320
for a in m. iter ( ) {
286
- nodes. push ( a. tensor_to_graph_impl (
287
- contracted_indices,
288
- index_group,
289
- connections,
290
- g,
291
- ) ?) ;
321
+ nodes. push ( a. tensor_to_graph_impl ( indices, index_group, connections, g) ?) ;
292
322
}
293
323
let node = g. add_node (
294
324
Symbol :: new_with_attributes ( "PROD" , & [ FunctionAttribute :: Symmetric ] )
@@ -309,20 +339,16 @@ impl<'a> AtomView<'a> {
309
339
for arg in a {
310
340
let mut sub_connections = connections. to_vec ( ) ;
311
341
312
- let node = arg. tensor_to_graph_impl (
313
- contracted_indices,
314
- index_group,
315
- & mut sub_connections,
316
- g,
317
- ) ?;
342
+ let node =
343
+ arg. tensor_to_graph_impl ( indices, index_group, & mut sub_connections, g) ?;
318
344
319
345
subgraphs. push ( ( node, sub_connections) ) ;
320
346
}
321
347
322
348
if subgraphs. iter ( ) . any ( |x| {
323
349
x. 1 . iter ( )
324
350
. zip ( & subgraphs[ 0 ] . 1 )
325
- . any ( |( a, b) | a. len ( ) != b. len ( ) )
351
+ . any ( |( a, b) | a. 0 . len ( ) != b. 0 . len ( ) || a . 1 != b . 1 )
326
352
} ) {
327
353
return Err (
328
354
"All components of nested sums must have the same open indices" . to_owned ( ) ,
@@ -344,7 +370,7 @@ impl<'a> AtomView<'a> {
344
370
for c in connections. iter_mut ( ) . zip ( cons) {
345
371
// add new open indices from this subgraph
346
372
if * c. 0 != c. 1 {
347
- c. 0 . extend ( c. 1 ) ;
373
+ c. 0 . 0 . extend ( c. 1 . 0 ) ;
348
374
}
349
375
}
350
376
}
0 commit comments