-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmultitraj_patch.py
225 lines (188 loc) · 9.12 KB
/
multitraj_patch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
# Public API of this module.
__all__ = [
    "InitialStateGenerator",
    "EnhancedMultiTrajResult",
    "EnhancedMultiTrajSolver",
]
from qutip.solver.multitraj import MultiTrajSolver, _get_map
import numpy as np
import qutip as qt
from time import time
from operator import itemgetter
from collections import Counter
import bisect
from typing import Any
class InitialStateGenerator:
def __init__(self, initial_conditions: list[tuple[Any, float, complex]],
ntraj: int):
"""
initial_conditions: a list of possible initial conditions
Each entry in the list contains three parts:
* state: the state to be passed to the solver's `run` method
* frequency: the desired frequency (between 0 and 1) of this
initial state among the trajectories
* weight: an additional weight to be added to trajectories
starting from this initial state
Note that sum(f for f in frequencies) must be one.
ntraj: number of trajectories
There must be at least one trajectory for each state with non-zero
frequency.
An InitialStateGenerator object represents a list of states together
with the number of trajectories starting from the respective states,
`trajectory_count(n)`, and corrected weights, `corrected_weight(n)`.
It is guaranteed that the total number of trajectories is `ntraj`:
sum_n trajectory_count(n) = ntraj.
It is further guaranteed that for each n,
weight * frequency = corrected_weight * (trajectory_count / ntraj).
We try to generate a distribution of trajectory counts that
approximates the provided frequencies as well as possible under these
constraints,
trajectory_count ~ frequency * ntraj.
"""
self.ntraj = ntraj
self._states: list[tuple[Any, float, complex, int]] = []
# remove zero-frequency entries
# also, note down originally requested frequency for each entry
# and calculate the "target" freq * ntraj
filtered_ics = [(state, freq, weight, freq * ntraj)
for state, freq, weight in initial_conditions
if freq > 0]
if len(filtered_ics) > ntraj:
raise ValueError("Not enough trajectories "
"for mixed initial conditions")
# The following algorithm is loosely based on
# https://stackoverflow.com/a/792490
# We initially round up because each state needs at least
# one trajectory. We then remove trajectories until the correct
# total number of trajectories is reached.
# We remove the trajectory from the state with maximum result / target,
# but we never remove the only remaining trajectory
self._states = []
under_consideration = []
total_number = 0
for state, freq, weight, target in filtered_ics:
result = int(np.ceil(target))
total_number += result
# if only one trajectory, can be added to self._states and not
# considered further
if result == 1:
self._states.append((state, freq, weight, result))
continue
ratio = result / target
# under_consideration is kept sorted according to the ratio
bisect.insort(under_consideration,
(state, freq, weight, result, ratio),
key=itemgetter(4))
while total_number > ntraj:
state, freq, weight, result, _ = under_consideration.pop()
result -= 1
total_number -= 1
if result == 1:
self._states.append((state, freq, weight, result))
continue
ratio = result / target
bisect.insort(under_consideration,
(state, freq, weight, result, ratio),
key=itemgetter(4))
# Finally we have achieved total_number = ntraj, add all remaining
# states to self._states
for state, freq, weight, result, _ in under_consideration:
self._states.append((state, freq, weight, result))
def nstates(self):
return len(self._states)
def state(self, n: int) -> Any:
return self._states[n][0]
def trajectory_count(self, n: int) -> int:
return self._states[n][3]
def weight(self, n: int) -> complex:
_, orig_freq, extra_weight, traj_count = self._states[n]
return extra_weight * orig_freq * self.ntraj / traj_count
def state_numbers(self) -> list[int]:
counts = Counter({n: self.trajectory_count(n)
for n in range(self.nstates())})
return list(counts.elements())
class EnhancedMultiTrajResult(qt.MultiTrajResult):
    """
    Multi-trajectory result that honours a per-trajectory `weight`.

    Trajectories without a `weight` attribute are passed to the parent class
    unchanged. For trajectories that carry one, the states are converted to
    density matrices and, together with the expectation data, multiplied by
    the weight before being handed to the parent class.
    """

    def _weighted_dm(self, state, weight):
        """
        Return `ket2dm(state) * weight`, or None if `state` is None.

        NOTE(review): `qt.ket2dm` is applied unconditionally, so `state` is
        presumably always a ket here -- confirm against the solvers in use.
        """
        if state is None:
            return state
        return qt.ket2dm(state) * weight

    def add(self, trajectory_info: tuple[np.random.SeedSequence, qt.Result]):
        """
        Add one trajectory, applying its weight if it has one.

        trajectory_info: tuple (seed, trajectory)
            If `trajectory` has a `weight` attribute, its `states`,
            `final_state` and `e_data` are temporarily replaced by their
            weighted counterparts while the parent's `add` runs. The original
            values are restored afterwards, even if the parent raises.

        Returns whatever the parent class's `add` returns.
        """
        _, trajectory = trajectory_info
        if not hasattr(trajectory, 'weight'):
            return super().add(trajectory_info)

        weight = trajectory.weight
        # Capture all original values before anything is overwritten.
        old_states = trajectory.states
        old_final_state = trajectory.final_state
        old_edata = trajectory.e_data
        try:
            trajectory.states = [self._weighted_dm(state, weight)
                                 for state in old_states]
            trajectory.final_state = self._weighted_dm(old_final_state,
                                                       weight)
            # We make the values of e_data arrays with a complex data type.
            # Otherwise, we would get an exception in situations where the
            # first trajectory only contains real numbers (meaning that the
            # multi-trajectory result gets initialized with float-typed
            # arrays) and a later trajectory contains complex numbers and
            # can't get added.
            # `np.complex128` is used because the `np.complex_` alias was
            # removed in NumPy 2.0.
            trajectory.e_data = {
                key: weight * np.asarray(data, dtype=np.complex128)
                for key, data in old_edata.items()
            }
            return super().add(trajectory_info)
        finally:
            # Leave the caller's trajectory object as we found it.
            trajectory.states = old_states
            trajectory.final_state = old_final_state
            trajectory.e_data = old_edata
class EnhancedMultiTrajSolver(MultiTrajSolver):
    """
    MultiTrajSolver extension with time-aware state restoration, integration
    through the integrator's `run` method, and support for mixed initial
    states via `run_mixed`.
    """

    resultclass = EnhancedMultiTrajResult

    # little hack to make _restore_state aware of the current time
    def _initialize_run_one_traj(self, seed, state, tlist, e_ops):
        """Record the initial time, then defer to the parent class."""
        self._restore_state_time = tlist[0]
        return super()._initialize_run_one_traj(seed, state, tlist, e_ops)

    # Make use of Integrator's `run` method which might be more efficient
    # than repeated calls to its `integrate` method.
    def _integrate_one_traj(self, seed: np.random.SeedSequence,
                            tlist: list[float], result: qt.Result
                            ) -> tuple[np.random.SeedSequence, qt.Result]:
        """
        Integrate one trajectory over `tlist` and store each state in
        `result`.

        `self._restore_state_time` is updated to the current time before
        each call to `_restore_state`. Returns the tuple (seed, result).
        """
        # Note that integrator.run discards first value of tlist
        for t, state in self._integrator.run(tlist):
            self._restore_state_time = t
            result.add(t, self._restore_state(state, copy=False))
        return seed, result

    def step(self, t, *, args=None, copy=True):
        """Record the time `t`, then defer to the parent's `step`."""
        self._restore_state_time = t
        return super().step(t, args=args, copy=copy)

    # Support for mixed initial state
    # Argument types mostly too complicated for type hints
    # We do not support target tolerance
    def run_mixed(self, state_generator: InitialStateGenerator,
                  tlist: list[float], *,
                  args=None, e_ops=(), timeout=None, seed=None):
        """
        Run trajectories starting from a mixture of initial states.

        state_generator: provides the initial states, the number of
            trajectories per state (`state_generator.ntraj` in total) and
            the weight attached to each trajectory
        tlist: list of times, forwarded to each single-trajectory run
        args: forwarded to `self._argument`
        e_ops: forwarded to the result object and to each trajectory run
        timeout: forwarded to the map function as `map_kw['timeout']`
        seed: forwarded to `self._read_seed`

        Returns an instance of `self.resultclass` to which every trajectory
        has been added.
        """
        start_time = time()
        self._argument(args)
        stats = self._initialize_stats()
        seeds = self._read_seed(seed, state_generator.ntraj)

        result = self.resultclass(
            e_ops, self.options, solver=self.name, stats=stats
        )

        map_func = _get_map[self.options['map']]
        map_kw = {
            'timeout': timeout,
            'job_timeout': self.options['job_timeout'],
            'num_cpus': self.options['num_cpus'],
        }
        # Pair every seed with the index of the initial state that the
        # corresponding trajectory starts from.
        trajectory_infos = list(zip(seeds, state_generator.state_numbers()))
        stats['preparation time'] += time() - start_time

        start_time = time()
        map_func(self._run_one_traj_mixed, trajectory_infos,
                 (state_generator, tlist, e_ops),
                 reduce_func=result.add, map_kw=map_kw,
                 progress_bar=self.options["progress_bar"],
                 progress_bar_kwargs=self.options["progress_kwargs"])
        result.stats['run time'] = time() - start_time
        return result

    def _run_one_traj_mixed(
            self,
            trajectory_info: tuple[np.random.SeedSequence, int],
            state_generator: InitialStateGenerator,
            tlist: list[float], e_ops: Any
            ) -> tuple[np.random.SeedSequence, qt.Result]:
        """
        Run a single trajectory from the initial state selected by
        `trajectory_info`.

        trajectory_info: tuple (seed, state_number), where `state_number`
            indexes into `state_generator`

        Returns the tuple (seed, result); the trajectory result carries the
        generator's weight for that state in its `weight` attribute.
        """
        seed, state_number = trajectory_info
        state = self._prepare_state(
            state_generator.state(state_number))
        weight = state_generator.weight(state_number)
        seed, result = self._run_one_traj(seed, state, tlist, e_ops)
        result.weight = weight
        return seed, result