Skip to content

Commit

Permalink
fix plot for linearized trees
Browse files Browse the repository at this point in the history
  • Loading branch information
pierrenodet committed Dec 20, 2024
1 parent 92209ce commit 714bab9
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 28 deletions.
55 changes: 36 additions & 19 deletions examples/plot_linearized.ipynb

Large diffs are not rendered by default.

16 changes: 7 additions & 9 deletions mislabeled/probe/_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,19 +90,17 @@ def grad_p(self, X, y):
# gradients w.r.t. the parameters (weight, intercept)
dl_dy = self.grad_y(X, y)

X_p = X
X_p = X if not sp.issparse(X) else X.toarray()
if self.intercept is not None:
X_p = np.hstack((X, np.ones((X.shape[0], 1))))
if sp.issparse(X_p):
X_p = X_p.toarray()
X_p = np.hstack((X_p, np.ones((X_p.shape[0], 1))))

return dl_dy[:, :, None] * X_p[:, None, :]

def hessian(self, X, y):
X_p = X
X_p = X if not sp.issparse(X) else X.toarray()

if self.intercept is not None:
X_p = np.hstack((X, np.ones((X.shape[0], 1))))
if sp.issparse(X_p):
X_p = X_p.toarray()
X_p = np.hstack((X_p, np.ones((X_p.shape[0], 1))))

if self.loss == "l2":
H = 2.0 * X_p.T @ X_p
Expand Down Expand Up @@ -179,7 +177,7 @@ def linearize_linear_model_ridge(estimator, X, y):

@linearize.register(SGDClassifier)
def linearize_linear_model_sgdclassifier(estimator, X, y):
X, y = check_X_y(X, y, accept_sparse=False, dtype=[np.float64, np.float32])
X, y = check_X_y(X, y, accept_sparse=True, dtype=[np.float64, np.float32])
coef = estimator.coef_.T
intercept = estimator.intercept_ if estimator.fit_intercept else None
linear = LinearModel(coef, intercept, loss=estimator.loss, regul=estimator.alpha)
Expand Down

0 comments on commit 714bab9

Please sign in to comment.