Skip to content

Commit

Permalink
Merge pull request #180 from dswah/intercept-ui
Browse files Browse the repository at this point in the history
Intercept ui
  • Loading branch information
dswah authored Jul 6, 2018
2 parents 77c82e5 + 77ff2a5 commit 023b0fd
Show file tree
Hide file tree
Showing 16 changed files with 165 additions and 86 deletions.
9 changes: 6 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ fig, axs = plt.subplots(1, 3)
titles = ['year', 'age', 'education']

for i, ax in enumerate(axs):
pdep, confi = gam.partial_dependence(XX, feature=i+1, width=.95)
pdep, confi = gam.partial_dependence(XX, feature=i, width=.95)

ax.plot(XX[:, i], pdep)
ax.plot(XX[:, i], *confi, c='r', ls='--')
Expand Down Expand Up @@ -179,7 +179,7 @@ fig, axs = plt.subplots(1, 3)
titles = ['student', 'balance', 'income']

for i, ax in enumerate(axs):
pdep, confi = gam.partial_dependence(XX, feature=i+1, width=.95)
pdep, confi = gam.partial_dependence(XX, feature=i, width=.95)

ax.plot(XX[:, i], pdep)
ax.plot(XX[:, i], confi[0], c='r', ls='--')
Expand Down Expand Up @@ -235,7 +235,7 @@ gam = PoissonGAM().gridsearch(X, y)

plt.hist(faithful(return_X_y=False)['eruptions'], bins=200, color='k');
plt.plot(X, gam.predict(X), color='r')
plt.title('Lam: {0:.2f}'.format(gam.lam))
plt.title('Best Lambda: {0:.2f}'.format(gam.lam))
```
<img src=imgs/pygam_poisson.png>

Expand Down Expand Up @@ -322,6 +322,9 @@ pyGAM is intuitive, modular, and adheres to a familiar API:

```python
from pygam import LogisticGAM
from pygam.datasets import toy_classification

X, y = toy_classification(return_X_y=True)

gam = LogisticGAM()
gam.fit(X, y)
Expand Down
84 changes: 12 additions & 72 deletions gen_imgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from pygam import *
from pygam.utils import generate_X_grid
from pygam.datasets import hepatitis, wage, faithful, mcycle, trees, default, cake, toy_classification

np.random.seed(420)
fontP = FontProperties()
Expand All @@ -20,60 +21,6 @@
# monotonic increasing, concave constraint on hep data
# prediction intervals on motorcycle data


def hepatitis():
hep = pd.read_csv('datasets/hepatitis_A_bulgaria.csv').astype(float)
# eliminate 0/0
mask = (hep.total > 0).values
hep = hep[mask]
X = hep.age.values
y = hep.hepatitis_A_positive.values / hep.total.values
return X, y

def mcycle():
motor = pd.read_csv('datasets/mcycle.csv', index_col=0)
X = motor.times.values[:,None]
y = motor.accel
return X, y

def faithful():
faithful = pd.read_csv('datasets/faithful.csv', index_col=0)
y, xx, _ = plt.hist(faithful.values[:,0], bins=200, color='k')
X = xx[:-1] + np.diff(xx)/2 # get midpoints of bins
return X, y

def wage():
wage = pd.read_csv('datasets/wage.csv', index_col=0)
X = wage[['year', 'age', 'education']].values
y = wage['wage'].values

# change education level to integers
X[:,-1] = np.unique(X[:,-1], return_inverse=True)[1]
return X, y

def trees():
trees = pd.read_csv('datasets/trees.csv', index_col=0)
y = trees.Volume
X = trees[['Girth', 'Height']]
return X, y

def default():
default = pd.read_csv('datasets/default.csv', index_col=0)
default = default.values
default[:,0] = np.unique(default[:,0], return_inverse=True)[1]
default[:,1] = np.unique(default[:,1], return_inverse=True)[1]
X = default[:,1:]
y = default[:,0]
return X, y

def cake():
cake = pd.read_csv('datasets/cake.csv', index_col=0)
X = cake[['recipe', 'replicate', 'temperature']].values
X[:,0] = np.unique(cake.values[:,1], return_inverse=True)[1]
X[:,1] -= 1
y = cake['angle'].values
return X, y

def gen_basis_fns():
X, y = hepatitis()
gam = LinearGAM(lam=.6, fit_intercept=False).fit(X, y)
Expand Down Expand Up @@ -109,10 +56,10 @@ def faithful_data_poisson():
gam = PoissonGAM().gridsearch(X, y)

plt.figure()
plt.bar(X, y, width=np.diff(X)[0], color='k')
plt.hist(faithful(return_X_y=False)['eruptions'], bins=200, color='k');

plt.plot(X, gam.predict(X), color='r')
plt.title('Best Lamba: {0:.2f}'.format(gam.lam))
plt.title('Best Lambda: {0:.2f}'.format(gam.lam))
plt.savefig('imgs/pygam_poisson.png', dpi=300)

def single_data_linear():
Expand All @@ -125,7 +72,7 @@ def single_data_linear():
plt.figure()
plt.scatter(X, y, facecolor='gray', edgecolors='none')
plt.plot(X, gam.predict(X), color='r')
plt.title('Best Lamba: {0:.2f}'.format(gam.lam))
plt.title('Best Lambda: {0:.2f}'.format(gam.lam))
plt.savefig('imgs/pygam_single_pred_linear.png', dpi=300)

def mcycle_data_linear():
Expand Down Expand Up @@ -172,8 +119,8 @@ def wage_data_linear():

titles = ['year', 'age', 'education']
for i, ax in enumerate(axs):
ax.plot(XX[:, i], gam.partial_dependence(XX, feature=i+1))
ax.plot(XX[:, i], *gam.partial_dependence(XX, feature=i+1, width=.95)[1],
ax.plot(XX[:, i], gam.partial_dependence(XX, feature=i))
ax.plot(XX[:, i], *gam.partial_dependence(XX, feature=i, width=.95)[1],
c='r', ls='--')
if i == 0:
ax.set_ylim(-30,30);
Expand All @@ -195,8 +142,8 @@ def default_data_logistic(n=500):

titles = ['student', 'balance', 'income']
for i, ax in enumerate(axs):
ax.plot(XX[:, i], gam.partial_dependence(XX, feature=i+1))
ax.plot(XX[:, i], *gam.partial_dependence(XX, feature=i+1, width=.95)[1],
ax.plot(XX[:, i], gam.partial_dependence(XX, feature=i))
ax.plot(XX[:, i], *gam.partial_dependence(XX, feature=i, width=.95)[1],
c='r', ls='--')
ax.set_title(titles[i])

Expand Down Expand Up @@ -281,24 +228,17 @@ def trees_data_custom():
# plt.savefig('imgs/pygam_lambda_gridsearch.png', dpi=300)


def gen_multi_data(n=200):
def gen_multi_data(n=5000):
"""
multivariate Logistic problem
"""
n = 5000
x = np.random.rand(n,5) * 10 - 5
cat = np.random.randint(0,4, n)
x = np.c_[x, cat]
log_odds = (-0.5*x[:,0]**2) + 5 +(-0.5*x[:,1]**2) + np.mod(x[:,-1], 2)*-30
p = 1/(1+np.exp(-log_odds)).squeeze()

obs = (np.random.rand(len(x)) < p).astype(np.int)
X, y = toy_classification(return_X_y=True, n=n)

lgam = LogisticGAM()
lgam.fit(x, obs)
lgam.fit(X, y)

plt.figure()
plt.plot(lgam.partial_dependence(np.sort(x, axis=0)))
plt.plot(lgam.partial_dependence(np.sort(X, axis=0)))
plt.savefig('imgs/pygam_multi_pdep.png', dpi=300)

plt.figure()
Expand Down
Binary file modified imgs/pygam_basis.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified imgs/pygam_cake_data.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified imgs/pygam_constraints.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified imgs/pygam_custom.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified imgs/pygam_default_data_logistic.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified imgs/pygam_mcycle_data_extrapolation.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified imgs/pygam_mcycle_data_linear.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified imgs/pygam_poisson.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified imgs/pygam_wage_data_linear.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 3 additions & 1 deletion pygam/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pygam.datasets.load_datasets import default
from pygam.datasets.load_datasets import cake
from pygam.datasets.load_datasets import hepatitis
from pygam.datasets.load_datasets import toy_classification

__all__ = ['mcycle',
'coal',
Expand All @@ -19,4 +20,5 @@
'wage',
'default',
'cake',
'hepatitis']
'hepatitis',
'toy_classification']
57 changes: 57 additions & 0 deletions pygam/datasets/load_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,3 +299,60 @@ def hepatitis(return_X_y=True):
y = hep.hepatitis_A_positive.values / hep.total.values
return _clean_X_y(X, y)
return hep

def toy_classification(return_X_y=True, n=5000):
"""toy classification dataset with irrelevant features
fitting a logistic model on this data and performing a model summary
should reveal that features 2,3,4 are not significant.
Parameters
----------
return_X_y : bool,
if True, returns a model-ready tuple of data (X, y)
otherwise, returns a Pandas DataFrame
n : int, default: 5000
number of samples to generate
Returns
-------
model-ready tuple of data (X, y)
OR
Pandas DataFrame
Notes
-----
X contains 5 variables:
continuous feature 0
continuous feature 1
irrelevant feature 0
irrelevant feature 1
irrelevant feature 2
categorical feature 0
y contains binary labels
Also, this dataset is randomly generated and will vary each time.
"""
# make features
X = np.random.rand(n,5) * 10 - 5
cat = np.random.randint(0,4, n)
X = np.c_[X, cat]

# make observations
log_odds = (-0.5*X[:,0]**2) + 5 +(-0.5*X[:,1]**2) + np.mod(X[:,-1], 2)*-30
p = 1/(1+np.exp(-log_odds)).squeeze()
y = (np.random.rand(n) < p).astype(np.int)

if return_X_y:
return X, y
else:
return pd.DataFrame(np.c_[X, y], columns=[['continuous0',
'continuous1',
'irrelevant0',
'irrelevant1',
'irrelevant2',
'categorical0',
'observations'
]])
34 changes: 25 additions & 9 deletions pygam/pygam.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from pygam.utils import space_row
from pygam.utils import sig_code
from pygam.utils import gen_edge_knots
from pygam.utils import generate_X_grid
from pygam.utils import b_spline_basis
from pygam.utils import combine
from pygam.utils import cholesky
Expand Down Expand Up @@ -1000,7 +1001,7 @@ def _initial_estimate(self, y, modelmat):
return np.ones(m) * np.sqrt(EPS)

# transform the problem to the linear scale
y = deepcopy(y)
y = deepcopy(y).astype('float64')
y[y == 0] += .01 # edge case for log link, inverse link, and logit link
y[y == 1] -= .01 # edge case for logit link

Expand Down Expand Up @@ -1779,7 +1780,7 @@ def _select_feature(self, feature):
b = np.sum(self._n_coeffs[feature])
return np.arange(a, a+b, dtype=int)

def partial_dependence(self, X, feature=-1, width=None, quantiles=None):
def partial_dependence(self, X=None, feature=-1, width=None, quantiles=None):
"""
Computes the feature functions for the GAM
and possibly their confidence intervals.
Expand All @@ -1789,14 +1790,16 @@ def partial_dependence(self, X, feature=-1, width=None, quantiles=None):
Parameters
----------
X : array
input data of shape (n_samples, m_features)
X : array or None, default: None
input data of shape (n_samples, m_features).
if None, an equally spaced grid of 500 points is generated for
each feature function.
feature : array-like of ints, default: -1
feature for which to compute the partial dependence functions
if feature == -1, then all features are selected,
excluding the intercept
if feature == 0 and gam.fit_intercept is True, then the intercept's
patial dependence is returned
if feature == 'intercept' and gam.fit_intercept is True,
then the intercept's partial dependence is returned
width : float in [0, 1], default: None
width of the confidence interval
if None, defaults to 0.95
Expand All @@ -1815,15 +1818,28 @@ def partial_dependence(self, X, feature=-1, width=None, quantiles=None):
raise AttributeError('GAM has not been fitted. Call fit first.')

m = len(self._n_coeffs) - self._fit_intercept
X = check_X(X, n_feats=m, edge_knots=self._edge_knots,
dtypes=self._dtype, verbose=self.verbose)

if X is not None:
X = check_X(X, n_feats=m,
edge_knots=self._edge_knots, dtypes=self._dtype,
verbose=self.verbose)
else:
X = generate_X_grid(self)

p_deps = []

compute_quantiles = (width is not None) or (quantiles is not None)
conf_intervals = []

if feature == -1:
# make coding more pythonic for users
if feature == 'intercept':
if not self._fit_intercept:
raise ValueError('intercept is not fitted')
feature = 0
elif feature == -1:
feature = np.arange(m) + self._fit_intercept
else:
feature += self._fit_intercept

# convert to array
feature = np.atleast_1d(feature)
Expand Down
8 changes: 7 additions & 1 deletion pygam/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np

from pygam import *
from pygam.datasets import mcycle, coal, faithful, cake, coal, default, trees, hepatitis, wage
from pygam.datasets import mcycle, coal, faithful, cake, coal, default, trees, hepatitis, wage, toy_classification


@pytest.fixture
Expand Down Expand Up @@ -55,3 +55,9 @@ def hepatitis_X_y():
# y is real
# recommend LinearGAM
return hepatitis(return_X_y=True)

@pytest.fixture
def toy_classification_X_y():
# y is binary ints
# recommend LogisticGAM
return toy_classification(return_X_y=True)
Loading

0 comments on commit 023b0fd

Please sign in to comment.