Skip to content

Commit

Permalink
Merge pull request #12 from Orange-OpenSource/lin-tree-fix
Browse files Browse the repository at this point in the history
Fix plot for linearized trees
  • Loading branch information
pierrenodet authored Dec 20, 2024
2 parents 92209ce + 75f6e18 commit ff2d499
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 34 deletions.
56 changes: 37 additions & 19 deletions examples/plot_linearized.ipynb

Large diffs are not rendered by default.

25 changes: 10 additions & 15 deletions mislabeled/probe/_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,19 +90,17 @@ def grad_p(self, X, y):
# gradients w.r.t. the parameters (weight, intercept)
dl_dy = self.grad_y(X, y)

X_p = X
X_p = X if not sp.issparse(X) else X.toarray()
if self.intercept is not None:
X_p = np.hstack((X, np.ones((X.shape[0], 1))))
if sp.issparse(X_p):
X_p = X_p.toarray()
X_p = np.hstack((X_p, np.ones((X_p.shape[0], 1))))

return dl_dy[:, :, None] * X_p[:, None, :]

def hessian(self, X, y):
X_p = X
X_p = X if not sp.issparse(X) else X.toarray()

if self.intercept is not None:
X_p = np.hstack((X, np.ones((X.shape[0], 1))))
if sp.issparse(X_p):
X_p = X_p.toarray()
X_p = np.hstack((X_p, np.ones((X_p.shape[0], 1))))

if self.loss == "l2":
H = 2.0 * X_p.T @ X_p
Expand Down Expand Up @@ -179,7 +177,7 @@ def linearize_linear_model_ridge(estimator, X, y):

@linearize.register(SGDClassifier)
def linearize_linear_model_sgdclassifier(estimator, X, y):
X, y = check_X_y(X, y, accept_sparse=False, dtype=[np.float64, np.float32])
X, y = check_X_y(X, y, accept_sparse=True, dtype=[np.float64, np.float32])
coef = estimator.coef_.T
intercept = estimator.intercept_ if estimator.fit_intercept else None
linear = LinearModel(coef, intercept, loss=estimator.loss, regul=estimator.alpha)
Expand Down Expand Up @@ -213,19 +211,16 @@ def linearize_trees(
X,
y,
default_linear_model=dict(
classification=LogisticRegressionCV(
max_iter=1000, fit_intercept=False, n_jobs=-1
),
regression=RidgeCV(fit_intercept=False),
classification=LogisticRegression(max_iter=1000),
regression=RidgeCV(),
),
):
leaves = OneHotEncoder().fit_transform(estimator.apply(X).reshape(X.shape[0], -1))
if is_classifier(estimator):
linear = default_linear_model["classification"]
linear.fit(leaves, y)
else:
linear = default_linear_model["regression"]
linear.fit(leaves, y)
linear.fit(leaves, y)
return linearize(linear, leaves, y)


Expand Down

0 comments on commit ff2d499

Please sign in to comment.