-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathparameter_shift.py
105 lines (71 loc) · 2.57 KB
/
parameter_shift.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""
Demonstration of the correctness and efficiency of the parameter-shift gradient implementation.
"""
import sys
import numpy as np
sys.path.insert(0, "../")
import tensorcircuit as tc
from tensorcircuit import experimental as E
# Select the TensorFlow backend; K is the backend handle used below for
# jit, grad, real and ones.
K = tc.set_backend("tensorflow")

# Problem size: n qubits and m circuit layers. Parameter tensors are built
# with shape [n, m] and indexed as param[qubit, layer].
n, m = 6, 3
def f1(param):
    """Return the real part of the Pauli-Y expectation on qubit ``n // 2``.

    The circuit has ``m`` layers. Each layer is a CNOT ladder over
    neighbouring qubits followed by an RX rotation on every qubit.

    :param param: rotation angles indexed as ``param[qubit, layer]``
        (the callers below pass an ``[n, m]`` tensor).
    :return: real-valued expectation, suitable for differentiation.
    """
    circuit = tc.Circuit(n)
    for layer in range(m):
        # Entangling ladder: CNOT between each neighbouring pair of qubits.
        for q in range(n - 1):
            circuit.cnot(q, q + 1)
        # Single-qubit RX rotations using this layer's angles.
        for q in range(n):
            circuit.rx(q, theta=param[q, layer])
    return K.real(circuit.expectation_ps(y=[n // 2]))
# Single-weight case: compare the jit-compiled autodiff gradient of f1 (g1f1)
# with the jit-compiled parameter-shift gradient (g2f1), both evaluated at
# param = ones. r1 is kept as the reference for the sampled check further down.
theta1 = K.ones([n, m], dtype="float32")
g1f1 = K.jit(K.grad(f1))
r1, ts, tr = tc.utils.benchmark(g1f1, theta1)
g2f1 = K.jit(E.parameter_shift_grad(f1))
r2, ts, tr = tc.utils.benchmark(g2f1, theta1)
np.testing.assert_allclose(actual=r1, desired=r2, atol=1e-5)
print("equality test passed!")
# Multiple weight arguments version: gradients with respect to two parameter tensors
def f2(paramzz, paramx):
    """Return the real part of the Pauli-Y expectation on qubit ``n // 2``.

    The circuit has ``m`` layers. Each layer is a ladder of RZZ gates over
    neighbouring qubits followed by an RX rotation on every qubit. The two
    gate families take their angles from separate tensors, so gradients can
    be requested with respect to both arguments.

    :param paramzz: RZZ angles indexed as ``paramzz[qubit, layer]``; only
        rows ``0 .. n - 2`` are read.
    :param paramx: RX angles indexed as ``paramx[qubit, layer]``.
    :return: real-valued expectation, suitable for differentiation.
    """
    circuit = tc.Circuit(n)
    for layer in range(m):
        # Two-qubit ZZ rotations between neighbouring qubits.
        for q in range(n - 1):
            circuit.rzz(q, q + 1, theta=paramzz[q, layer])
        # Single-qubit RX rotations.
        for q in range(n):
            circuit.rx(q, theta=paramx[q, layer])
    return K.real(circuit.expectation_ps(y=[n // 2]))
# Multiple-argument case: differentiate f2 with respect to both weight tensors
# (argnums=(0, 1)) using autodiff and parameter shift, and compare the results.
theta_zz = K.ones([n, m], dtype="float32")
theta_x = K.ones([n, m], dtype="float32")
g1f2 = K.jit(K.grad(f2, argnums=(0, 1)))
r12, ts, tr = tc.utils.benchmark(g1f2, theta_zz, theta_x)
g2f2 = K.jit(E.parameter_shift_grad(f2, argnums=(0, 1)))
r22, ts, tr = tc.utils.benchmark(g2f2, theta_zz, theta_x)
# One gradient per requested argnum: index 0 -> paramzz, index 1 -> paramx.
np.testing.assert_allclose(r12[0], r22[0], atol=1e-5)
np.testing.assert_allclose(r12[1], r22[1], atol=1e-5)
print("multiple weight inputs: equality test passed!")
# sampled expectation version
def f3(param):
    """Same circuit as ``f1``, but read out with ``sample_expectation_ps``.

    No ``shots`` argument is passed to ``sample_expectation_ps`` here. The
    check below expects the result to match the autodiff gradient of ``f1``
    to within 1e-5.

    :param param: rotation angles indexed as ``param[qubit, layer]``.
    :return: real part of the Pauli-Y expectation on qubit ``n // 2``.
    """
    circuit = tc.Circuit(n)
    for layer in range(m):
        # Entangling CNOT ladder.
        for q in range(n - 1):
            circuit.cnot(q, q + 1)
        # Per-qubit RX rotations for this layer.
        for q in range(n):
            circuit.rx(q, theta=param[q, layer])
    return K.real(circuit.sample_expectation_ps(y=[n // 2]))
# Sampled-expectation check: the parameter-shift gradient of f3 should agree
# with the autodiff reference r1 computed for the identical circuit in f1.
theta3 = K.ones([n, m], dtype="float32")
g2f3 = K.jit(E.parameter_shift_grad(f3))
r2, ts, tr = tc.utils.benchmark(g2f3, theta3)
np.testing.assert_allclose(actual=r1, desired=r2, atol=1e-5)
print("analytical sampled expectation: equality test passed!")
# Finite-shot sampled expectation version (currently disabled). With a finite
# number of shots the estimate is sampled, so it is compared against r1 with a
# looser tolerance (atol=1e-3) than the checks above.
# def f3(param):
#     c = tc.Circuit(n)
#     for j in range(m):
#         for i in range(n - 1):
#             c.cnot(i, i + 1)
#         for i in range(n):
#             c.rx(i, theta=param[i, j])
#     return K.real(c.sample_expectation_ps(y=[n // 2], shots=81920))
# g2f3 = K.jit(E.parameter_shift_grad(f3))
# r2, ts, tr = tc.utils.benchmark(g2f3, K.ones([n, m], dtype="float32"))
# print(r1 - r2)
# np.testing.assert_allclose(r1 - r2, np.zeros_like(r1), atol=1e-3)
# print("finite sampled expectation: equality test passed!")