Internal API
If you want to implement new functionality by implementing new estimators or wrapping existing sklearn estimators, this page will guide you through the necessary steps.
Let's say you want to wrap an estimator called AwesomeEstimator from sklearn.awesome_module that implements the usual fit, predict and score methods.
All sklearn-xarray estimators should inherit from sklearn_xarray.common.wrappers.EstimatorWrapper, which provides the fit method that every sklearn estimator must possess. All other methods are optional and provided by mixins. The constructor of the wrapper should construct an instance of the wrapped estimator and pass it to the superclass constructor:
from sklearn_xarray.common.wrappers import (
    EstimatorWrapper, _ImplementsPredictMixin, _ImplementsScoreMixin)
from sklearn.awesome_module import AwesomeEstimator as _AwesomeEstimator


class AwesomeEstimator(EstimatorWrapper, _ImplementsPredictMixin,
                       _ImplementsScoreMixin):
    """ An example class demonstrating the internal API. """

    def __init__(self, param_1=None, param_2=None, reshapes=None,
                 sample_dim='sample', compat=False):
        estimator = _AwesomeEstimator(param_1=param_1, param_2=param_2)
        super(AwesomeEstimator, self).__init__(
            estimator, reshapes=reshapes, sample_dim=sample_dim, compat=compat)
Notice that the keyword arguments reshapes, sample_dim and compat are also passed. If the wrapped estimator changes the number of features when calling predict, you can specify reshapes='feature', but it's not absolutely necessary. Check out the documentation for more details.
Now let's say the AwesomeEstimator also has a do_stuff method that does not have a corresponding wrapper yet. You'll have to implement a mixin in sklearn_xarray.common.base that wraps this method. It should include a public do_stuff method as well as an internal _do_stuff method whose input is a DataArray and which calls the wrapped estimator's do_stuff. _CommonEstimatorWrapper._call_fitted will then take care of the rest:
import numpy as np
from sklearn_xarray.common.base import _CommonEstimatorWrapper


class _ImplementsDoStuffMixin(_CommonEstimatorWrapper):

    def _do_stuff(self, estimator, X):
        """ Do stuff with ``self.estimator`` and update dims. """
        if self.sample_dim is not None:
            # transpose to sample dim first, do stuff and transpose back
            order = self._get_transpose_order(X)
            X_arr = np.transpose(X.data, order)
            X_d = estimator.do_stuff(X_arr)
            if X_d.ndim == X.ndim:
                X_d = np.transpose(X_d, np.argsort(order))
        else:
            X_d = estimator.do_stuff(X.data)
        # update dims
        dims_new = self._update_dims(X, X_d)
        return X_d, dims_new

    def do_stuff(self, X):
        """ A wrapper around the do_stuff function. """
        return self._call_fitted('do_stuff', X)
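The transpose-and-restore step in _do_stuff works because np.argsort(order) is the inverse permutation of order. The NumPy-only sketch below illustrates that idea in isolation. The helper apply_sample_first and the way it builds order are hypothetical stand-ins chosen for this example; they are not the library's _get_transpose_order.

```python
import numpy as np


def apply_sample_first(func, data, sample_axis):
    """Move ``sample_axis`` to the front, apply ``func``, restore axis order."""
    # Assumption: sample axis first, remaining axes in their original order.
    order = [sample_axis] + [ax for ax in range(data.ndim) if ax != sample_axis]
    result = func(np.transpose(data, order))
    if result.ndim == data.ndim:
        # argsort(order) is the inverse permutation, so this undoes the transpose.
        result = np.transpose(result, np.argsort(order))
    return result


# The sample axis is the last one here, so the permutation is not self-inverse.
data = np.arange(24.0).reshape(2, 3, 4)
out = apply_sample_first(lambda a: a - a.mean(axis=0), data, sample_axis=2)
```

The function passed in only ever sees the sample axis at position 0, which mirrors how the wrapped estimator is called. When it changes the number of dimensions, the result is returned without transposing back, as in _do_stuff.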
If you want to implement new estimators, you can obviously just implement a standard sklearn-compatible estimator and wrap it with the steps described above. But you'll probably want to implement some functionality that makes use of xarray's coordinates or other features. In that case you'll have to implement the estimator from scratch.
First of all, your estimator should implement a fit method that determines the type of the training data:
from sklearn_xarray.utils import is_dataarray, is_dataset


def fit(self, X, y=None, **fit_params):
    """ Fit estimator to data. """
    if is_dataset(X):
        self.type_ = 'Dataset'
        # implement fitting procedure for Datasets
    elif is_dataarray(X):
        self.type_ = 'DataArray'
        # implement fitting procedure for DataArrays
    else:
        self.type_ = 'other'
        # implement fitting procedure for numpy-like
    return self
All other methods that work on fitted estimators should call check_is_fitted first:
from sklearn.utils.validation import check_is_fitted
check_is_fitted(self, ['type_'])
You should probably implement the functionality for DataArrays in a separate method. For Datasets you can then just call that method for each of the data_vars and join them together afterwards.
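As a rough sketch of how these pieces can fit together, the example below defines a hypothetical Demeaner estimator that records the input type in fit, guards transform with check_is_fitted, and handles Datasets by applying a DataArray method to each data variable. The names Demeaner and _transform_dataarray are made up for illustration, and plain isinstance checks stand in for is_dataset and is_dataarray so the example depends only on xarray and scikit-learn.

```python
import numpy as np
import xarray as xr
from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_is_fitted


class Demeaner(BaseEstimator):
    """Hypothetical estimator: subtract the mean along ``dim``."""

    def __init__(self, dim='sample'):
        self.dim = dim

    def fit(self, X, y=None, **fit_params):
        # Record the container type seen during fitting.
        if isinstance(X, xr.Dataset):
            self.type_ = 'Dataset'
        elif isinstance(X, xr.DataArray):
            self.type_ = 'DataArray'
        else:
            self.type_ = 'other'
        return self

    def _transform_dataarray(self, X):
        # All DataArray-specific logic lives in one place.
        return X - X.mean(dim=self.dim)

    def transform(self, X):
        check_is_fitted(self, ['type_'])
        if self.type_ == 'Dataset':
            # Apply the DataArray method per data variable, then rejoin.
            return xr.Dataset(
                {name: self._transform_dataarray(da)
                 for name, da in X.data_vars.items()},
                attrs=X.attrs)
        if self.type_ == 'DataArray':
            return self._transform_dataarray(X)
        # numpy-like input: assume samples along the first axis.
        X = np.asarray(X)
        return X - X.mean(axis=0)


da = xr.DataArray(np.array([[1.0, 2.0], [3.0, 6.0]]),
                  dims=('sample', 'feature'))
ds = xr.Dataset({'a': da, 'b': da * 2})
out = Demeaner().fit(ds).transform(ds)
```

Keeping the per-variable logic in a single method means the Dataset branch stays a thin loop, and calling transform before fit raises scikit-learn's NotFittedError.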
Take a look at the estimators in the preprocessing module for a couple of examples.