diff --git a/pygam/tests/test_GAM_params.py b/pygam/tests/test_GAM_params.py index 57e5adee..5cd76e85 100644 --- a/pygam/tests/test_GAM_params.py +++ b/pygam/tests/test_GAM_params.py @@ -92,6 +92,15 @@ def test_no_explicit_terms_custom_lambda(self, wage_X_y): gam.gridsearch(X, y) assert gam._is_fitted + def test_n_splines_not_int(self, mcycle_X_y): + """ + used to fail for n_splines of type np.int64, as returned by np.arange + """ + X, y = mcycle_X_y + gam = LinearGAM(n_splines=np.arange(9,10)[0]).fit(X, y) + assert gam._is_fitted + + # TODO categorical dtypes get no fit linear even if fit linear TRUE # TODO categorical dtypes get their own number of splines # TODO can force continuous dtypes on categorical vars if wanted diff --git a/pygam/utils.py b/pygam/utils.py index b2930bdb..b437c23f 100644 --- a/pygam/utils.py +++ b/pygam/utils.py @@ -4,6 +4,7 @@ from __future__ import division from copy import deepcopy +import numbers import sys import warnings @@ -593,10 +594,10 @@ def b_spline_basis(x, edge_knots, n_splines=20, spline_order=3, sparse=True, raise ValueError('Data must be 1-D, but found {}'\ .format(np.ravel(x).ndim)) - if (n_splines < 1) or (type(n_splines) is not int): + if (n_splines < 1) or not isinstance(n_splines, numbers.Integral): raise ValueError('n_splines must be int >= 1') - if (spline_order < 0) or (type(spline_order) is not int): + if (spline_order < 0) or not isinstance(spline_order, numbers.Integral): raise ValueError('spline_order must be int >= 1') if n_splines < spline_order + 1: