From bae3e923564c36c6397af2ea73fac0df8269a2c8 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Fri, 22 Nov 2024 12:48:56 -0800 Subject: [PATCH] TST: Added test for unionize_coeff_matrices --- tests/test_all.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/test_all.py b/tests/test_all.py index 13274fd..56da312 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -1,6 +1,9 @@ +import numpy as np +import pysindy as ps import pytest from gen_experiments.typing import NestedDict +from gen_experiments.utils import unionize_coeff_matrices def test_flatten_nested_dict(): @@ -19,3 +22,16 @@ def test_flatten_nested_bad_dict(): with pytest.raises(TypeError, match="Only string keys allowed"): deep = NestedDict(a={1: 1}) deep.flatten() + + +def test_unionize_coeff_matrices(): + # lib = ps.PolynomialLibrary().fit(np.array([[1, 1]])) + model = ps.SINDy(feature_names=["x", "y"]) + data = np.arange(10) + data = np.vstack((data, data)).T + model.fit(data, 0.1) + coeff_true = [{"y": -1, "zorp_x": 0.1}, {"x": 1, "zorp_y": 0.1}] + true, est, feats = unionize_coeff_matrices(model, coeff_true) + assert len(feats) == true.shape[1] + assert len(feats) == est.shape[1] + assert est.shape == true.shape