diff --git a/README.md b/README.md index babb4477..818e8c7c 100644 --- a/README.md +++ b/README.md @@ -93,7 +93,7 @@ fig, axs = plt.subplots(1, 3) titles = ['year', 'age', 'education'] for i, ax in enumerate(axs): - pdep, confi = gam.partial_dependence(XX, feature=i+1, width=.95) + pdep, confi = gam.partial_dependence(XX, feature=i, width=.95) ax.plot(XX[:, i], pdep) ax.plot(XX[:, i], *confi, c='r', ls='--') @@ -179,7 +179,7 @@ fig, axs = plt.subplots(1, 3) titles = ['student', 'balance', 'income'] for i, ax in enumerate(axs): - pdep, confi = gam.partial_dependence(XX, feature=i+1, width=.95) + pdep, confi = gam.partial_dependence(XX, feature=i, width=.95) ax.plot(XX[:, i], pdep) ax.plot(XX[:, i], confi[0], c='r', ls='--') @@ -235,7 +235,7 @@ gam = PoissonGAM().gridsearch(X, y) plt.hist(faithful(return_X_y=False)['eruptions'], bins=200, color='k'); plt.plot(X, gam.predict(X), color='r') -plt.title('Lam: {0:.2f}'.format(gam.lam)) +plt.title('Best Lambda: {0:.2f}'.format(gam.lam)) ``` @@ -322,6 +322,9 @@ pyGAM is intuitive, modular, and adheres to a familiar API: ```python from pygam import LogisticGAM +from pygam.datasets import toy_classification + +X, y = toy_classification(return_X_y=True) gam = LogisticGAM() gam.fit(X, y) diff --git a/gen_imgs.py b/gen_imgs.py index 41a1f0cf..91adcbc5 100644 --- a/gen_imgs.py +++ b/gen_imgs.py @@ -8,6 +8,7 @@ from pygam import * from pygam.utils import generate_X_grid +from pygam.datasets import hepatitis, wage, faithful, mcycle, trees, default, cake, toy_classification np.random.seed(420) fontP = FontProperties() @@ -20,60 +21,6 @@ # monotonic increasing, concave constraint on hep data # prediction intervals on motorcycle data - -def hepatitis(): - hep = pd.read_csv('datasets/hepatitis_A_bulgaria.csv').astype(float) - # eliminate 0/0 - mask = (hep.total > 0).values - hep = hep[mask] - X = hep.age.values - y = hep.hepatitis_A_positive.values / hep.total.values - return X, y - -def mcycle(): - motor = pd.read_csv('datasets/mcycle.csv', index_col=0) - X = motor.times.values[:,None] - y = motor.accel - return X, y - -def faithful(): - faithful = pd.read_csv('datasets/faithful.csv', index_col=0) - y, xx, _ = plt.hist(faithful.values[:,0], bins=200, color='k') - X = xx[:-1] + np.diff(xx)/2 # get midpoints of bins - return X, y - -def wage(): - wage = pd.read_csv('datasets/wage.csv', index_col=0) - X = wage[['year', 'age', 'education']].values - y = wage['wage'].values - - # change education level to integers - X[:,-1] = np.unique(X[:,-1], return_inverse=True)[1] - return X, y - -def trees(): - trees = pd.read_csv('datasets/trees.csv', index_col=0) - y = trees.Volume - X = trees[['Girth', 'Height']] - return X, y - -def default(): - default = pd.read_csv('datasets/default.csv', index_col=0) - default = default.values - default[:,0] = np.unique(default[:,0], return_inverse=True)[1] - default[:,1] = np.unique(default[:,1], return_inverse=True)[1] - X = default[:,1:] - y = default[:,0] - return X, y - -def cake(): - cake = pd.read_csv('datasets/cake.csv', index_col=0) - X = cake[['recipe', 'replicate', 'temperature']].values - X[:,0] = np.unique(cake.values[:,1], return_inverse=True)[1] - X[:,1] -= 1 - y = cake['angle'].values - return X, y - def gen_basis_fns(): X, y = hepatitis() gam = LinearGAM(lam=.6, fit_intercept=False).fit(X, y) @@ -109,10 +56,10 @@ def faithful_data_poisson(): gam = PoissonGAM().gridsearch(X, y) plt.figure() - plt.bar(X, y, width=np.diff(X)[0], color='k') + plt.hist(faithful(return_X_y=False)['eruptions'], bins=200, color='k'); plt.plot(X, gam.predict(X), color='r') - plt.title('Best Lamba: {0:.2f}'.format(gam.lam)) + plt.title('Best Lambda: {0:.2f}'.format(gam.lam)) plt.savefig('imgs/pygam_poisson.png', dpi=300) def single_data_linear(): @@ -125,7 +72,7 @@ def single_data_linear(): plt.figure() plt.scatter(X, y, facecolor='gray', edgecolors='none') plt.plot(X, gam.predict(X), color='r') - plt.title('Best Lamba: {0:.2f}'.format(gam.lam)) + plt.title('Best Lambda: {0:.2f}'.format(gam.lam)) plt.savefig('imgs/pygam_single_pred_linear.png', dpi=300) def mcycle_data_linear(): @@ -172,8 +119,8 @@ def wage_data_linear(): titles = ['year', 'age', 'education'] for i, ax in enumerate(axs): - ax.plot(XX[:, i], gam.partial_dependence(XX, feature=i+1)) - ax.plot(XX[:, i], *gam.partial_dependence(XX, feature=i+1, width=.95)[1], + ax.plot(XX[:, i], gam.partial_dependence(XX, feature=i)) + ax.plot(XX[:, i], *gam.partial_dependence(XX, feature=i, width=.95)[1], c='r', ls='--') if i == 0: ax.set_ylim(-30,30); @@ -195,8 +142,8 @@ def default_data_logistic(n=500): titles = ['student', 'balance', 'income'] for i, ax in enumerate(axs): - ax.plot(XX[:, i], gam.partial_dependence(XX, feature=i+1)) - ax.plot(XX[:, i], *gam.partial_dependence(XX, feature=i+1, width=.95)[1], + ax.plot(XX[:, i], gam.partial_dependence(XX, feature=i)) + ax.plot(XX[:, i], *gam.partial_dependence(XX, feature=i, width=.95)[1], c='r', ls='--') ax.set_title(titles[i]) @@ -281,24 +228,17 @@ def trees_data_custom(): # plt.savefig('imgs/pygam_lambda_gridsearch.png', dpi=300) -def gen_multi_data(n=200): +def gen_multi_data(n=5000): """ multivariate Logistic problem """ - n = 5000 - x = np.random.rand(n,5) * 10 - 5 - cat = np.random.randint(0,4, n) - x = np.c_[x, cat] - log_odds = (-0.5*x[:,0]**2) + 5 +(-0.5*x[:,1]**2) + np.mod(x[:,-1], 2)*-30 - p = 1/(1+np.exp(-log_odds)).squeeze() - - obs = (np.random.rand(len(x)) < p).astype(np.int) + X, y = toy_classification(return_X_y=True, n=n) lgam = LogisticGAM() - lgam.fit(x, obs) + lgam.fit(X, y) plt.figure() - plt.plot(lgam.partial_dependence(np.sort(x, axis=0))) + plt.plot(lgam.partial_dependence(np.sort(X, axis=0))) plt.savefig('imgs/pygam_multi_pdep.png', dpi=300) plt.figure() diff --git a/imgs/pygam_basis.png b/imgs/pygam_basis.png index 188be586..3159fd32 100644 Binary files a/imgs/pygam_basis.png and b/imgs/pygam_basis.png differ diff --git a/imgs/pygam_cake_data.png b/imgs/pygam_cake_data.png index 40bcfdd0..d849f054 100644 Binary files a/imgs/pygam_cake_data.png and b/imgs/pygam_cake_data.png differ diff --git a/imgs/pygam_constraints.png b/imgs/pygam_constraints.png index 390aab5e..e517ab66 100644 Binary files a/imgs/pygam_constraints.png and b/imgs/pygam_constraints.png differ diff --git a/imgs/pygam_custom.png b/imgs/pygam_custom.png index 1b8413e4..13f59990 100644 Binary files a/imgs/pygam_custom.png and b/imgs/pygam_custom.png differ diff --git a/imgs/pygam_default_data_logistic.png b/imgs/pygam_default_data_logistic.png index c006bc1c..a64b29d7 100644 Binary files a/imgs/pygam_default_data_logistic.png and b/imgs/pygam_default_data_logistic.png differ diff --git a/imgs/pygam_mcycle_data_extrapolation.png b/imgs/pygam_mcycle_data_extrapolation.png index df2e948a..3fbab13a 100644 Binary files a/imgs/pygam_mcycle_data_extrapolation.png and b/imgs/pygam_mcycle_data_extrapolation.png differ diff --git a/imgs/pygam_mcycle_data_linear.png b/imgs/pygam_mcycle_data_linear.png index 9526343e..9d08dff7 100644 Binary files a/imgs/pygam_mcycle_data_linear.png and b/imgs/pygam_mcycle_data_linear.png differ diff --git a/imgs/pygam_poisson.png b/imgs/pygam_poisson.png index 5bf94b4f..d3253944 100644 Binary files a/imgs/pygam_poisson.png and b/imgs/pygam_poisson.png differ diff --git a/imgs/pygam_wage_data_linear.png b/imgs/pygam_wage_data_linear.png index 2249357d..ccfc6a81 100644 Binary files a/imgs/pygam_wage_data_linear.png and b/imgs/pygam_wage_data_linear.png differ