Skip to content

Commit

Permalink
Merge pull request #280 from flatironinstitute/improve_transformer_api
Browse files Browse the repository at this point in the history
Improve transformer api
  • Loading branch information
BalzaniEdoardo authored Dec 20, 2024
2 parents e8a62e8 + a93a002 commit 8b1b403
Show file tree
Hide file tree
Showing 5 changed files with 1,033 additions and 69 deletions.
13 changes: 7 additions & 6 deletions docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,8 @@ scores = np.zeros((len(regularizer_strength) * len(n_basis_funcs), n_folds))
coeffs = {}
# initialize basis and model
basis = nmo.basis.TransformerBasis(nmo.basis.RaisedCosineLinearEval(6))
basis = nmo.basis.RaisedCosineLinearEval(6).set_input_shape(1)
basis = nmo.basis.TransformerBasis(basis)
model = nmo.glm.GLM(regularizer="Ridge")
# loop over combinations
Expand Down Expand Up @@ -441,13 +442,13 @@ We are now able to capture the distribution of the firing rate appropriately: bo

In the previous example we set the number of basis functions of the [`Basis`](nemos.basis._basis.Basis) wrapped in our [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis). However, if we are, for example, not sure about the type of basis functions we want to use, or we have already defined some basis functions of our own, then we can use cross-validation to directly evaluate those as well.

Here we include `transformerbasis___basis` in the parameter grid to try different values for `TransformerBasis._basis`:
Here we include `transformerbasis__basis` in the parameter grid to try different values for `TransformerBasis.basis`:


```{code-cell} ipython3
param_grid = dict(
glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6),
transformerbasis___basis=(
transformerbasis__basis=(
nmo.basis.RaisedCosineLinearEval(5).set_input_shape(1),
nmo.basis.RaisedCosineLinearEval(10).set_input_shape(1),
nmo.basis.RaisedCosineLogEval(5).set_input_shape(1),
Expand Down Expand Up @@ -481,7 +482,7 @@ cvdf = pd.DataFrame(gridsearch.cv_results_)
# Read out the number of basis functions
cvdf["transformerbasis_config"] = [
f"{b.__class__.__name__} - {b.n_basis_funcs}"
for b in cvdf["param_transformerbasis___basis"]
for b in cvdf["param_transformerbasis__basis"]
]
cvdf_wide = cvdf.pivot(
Expand Down Expand Up @@ -537,7 +538,7 @@ Please note that because it would lead to unexpected behavior, mixing the two wa
param_grid = dict(
glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6),
transformerbasis__n_basis_funcs=(3, 5, 10, 20, 100),
transformerbasis___basis=(
transformerbasis__basis=(
nmo.basis.RaisedCosineLinearEval(5).set_input_shape(1),
nmo.basis.RaisedCosineLinearEval(10).set_input_shape(1),
nmo.basis.RaisedCosineLogEval(5).set_input_shape(1),
Expand Down Expand Up @@ -592,7 +593,7 @@ cvdf = pd.DataFrame(gridsearch.cv_results_)
# Read out the number of basis functions
cvdf["transformerbasis_config"] = [
f"{b.__class__.__name__} - {b.n_basis_funcs}"
for b in cvdf["param_transformerbasis___basis"]
for b in cvdf["param_transformerbasis__basis"]
]
cvdf_wide = cvdf.pivot(
Expand Down
148 changes: 98 additions & 50 deletions src/nemos/basis/_transformer_basis.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import copy
from functools import wraps
from typing import TYPE_CHECKING, Generator

import numpy as np
Expand All @@ -11,6 +12,18 @@
from ._basis import Basis


def transformer_chaining(func):
    """Wrap a bound method so that chaining returns the outer transformer.

    Some methods of the wrapped basis return the basis itself to allow
    method chaining. When such a method is reached through an outer object
    that stores the basis in its ``basis`` attribute, returning the inner
    basis would silently drop the outer wrapper from the chain. The wrapper
    built here swaps that return value for the outer object.

    Parameters
    ----------
    func :
        A callable that is **already bound** to the inner basis (e.g.
        ``outer.basis.set_input_shape``). It is invoked without the outer
        ``self``.

    Returns
    -------
    :
        A function with signature ``wrapper(self, *args, **kwargs)`` where
        ``self`` is the outer object. It returns ``self`` when ``func``
        returned ``self.basis`` (identity check), and the unmodified result
        of ``func`` otherwise.
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # ``func`` is already bound to the inner basis, so the outer ``self``
        # is intentionally not forwarded.
        result = func(*args, **kwargs)

        # If the method returned the inner basis, hand back the outer object
        # instead (no deepcopy here); any other result passes through as is.
        return self if result is self.basis else result

    return wrapper


class TransformerBasis:
"""Basis as ``scikit-learn`` transformers.
Expand Down Expand Up @@ -61,8 +74,16 @@ class TransformerBasis:
Cross-validated number of basis: {'compute_features__n_basis_funcs': 10}
"""

_chainable_methods = (
"set_kernel",
"set_input_shape",
"_set_input_independent_states",
"setup_basis",
)

def __init__(self, basis: Basis):
self._basis = copy.deepcopy(basis)
self.basis = copy.deepcopy(basis)
self._wrapped_methods = {} # Cache for wrapped methods

@staticmethod
def _check_initialized(basis):
Expand All @@ -73,14 +94,6 @@ def _check_initialized(basis):
"`fit_transform`."
)

@property
def basis(self):
return self._basis

@basis.setter
def basis(self, basis):
self._basis = basis

def _unpack_inputs(self, X: FeatureMatrix) -> Generator:
"""Unpack inputs.
Expand All @@ -95,7 +108,7 @@ def _unpack_inputs(self, X: FeatureMatrix) -> Generator:
Returns
-------
:
A list of each individual input.
A generator looping on each individual input.
"""
n_samples = X.shape[0]
Expand All @@ -110,9 +123,7 @@ def _unpack_inputs(self, X: FeatureMatrix) -> Generator:

def fit(self, X: FeatureMatrix, y=None):
"""
Compute the convolutional kernels.
If any of the 1D basis in self._basis is in "conv" mode, it computes the convolutional kernels.
Check the input structure and, if necessary, compute the convolutional kernels.
Parameters
----------
Expand All @@ -126,6 +137,13 @@ def fit(self, X: FeatureMatrix, y=None):
self :
The transformer object.
Raises
------
RuntimeError
If ``self.n_basis_input`` is None. Call ``self.set_input_shape`` before calling ``fit`` to avoid this.
ValueError:
If the number of columns in X does not match ``self.n_basis_input_``.
Examples
--------
>>> import numpy as np
Expand All @@ -139,9 +157,9 @@ def fit(self, X: FeatureMatrix, y=None):
>>> transformer = TransformerBasis(basis)
>>> transformer_fitted = transformer.fit(X)
"""
self._check_initialized(self._basis)
self._check_initialized(self.basis)
self._check_input(X, y)
self._basis.setup_basis(*self._unpack_inputs(X))
self.basis.setup_basis(*self._unpack_inputs(X))
return self

def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix:
Expand Down Expand Up @@ -181,10 +199,11 @@ def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix:
>>> # Transform basis
>>> feature_transformed = transformer.transform(X)
"""
self._check_initialized(self._basis)
self._check_initialized(self.basis)
self._check_input(X, y)
# transpose does not work with pynapple
# can't use func(*X.T) to unwrap
return self._basis._compute_features(*self._unpack_inputs(X))
return self.basis._compute_features(*self._unpack_inputs(X))

def fit_transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix:
"""
Expand Down Expand Up @@ -231,7 +250,11 @@ def __getstate__(self):
See https://docs.python.org/3/library/pickle.html#object.__getstate__
and https://docs.python.org/3/library/pickle.html#pickle-state
"""
return {"_basis": self._basis}
# this is the only state needed at initialization
# returning the cached wrapped methods would create
# a circular binding of the state to self (i.e. infinite recursion when
# unpickling).
return {"basis": self.basis}

def __setstate__(self, state):
"""
Expand All @@ -243,12 +266,17 @@ def __setstate__(self, state):
See https://docs.python.org/3/library/pickle.html#object.__setstate__
and https://docs.python.org/3/library/pickle.html#pickle-state
"""
self._basis = state["_basis"]
self.basis = state["basis"]
self._wrapped_methods = {} # Reinitialize the cache

def __getattr__(self, name: str):
"""
Enable easy access to attributes of the underlying Basis object.
This method caches all chainable methods (methods returning self) in a dictionary.
These methods are created the first time they are accessed by decorating the `self.basis.name`
and cached for future use.
Examples
--------
>>> from nemos import basis
Expand All @@ -259,11 +287,28 @@ def __getattr__(self, name: str):
>>> trans_bas.n_basis_funcs
5
"""
return getattr(self._basis, name)
# Check if the method has already been wrapped
if name in self._wrapped_methods:
return self._wrapped_methods[name]

if not hasattr(self.basis, name) or name == "to_transformer":
raise AttributeError(f"'TransformerBasis' object has no attribute '{name}'")

# Get the original attribute from the basis
attr = getattr(self.basis, name)

# If the attribute is one of the chainable methods, wrap it dynamically
if name in self._chainable_methods:
wrapped = transformer_chaining(attr).__get__(self)
self._wrapped_methods[name] = wrapped # Cache the wrapped method
return wrapped

# For every other attribute (including non-chainable methods), return it directly
return attr

def __setattr__(self, name: str, value) -> None:
r"""
Allow setting _basis or the attributes of _basis with a convenient dot assignment syntax.
Allow setting basis or the attributes of basis with a convenient dot assignment syntax.
Setting any other attribute is not allowed.
Expand All @@ -274,33 +319,33 @@ def __setattr__(self, name: str, value) -> None:
Raises
------
ValueError
If the attribute being set is not ``_basis`` or an attribute of ``_basis``.
If the attribute being set is not ``basis`` or an attribute of ``basis``.
Examples
--------
>>> import nemos as nmo
>>> trans_bas = nmo.basis.TransformerBasis(nmo.basis.MSplineEval(10))
>>> # allowed
>>> trans_bas._basis = nmo.basis.BSplineEval(10)
>>> trans_bas.basis = nmo.basis.BSplineEval(10)
>>> # allowed
>>> trans_bas.n_basis_funcs = 20
>>> # not allowed
>>> try:
... trans_bas.random_attribute_name = "some value"
... trans_bas.rand_atrr = "some value"
... except ValueError as e:
... print(repr(e))
ValueError('Only setting _basis or existing attributes of _basis is allowed.')
ValueError('Only setting basis or existing attributes of basis is allowed. Attempt to set `rand_atrr`.')
"""
# allow self._basis = basis
if name == "_basis":
# allow self.basis = basis and other attrs of self to be retrievable
if name in ["basis", "_wrapped_methods"]:
super().__setattr__(name, value)
# allow changing existing attributes of self._basis
elif hasattr(self._basis, name):
setattr(self._basis, name, value)
# allow changing existing attributes of self.basis
elif hasattr(self.basis, name):
setattr(self.basis, name, value)
# don't allow setting any other attribute
else:
raise ValueError(
"Only setting _basis or existing attributes of _basis is allowed."
f"Only setting basis or existing attributes of basis is allowed. Attempt to set `{name}`."
)

def __sklearn_clone__(self) -> TransformerBasis:
Expand All @@ -312,15 +357,15 @@ def __sklearn_clone__(self) -> TransformerBasis:
For more info: https://scikit-learn.org/stable/developers/develop.html#cloning
"""
cloned_obj = TransformerBasis(self._basis.__sklearn_clone__())
cloned_obj._basis.kernel_ = None
cloned_obj = TransformerBasis(self.basis.__sklearn_clone__())
cloned_obj.basis.kernel_ = None
return cloned_obj

def set_params(self, **parameters) -> TransformerBasis:
"""
Set TransformerBasis parameters.
When used with ``sklearn.model_selection``, users can set either the ``_basis`` attribute directly
When used with ``sklearn.model_selection``, users can set either the ``basis`` attribute directly
or the parameters of the underlying Basis, but not both.
Examples
Expand All @@ -329,38 +374,41 @@ def set_params(self, **parameters) -> TransformerBasis:
>>> basis = MSplineEval(10)
>>> transformer_basis = TransformerBasis(basis=basis)
>>> # setting parameters of _basis is allowed
>>> # setting parameters of basis is allowed
>>> print(transformer_basis.set_params(n_basis_funcs=8).n_basis_funcs)
8
>>> # setting _basis directly is allowed
>>> print(type(transformer_basis.set_params(_basis=BSplineEval(10))._basis))
>>> # setting basis directly is allowed
>>> print(type(transformer_basis.set_params(basis=BSplineEval(10)).basis))
<class 'nemos.basis.basis.BSplineEval'>
>>> # mixing is not allowed, this will raise an exception
>>> try:
... transformer_basis.set_params(_basis=BSplineEval(10), n_basis_funcs=2)
... transformer_basis.set_params(basis=BSplineEval(10), n_basis_funcs=2)
... except ValueError as e:
... print(repr(e))
ValueError('Set either new _basis object or parameters for existing _basis, not both.')
ValueError('Set either new basis object or parameters for existing basis, not both.')
"""
new_basis = parameters.pop("_basis", None)
new_basis = parameters.pop("basis", None)
if new_basis is not None:
self._basis = new_basis
self.basis = new_basis
if len(parameters) > 0:
raise ValueError(
"Set either new _basis object or parameters for existing _basis, not both."
"Set either new basis object or parameters for existing basis, not both."
)
else:
self._basis = self._basis.set_params(**parameters)
self.basis = self.basis.set_params(**parameters)

return self

def get_params(self, deep: bool = True) -> dict:
"""Extend the dict of parameters from the underlying Basis with _basis."""
return {"_basis": self._basis, **self._basis.get_params(deep)}
"""Extend the dict of parameters from the underlying Basis with basis."""
return {"basis": self.basis, **self.basis.get_params(deep)}

def __dir__(self) -> list[str]:
"""Extend the list of properties of methods with the ones from the underlying Basis."""
return list(super().__dir__()) + list(self._basis.__dir__())
unique_attrs = set(list(super().__dir__()) + list(self.basis.__dir__()))
# discard without raising errors if not present
unique_attrs.discard("to_transformer")
return list(unique_attrs)

def __add__(self, other: TransformerBasis) -> TransformerBasis:
"""
Expand All @@ -376,7 +424,7 @@ def __add__(self, other: TransformerBasis) -> TransformerBasis:
: TransformerBasis
The resulting Basis object.
"""
return TransformerBasis(self._basis + other._basis)
return TransformerBasis(self.basis + other.basis)

def __mul__(self, other: TransformerBasis) -> TransformerBasis:
"""
Expand All @@ -392,7 +440,7 @@ def __mul__(self, other: TransformerBasis) -> TransformerBasis:
:
The resulting Basis object.
"""
return TransformerBasis(self._basis * other._basis)
return TransformerBasis(self.basis * other.basis)

def __pow__(self, exponent: int) -> TransformerBasis:
"""Exponentiation of a TransformerBasis object.
Expand All @@ -418,7 +466,7 @@ def __pow__(self, exponent: int) -> TransformerBasis:
If the integer is zero or negative.
"""
# errors are handled by Basis.__pow__
return TransformerBasis(self._basis**exponent)
return TransformerBasis(self.basis**exponent)

def _check_input(self, X: FeatureMatrix, y=None):
"""Check that the input structure.
Expand Down Expand Up @@ -455,5 +503,5 @@ def _check_input(self, X: FeatureMatrix, y=None):
if y is not None and y.shape[0] != X.shape[0]:
raise ValueError(
"X and y must have the same number of samples. "
f"X has {X.shpae[0]} samples, while y has {y.shape[0]} samples."
f"X has {X.shape[0]} samples, while y has {y.shape[0]} samples."
)
Loading

0 comments on commit 8b1b403

Please sign in to comment.