diff --git a/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md b/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md index 4073b928..68c08000 100644 --- a/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md +++ b/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md @@ -321,7 +321,8 @@ scores = np.zeros((len(regularizer_strength) * len(n_basis_funcs), n_folds)) coeffs = {} # initialize basis and model -basis = nmo.basis.TransformerBasis(nmo.basis.RaisedCosineLinearEval(6)) +basis = nmo.basis.RaisedCosineLinearEval(6).set_input_shape(1) +basis = nmo.basis.TransformerBasis(basis) model = nmo.glm.GLM(regularizer="Ridge") # loop over combinations @@ -441,13 +442,13 @@ We are now able to capture the distribution of the firing rate appropriately: bo In the previous example we set the number of basis functions of the [`Basis`](nemos.basis._basis.Basis) wrapped in our [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis). However, if we are for example not sure about the type of basis functions we want to use, or we have already defined some basis functions of our own, then we can use cross-validation to directly evaluate those as well. 
-Here we include `transformerbasis___basis` in the parameter grid to try different values for `TransformerBasis._basis`: +Here we include `transformerbasis__basis` in the parameter grid to try different values for `TransformerBasis.basis`: ```{code-cell} ipython3 param_grid = dict( glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6), - transformerbasis___basis=( + transformerbasis__basis=( nmo.basis.RaisedCosineLinearEval(5).set_input_shape(1), nmo.basis.RaisedCosineLinearEval(10).set_input_shape(1), nmo.basis.RaisedCosineLogEval(5).set_input_shape(1), @@ -481,7 +482,7 @@ cvdf = pd.DataFrame(gridsearch.cv_results_) # Read out the number of basis functions cvdf["transformerbasis_config"] = [ f"{b.__class__.__name__} - {b.n_basis_funcs}" - for b in cvdf["param_transformerbasis___basis"] + for b in cvdf["param_transformerbasis__basis"] ] cvdf_wide = cvdf.pivot( @@ -537,7 +538,7 @@ Please note that because it would lead to unexpected behavior, mixing the two wa param_grid = dict( glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6), transformerbasis__n_basis_funcs=(3, 5, 10, 20, 100), - transformerbasis___basis=( + transformerbasis__basis=( nmo.basis.RaisedCosineLinearEval(5).set_input_shape(1), nmo.basis.RaisedCosineLinearEval(10).set_input_shape(1), nmo.basis.RaisedCosineLogEval(5).set_input_shape(1), @@ -592,7 +593,7 @@ cvdf = pd.DataFrame(gridsearch.cv_results_) # Read out the number of basis functions cvdf["transformerbasis_config"] = [ f"{b.__class__.__name__} - {b.n_basis_funcs}" - for b in cvdf["param_transformerbasis___basis"] + for b in cvdf["param_transformerbasis__basis"] ] cvdf_wide = cvdf.pivot( diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index 4420eb40..37ba60ab 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -1,6 +1,7 @@ from __future__ import annotations import copy +from functools import wraps from typing import TYPE_CHECKING, Generator import numpy as np @@ 
-11,6 +12,18 @@ from ._basis import Basis +def transformer_chaining(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + # Call the wrapped function and capture its return value + result = func(*args, **kwargs) + + # If the method returns the inner `self`, replace it with the outer `self` (no deepcopy here). + return self if result is self.basis else result + + return wrapper + + class TransformerBasis: """Basis as ``scikit-learn`` transformers. @@ -61,8 +74,16 @@ class TransformerBasis: Cross-validated number of basis: {'compute_features__n_basis_funcs': 10} """ + _chainable_methods = ( + "set_kernel", + "set_input_shape", + "_set_input_independent_states", + "setup_basis", + ) + def __init__(self, basis: Basis): - self._basis = copy.deepcopy(basis) + self.basis = copy.deepcopy(basis) + self._wrapped_methods = {} # Cache for wrapped methods @staticmethod def _check_initialized(basis): @@ -73,14 +94,6 @@ def _check_initialized(basis): "`fit_transform`." ) - @property - def basis(self): - return self._basis - - @basis.setter - def basis(self, basis): - self._basis = basis - def _unpack_inputs(self, X: FeatureMatrix) -> Generator: """Unpack inputs. @@ -95,7 +108,7 @@ def _unpack_inputs(self, X: FeatureMatrix) -> Generator: Returns ------- : - A list of each individual input. + A generator looping on each individual input. """ n_samples = X.shape[0] @@ -110,9 +123,7 @@ def _unpack_inputs(self, X: FeatureMatrix) -> Generator: def fit(self, X: FeatureMatrix, y=None): """ - Compute the convolutional kernels. - - If any of the 1D basis in self._basis is in "conv" mode, it computes the convolutional kernels. + Check the input structure and, if necessary, compute the convolutional kernels. Parameters ---------- @@ -126,6 +137,13 @@ def fit(self, X: FeatureMatrix, y=None): self : The transformer object. + Raises + ------ + RuntimeError + If ``self.n_basis_input`` is None. Call ``self.set_input_shape`` before calling ``fit`` to avoid this. 
+ ValueError + If the number of columns in X does not match ``self.n_basis_input_``. + Examples -------- >>> import numpy as np >>> from nemos.basis import MSplineEval >>> from nemos.basis import TransformerBasis >>> # Example input >>> X = np.random.normal(size=(100, 2)) >>> # Define and fit tranformation basis >>> basis = MSplineEval(10).set_input_shape(2) >>> transformer = TransformerBasis(basis) >>> transformer_fitted = transformer.fit(X) """ - self._check_initialized(self._basis) + self._check_initialized(self.basis) self._check_input(X, y) - self._basis.setup_basis(*self._unpack_inputs(X)) + self.basis.setup_basis(*self._unpack_inputs(X)) return self def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: @@ -181,10 +199,11 @@ def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: >>> # Transform basis >>> feature_transformed = transformer.transform(X) """ - self._check_initialized(self._basis) + self._check_initialized(self.basis) + self._check_input(X, y) # transpose does not work with pynapple # can't use func(*X.T) to unwrap - return self._basis._compute_features(*self._unpack_inputs(X)) + return self.basis._compute_features(*self._unpack_inputs(X)) def fit_transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: """ @@ -231,7 +250,11 @@ def __getstate__(self): See https://docs.python.org/3/library/pickle.html#object.__getstate__ and https://docs.python.org/3/library/pickle.html#pickle-state """ - return {"_basis": self._basis} + # this is the only state needed at initialization + # returning the cached wrapped methods would create + # a circular binding of the state to self (i.e. infinite recursion when + # unpickling). + return {"basis": self.basis} def __setstate__(self, state): """ @@ -243,12 +266,17 @@ def __setstate__(self, state): See https://docs.python.org/3/library/pickle.html#object.__setstate__ and https://docs.python.org/3/library/pickle.html#pickle-state """ - self._basis = state["_basis"] + self.basis = state["basis"] + self._wrapped_methods = {} # Reinitialize the cache def __getattr__(self, name: str): """ Enable easy access to attributes of the underlying Basis object.
+ This method caches all chainable methods (methods returning self) in a dictionary. + These methods are created the first time they are accessed by decorating the `self.basis.name` + and cached for future use. + Examples -------- >>> from nemos import basis @@ -259,11 +287,28 @@ >>> trans_bas.n_basis_funcs 5 """ - return getattr(self._basis, name) + # Check if the method has already been wrapped + if name in self._wrapped_methods: + return self._wrapped_methods[name] + + if not hasattr(self.basis, name) or name == "to_transformer": + raise AttributeError(f"'TransformerBasis' object has no attribute '{name}'") + + # Get the original attribute from the basis + attr = getattr(self.basis, name) + + # If the attribute is a callable method, wrap it dynamically + if name in self._chainable_methods: + wrapped = transformer_chaining(attr).__get__(self) + self._wrapped_methods[name] = wrapped # Cache the wrapped method + return wrapped + + # For non-callable attributes, return them directly + return attr def __setattr__(self, name: str, value) -> None: r""" - Allow setting _basis or the attributes of _basis with a convenient dot assignment syntax. + Allow setting basis or the attributes of basis with a convenient dot assignment syntax. Setting any other attribute is not allowed. @@ -274,33 +319,33 @@ def __setattr__(self, name: str, value) -> None: Raises ------ ValueError - If the attribute being set is not ``_basis`` or an attribute of ``_basis``. + If the attribute being set is not ``basis`` or an attribute of ``basis``. Examples -------- >>> import nemos as nmo >>> trans_bas = nmo.basis.TransformerBasis(nmo.basis.MSplineEval(10)) >>> # allowed - >>> trans_bas._basis = nmo.basis.BSplineEval(10) + >>> trans_bas.basis = nmo.basis.BSplineEval(10) >>> # allowed >>> trans_bas.n_basis_funcs = 20 >>> # not allowed >>> try: - ... trans_bas.random_attribute_name = "some value" + ... trans_bas.rand_atrr = "some value" ... 
except ValueError as e: ... print(repr(e)) - ValueError('Only setting _basis or existing attributes of _basis is allowed.') + ValueError('Only setting basis or existing attributes of basis is allowed. Attempt to set `rand_atrr`.') """ - # allow self._basis = basis - if name == "_basis": + # allow self.basis = basis and other attrs of self to be retrievable + if name in ["basis", "_wrapped_methods"]: super().__setattr__(name, value) - # allow changing existing attributes of self._basis - elif hasattr(self._basis, name): - setattr(self._basis, name, value) + # allow changing existing attributes of self.basis + elif hasattr(self.basis, name): + setattr(self.basis, name, value) # don't allow setting any other attribute else: raise ValueError( - "Only setting _basis or existing attributes of _basis is allowed." + f"Only setting basis or existing attributes of basis is allowed. Attempt to set `{name}`." ) def __sklearn_clone__(self) -> TransformerBasis: @@ -312,15 +357,15 @@ def __sklearn_clone__(self) -> TransformerBasis: For more info: https://scikit-learn.org/stable/developers/develop.html#cloning """ - cloned_obj = TransformerBasis(self._basis.__sklearn_clone__()) - cloned_obj._basis.kernel_ = None + cloned_obj = TransformerBasis(self.basis.__sklearn_clone__()) + cloned_obj.basis.kernel_ = None return cloned_obj def set_params(self, **parameters) -> TransformerBasis: """ Set TransformerBasis parameters. - When used with ``sklearn.model_selection``, users can set either the ``_basis`` attribute directly + When used with ``sklearn.model_selection``, users can set either the ``basis`` attribute directly or the parameters of the underlying Basis, but not both. 
Examples @@ -329,38 +374,41 @@ def set_params(self, **parameters) -> TransformerBasis: >>> basis = MSplineEval(10) >>> transformer_basis = TransformerBasis(basis=basis) - >>> # setting parameters of _basis is allowed + >>> # setting parameters of basis is allowed >>> print(transformer_basis.set_params(n_basis_funcs=8).n_basis_funcs) 8 - >>> # setting _basis directly is allowed - >>> print(type(transformer_basis.set_params(_basis=BSplineEval(10))._basis)) + >>> # setting basis directly is allowed + >>> print(type(transformer_basis.set_params(basis=BSplineEval(10)).basis)) >>> # mixing is not allowed, this will raise an exception >>> try: - ... transformer_basis.set_params(_basis=BSplineEval(10), n_basis_funcs=2) + ... transformer_basis.set_params(basis=BSplineEval(10), n_basis_funcs=2) ... except ValueError as e: ... print(repr(e)) - ValueError('Set either new _basis object or parameters for existing _basis, not both.') + ValueError('Set either new basis object or parameters for existing basis, not both.') """ - new_basis = parameters.pop("_basis", None) + new_basis = parameters.pop("basis", None) if new_basis is not None: - self._basis = new_basis + self.basis = new_basis if len(parameters) > 0: raise ValueError( - "Set either new _basis object or parameters for existing _basis, not both." + "Set either new basis object or parameters for existing basis, not both." 
) else: - self._basis = self._basis.set_params(**parameters) + self.basis = self.basis.set_params(**parameters) return self def get_params(self, deep: bool = True) -> dict: - """Extend the dict of parameters from the underlying Basis with _basis.""" - return {"_basis": self._basis, **self._basis.get_params(deep)} + """Extend the dict of parameters from the underlying Basis with basis.""" + return {"basis": self.basis, **self.basis.get_params(deep)} def __dir__(self) -> list[str]: """Extend the list of properties of methods with the ones from the underlying Basis.""" - return list(super().__dir__()) + list(self._basis.__dir__()) + unique_attrs = set(list(super().__dir__()) + list(self.basis.__dir__())) + # discard without raising errors if not present + unique_attrs.discard("to_transformer") + return list(unique_attrs) def __add__(self, other: TransformerBasis) -> TransformerBasis: """ @@ -376,7 +424,7 @@ def __add__(self, other: TransformerBasis) -> TransformerBasis: : TransformerBasis The resulting Basis object. """ - return TransformerBasis(self._basis + other._basis) + return TransformerBasis(self.basis + other.basis) def __mul__(self, other: TransformerBasis) -> TransformerBasis: """ @@ -392,7 +440,7 @@ def __mul__(self, other: TransformerBasis) -> TransformerBasis: : The resulting Basis object. """ - return TransformerBasis(self._basis * other._basis) + return TransformerBasis(self.basis * other.basis) def __pow__(self, exponent: int) -> TransformerBasis: """Exponentiation of a TransformerBasis object. @@ -418,7 +466,7 @@ def __pow__(self, exponent: int) -> TransformerBasis: If the integer is zero or negative. """ # errors are handled by Basis.__pow__ - return TransformerBasis(self._basis**exponent) + return TransformerBasis(self.basis**exponent) def _check_input(self, X: FeatureMatrix, y=None): """Check that the input structure. 
@@ -455,5 +503,5 @@ def _check_input(self, X: FeatureMatrix, y=None): if y is not None and y.shape[0] != X.shape[0]: raise ValueError( "X and y must have the same number of samples. " - f"X has {X.shpae[0]} samples, while y has {y.shape[0]} samples." + f"X has {X.shape[0]} samples, while y has {y.shape[0]} samples." ) diff --git a/tests/test_basis.py b/tests/test_basis.py index 8daeb683..4c2dfb40 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -1,6 +1,5 @@ import inspect import itertools -import pickle import re from contextlib import nullcontext as does_not_raise from functools import partial @@ -1271,7 +1270,7 @@ def test_transformer_get_params(self, cls): bas.set_input_shape(*([1] * bas._n_input_dimensionality)) bas_transformer = bas.to_transformer() params_transf = bas_transformer.get_params() - params_transf.pop("_basis") + params_transf.pop("basis") params_basis = bas.get_params() rates_1 = params_basis.pop("decay_rates", 1) rates_2 = params_transf.pop("decay_rates", 1) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 9e52a4f2..e824745e 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -21,10 +21,10 @@ ) def test_sklearn_transformer_pipeline(bas, poissonGLM_model_instantiation): X, y, model, _, _ = poissonGLM_model_instantiation - bas = TransformerBasis(bas.set_input_shape(*([1] * bas._n_input_dimensionality))) + bas = TransformerBasis(bas).set_input_shape(*([1] * bas._n_input_dimensionality)) pipe = pipeline.Pipeline([("eval", bas), ("fit", model)]) - pipe.fit(X[:, : bas._basis._n_input_dimensionality] ** 2, y) + pipe.fit(X[:, : bas.basis._n_input_dimensionality] ** 2, y) @pytest.mark.parametrize( @@ -39,7 +39,7 @@ def test_sklearn_transformer_pipeline(bas, poissonGLM_model_instantiation): ) def test_sklearn_transformer_pipeline_cv(bas, poissonGLM_model_instantiation): X, y, model, _, _ = poissonGLM_model_instantiation - bas = TransformerBasis(bas.set_input_shape(*([1] * bas._n_input_dimensionality))) + bas = 
TransformerBasis(bas).set_input_shape(*([1] * bas._n_input_dimensionality)) pipe = pipeline.Pipeline([("basis", bas), ("fit", model)]) param_grid = dict(basis__n_basis_funcs=(4, 5, 10)) gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score="raise") @@ -60,7 +60,7 @@ def test_sklearn_transformer_pipeline_cv_multiprocess( bas, poissonGLM_model_instantiation ): X, y, model, _, _ = poissonGLM_model_instantiation - bas = TransformerBasis(bas.set_input_shape(*([1] * bas._n_input_dimensionality))) + bas = TransformerBasis(bas).set_input_shape(*([1] * bas._n_input_dimensionality)) pipe = pipeline.Pipeline([("basis", bas), ("fit", model)]) param_grid = dict(basis__n_basis_funcs=(4, 5, 10)) gridsearch = GridSearchCV( @@ -89,7 +89,7 @@ def test_sklearn_transformer_pipeline_cv_directly_over_basis( bas.set_input_shape(*([1] * bas._n_input_dimensionality)) pipe = pipeline.Pipeline([("transformerbasis", bas), ("fit", model)]) param_grid = dict( - transformerbasis___basis=( + transformerbasis__basis=( bas_cls(5).set_input_shape(*([1] * bas._n_input_dimensionality)), bas_cls(10).set_input_shape(*([1] * bas._n_input_dimensionality)), bas_cls(20).set_input_shape(*([1] * bas._n_input_dimensionality)), @@ -117,13 +117,13 @@ def test_sklearn_transformer_pipeline_cv_illegal_combination( bas.set_input_shape(*([1] * bas._n_input_dimensionality)) pipe = pipeline.Pipeline([("transformerbasis", bas), ("fit", model)]) param_grid = dict( - transformerbasis___basis=(bas_cls(5), bas_cls(10), bas_cls(20)), + transformerbasis__basis=(bas_cls(5), bas_cls(10), bas_cls(20)), transformerbasis__n_basis_funcs=(4, 5, 10), ) gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score="raise") with pytest.raises( ValueError, - match="Set either new _basis object or parameters for existing _basis, not both.", + match="Set either new basis object or parameters for existing basis, not both.", ): gridsearch.fit(X[:, : bas._n_input_dimensionality] ** 2, y) @@ -173,14 +173,14 @@ 
def test_sklearn_transformer_pipeline_pynapple( ep = nap.IntervalSet(start=[0, 20.5], end=[20, X.shape[0]]) X_nap = nap.TsdFrame(t=np.arange(X.shape[0]), d=X, time_support=ep) y_nap = nap.Tsd(t=np.arange(X.shape[0]), d=y, time_support=ep) - bas = bas.set_input_shape(*([1] * bas._n_input_dimensionality)) - bas = TransformerBasis(bas) + bas = TransformerBasis(bas).set_input_shape(*([1] * bas._n_input_dimensionality)) + # fit a pipeline & predict from pynapple pipe = pipeline.Pipeline([("eval", bas), ("fit", model)]) - pipe.fit(X_nap[:, : bas._basis._n_input_dimensionality] ** 2, y_nap) + pipe.fit(X_nap[:, : bas.basis._n_input_dimensionality] ** 2, y_nap) # get rate - rate = pipe.predict(X_nap[:, : bas._basis._n_input_dimensionality] ** 2) + rate = pipe.predict(X_nap[:, : bas.basis._n_input_dimensionality] ** 2) # check rate is Tsd with same time info assert isinstance(rate, nap.Tsd) assert np.all(rate.t == X_nap.t) diff --git a/tests/test_transformer_basis.py b/tests/test_transformer_basis.py new file mode 100644 index 00000000..97e55b50 --- /dev/null +++ b/tests/test_transformer_basis.py @@ -0,0 +1,916 @@ +import pickle +from contextlib import nullcontext as does_not_raise + +import numpy as np +import pytest +from conftest import CombinedBasis, list_all_basis_classes +from sklearn.base import clone as sk_clone +from sklearn.pipeline import Pipeline + +import nemos as nmo +from nemos import basis +from nemos._inspect_utils import get_subclass_methods, list_abstract_methods +from nemos.basis import AdditiveBasis, MSplineConv, MultiplicativeBasis + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_transformer_has_the_same_public_attributes_as_basis( + basis_cls, basis_class_specific_params +): + n_basis_funcs = 5 + bas = CombinedBasis().instantiate_basis( + n_basis_funcs, basis_cls, basis_class_specific_params, window_size=10 + ) + + public_attrs_basis = {attr for attr in dir(bas) if not attr.startswith("_")} + 
public_attrs_transformerbasis = { + attr + for attr in dir( + bas.set_input_shape(*([1] * bas._n_input_dimensionality)).to_transformer() + ) + if not attr.startswith("_") + } + + assert public_attrs_transformerbasis - public_attrs_basis == { + "fit", + "fit_transform", + "transform", + "basis", + } + + assert public_attrs_basis - public_attrs_transformerbasis == {"to_transformer"} + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), +) +def test_to_transformer_and_constructor_are_equivalent( + basis_cls, basis_class_specific_params +): + n_basis_funcs = 5 + bas = CombinedBasis().instantiate_basis( + n_basis_funcs, basis_cls, basis_class_specific_params, window_size=10 + ) + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) + trans_bas_a = bas.to_transformer() + trans_bas_b = basis.TransformerBasis(bas) + + # they both just have a _basis + assert ( + list(trans_bas_a.__dict__.keys()) + == list(trans_bas_b.__dict__.keys()) + == ["basis", "_wrapped_methods"] + ) + # and those bases are the same + assert np.all( + trans_bas_a.basis.__dict__.pop("_decay_rates", 1) + == trans_bas_b.basis.__dict__.pop("_decay_rates", 1) + ) + + # extract the wrapped func for these methods + wrapped_methods_a = {} + for method in trans_bas_a._chainable_methods: + out = trans_bas_a.basis.__dict__.pop(method, False) + val = out if out is False else out.__func__.__qualname__ + wrapped_methods_a.update({method: val}) + + wrapped_methods_b = {} + for method in trans_bas_b._chainable_methods: + out = trans_bas_b.basis.__dict__.pop(method, False) + val = out if out is False else out.__func__.__qualname__ + wrapped_methods_b.update({method: val}) + + assert wrapped_methods_a == wrapped_methods_b + assert trans_bas_a.basis.__dict__ == trans_bas_b.basis.__dict__ + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_basis_to_transformer_makes_a_copy(basis_cls, basis_class_specific_params): + bas_a = 
CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + trans_bas_a = bas_a.set_input_shape( + *([1] * bas_a._n_input_dimensionality) + ).to_transformer() + + # changing an attribute in bas should not change trans_bas + if basis_cls in [basis.AdditiveBasis, basis.MultiplicativeBasis]: + bas_a.basis1.n_basis_funcs = 10 + assert trans_bas_a.basis.basis1.n_basis_funcs == 5 + + # changing an attribute in the transformer basis should not change the original + bas_b = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + bas_b.set_input_shape(*([1] * bas_b._n_input_dimensionality)) + trans_bas_b = bas_b.to_transformer() + trans_bas_b.basis.basis1.n_basis_funcs = 100 + assert bas_b.basis1.n_basis_funcs == 5 + else: + bas_a.n_basis_funcs = 10 + assert trans_bas_a.n_basis_funcs == 5 + + # changing an attribute in the transformer basis should not change the original + bas_b = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + trans_bas_b = bas_b.set_input_shape( + *([1] * bas_b._n_input_dimensionality) + ).to_transformer() + trans_bas_b.n_basis_funcs = 100 + assert bas_b.n_basis_funcs == 5 + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +@pytest.mark.parametrize("n_basis_funcs", [5, 10, 20]) +def test_transformerbasis_getattr( + basis_cls, n_basis_funcs, basis_class_specific_params +): + bas = CombinedBasis().instantiate_basis( + n_basis_funcs, basis_cls, basis_class_specific_params, window_size=30 + ) + trans_basis = basis.TransformerBasis( + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) + ) + if basis_cls in [basis.AdditiveBasis, basis.MultiplicativeBasis]: + for bas in [getattr(trans_basis.basis, attr) for attr in ("basis1", "basis2")]: + assert bas.n_basis_funcs == n_basis_funcs + else: + assert trans_basis.n_basis_funcs == n_basis_funcs + + +@pytest.mark.parametrize( + "basis_cls", + 
list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), +) +@pytest.mark.parametrize("n_basis_funcs_init", [5]) +@pytest.mark.parametrize("n_basis_funcs_new", [6, 10, 20]) +def test_transformerbasis_set_params( + basis_cls, n_basis_funcs_init, n_basis_funcs_new, basis_class_specific_params +): + bas = CombinedBasis().instantiate_basis( + n_basis_funcs_init, basis_cls, basis_class_specific_params, window_size=10 + ) + trans_basis = basis.TransformerBasis( + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) + ) + trans_basis.set_params(n_basis_funcs=n_basis_funcs_new) + + assert trans_basis.n_basis_funcs == n_basis_funcs_new + assert trans_basis.basis.n_basis_funcs == n_basis_funcs_new + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), +) +def test_transformerbasis_setattr_basis(basis_cls, basis_class_specific_params): + # setting the _basis attribute should change it + bas = CombinedBasis().instantiate_basis( + 10, basis_cls, basis_class_specific_params, window_size=30 + ) + trans_bas = basis.TransformerBasis( + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) + ) + + bas = CombinedBasis().instantiate_basis( + 20, basis_cls, basis_class_specific_params, window_size=30 + ) + + trans_bas.basis = bas.set_input_shape(*([1] * bas._n_input_dimensionality)) + + assert trans_bas.n_basis_funcs == 20 + assert trans_bas.basis.n_basis_funcs == 20 + assert isinstance(trans_bas.basis, basis_cls) + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), +) +def test_transformerbasis_setattr_basis_attribute( + basis_cls, basis_class_specific_params +): + # setting an attribute that is an attribute of the underlying _basis + # should propagate setting it on _basis itself + bas = CombinedBasis().instantiate_basis( + 10, basis_cls, basis_class_specific_params, window_size=10 + ) + trans_bas = basis.TransformerBasis( + bas.set_input_shape(*([1] 
* bas._n_input_dimensionality)) + ) + trans_bas.n_basis_funcs = 20 + + assert trans_bas.n_basis_funcs == 20 + assert trans_bas.basis.n_basis_funcs == 20 + assert isinstance(trans_bas.basis, basis_cls) + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), +) +def test_transformerbasis_copy_basis_on_construct( + basis_cls, basis_class_specific_params +): + # modifying the transformerbasis's attributes shouldn't + # touch the original basis that was used to create it + orig_bas = CombinedBasis().instantiate_basis( + 10, basis_cls, basis_class_specific_params, window_size=10 + ) + orig_bas = orig_bas.set_input_shape(*([1] * orig_bas._n_input_dimensionality)) + trans_bas = basis.TransformerBasis(orig_bas) + trans_bas.n_basis_funcs = 20 + + assert orig_bas.n_basis_funcs == 10 + assert trans_bas.basis.n_basis_funcs == 20 + assert trans_bas.basis.n_basis_funcs == 20 + assert isinstance(trans_bas.basis, basis_cls) + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_transformerbasis_setattr_illegal_attribute( + basis_cls, basis_class_specific_params +): + # changing an attribute that is not _basis or an attribute of _basis + # is not allowed + bas = CombinedBasis().instantiate_basis( + 10, basis_cls, basis_class_specific_params, window_size=10 + ) + trans_bas = basis.TransformerBasis( + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) + ) + + with pytest.raises( + ValueError, + match="Only setting basis or existing attributes of basis is allowed.", + ): + trans_bas.random_attr = "random value" + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_transformerbasis_addition(basis_cls, basis_class_specific_params): + n_basis_funcs_a = 5 + n_basis_funcs_b = n_basis_funcs_a * 2 + bas_a = CombinedBasis().instantiate_basis( + n_basis_funcs_a, basis_cls, basis_class_specific_params, window_size=10 + ) + bas_a.set_input_shape(*([1] * 
bas_a._n_input_dimensionality)) + bas_b = CombinedBasis().instantiate_basis( + n_basis_funcs_b, basis_cls, basis_class_specific_params, window_size=10 + ) + bas_b.set_input_shape(*([1] * bas_b._n_input_dimensionality)) + trans_bas_a = basis.TransformerBasis(bas_a) + trans_bas_b = basis.TransformerBasis(bas_b) + trans_bas_sum = trans_bas_a + trans_bas_b + assert isinstance(trans_bas_sum, basis.TransformerBasis) + assert isinstance(trans_bas_sum.basis, basis.AdditiveBasis) + assert ( + trans_bas_sum.n_basis_funcs + == trans_bas_a.n_basis_funcs + trans_bas_b.n_basis_funcs + ) + assert ( + trans_bas_sum._n_input_dimensionality + == trans_bas_a._n_input_dimensionality + trans_bas_b._n_input_dimensionality + ) + if basis_cls not in [basis.AdditiveBasis, basis.MultiplicativeBasis]: + assert trans_bas_sum.basis1.n_basis_funcs == n_basis_funcs_a + assert trans_bas_sum.basis2.n_basis_funcs == n_basis_funcs_b + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_transformerbasis_multiplication(basis_cls, basis_class_specific_params): + n_basis_funcs_a = 5 + n_basis_funcs_b = n_basis_funcs_a * 2 + bas1 = CombinedBasis().instantiate_basis( + n_basis_funcs_a, basis_cls, basis_class_specific_params, window_size=10 + ) + trans_bas_a = basis.TransformerBasis( + bas1.set_input_shape(*([1] * bas1._n_input_dimensionality)) + ) + bas2 = CombinedBasis().instantiate_basis( + n_basis_funcs_b, basis_cls, basis_class_specific_params, window_size=10 + ) + trans_bas_b = basis.TransformerBasis( + bas2.set_input_shape(*([1] * bas2._n_input_dimensionality)) + ) + trans_bas_prod = trans_bas_a * trans_bas_b + assert isinstance(trans_bas_prod, basis.TransformerBasis) + assert isinstance(trans_bas_prod.basis, basis.MultiplicativeBasis) + assert ( + trans_bas_prod.n_basis_funcs + == trans_bas_a.n_basis_funcs * trans_bas_b.n_basis_funcs + ) + assert ( + trans_bas_prod._n_input_dimensionality + == trans_bas_a._n_input_dimensionality + trans_bas_b._n_input_dimensionality 
+ ) + if basis_cls not in [basis.AdditiveBasis, basis.MultiplicativeBasis]: + assert trans_bas_prod.basis1.n_basis_funcs == n_basis_funcs_a + assert trans_bas_prod.basis2.n_basis_funcs == n_basis_funcs_b + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +@pytest.mark.parametrize( + "exponent, error_type, error_message", + [ + (2, does_not_raise, None), + (5, does_not_raise, None), + (0.5, TypeError, "Exponent should be an integer"), + (-1, ValueError, "Exponent should be a non-negative integer"), + ], +) +def test_transformerbasis_exponentiation( + basis_cls, exponent: int, error_type, error_message, basis_class_specific_params +): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + trans_bas = basis.TransformerBasis( + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) + ) + + if not isinstance(exponent, int): + with pytest.raises(error_type, match=error_message): + trans_bas_exp = trans_bas**exponent + assert isinstance(trans_bas_exp, basis.TransformerBasis) + assert isinstance(trans_bas_exp.basis, basis.MultiplicativeBasis) + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_transformerbasis_dir(basis_cls, basis_class_specific_params): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + trans_bas = basis.TransformerBasis( + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) + ) + for attr_name in ( + "fit", + "transform", + "fit_transform", + "n_basis_funcs", + "mode", + "window_size", + ): + if ( + attr_name == "window_size" + and "Conv" not in trans_bas.basis.__class__.__name__ + ): + continue + assert attr_name in dir(trans_bas) + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes("Conv"), +) +def test_transformerbasis_sk_clone_kernel_noned(basis_cls, basis_class_specific_params): + orig_bas = CombinedBasis().instantiate_basis( + 10, basis_cls, 
basis_class_specific_params, window_size=20 + ) + orig_bas.set_input_shape(*([1] * orig_bas._n_input_dimensionality)) + trans_bas = basis.TransformerBasis(orig_bas) + + # kernel should be saved in the object after fit + trans_bas.fit(np.random.randn(100, 1)) + assert isinstance(trans_bas.kernel_, np.ndarray) + + # cloning should set kernel_ to None + trans_bas_clone = sk_clone(trans_bas) + + # the original object should still have kernel_ + assert isinstance(trans_bas.kernel_, np.ndarray) + # but the clone should not have one + assert trans_bas_clone.kernel_ is None + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +@pytest.mark.parametrize("n_basis_funcs", [5]) +def test_transformerbasis_pickle( + tmpdir, basis_cls, n_basis_funcs, basis_class_specific_params +): + bas = CombinedBasis().instantiate_basis( + n_basis_funcs, basis_cls, basis_class_specific_params, window_size=10 + ) + # the test that tries cross-validation with n_jobs = 2 already should test this + trans_bas = basis.TransformerBasis( + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) + ) + filepath = tmpdir / "transformerbasis.pickle" + with open(filepath, "wb") as f: + pickle.dump(trans_bas, f) + with open(filepath, "rb") as f: + trans_bas2 = pickle.load(f) + + assert isinstance(trans_bas2, basis.TransformerBasis) + if basis_cls in [basis.AdditiveBasis, basis.MultiplicativeBasis]: + for bas in [getattr(trans_bas2.basis, attr) for attr in ("basis1", "basis2")]: + assert bas.n_basis_funcs == n_basis_funcs + else: + assert trans_bas2.n_basis_funcs == n_basis_funcs + + +@pytest.mark.parametrize( + "set_input, expectation", + [ + (True, does_not_raise()), + ( + False, + pytest.raises( + RuntimeError, + match="Cannot apply TransformerBasis: the provided basis has no defined input shape", + ), + ), + ], +) +@pytest.mark.parametrize( + "inp", [np.ones((10,)), np.ones((10, 1)), np.ones((10, 2)), np.ones((10, 2, 3))] +) +@pytest.mark.parametrize( + "basis_cls", + 
list_all_basis_classes(), +) +def test_to_transformer_and_set_input( + basis_cls, inp, set_input, expectation, basis_class_specific_params +): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + if set_input: + bas.set_input_shape(*([inp] * bas._n_input_dimensionality)) + trans = bas.to_transformer() + with expectation: + X = np.concatenate( + [inp.reshape(inp.shape[0], -1)] * bas._n_input_dimensionality, axis=1 + ) + trans.fit(X) + + +@pytest.mark.parametrize( + "inp, expectation", + [ + (np.ones((10,)), pytest.raises(ValueError, match="X must be 2-")), + (np.ones((10, 1)), does_not_raise()), + (np.ones((10, 2)), does_not_raise()), + (np.ones((10, 2, 3)), pytest.raises(ValueError, match="X must be 2-")), + ], +) +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_transformer_fit(basis_cls, inp, basis_class_specific_params, expectation): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + transformer = bas.set_input_shape( + *([inp] * bas._n_input_dimensionality) + ).to_transformer() + X = np.concatenate( + [inp.reshape(inp.shape[0], -1)] * bas._n_input_dimensionality, axis=1 + ) + transformer.fit(X) + if "Conv" in basis_cls.__name__: + assert transformer.kernel_ is not None + + # try and pass segmented time series + if isinstance(bas, (basis.AdditiveBasis, basis.MultiplicativeBasis)): + if inp.ndim == 2: + expectation = pytest.raises(ValueError, match="Input mismatch: expected ") + + with expectation: + transformer.fit(*([inp] * bas._n_input_dimensionality)) + + +@pytest.mark.parametrize( + "inp", + [ + np.ones((10, 1)), + np.ones((10, 2)), + ], +) +@pytest.mark.parametrize( + "delta_input, expectation", + [ + (0, does_not_raise()), + (1, pytest.raises(ValueError, match="Input mismatch: expected ")), + (-1, pytest.raises(ValueError, match="Input mismatch: expected ")), + ], +) +@pytest.mark.parametrize( + "basis_cls", + 
list_all_basis_classes(), +) +def test_transformer_fit_input_shape_mismatch( + basis_cls, delta_input, inp, basis_class_specific_params, expectation +): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + transformer = bas.set_input_shape( + *([inp] * bas._n_input_dimensionality) + ).to_transformer() + X = np.random.randn(10, int(sum(bas._n_basis_input_) + delta_input)) + with expectation: + transformer.fit(X) + + +@pytest.mark.parametrize( + "inp", + [ + np.random.randn( + 10, + ), + np.random.randn(10, 1), + np.random.randn(10, 2), + np.random.randn(10, 2, 3), + ], +) +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_transformer_transform(basis_cls, inp, basis_class_specific_params): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + transformer = bas.set_input_shape( + *([inp] * bas._n_input_dimensionality) + ).to_transformer() + X = np.concatenate( + [inp.reshape(inp.shape[0], -1)] * bas._n_input_dimensionality, axis=1 + ) + transformer.fit(X) + + out = transformer.transform(X) + out2 = bas.compute_features(*([inp] * bas._n_input_dimensionality)) + + assert np.array_equal(out, out2, equal_nan=True) + + +@pytest.mark.parametrize( + "inp", + [ + np.random.randn( + 10, + ), + np.random.randn(10, 1), + np.random.randn(10, 2), + np.random.randn(10, 2, 3), + ], +) +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_transformer_fit_transform(basis_cls, inp, basis_class_specific_params): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + transformer = bas.set_input_shape( + *([inp] * bas._n_input_dimensionality) + ).to_transformer() + X = np.concatenate( + [inp.reshape(inp.shape[0], -1)] * bas._n_input_dimensionality, axis=1 + ) + + out = transformer.fit_transform(X) + out2 = bas.compute_features(*([inp] * bas._n_input_dimensionality)) 
+ + assert np.array_equal(out, out2, equal_nan=True) + + +@pytest.mark.parametrize( + "inp", + [ + np.ones((10, 1)), + np.ones((10, 2)), + ], +) +@pytest.mark.parametrize( + "delta_input, expectation", + [ + (0, does_not_raise()), + (1, pytest.raises(ValueError, match="Input mismatch: expected ")), + (-1, pytest.raises(ValueError, match="Input mismatch: expected ")), + ], +) +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_transformer_fit_transform_input_shape_mismatch( + basis_cls, delta_input, inp, basis_class_specific_params, expectation +): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + transformer = bas.set_input_shape( + *([inp] * bas._n_input_dimensionality) + ).to_transformer() + X = np.random.randn(10, int(sum(bas._n_basis_input_) + delta_input)) + with expectation: + transformer.fit_transform(X) + + +@pytest.mark.parametrize( + "inp, expectation", + [ + (np.ones((10,)), pytest.raises(ValueError, match="X must be 2-")), + (np.ones((10, 1)), does_not_raise()), + (np.ones((10, 2)), does_not_raise()), + (np.ones((10, 2, 3)), pytest.raises(ValueError, match="X must be 2-")), + ], +) +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_transformer_fit_transform_input_struct( + basis_cls, inp, basis_class_specific_params, expectation +): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + transformer = bas.set_input_shape( + *([inp] * bas._n_input_dimensionality) + ).to_transformer() + X = np.concatenate( + [inp.reshape(inp.shape[0], -1)] * bas._n_input_dimensionality, axis=1 + ) + transformer.fit_transform(X) + + if "Conv" in basis_cls.__name__: + assert transformer.kernel_ is not None + + # try and pass a tuple of time series + if ( + isinstance(bas, (basis.AdditiveBasis, basis.MultiplicativeBasis)) + and inp.ndim != 2 + ): + expectation = pytest.raises(ValueError, match="X must 
be 2-") + elif ( + isinstance(bas, (basis.AdditiveBasis, basis.MultiplicativeBasis)) + and inp.ndim == 2 + ): + expectation = pytest.raises(ValueError, match="Input mismatch: expected") + with expectation: + transformer.fit(*([inp] * bas._n_input_dimensionality)) + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +@pytest.mark.parametrize( + "inp", + [ + 0.1 + * np.random.randn( + 100, + ), + 0.1 * np.random.randn(100, 1), + 0.1 * np.random.randn(100, 2), + 0.1 * np.random.randn(100, 1, 2), + ], +) +def test_transformer_in_pipeline(basis_cls, inp, basis_class_specific_params): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + transformer = bas.set_input_shape( + *([inp] * bas._n_input_dimensionality) + ).to_transformer() + # fit outside pipeline + X = bas.compute_features(*([inp] * bas._n_input_dimensionality)) + log_mu = X.dot(0.005 * np.ones(X.shape[1])) + y = np.full(X.shape[0], 0) + y[~np.isnan(log_mu)] = np.random.poisson( + np.exp(log_mu[~np.isnan(log_mu)] - np.nanmean(log_mu)) + ) + model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.001).fit(X, y) + + # pipeline + pipe = Pipeline( + [ + ("bas", transformer), + ("glm", nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.001)), + ] + ) + x = np.concatenate( + [inp.reshape(inp.shape[0], -1)] * bas._n_input_dimensionality, axis=1 + ) + pipe.fit(x, y) + np.testing.assert_allclose(pipe["glm"].coef_, model.coef_) + + # set basis & refit + if isinstance(bas, (basis.AdditiveBasis, basis.MultiplicativeBasis)): + pipe.set_params(bas__basis2__n_basis_funcs=4) + assert ( + bas.basis2.n_basis_funcs == 5 + ) # make sure that the change did not affect bas + X = bas.set_params(basis2__n_basis_funcs=4).compute_features( + *([inp] * bas._n_input_dimensionality) + ) + else: + pipe.set_params(bas__n_basis_funcs=4) + assert bas.n_basis_funcs == 5 # make sure that the change did not affect bas + X = 
bas.set_params(n_basis_funcs=4).compute_features( + *([inp] * bas._n_input_dimensionality) + ) + pipe.fit(x, y) + model.fit(X, y) + np.testing.assert_allclose(pipe["glm"].coef_, model.coef_) + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_initialization(basis_cls, basis_class_specific_params): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + transformer = bas.to_transformer() + with pytest.raises(RuntimeError, match="Cannot apply TransformerBasis"): + transformer.fit(np.ones((100,))) + + with pytest.raises(RuntimeError, match="Cannot apply TransformerBasis"): + transformer.transform(np.ones((100,))) + + with pytest.raises(RuntimeError, match="Cannot apply TransformerBasis"): + transformer.fit_transform(np.ones((100,))) + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_basis_setter(basis_cls, basis_class_specific_params): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + + bas2 = CombinedBasis().instantiate_basis( + 7, basis_cls, basis_class_specific_params, window_size=10 + ) + transformer = bas.to_transformer() + transformer.basis = bas2 + assert transformer.basis.n_basis_funcs == bas2.n_basis_funcs + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_getstate(basis_cls, basis_class_specific_params): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + transformer = bas.to_transformer() + state = transformer.__getstate__() + assert {"_basis": transformer.basis} == state + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_eetstate(basis_cls, basis_class_specific_params): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + bas2 = CombinedBasis().instantiate_basis( + 7, basis_cls, 
basis_class_specific_params, window_size=10 + ) + transformer = bas.to_transformer() + state = {"basis": bas2} + transformer.__setstate__(state) + assert transformer.basis == bas2 + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_to_transformer_not_an_attribute_of_transformer_basis( + basis_cls, basis_class_specific_params +): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + bas = bas.to_transformer() + assert "to_transformer" not in bas.__dir__() + + with pytest.raises( + AttributeError, + match="'TransformerBasis' object has no attribute 'to_transformer'", + ): + bas.to_transformer() + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_getstate(basis_cls, basis_class_specific_params): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + transformer = bas.to_transformer() + lst = transformer.__dir__() + dict_abst_method = list_abstract_methods(nmo.basis._basis.Basis) + + # check it finds all abc basis methods + for meth in dict_abst_method: + assert meth[0] in lst + + # check all reimplemented methods + dict_reimplemented_method = get_subclass_methods(basis_cls) + for meth in dict_reimplemented_method: + if meth[0] == "to_transformer": + continue + assert meth[0] in lst + + # check that it is a trnasformer + for meth in ["fit", "transform", "fit_transform"]: + assert meth in lst + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +@pytest.mark.parametrize( + "inp, expectation", + [ + ( + np.random.randn(10, 2), + pytest.raises(ValueError, match="Input mismatch: expected \d inputs"), + ), + ( + np.random.randn(10, 3, 1), + pytest.raises(ValueError, match="X must be 2-dimensional"), + ), + ( + {1: np.random.randn(10, 3)}, + pytest.raises(ValueError, match="The input must be a 2-dimensional array"), + ), + (np.random.randn(10, 3), does_not_raise()), + ], +) 
+@pytest.mark.parametrize("method", ["fit", "transform", "fit_transform"]) +def test_check_input(inp, expectation, basis_cls, basis_class_specific_params, method): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + # set kernels + bas._set_input_independent_states() + # set input shape + transformer = bas.to_transformer().set_input_shape( + *([3] * bas._n_input_dimensionality) + ) + if isinstance(bas, (AdditiveBasis, MultiplicativeBasis)): + if hasattr(inp, "ndim"): + ndim = inp.ndim + inp = np.concatenate( + [inp.reshape(inp.shape[0], -1)] * bas._n_input_dimensionality, axis=1 + ) + if ndim == 3: + inp = inp[..., np.newaxis] + + meth = getattr(transformer, method) + + with expectation: + meth(inp) + with pytest.raises(ValueError, match="X and y must have the same"): + meth(inp, np.ones(11))