Skip to content

Commit

Permalink
Merge pull request #184 from dswah/gen_x_grid
Browse files Browse the repository at this point in the history
move generate_X_grid to gam method
  • Loading branch information
dswah authored Jul 7, 2018
2 parents 7902f2d + 1ebcf86 commit f076346
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 50 deletions.
9 changes: 3 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,12 @@ For **regression** problems, we can use a **linear GAM** which models:

```python
from pygam import LinearGAM
from pygam.utils import generate_X_grid
from pygam.datasets import wage

X, y = wage(return_X_y=True)

gam = LinearGAM(n_splines=10).gridsearch(X, y)
XX = generate_X_grid(gam)
XX = gam.generate_X_grid()

fig, axs = plt.subplots(1, 3)
titles = ['year', 'age', 'education']
Expand Down Expand Up @@ -131,13 +130,12 @@ With **LinearGAMs**, we can also check the **prediction intervals**:

```python
from pygam import LinearGAM
from pygam.utils import generate_X_grid
from pygam.datasets import mcycle

X, y = mcycle(return_X_y=True)

gam = LinearGAM().gridsearch(X, y)
XX = generate_X_grid(gam)
XX = gam.generate_X_grid()

plt.plot(XX, gam.predict(XX), 'r--')
plt.plot(XX, gam.prediction_intervals(XX, width=.95), color='b', ls='--')
Expand Down Expand Up @@ -167,13 +165,12 @@ For **binary classification** problems, we can use a **logistic GAM** which mode

```python
from pygam import LogisticGAM
from pygam.utils import generate_X_grid
from pygam.datasets import default

X, y = default(return_X_y=True)

gam = LogisticGAM().gridsearch(X, y)
XX = generate_X_grid(gam)
XX = gam.generate_X_grid()

fig, axs = plt.subplots(1, 3)
titles = ['student', 'balance', 'income']
Expand Down
13 changes: 6 additions & 7 deletions gen_imgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from matplotlib.font_manager import FontProperties

from pygam import *
from pygam.utils import generate_X_grid
from pygam.datasets import hepatitis, wage, faithful, mcycle, trees, default, cake, toy_classification

np.random.seed(420)
Expand All @@ -24,7 +23,7 @@
def gen_basis_fns():
X, y = hepatitis()
gam = LinearGAM(lam=.6, fit_intercept=False).fit(X, y)
XX = generate_X_grid(gam)
XX = gam.generate_X_grid()

plt.figure()
fig, ax = plt.subplots(2,1)
Expand All @@ -44,7 +43,7 @@ def cake_data_in_one():
gam = LinearGAM(fit_intercept=True)
gam.gridsearch(X,y)

XX = generate_X_grid(gam)
XX = gam.generate_X_grid()

plt.figure()
plt.plot(gam.partial_dependence(XX))
Expand Down Expand Up @@ -81,7 +80,7 @@ def mcycle_data_linear():
gam = LinearGAM()
gam.gridsearch(X, y)

XX = generate_X_grid(gam)
XX = gam.generate_X_grid()
plt.figure()
plt.scatter(X, y, facecolor='gray', edgecolors='none')
plt.plot(XX, gam.predict(XX), 'r--')
Expand Down Expand Up @@ -112,7 +111,7 @@ def wage_data_linear():
gam = LinearGAM(n_splines=10)
gam.gridsearch(X, y, lam=np.logspace(-5,3,50))

XX = generate_X_grid(gam)
XX = gam.generate_X_grid()

plt.figure()
fig, axs = plt.subplots(1,3)
Expand All @@ -129,13 +128,13 @@ def wage_data_linear():
fig.tight_layout()
plt.savefig('imgs/pygam_wage_data_linear.png', dpi=300)

def default_data_logistic(n=500):
def default_data_logistic():
X, y = default()

gam = LogisticGAM()
gam.gridsearch(X, y)

XX = generate_X_grid(gam)
XX = gam.generate_X_grid()

plt.figure()
fig, axs = plt.subplots(1,3)
Expand Down
25 changes: 23 additions & 2 deletions pygam/pygam.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@
from pygam.utils import space_row
from pygam.utils import sig_code
from pygam.utils import gen_edge_knots
from pygam.utils import generate_X_grid
from pygam.utils import b_spline_basis
from pygam.utils import combine
from pygam.utils import cholesky
Expand Down Expand Up @@ -567,6 +566,28 @@ def _validate_data_dep_params(self, X):
if self._fit_intercept:
self._n_coeffs = [1] + self._n_coeffs

def generate_X_grid(self, n=500):
    """
    create a nice grid of X data

    array is sorted by feature and uniformly spaced,
    so the marginal and joint distributions are likely wrong

    Parameters
    ----------
    n : int, default: 500
        number of data points to create

    Returns
    -------
    np.array of shape (n, n_features)

    Raises
    ------
    AttributeError
        if the GAM has not been fitted
    """
    if not self._is_fitted:
        raise AttributeError('GAM has not been fitted. Call fit first.')

    # one evenly spaced column per entry of the edge knots, running from
    # that entry's first knot to its last knot
    columns = [np.linspace(ek[0], ek[-1], num=n) for ek in self._edge_knots]

    # stack as rows, then transpose so each row of the result is one grid point
    return np.vstack(columns).T

def loglikelihood(self, X, y, weights=None):
"""
compute the log-likelihood of the dataset using the current model
Expand Down Expand Up @@ -1824,7 +1845,7 @@ def partial_dependence(self, X=None, feature=-1, width=None, quantiles=None):
edge_knots=self._edge_knots, dtypes=self._dtype,
verbose=self.verbose)
else:
X = generate_X_grid(self)
X = self.generate_X_grid()

p_deps = []

Expand Down
11 changes: 5 additions & 6 deletions pygam/tests/test_GAM_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import scipy as sp

from pygam import *
from pygam.utils import generate_X_grid


@pytest.fixture
Expand Down Expand Up @@ -380,7 +379,7 @@ def test_shape_of_random_samples(self, mcycle_X_y, mcycle_gam):
assert sample_mu.shape == (n_draws, n_samples)
assert sample_y.shape == (n_draws, n_samples)

XX = generate_X_grid(mcycle_gam)
XX = mcycle_gam.generate_X_grid()
n_samples_in_grid = len(XX)
sample_coef = mcycle_gam.sample(X, y, quantity='coef', n_draws=n_draws,
sample_at_X=XX)
Expand Down Expand Up @@ -430,7 +429,7 @@ def test_prediction_interval_unknown_scale():
gam_a = LinearGAM(fit_linear=True, fit_splines=False).fit(X, y)
gam_b = LinearGAM(n_splines=4).fit(X, y)

XX = generate_X_grid(gam_a)
XX = gam_a.generate_X_grid()
intervals_a = gam_a.prediction_intervals(XX, quantiles=[0.1, .9]).mean(axis=0)
intervals_b = gam_b.prediction_intervals(XX, quantiles=[0.1, .9]).mean(axis=0)

Expand All @@ -452,7 +451,7 @@ def test_prediction_interval_known_scale():
gam_a = LinearGAM(fit_linear=True, fit_splines=False, scale=1.).fit(X, y)
gam_b = LinearGAM(n_splines=4, scale=1.).fit(X, y)

XX = generate_X_grid(gam_a)
XX = gam_a.generate_X_grid()
intervals_a = gam_a.prediction_intervals(XX, quantiles=[0.1, .9]).mean(axis=0)
intervals_b = gam_b.prediction_intervals(XX, quantiles=[0.1, .9]).mean(axis=0)

Expand Down Expand Up @@ -537,7 +536,7 @@ def test_pythonic_UI_in_pdeps(mcycle_gam):
to index into features starting at 0
and select the intercept by choosing feature='intercept'
"""
X = generate_X_grid(mcycle_gam)
X = mcycle_gam.generate_X_grid()

# check all features gives no intercept
pdeps = mcycle_gam.partial_dependence(X=X, feature=-1)
Expand Down Expand Up @@ -570,7 +569,7 @@ def test_no_X_needed_for_partial_dependence(mcycle_gam):
"""
partial_dependence() method uses generate_X_grid by default for the X array
"""
XX = generate_X_grid(mcycle_gam)
XX = mcycle_gam.generate_X_grid()
assert (mcycle_gam.partial_dependence() == mcycle_gam.partial_dependence(X=XX)).all()

def test_initial_estimate_runs_for_int_obseravtions(toy_classification_X_y):
Expand Down
10 changes: 4 additions & 6 deletions pygam/tests/test_penalties.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
from pygam.penalties import none
from pygam.penalties import wrap_penalty

from pygam.utils import generate_X_grid


def test_single_spline_penalty():
"""
Expand Down Expand Up @@ -64,7 +62,7 @@ def test_monotonic_inchepatitis_X_y(hepatitis_X_y):
gam = LinearGAM(constraints='monotonic_inc')
gam.fit(X, y)

XX = generate_X_grid(gam)
XX = gam.generate_X_grid()
Y = gam.predict(np.sort(XX))
diffs = np.diff(Y, n=1)
assert(((diffs >= 0) + np.isclose(diffs, 0.)).all())
Expand All @@ -78,7 +76,7 @@ def test_monotonic_dec(hepatitis_X_y):
gam = LinearGAM(constraints='monotonic_dec')
gam.fit(X, y)

XX = generate_X_grid(gam)
XX = gam.generate_X_grid()
Y = gam.predict(np.sort(XX))
diffs = np.diff(Y, n=1)
assert(((diffs <= 0) + np.isclose(diffs, 0.)).all())
Expand All @@ -92,7 +90,7 @@ def test_convex(hepatitis_X_y):
gam = LinearGAM(constraints='convex')
gam.fit(X, y)

XX = generate_X_grid(gam)
XX = gam.generate_X_grid()
Y = gam.predict(np.sort(XX))
diffs = np.diff(Y, n=2)
assert(((diffs >= 0) + np.isclose(diffs, 0.)).all())
Expand All @@ -106,7 +104,7 @@ def test_concave(hepatitis_X_y):
gam = LinearGAM(constraints='concave')
gam.fit(X, y)

XX = generate_X_grid(gam)
XX = gam.generate_X_grid()
Y = gam.predict(np.sort(XX))
diffs = np.diff(Y, n=2)
assert(((diffs <= 0) + np.isclose(diffs, 0.)).all())
Expand Down
23 changes: 0 additions & 23 deletions pygam/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,29 +89,6 @@ def cholesky(A, sparse=True, verbose=True):
return L


def generate_X_grid(gam, n=500):
    """
    tool to create a nice grid of X data if no X data is supplied

    array is sorted by feature and uniformly spaced, so the marginal and joint
    distributions are likely wrong

    Parameters
    ----------
    gam : GAM instance
    n : int, default: 500
        number of data points to create

    Returns
    -------
    np.array of shape (n, n_features)
    """
    # one evenly spaced column per entry of the GAM's edge knots, running
    # from that entry's first knot to its last knot
    columns = [np.linspace(ek[0], ek[-1], num=n) for ek in gam._edge_knots]

    # stack as rows, then transpose so each row of the result is one grid point
    return np.vstack(columns).T


def check_dtype(X, ratio=.95):
"""
tool to identify the data-types of the features in data matrix X.
Expand Down

0 comments on commit f076346

Please sign in to comment.