Skip to content

Commit

Permalink
Merge pull request #206 from dswah/ints
Browse files Browse the repository at this point in the history
Ints
  • Loading branch information
dswah authored Sep 22, 2018
2 parents 871c1c5 + be35d42 commit cf81b0d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
9 changes: 9 additions & 0 deletions pygam/tests/test_GAM_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,15 @@ def test_no_explicit_terms_custom_lambda(self, wage_X_y):
gam.gridsearch(X, y)
assert gam._is_fitted

def test_n_splines_not_int(self, mcycle_X_y):
"""
used to fail for n_splines of type np.int64, as returned by np.arange
"""
X, y = mcycle_X_y
gam = LinearGAM(n_splines=np.arange(9,10)[0]).fit(X, y)
assert gam._is_fitted


# TODO categorical dtypes get no fit linear even if fit linear TRUE
# TODO categorical dtypes get their own number of splines
# TODO can force continuous dtypes on categorical vars if wanted
5 changes: 3 additions & 2 deletions pygam/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from __future__ import division
from copy import deepcopy
import numbers
import sys
import warnings

Expand Down Expand Up @@ -593,10 +594,10 @@ def b_spline_basis(x, edge_knots, n_splines=20, spline_order=3, sparse=True,
raise ValueError('Data must be 1-D, but found {}'\
.format(np.ravel(x).ndim))

if (n_splines < 1) or (type(n_splines) is not int):
if (n_splines < 1) or not isinstance(n_splines, numbers.Integral):
raise ValueError('n_splines must be int >= 1')

if (spline_order < 0) or (type(spline_order) is not int):
if (spline_order < 0) or not isinstance(spline_order, numbers.Integral):
raise ValueError('spline_order must be int >= 1')

if n_splines < spline_order + 1:
Expand Down

0 comments on commit cf81b0d

Please sign in to comment.