diff --git a/qutip/core/environment.py b/qutip/core/environment.py index 4d9724f179..d3a7fbfab4 100644 --- a/qutip/core/environment.py +++ b/qutip/core/environment.py @@ -479,25 +479,11 @@ def approx_by_cf_fit( guess_re, lower_re, upper_re = guess, lower, upper guess_im, lower_im, upper_im = guess, lower, upper - print(lower_re) - print(guess_re) - print(upper_re) - if guess_re is None: - guess_fun_re = None - else: - def guess_fun_re(N): - return np.tile(guess_re, (N, 1)) - if guess_im is None: - guess_fun_im = None - else: - def guess_fun_im(N): - return np.tile(guess_im, (N, 1)) - # Fit real part start_real = time() rmse_real, params_real = iterated_fit( _cf_real_fit_model, num_params, tlist, np.real(clist), target_rsme, - guess_fun_re, Nr_min, Nr_max, lower_re, upper_re + guess_re, Nr_min, Nr_max, lower_re, upper_re ) end_real = time() fit_time_real = end_real - start_real @@ -506,7 +492,7 @@ def guess_fun_im(N): start_imag = time() rmse_imag, params_imag = iterated_fit( _cf_imag_fit_model, num_params, tlist, np.imag(clist), target_rsme, - guess_fun_im, Ni_min, Ni_max, lower_im, upper_im + guess_im, Ni_min, Ni_max, lower_im, upper_im ) end_imag = time() fit_time_imag = end_imag - start_imag @@ -643,16 +629,10 @@ def approx_by_sd_fit( if guess is None and lower is None and upper is None: guess, lower, upper = _default_guess_sd(wlist, jlist) - if guess is None: - guess_fun = None - else: - def guess_fun(N): - return np.tile(guess, (N, 1)) - # Fit start = time() rmse, params = iterated_fit( - _sd_fit_model, 3, wlist, jlist, target_rsme, guess_fun, + _sd_fit_model, 3, wlist, jlist, target_rsme, guess, Nmin, Nmax, lower, upper ) end = time() diff --git a/qutip/utilities.py b/qutip/utilities.py index 3bb99428c4..1a44a279a6 100644 --- a/qutip/utilities.py +++ b/qutip/utilities.py @@ -348,7 +348,7 @@ def iterated_fit( fun: Callable[..., complex], num_params: int, xdata: ArrayLike, ydata: ArrayLike, target_rmse: float = 1e-5, - guess: Callable[[int], ArrayLike] = None, + guess: ArrayLike | Callable[[int], ArrayLike] = None, Nmin: int = 1, Nmax: int = 10, lower: ArrayLike = None, upper: ArrayLike = None ) -> tuple[float, ArrayLike]: @@ -376,10 +376,13 @@ def iterated_fit( The dependent data. target_rmse : optional, float Desired normalized root mean squared error (default `1e-5`). - guess : optional, callable - A function that, given a number `N` of terms, returns an array - `[[p11, ..., p1n], [p21, ..., p2n], ..., [pN1, ..., pNn]]` - of initial guesses. + guess : optional, array_like or callable + This can be either a list of length `n`, with the i-th entry being the + guess for the parameter :math:`p_{k,i}` (for all terms :math:`k`), or a + function that provides different initial guesses for each term. + Specifically, given a number `N` of terms, the function returns an + array `[[p11, ..., p1n], [p21, ..., p2n], ..., [pN1, ..., pNn]]` of + initial guesses. Nmin : optional, int The minimum number of terms to be used for the fit (default 1). Nmax : optional, int @@ -418,14 +421,16 @@ def iterated_fit( while rmse1 > target_rmse and N <= Nmax: if guess is None: guesses = np.ones((N, num_params), dtype=float) - else: + elif callable(guess): guesses = np.array(guess(N)) if guesses.shape != (N, num_params): raise ValueError( "The shape of the provided fit guesses is not consistent") + else: + guesses = np.tile(guess, (N, 1)) - lower_repeat = np.repeat(lower, N) - upper_repeat = np.repeat(upper, N) + lower_repeat = np.tile(lower, N) + upper_repeat = np.tile(upper, N) rmse1, params = _fit(fun, num_params, xdata, ydata, N, guesses, lower_repeat, upper_repeat) N += 1