diff --git a/src/gentropy/method/l2g/trainer.py b/src/gentropy/method/l2g/trainer.py
index 62eb78d92..3288d9ea8 100644
--- a/src/gentropy/method/l2g/trainer.py
+++ b/src/gentropy/method/l2g/trainer.py
@@ -95,9 +95,9 @@ def fit(
             and self.y_train is not None
             and self.features_list is not None
         ):
-            assert self.x_train.size != 0 and self.y_train.size != 0, (
-                "Train data not set, nothing to fit."
-            )
+            assert (
+                self.x_train.size != 0 and self.y_train.size != 0
+            ), "Train data not set, nothing to fit."
             fitted_model = self.model.model.fit(X=self.x_train, y=self.y_train)
             self.model = LocusToGeneModel(
                 model=fitted_model,
@@ -112,7 +112,7 @@ def _get_shap_explanation(
         self: LocusToGeneTrainer,
         model: LocusToGeneModel,
     ) -> Explanation:
-        """Get the SHAP values for the given model and data. We pass the full X matrix (without the labels) to interpret their shap values.
+        """Get the SHAP values for the given model and data. We sample 1,000 rows from the full X matrix (without the labels) to interpret their SHAP values.
 
         Args:
             model (LocusToGeneModel): Model to explain.
@@ -133,12 +133,15 @@ def _get_shap_explanation(
             model.model,
             data=training_data,
             feature_perturbation="interventional",
+            model_output="probability",
         )
         try:
-            return explainer(training_data)
+            return explainer(training_data.sample(n=1_000))
         except Exception as e:
             if "Additivity check failed in TreeExplainer" in repr(e):
-                return explainer(training_data, check_additivity=False)
+                return explainer(
+                    training_data.sample(n=1_000), check_additivity=False
+                )
             else:
                 raise
 
@@ -191,9 +194,9 @@ def log_to_wandb(
             or self.features_list is None
         ):
             raise RuntimeError("Train data not set, we cannot log to W&B.")
-        assert self.x_train.size != 0 and self.y_train.size != 0, (
-            "Train data not set, nothing to evaluate."
-        )
+        assert (
+            self.x_train.size != 0 and self.y_train.size != 0
+        ), "Train data not set, nothing to evaluate."
        fitted_classifier = self.model.model
         y_predicted = fitted_classifier.predict(self.x_test)
         y_probas = fitted_classifier.predict_proba(self.x_test)
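For context, the sketch below illustrates the SHAP usage this patch moves to: a `TreeExplainer` built with a background dataset, interventional feature perturbation, and probability-space outputs, applied to a 1,000-row sample rather than the full feature matrix. This is an illustration under assumptions, not gentropy code: the classifier, feature names, and toy data are placeholders; only the explainer arguments and the additivity-check fallback mirror the diff.

```python
# Illustrative sketch only: the classifier and data below are placeholders,
# not gentropy objects; only the explainer arguments mirror the patch.
import pandas as pd
import shap
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier

# Toy feature matrix standing in for the L2G training data (labels are not passed to SHAP).
X, y = make_classification(n_samples=2_000, n_features=8, random_state=42)
training_data = pd.DataFrame(X, columns=[f"feature_{i}" for i in range(8)])
model = GradientBoostingClassifier(random_state=42).fit(training_data, y)

# Mirrors the patched construction: a background dataset enables interventional
# perturbation, and model_output="probability" expresses SHAP values in
# probability space rather than the model's raw margin.
explainer = shap.TreeExplainer(
    model,
    data=training_data,
    feature_perturbation="interventional",
    model_output="probability",
)

# As in the patch, explain a 1,000-row sample rather than the full matrix to keep
# the interventional computation tractable; if TreeExplainer raises its occasional
# additivity error, retry with the check disabled.
sample = training_data.sample(n=1_000, random_state=42)
try:
    explanation = explainer(sample)
except Exception as e:
    if "Additivity check failed in TreeExplainer" in repr(e):
        explanation = explainer(sample, check_additivity=False)
    else:
        raise

print(explanation.values.shape)  # one SHAP value per sampled row and feature
```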