Skip to content

Commit

Permalink
chore(trainer): sample train set to base model importances on
Browse files (browse the repository at this point in the history)
  • Loading branch information
ireneisdoomed committed Feb 24, 2025
1 parent a56f2d0 commit a5ff33f
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions src/gentropy/method/l2g/trainer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -95,9 +95,9 @@ def fit(
and self.y_train is not None
and self.features_list is not None
):
assert self.x_train.size != 0 and self.y_train.size != 0, (
"Train data not set, nothing to fit."
)
assert (
self.x_train.size != 0 and self.y_train.size != 0
), "Train data not set, nothing to fit."
fitted_model = self.model.model.fit(X=self.x_train, y=self.y_train)
self.model = LocusToGeneModel(
model=fitted_model,
Expand All @@ -112,7 +112,7 @@ def _get_shap_explanation(
self: LocusToGeneTrainer,
model: LocusToGeneModel,
) -> Explanation:
"""Get the SHAP values for the given model and data. We pass the full X matrix (without the labels) to interpret their shap values.
"""Get the SHAP values for the given model and data. We sample the full X matrix (without the labels) to interpret their shap values.
Args:
model (LocusToGeneModel): Model to explain.
Expand All @@ -133,12 +133,15 @@ def _get_shap_explanation(
model.model,
data=training_data,
feature_perturbation="interventional",
model_output="probability",
)
try:
return explainer(training_data)
return explainer(training_data.sample(n=1_000))
except Exception as e:
if "Additivity check failed in TreeExplainer" in repr(e):
return explainer(training_data, check_additivity=False)
return explainer(
training_data.sample(n=1_000), check_additivity=False
)
else:
raise

Expand Down Expand Up @@ -191,9 +194,9 @@ def log_to_wandb(
or self.features_list is None
):
raise RuntimeError("Train data not set, we cannot log to W&B.")
assert self.x_train.size != 0 and self.y_train.size != 0, (
"Train data not set, nothing to evaluate."
)
assert (
self.x_train.size != 0 and self.y_train.size != 0
), "Train data not set, nothing to evaluate."
fitted_classifier = self.model.model
y_predicted = fitted_classifier.predict(self.x_test)
y_probas = fitted_classifier.predict_proba(self.x_test)
Expand Down

0 comments on commit a5ff33f

Please sign in to comment.