Skip to content

Commit

Permalink
regen images, readme has no i+1, use toy_classification dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
dswah committed Jul 5, 2018
1 parent 2e80ecc commit 6253643
Show file tree
Hide file tree
Showing 11 changed files with 18 additions and 75 deletions.
9 changes: 6 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ fig, axs = plt.subplots(1, 3)
titles = ['year', 'age', 'education']

for i, ax in enumerate(axs):
pdep, confi = gam.partial_dependence(XX, feature=i+1, width=.95)
pdep, confi = gam.partial_dependence(XX, feature=i, width=.95)

ax.plot(XX[:, i], pdep)
ax.plot(XX[:, i], *confi, c='r', ls='--')
Expand Down Expand Up @@ -179,7 +179,7 @@ fig, axs = plt.subplots(1, 3)
titles = ['student', 'balance', 'income']

for i, ax in enumerate(axs):
pdep, confi = gam.partial_dependence(XX, feature=i+1, width=.95)
pdep, confi = gam.partial_dependence(XX, feature=i, width=.95)

ax.plot(XX[:, i], pdep)
ax.plot(XX[:, i], confi[0], c='r', ls='--')
Expand Down Expand Up @@ -235,7 +235,7 @@ gam = PoissonGAM().gridsearch(X, y)

plt.hist(faithful(return_X_y=False)['eruptions'], bins=200, color='k');
plt.plot(X, gam.predict(X), color='r')
plt.title('Lam: {0:.2f}'.format(gam.lam))
plt.title('Best Lambda: {0:.2f}'.format(gam.lam))
```
<img src=imgs/pygam_poisson.png>

Expand Down Expand Up @@ -322,6 +322,9 @@ pyGAM is intuitive, modular, and adheres to a familiar API:

```python
from pygam import LogisticGAM
from pygam.datasets import toy_classification

X, y = toy_classification(return_X_y=True)

gam = LogisticGAM()
gam.fit(X, y)
Expand Down
84 changes: 12 additions & 72 deletions gen_imgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from pygam import *
from pygam.utils import generate_X_grid
from pygam.datasets import hepatitis, wage, faithful, mcycle, trees, default, cake, toy_classification

np.random.seed(420)
fontP = FontProperties()
Expand All @@ -20,60 +21,6 @@
# monotonic increasing, concave constraint on hep data
# prediction intervals on motorcycle data


def hepatitis():
hep = pd.read_csv('datasets/hepatitis_A_bulgaria.csv').astype(float)
# eliminate 0/0
mask = (hep.total > 0).values
hep = hep[mask]
X = hep.age.values
y = hep.hepatitis_A_positive.values / hep.total.values
return X, y

def mcycle():
motor = pd.read_csv('datasets/mcycle.csv', index_col=0)
X = motor.times.values[:,None]
y = motor.accel
return X, y

def faithful():
faithful = pd.read_csv('datasets/faithful.csv', index_col=0)
y, xx, _ = plt.hist(faithful.values[:,0], bins=200, color='k')
X = xx[:-1] + np.diff(xx)/2 # get midpoints of bins
return X, y

def wage():
wage = pd.read_csv('datasets/wage.csv', index_col=0)
X = wage[['year', 'age', 'education']].values
y = wage['wage'].values

# change education level to integers
X[:,-1] = np.unique(X[:,-1], return_inverse=True)[1]
return X, y

def trees():
trees = pd.read_csv('datasets/trees.csv', index_col=0)
y = trees.Volume
X = trees[['Girth', 'Height']]
return X, y

def default():
default = pd.read_csv('datasets/default.csv', index_col=0)
default = default.values
default[:,0] = np.unique(default[:,0], return_inverse=True)[1]
default[:,1] = np.unique(default[:,1], return_inverse=True)[1]
X = default[:,1:]
y = default[:,0]
return X, y

def cake():
cake = pd.read_csv('datasets/cake.csv', index_col=0)
X = cake[['recipe', 'replicate', 'temperature']].values
X[:,0] = np.unique(cake.values[:,1], return_inverse=True)[1]
X[:,1] -= 1
y = cake['angle'].values
return X, y

def gen_basis_fns():
X, y = hepatitis()
gam = LinearGAM(lam=.6, fit_intercept=False).fit(X, y)
Expand Down Expand Up @@ -109,10 +56,10 @@ def faithful_data_poisson():
gam = PoissonGAM().gridsearch(X, y)

plt.figure()
plt.bar(X, y, width=np.diff(X)[0], color='k')
plt.hist(faithful(return_X_y=False)['eruptions'], bins=200, color='k');

plt.plot(X, gam.predict(X), color='r')
plt.title('Best Lamba: {0:.2f}'.format(gam.lam))
plt.title('Best Lambda: {0:.2f}'.format(gam.lam))
plt.savefig('imgs/pygam_poisson.png', dpi=300)

def single_data_linear():
Expand All @@ -125,7 +72,7 @@ def single_data_linear():
plt.figure()
plt.scatter(X, y, facecolor='gray', edgecolors='none')
plt.plot(X, gam.predict(X), color='r')
plt.title('Best Lamba: {0:.2f}'.format(gam.lam))
plt.title('Best Lambda: {0:.2f}'.format(gam.lam))
plt.savefig('imgs/pygam_single_pred_linear.png', dpi=300)

def mcycle_data_linear():
Expand Down Expand Up @@ -172,8 +119,8 @@ def wage_data_linear():

titles = ['year', 'age', 'education']
for i, ax in enumerate(axs):
ax.plot(XX[:, i], gam.partial_dependence(XX, feature=i+1))
ax.plot(XX[:, i], *gam.partial_dependence(XX, feature=i+1, width=.95)[1],
ax.plot(XX[:, i], gam.partial_dependence(XX, feature=i))
ax.plot(XX[:, i], *gam.partial_dependence(XX, feature=i, width=.95)[1],
c='r', ls='--')
if i == 0:
ax.set_ylim(-30,30);
Expand All @@ -195,8 +142,8 @@ def default_data_logistic(n=500):

titles = ['student', 'balance', 'income']
for i, ax in enumerate(axs):
ax.plot(XX[:, i], gam.partial_dependence(XX, feature=i+1))
ax.plot(XX[:, i], *gam.partial_dependence(XX, feature=i+1, width=.95)[1],
ax.plot(XX[:, i], gam.partial_dependence(XX, feature=i))
ax.plot(XX[:, i], *gam.partial_dependence(XX, feature=i, width=.95)[1],
c='r', ls='--')
ax.set_title(titles[i])

Expand Down Expand Up @@ -281,24 +228,17 @@ def trees_data_custom():
# plt.savefig('imgs/pygam_lambda_gridsearch.png', dpi=300)


def gen_multi_data(n=200):
def gen_multi_data(n=5000):
"""
multivariate Logistic problem
"""
n = 5000
x = np.random.rand(n,5) * 10 - 5
cat = np.random.randint(0,4, n)
x = np.c_[x, cat]
log_odds = (-0.5*x[:,0]**2) + 5 +(-0.5*x[:,1]**2) + np.mod(x[:,-1], 2)*-30
p = 1/(1+np.exp(-log_odds)).squeeze()

obs = (np.random.rand(len(x)) < p).astype(np.int)
X, y = toy_classification(return_X_y=True, n=n)

lgam = LogisticGAM()
lgam.fit(x, obs)
lgam.fit(X, y)

plt.figure()
plt.plot(lgam.partial_dependence(np.sort(x, axis=0)))
plt.plot(lgam.partial_dependence(np.sort(X, axis=0)))
plt.savefig('imgs/pygam_multi_pdep.png', dpi=300)

plt.figure()
Expand Down
Binary file modified imgs/pygam_basis.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified imgs/pygam_cake_data.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified imgs/pygam_constraints.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified imgs/pygam_custom.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified imgs/pygam_default_data_logistic.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified imgs/pygam_mcycle_data_extrapolation.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified imgs/pygam_mcycle_data_linear.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified imgs/pygam_poisson.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified imgs/pygam_wage_data_linear.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 6253643

Please sign in to comment.