Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve transformer api #280

Merged
merged 58 commits into from
Dec 20, 2024
Merged

Improve transformer api #280

merged 58 commits into from
Dec 20, 2024

Conversation

BalzaniEdoardo
Copy link
Collaborator

@BalzaniEdoardo BalzaniEdoardo commented Dec 14, 2024

Small PR allowing the chaining behavior of TransformerBasis.

Before this PRs, basis methods returning self could not be chained when the basis was wrapped by the TransformerBasis class.

Old behavior

>>> import nemos as nmo
>>> from nemos.basis._basis import Basis
>>> from nemos.basis._transformer_basis import TransformerBasis
>>> transformer_basis = nmo.basis.BSplineEval(5).to_transformer()
>>> out = transformer_basis.set_input_shape(1)
>>> isinstance(out, TransformerBasis)
False
>>> isinstance(out, Basis)
True

New behavior

>>> import nemos as nmo
>>> from nemos.basis._basis import Basis
>>> from nemos.basis._transformer_basis import TransformerBasis
>>> transformer_basis = nmo.basis.BSplineEval(5).to_transformer()
>>> out = transformer_basis.set_input_shape(1)
>>> isinstance(out, TransformerBasis)
True
>>> isinstance(out, Basis)
False

Key Features

  • Decorator that processes the output of chainable basis methods, setting the method output (the updated basis object) to the _basis attribute of the transformer, and returns the transformer.
  • New tuple attribute of the TransformerBasis listing all the chainable methods of basis.
  • Run-time decoration of chainable methods in the __getattr__ when first called + caching. Run-time decorating is necessary because if set at initialization, it would create a circular reference, and result in infinite loop when pickling.

Copy link
Collaborator

@sjvenditto sjvenditto left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I committed a fix to a couple typos I found, otherwise it looks good!

@BalzaniEdoardo
Copy link
Collaborator Author

BalzaniEdoardo commented Dec 17, 2024 via email

Copy link
Member

@billbrod billbrod left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't really have any notes about the chaining, that all looks good to me.

  • TransformerBasis objects have a to_transformer() method, which I think just returns self. This is a bit confusing -- is there a reason to have this behavior?
  • transform and fit_transform should have the notes about set_input_shape being called first as well (like fit), as well as the Raises section of the docstring

src/nemos/basis/_transformer_basis.py Outdated Show resolved Hide resolved
src/nemos/basis/_transformer_basis.py Outdated Show resolved Hide resolved
src/nemos/basis/_transformer_basis.py Outdated Show resolved Hide resolved
src/nemos/basis/_transformer_basis.py Outdated Show resolved Hide resolved
tests/test_transformer_basis.py Show resolved Hide resolved
@BalzaniEdoardo BalzaniEdoardo merged commit 8b1b403 into development Dec 20, 2024
13 checks passed
@BalzaniEdoardo BalzaniEdoardo deleted the improve_transformer_api branch December 20, 2024 03:53
@billbrod billbrod mentioned this pull request Dec 20, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants