Skip to content

Commit

Permalink
Bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
pmenczel committed Sep 15, 2024
1 parent cae3898 commit 8db2833
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 31 deletions.
26 changes: 3 additions & 23 deletions qutip/core/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,25 +479,11 @@ def approx_by_cf_fit(
guess_re, lower_re, upper_re = guess, lower, upper
guess_im, lower_im, upper_im = guess, lower, upper

print(lower_re)
print(guess_re)
print(upper_re)
if guess_re is None:
guess_fun_re = None
else:
def guess_fun_re(N):
return np.tile(guess_re, (N, 1))
if guess_im is None:
guess_fun_im = None
else:
def guess_fun_im(N):
return np.tile(guess_im, (N, 1))

# Fit real part
start_real = time()
rmse_real, params_real = iterated_fit(
_cf_real_fit_model, num_params, tlist, np.real(clist), target_rsme,
guess_fun_re, Nr_min, Nr_max, lower_re, upper_re
guess_re, Nr_min, Nr_max, lower_re, upper_re
)
end_real = time()
fit_time_real = end_real - start_real
Expand All @@ -506,7 +492,7 @@ def guess_fun_im(N):
start_imag = time()
rmse_imag, params_imag = iterated_fit(
_cf_imag_fit_model, num_params, tlist, np.imag(clist), target_rsme,
guess_fun_im, Ni_min, Ni_max, lower_im, upper_im
guess_im, Ni_min, Ni_max, lower_im, upper_im
)
end_imag = time()
fit_time_imag = end_imag - start_imag
Expand Down Expand Up @@ -643,16 +629,10 @@ def approx_by_sd_fit(
if guess is None and lower is None and upper is None:
guess, lower, upper = _default_guess_sd(wlist, jlist)

if guess is None:
guess_fun = None
else:
def guess_fun(N):
return np.tile(guess, (N, 1))

# Fit
start = time()
rmse, params = iterated_fit(
_sd_fit_model, 3, wlist, jlist, target_rsme, guess_fun,
_sd_fit_model, 3, wlist, jlist, target_rsme, guess,
Nmin, Nmax, lower, upper
)
end = time()
Expand Down
21 changes: 13 additions & 8 deletions qutip/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def iterated_fit(
fun: Callable[..., complex], num_params: int,
xdata: ArrayLike, ydata: ArrayLike,
target_rmse: float = 1e-5,
guess: Callable[[int], ArrayLike] = None,
guess: ArrayLike | Callable[[int], ArrayLike] = None,
Nmin: int = 1, Nmax: int = 10,
lower: ArrayLike = None, upper: ArrayLike = None
) -> tuple[float, ArrayLike]:
Expand Down Expand Up @@ -376,10 +376,13 @@ def iterated_fit(
The dependent data.
target_rmse : optional, float
Desired normalized root mean squared error (default `1e-5`).
guess : optional, callable
A function that, given a number `N` of terms, returns an array
`[[p11, ..., p1n], [p21, ..., p2n], ..., [pN1, ..., pNn]]`
of initial guesses.
guess : optional, array_like or callable
This can be either a list of length `n`, with the i-th entry being the
guess for the parameter :math:`p_{k,i}` (for all terms :math:`k`), or a
function that provides different initial guesses for each term.
Specifically, given a number `N` of terms, the function returns an
array `[[p11, ..., p1n], [p21, ..., p2n], ..., [pN1, ..., pNn]]` of
initial guesses.
Nmin : optional, int
The minimum number of terms to be used for the fit (default 1).
Nmax : optional, int
Expand Down Expand Up @@ -418,14 +421,16 @@ def iterated_fit(
while rmse1 > target_rmse and N <= Nmax:
if guess is None:
guesses = np.ones((N, num_params), dtype=float)
else:
elif callable(guess):
guesses = np.array(guess(N))
if guesses.shape != (N, num_params):
raise ValueError(
"The shape of the provided fit guesses is not consistent")
else:
guesses = np.tile(guess, (N, 1))

lower_repeat = np.repeat(lower, N)
upper_repeat = np.repeat(upper, N)
lower_repeat = np.tile(lower, N)
upper_repeat = np.tile(upper, N)
rmse1, params = _fit(fun, num_params, xdata, ydata, N,
guesses, lower_repeat, upper_repeat)
N += 1
Expand Down

0 comments on commit 8db2833

Please sign in to comment.