From a93a0024e27f735f1009c2de4a9fe53e2caeef35 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 19 Dec 2024 22:26:11 -0500 Subject: [PATCH] fix tutorials --- .../plot_06_sklearn_pipeline_cv_demo.md | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md b/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md index 4073b928..68c08000 100644 --- a/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md +++ b/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md @@ -321,7 +321,8 @@ scores = np.zeros((len(regularizer_strength) * len(n_basis_funcs), n_folds)) coeffs = {} # initialize basis and model -basis = nmo.basis.TransformerBasis(nmo.basis.RaisedCosineLinearEval(6)) +basis = nmo.basis.RaisedCosineLinearEval(6).set_input_shape(1) +basis = nmo.basis.TransformerBasis(basis) model = nmo.glm.GLM(regularizer="Ridge") # loop over combinations @@ -441,13 +442,13 @@ We are now able to capture the distribution of the firing rate appropriately: bo In the previous example we set the number of basis functions of the [`Basis`](nemos.basis._basis.Basis) wrapped in our [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis). However, if we are for example not sure about the type of basis functions we want to use, or we have already defined some basis functions of our own, then we can use cross-validation to directly evaluate those as well. -Here we include `transformerbasis___basis` in the parameter grid to try different values for `TransformerBasis._basis`: +Here we include `transformerbasis__basis` in the parameter grid to try different values for `TransformerBasis.basis`: ```{code-cell} ipython3 param_grid = dict( glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6), - transformerbasis___basis=( + transformerbasis__basis=( nmo.basis.RaisedCosineLinearEval(5).set_input_shape(1), nmo.basis.RaisedCosineLinearEval(10).set_input_shape(1), nmo.basis.RaisedCosineLogEval(5).set_input_shape(1), @@ -481,7 +482,7 @@ cvdf = pd.DataFrame(gridsearch.cv_results_) # Read out the number of basis functions cvdf["transformerbasis_config"] = [ f"{b.__class__.__name__} - {b.n_basis_funcs}" - for b in cvdf["param_transformerbasis___basis"] + for b in cvdf["param_transformerbasis__basis"] ] cvdf_wide = cvdf.pivot( @@ -537,7 +538,7 @@ Please note that because it would lead to unexpected behavior, mixing the two wa param_grid = dict( glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6), transformerbasis__n_basis_funcs=(3, 5, 10, 20, 100), - transformerbasis___basis=( + transformerbasis__basis=( nmo.basis.RaisedCosineLinearEval(5).set_input_shape(1), nmo.basis.RaisedCosineLinearEval(10).set_input_shape(1), nmo.basis.RaisedCosineLogEval(5).set_input_shape(1), @@ -592,7 +593,7 @@ cvdf = pd.DataFrame(gridsearch.cv_results_) # Read out the number of basis functions cvdf["transformerbasis_config"] = [ f"{b.__class__.__name__} - {b.n_basis_funcs}" - for b in cvdf["param_transformerbasis___basis"] + for b in cvdf["param_transformerbasis__basis"] ] cvdf_wide = cvdf.pivot(