diff --git a/src/gen_experiments/odes.py b/src/gen_experiments/odes.py index 2f3acf3..e45277a 100644 --- a/src/gen_experiments/odes.py +++ b/src/gen_experiments/odes.py @@ -1,3 +1,4 @@ +from functools import partial from typing import Callable import matplotlib.pyplot as plt @@ -14,9 +15,9 @@ from .utils import ( FullSINDyTrialData, SINDyTrialData, - _make_model, coeff_metrics, integration_metrics, + make_model, simulate_test_data, unionize_coeff_matrices, ) @@ -59,7 +60,7 @@ def forcing(t, x): p_duff = [0.2, 0.05, 1] -p_lotka = [1, 10] +p_lotka = [5, 1] p_ross = [0.2, 0.2, 5.7] p_hopf = [-0.05, 1, 1] @@ -73,7 +74,7 @@ def forcing(t, x): ], }, "lv": { - "rhsfunc": ps.utils.odes.lotka, + "rhsfunc": partial(ps.utils.odes.lotka, p=p_lotka), "input_features": ["x", "y"], "coeff_true": [ {"x": p_lotka[0], "x y": -p_lotka[1]}, @@ -181,7 +182,7 @@ def run( nonnegative=nonnegative, **sim_params, ) - model = _make_model(input_features, dt, diff_params, feat_params, opt_params) + model = make_model(input_features, dt, diff_params, feat_params, opt_params) model.fit(x_train) coeff_true, coefficients, feature_names = unionize_coeff_matrices(model, coeff_true) diff --git a/src/gen_experiments/pdes.py b/src/gen_experiments/pdes.py index 8cb712b..985ed10 100644 --- a/src/gen_experiments/pdes.py +++ b/src/gen_experiments/pdes.py @@ -7,9 +7,9 @@ from .utils import ( FullSINDyTrialData, SINDyTrialData, - _make_model, coeff_metrics, integration_metrics, + make_model, simulate_test_data, unionize_coeff_matrices, ) @@ -171,7 +171,7 @@ def run( dt=time_args[0], t_end=time_args[1], ) - model = _make_model(input_features, dt, diff_params, feat_params, opt_params) + model = make_model(input_features, dt, diff_params, feat_params, opt_params) model.fit(x_train, t=t_train) coeff_true, coefficients, feature_names = unionize_coeff_matrices(model, coeff_true)