feat(l2gprediction): add score explanation based on features #939

Open · wants to merge 12 commits into base: dev
11 changes: 11 additions & 0 deletions src/gentropy/assets/schemas/l2g_predictions.json
@@ -29,6 +29,17 @@
"valueContainsNull": true,
"valueType": "float"
}
},
{
"metadata": {},
"name": "shapleyValues",
"nullable": true,
"type": {
"keyType": "string",
"type": "map",
"valueContainsNull": false,
"valueType": "float"
}
}
]
}
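The new `shapleyValues` field mirrors the existing `locusToGeneFeatures` map: one SHAP value per feature, keyed by feature name. As a quick orientation, a minimal sketch of how the column could be inspected once predictions are written; the parquet path is hypothetical:

```python
# Sketch: reading back the new shapleyValues map column (path is hypothetical)
import pyspark.sql.functions as f
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
predictions = spark.read.parquet("gs://bucket/l2g_predictions/")  # hypothetical path

# Explode the map so each feature's SHAP contribution becomes its own row
(
    predictions.select(
        "studyLocusId",
        "geneId",
        "score",
        f.explode("shapleyValues").alias("feature", "shapleyValue"),
    )
    .orderBy(f.abs(f.col("shapleyValue")).desc())
    .show(10, truncate=False)
)
```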
32 changes: 32 additions & 0 deletions src/gentropy/common/spark_helpers.py
@@ -885,3 +885,35 @@ def calculate_harmonic_sum(input_array: Column) -> Column:
/ f.pow(x["pos"], 2)
/ f.lit(sum(1 / ((i + 1) ** 2) for i in range(1000))),
)


def convert_map_type_to_columns(df: DataFrame, map_column: Column) -> list[Column]:
"""Convert a MapType column into multiple columns, one for each key in the map.

Args:
df (DataFrame): A Spark DataFrame
map_column (Column): A Spark Column of MapType

Returns:
list[Column]: List of columns, one for each key in the map

Examples:
>>> df = spark.createDataFrame([({'a': 1, 'b': 2},), ({'c':3},)], ["map_col"])
>>> df.select(*convert_map_type_to_columns(df, f.col("map_col"))).show()
+----+----+----+
| a| b| c|
+----+----+----+
| 1| 2|null|
|null|null| 3|
+----+----+----+
<BLANKLINE>
"""
    # The schema is agnostic of the map keys, so collect the distinct keys first
keys = (
df.select(f.explode(map_column))
.select("key")
.distinct()
.rdd.flatMap(lambda x: x)
.collect()
)
return [map_column.getItem(k).alias(k) for k in keys]
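The `explain` method below performs the reverse transformation, folding per-feature columns back into a single `MapType` column with `f.create_map`. A minimal sketch of that idiom, assuming an active `spark` session; the `sum(..., ())` flattening is the same trick used in `explain`:

```python
# Sketch: folding columns back into a MapType column, the inverse of the helper above
import pyspark.sql.functions as f

cols = ["a", "b"]
df = spark.createDataFrame([(1.0, 2.0)], cols)  # assumes an active `spark` session

# sum(..., ()) flattens [(lit("a"), col("a")), (lit("b"), col("b"))] into one
# alternating key/value sequence, which is the shape f.create_map expects
df_with_map = df.withColumn(
    "map_col",
    f.create_map(*sum(((f.lit(c), f.col(c)) for c in cols), ())),
)
df_with_map.show()  # map_col: {a -> 1.0, b -> 2.0}
```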
14 changes: 12 additions & 2 deletions src/gentropy/dataset/dataset.py
@@ -156,15 +156,25 @@ def from_parquet(
def filter(self: Self, condition: Column) -> Self:
"""Creates a new instance of a Dataset with the DataFrame filtered by the condition.

+        Preserves all attributes from the original instance.

Args:
condition (Column): Condition to filter the DataFrame

Returns:
-            Self: Filtered Dataset
+            Self: Filtered Dataset with preserved attributes
"""
df = self._df.filter(condition)
class_constructor = self.__class__
-        return class_constructor(_df=df, _schema=class_constructor.get_schema())
+        # Get all attributes from the current instance
+        attrs = {
+            key: value
+            for key, value in self.__dict__.items()
+            if key not in ["_df", "_schema"]
+        }
+        return class_constructor(
+            _df=df, _schema=class_constructor.get_schema(), **attrs
+        )
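To illustrate why this matters for the rest of the PR, a small sketch of the behavior the rewritten `filter` enables; `df` and `l2g_model` are assumed to exist, and the threshold is illustrative:

```python
# Sketch: instance attributes such as `model` now survive filtering
import pyspark.sql.functions as f

predictions = L2GPrediction(
    _df=df,  # a DataFrame matching the L2GPrediction schema (assumed)
    _schema=L2GPrediction.get_schema(),
    model=l2g_model,  # a trained LocusToGeneModel (assumed)
)
high_confidence = predictions.filter(f.col("score") >= 0.5)  # threshold illustrative

# Previously the model attribute was dropped here, so explain() failed downstream
assert high_confidence.model is l2g_model
```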

def validate_schema(self: Dataset) -> None:
"""Validate DataFrame schema against expected class schema.
64 changes: 63 additions & 1 deletion src/gentropy/dataset/l2g_prediction.py
@@ -2,14 +2,18 @@

from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Type

import pyspark.sql.functions as f
import shap
from pyspark.sql import DataFrame
from pyspark.sql.types import StructType

from gentropy.common.schemas import parse_spark_schema
from gentropy.common.session import Session
from gentropy.common.spark_helpers import convert_map_type_to_columns
from gentropy.dataset.dataset import Dataset
from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix
from gentropy.dataset.study_index import StudyIndex
@@ -29,6 +33,8 @@ class L2GPrediction(Dataset):
confidence of the prediction that a gene is causal to an association.
"""

model: LocusToGeneModel | None = None

@classmethod
def get_schema(cls: type[L2GPrediction]) -> StructType:
"""Provides the schema for the L2GPrediction dataset.
@@ -85,7 +91,9 @@ def from_credible_set(
.select_features(features_list)
)

-        return l2g_model.predict(fm, session)
+        predictions = l2g_model.predict(fm, session)
+        predictions.model = l2g_model  # Set the model attribute
+        return predictions

def to_disease_target_evidence(
self: L2GPrediction,
@@ -128,6 +136,59 @@ def to_disease_target_evidence(
)
)

def explain(self: L2GPrediction) -> L2GPrediction:
"""Extract Shapley values for the L2G predictions and add them as a map in an additional column.

Returns:
L2GPrediction: L2GPrediction object with additional column containing feature name to Shapley value mappings

Raises:
ValueError: If the model is not set
"""
if self.model is None:
raise ValueError("Model not set, explainer cannot be created")

explainer = shap.TreeExplainer(
self.model.model, feature_perturbation="tree_path_dependent"
)
df_w_features = self.df.select(
"*", *convert_map_type_to_columns(self.df, f.col("locusToGeneFeatures"))
).drop("shapleyValues")
features_list = [
col for col in df_w_features.columns if col not in self.get_schema().names
]
pdf = df_w_features.select(features_list).toPandas()

# Calculate SHAP values
if pdf.shape[0] >= 10_000:
logging.warning(
"Calculating SHAP values for more than 10,000 rows. This may take a while..."
)
shap_values = explainer.shap_values(pdf.to_numpy())
for i, feature in enumerate(features_list):
pdf[f"shap_{feature}"] = [row[i] for row in shap_values]

spark_session = df_w_features.sparkSession
return L2GPrediction(
_df=df_w_features.join(
# Convert df with shapley values to Spark and join with original df
spark_session.createDataFrame(pdf.to_dict(orient="records")),
features_list,
)
.withColumn(
"shapleyValues",
f.create_map(
*sum(
((f.lit(col), f.col(f"shap_{col}")) for col in features_list),
(),
)
),
)
.select(*self.get_schema().names),
_schema=self.get_schema(),
model=self.model,
)
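For intuition on what the stored map means: `TreeExplainer` with `feature_perturbation="tree_path_dependent"` produces per-row contributions that sum to the model's raw margin output, so the per-feature map fully decomposes each prediction. A standalone sketch of that property, using XGBoost as a stand-in for the actual L2G classifier (an assumption):

```python
# Sketch: local additivity of TreeExplainer SHAP values (XGBoost as a stand-in model)
import numpy as np
import shap
from xgboost import XGBClassifier

rng = np.random.default_rng(42)
X = rng.random((200, 4))  # stand-in feature matrix
y = (X[:, 0] + X[:, 1] > 1.0).astype(int)
model = XGBClassifier(n_estimators=20, max_depth=3).fit(X, y)

explainer = shap.TreeExplainer(model, feature_perturbation="tree_path_dependent")
shap_values = explainer.shap_values(X)  # shape: (n_rows, n_features)

# Each row's SHAP values plus the expected value reproduce the raw margin output,
# which is why a per-feature map fully decomposes the prediction score
margin = model.predict(X, output_margin=True)
assert np.allclose(shap_values.sum(axis=1) + explainer.expected_value, margin, atol=1e-3)
```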

def add_locus_to_gene_features(
self: L2GPrediction, feature_matrix: L2GFeatureMatrix, features_list: list[str]
) -> L2GPrediction:
@@ -166,4 +227,5 @@ def add_locus_to_gene_features(
aggregated_features, on=["studyLocusId", "geneId"], how="left"
),
_schema=self.get_schema(),
model=self.model,
)
2 changes: 1 addition & 1 deletion src/gentropy/l2g.py
@@ -287,7 +287,7 @@ def run_predict(self) -> None:
f.col("score") >= self.l2g_threshold
).add_locus_to_gene_features(
self.feature_matrix, self.features_list
-        ).df.coalesce(self.session.output_partitions).write.mode(
+        ).explain().df.coalesce(self.session.output_partitions).write.mode(
self.session.write_mode
).parquet(self.predictions_path)
self.session.logger.info("L2G predictions saved successfully.")
1 change: 0 additions & 1 deletion src/gentropy/method/l2g/model.py
@@ -194,7 +194,6 @@ def _create_hugging_face_model_card(

- Distance: (from credible set variants to gene)
- Molecular QTL Colocalization
-- Chromatin Interaction: (e.g., promoter-capture Hi-C)
- Variant Pathogenicity: (from VEP)

More information at: https://opentargets.github.io/gentropy/python_api/methods/l2g/_l2g/