From 72259fca520b495c5eb4d142e74034b7b2ebbf48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Tue, 3 Dec 2024 09:33:13 +0000 Subject: [PATCH 01/10] feat(prediction): add `model` as instance attribute --- src/gentropy/dataset/l2g_prediction.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index 2bc286a40..72b6afdef 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -29,6 +29,8 @@ class L2GPrediction(Dataset): confidence of the prediction that a gene is causal to an association. """ + model: LocusToGeneModel | None = None + @classmethod def get_schema(cls: type[L2GPrediction]) -> StructType: """Provides the schema for the L2GPrediction dataset. @@ -85,7 +87,9 @@ def from_credible_set( .select_features(features_list) ) - return l2g_model.predict(fm, session) + predictions = l2g_model.predict(fm, session) + predictions.model = l2g_model # Set the model attribute + return predictions def to_disease_target_evidence( self: L2GPrediction, From 9e8c491961d7f9315011c04ac9ad18c34d3dc545 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Tue, 3 Dec 2024 11:14:55 +0000 Subject: [PATCH 02/10] feat: added `convert_map_type_to_columns` spark util --- src/gentropy/common/spark_helpers.py | 32 ++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/src/gentropy/common/spark_helpers.py b/src/gentropy/common/spark_helpers.py index 64a8bceb7..c55d77001 100644 --- a/src/gentropy/common/spark_helpers.py +++ b/src/gentropy/common/spark_helpers.py @@ -885,3 +885,35 @@ def calculate_harmonic_sum(input_array: Column) -> Column: / f.pow(x["pos"], 2) / f.lit(sum(1 / ((i + 1) ** 2) for i in range(1000))), ) + + +def convert_map_type_to_columns(df: DataFrame, map_column: Column) -> list[Column]: + """Convert a MapType column into multiple columns, one for each key in the map. + + Args: + df (DataFrame): A Spark DataFrame + map_column (Column): A Spark Column of MapType + + Returns: + list[Column]: List of columns, one for each key in the map + + Examples: + >>> df = spark.createDataFrame([({'a': 1, 'b': 2},), ({'c':3},)], ["map_col"]) + >>> df.select(*convert_map_type_to_columns(df, f.col("map_col"))).show() + +----+----+----+ + | a| b| c| + +----+----+----+ + | 1| 2|null| + |null|null| 3| + +----+----+----+ + + """ + # Schema is agnostic of the map keys, I have to collect them first + keys = ( + df.select(f.explode(map_column)) + .select("key") + .distinct() + .rdd.flatMap(lambda x: x) + .collect() + ) + return [map_column.getItem(k).alias(k) for k in keys] From 450a9375862cc0e94a813cab1d543cb319291dc0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Tue, 3 Dec 2024 11:38:41 +0000 Subject: [PATCH 03/10] feat(prediction): new method `explain` returns shapley values --- src/gentropy/dataset/l2g_prediction.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index 72b6afdef..c80c449d6 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -6,10 +6,12 @@ from typing import TYPE_CHECKING, Type import pyspark.sql.functions as f +import shap from pyspark.sql import DataFrame from gentropy.common.schemas import parse_spark_schema from gentropy.common.session import Session +from gentropy.common.spark_helpers import convert_map_type_to_columns from gentropy.dataset.dataset import Dataset from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix from gentropy.dataset.study_index import StudyIndex @@ -17,6 +19,7 @@ from gentropy.method.l2g.model import LocusToGeneModel if TYPE_CHECKING: + from numpy import ndarray as np_ndarray from pyspark.sql.types import StructType @@ -132,6 +135,29 @@ def to_disease_target_evidence( ) ) + def explain(self: L2GPrediction) -> np_ndarray: + """Extract Shapley values for the L2G predictions. + + Returns: + np_ndarray: Shapley values + + Raises: + ValueError: If the model is not set + """ + if self.model is None: + raise ValueError("Model not set, explainer cannot be created") + explainer = shap.TreeExplainer( + self.model.model, feature_perturbation="tree_path_dependent" + ) + features_matrix = ( + self.df.select( + *convert_map_type_to_columns(self.df, f.col("locusToGeneFeatures")) + ) + .toPandas() + .to_numpy() + ) + return explainer.shap_values(features_matrix) + def add_locus_to_gene_features( self: L2GPrediction, feature_matrix: L2GFeatureMatrix ) -> L2GPrediction: From 08ae6bd294bf3ed6d8a94ab801d0cbcddce7ed51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Wed, 4 Dec 2024 15:58:23 +0000 Subject: [PATCH 04/10] feat(prediction): `explain` returns predictions with shapley values --- .../assets/schemas/l2g_predictions.json | 11 ++++ src/gentropy/dataset/l2g_prediction.py | 60 +++++++++++++++---- 2 files changed, 59 insertions(+), 12 deletions(-) diff --git a/src/gentropy/assets/schemas/l2g_predictions.json b/src/gentropy/assets/schemas/l2g_predictions.json index 57247a49a..8bda086a3 100644 --- a/src/gentropy/assets/schemas/l2g_predictions.json +++ b/src/gentropy/assets/schemas/l2g_predictions.json @@ -29,6 +29,17 @@ "valueContainsNull": true, "valueType": "float" } + }, + { + "metadata": {}, + "name": "shapleyValues", + "nullable": true, + "type": { + "keyType": "string", + "type": "map", + "valueContainsNull": false, + "valueType": "float" + } } ] } diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index c80c449d6..5833810e9 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -7,7 +7,7 @@ import pyspark.sql.functions as f import shap -from pyspark.sql import DataFrame +from pyspark.sql import DataFrame, Window from gentropy.common.schemas import parse_spark_schema from gentropy.common.session import Session @@ -19,7 +19,6 @@ from gentropy.method.l2g.model import LocusToGeneModel if TYPE_CHECKING: - from numpy import ndarray as np_ndarray from pyspark.sql.types import StructType @@ -135,11 +134,12 @@ def to_disease_target_evidence( ) ) - def explain(self: L2GPrediction) -> np_ndarray: - """Extract Shapley values for the L2G predictions. + def explain(self: L2GPrediction) -> L2GPrediction: + """Extract Shapley values for the L2G predictions and add them as a map column. Returns: - np_ndarray: Shapley values + L2GPrediction: L2GPrediction object with an additional column 'shapleyValues' containing + feature name to Shapley value mappings Raises: ValueError: If the model is not set @@ -149,14 +149,50 @@ def explain(self: L2GPrediction) -> np_ndarray: explainer = shap.TreeExplainer( self.model.model, feature_perturbation="tree_path_dependent" ) - features_matrix = ( - self.df.select( - *convert_map_type_to_columns(self.df, f.col("locusToGeneFeatures")) - ) - .toPandas() - .to_numpy() + features_matrix = self.df.select( + *convert_map_type_to_columns(self.df, f.col("locusToGeneFeatures")) + ).toPandas() + shapley_values = explainer.shap_values(features_matrix.to_numpy()) + + # Create arrays of Shapley values for each feature + features_list = list(features_matrix.columns) + shapley_arrays = { + feature: [row[i] for row in shapley_values] + for i, feature in enumerate(features_list) + } + return L2GPrediction( + _df=( + self.df.withColumn( + # Add row index to ensure correct mapping between the predictions and the shapley values + "tmp_idx", + f.row_number().over( + Window.orderBy(f.monotonically_increasing_id()) + ), + ) + .withColumn( + "shapleyValues", + f.create_map( + *[ + item + for feature in features_list + for item in [ + f.lit(feature), + f.array( + [f.lit(float(x)) for x in shapley_arrays[feature]] + ) + .getItem( + f.col("tmp_idx") - f.lit(1) + ) # we substract one because row_number starts counting from 1 + .cast("float"), + ] + ] + ), + ) + .drop("tmp_idx") + ), + _schema=self.get_schema(), + model=self.model, ) - return explainer.shap_values(features_matrix) def add_locus_to_gene_features( self: L2GPrediction, feature_matrix: L2GFeatureMatrix From 9d40e6254b18bc6a8112fe5f3571b11592c9d2cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Wed, 4 Dec 2024 16:15:16 +0000 Subject: [PATCH 05/10] chore: compute `shapleyValues` in the l2g step --- src/gentropy/l2g.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py index 16922ef78..1825c305e 100644 --- a/src/gentropy/l2g.py +++ b/src/gentropy/l2g.py @@ -285,7 +285,7 @@ def run_predict(self) -> None: ) predictions.filter( f.col("score") >= self.l2g_threshold - ).add_locus_to_gene_features(self.feature_matrix).df.coalesce( + ).add_locus_to_gene_features(self.feature_matrix).explain().df.coalesce( self.session.output_partitions ).write.mode(self.session.write_mode).parquet(self.predictions_path) self.session.logger.info("L2G predictions saved successfully.") From f407512d0f7c825c41c632c5abbb13285286c0e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Thu, 5 Dec 2024 17:53:13 +0000 Subject: [PATCH 06/10] refactor: use pandas udf instead --- src/gentropy/dataset/l2g_prediction.py | 98 +++++++++++++++----------- 1 file changed, 56 insertions(+), 42 deletions(-) diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index 5833810e9..a0581d0d6 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -5,9 +5,11 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Type +import pandas as pd import pyspark.sql.functions as f import shap -from pyspark.sql import DataFrame, Window +from pyspark.sql import DataFrame +from pyspark.sql.functions import pandas_udf from gentropy.common.schemas import parse_spark_schema from gentropy.common.session import Session @@ -146,52 +148,64 @@ def explain(self: L2GPrediction) -> L2GPrediction: """ if self.model is None: raise ValueError("Model not set, explainer cannot be created") + + # Create explainer once explainer = shap.TreeExplainer( self.model.model, feature_perturbation="tree_path_dependent" ) - features_matrix = self.df.select( - *convert_map_type_to_columns(self.df, f.col("locusToGeneFeatures")) - ).toPandas() - shapley_values = explainer.shap_values(features_matrix.to_numpy()) - - # Create arrays of Shapley values for each feature - features_list = list(features_matrix.columns) - shapley_arrays = { - feature: [row[i] for row in shapley_values] - for i, feature in enumerate(features_list) - } - return L2GPrediction( - _df=( - self.df.withColumn( - # Add row index to ensure correct mapping between the predictions and the shapley values - "tmp_idx", - f.row_number().over( - Window.orderBy(f.monotonically_increasing_id()) - ), - ) - .withColumn( - "shapleyValues", - f.create_map( - *[ - item - for feature in features_list - for item in [ + + # Create UDF for Shapley calculation + @pandas_udf("array") + def calculate_shapley_values(features_pd: pd.DataFrame) -> pd.Series: + """Calculate Shapley values for a batch of features. + + Args: + features_pd (pd.DataFrame): Batch of features. + + Returns: + pd.Series: Series of Shapley values for the batch. + """ + feature_array = features_pd.to_numpy() + shapley_values = explainer.shap_values(feature_array) + return pd.Series([list(row) for row in shapley_values]) + + df_w_features = self.df.select( + "*", *convert_map_type_to_columns(self.df, f.col("locusToGeneFeatures")) + ) + features_list = [ + col for col in df_w_features.columns if col not in self.df.columns + ] + # Apply UDF and create map of feature names to Shapley values + result_df = ( + df_w_features.withColumn( + "shapley_array", + calculate_shapley_values( + f.array(*[f.col(feature) for feature in features_list]) + ), + ) + .withColumn( + "shapleyValues", + f.create_map( + *sum( + ( + ( f.lit(feature), - f.array( - [f.lit(float(x)) for x in shapley_arrays[feature]] - ) - .getItem( - f.col("tmp_idx") - f.lit(1) - ) # we substract one because row_number starts counting from 1 - .cast("float"), - ] - ] - ), - ) - .drop("tmp_idx") - ), - _schema=self.get_schema(), + f.element_at("shapley_array", f.lit(pos + 1)), + ) + for pos, feature in enumerate(features_list) + ), + (), + ) + ), + ) + .drop("shapley_array") + .select(*[field.name for field in self.get_schema().fields]) + ) + + return L2GPrediction( + _df=result_df, model=self.model, + _schema=self.get_schema(), ) def add_locus_to_gene_features( From f542395075e45e7a4717d9b660261c9c279a7256 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Fri, 6 Dec 2024 15:22:41 +0000 Subject: [PATCH 07/10] refactor: forget about udfs and get shaps single threaded --- src/gentropy/dataset/l2g_prediction.py | 70 +++++++++----------------- 1 file changed, 25 insertions(+), 45 deletions(-) diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index a0581d0d6..d4b2e314c 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -2,14 +2,14 @@ from __future__ import annotations +import logging from dataclasses import dataclass from typing import TYPE_CHECKING, Type -import pandas as pd import pyspark.sql.functions as f import shap from pyspark.sql import DataFrame -from pyspark.sql.functions import pandas_udf +from pyspark.sql.types import StructType from gentropy.common.schemas import parse_spark_schema from gentropy.common.session import Session @@ -137,11 +137,10 @@ def to_disease_target_evidence( ) def explain(self: L2GPrediction) -> L2GPrediction: - """Extract Shapley values for the L2G predictions and add them as a map column. + """Extract Shapley values for the L2G predictions and add them as a map in an additional column. Returns: - L2GPrediction: L2GPrediction object with an additional column 'shapleyValues' containing - feature name to Shapley value mappings + L2GPrediction: L2GPrediction object with additional column containing feature name to Shapley value mappings Raises: ValueError: If the model is not set @@ -149,62 +148,43 @@ def explain(self: L2GPrediction) -> L2GPrediction: if self.model is None: raise ValueError("Model not set, explainer cannot be created") - # Create explainer once explainer = shap.TreeExplainer( self.model.model, feature_perturbation="tree_path_dependent" ) - - # Create UDF for Shapley calculation - @pandas_udf("array") - def calculate_shapley_values(features_pd: pd.DataFrame) -> pd.Series: - """Calculate Shapley values for a batch of features. - - Args: - features_pd (pd.DataFrame): Batch of features. - - Returns: - pd.Series: Series of Shapley values for the batch. - """ - feature_array = features_pd.to_numpy() - shapley_values = explainer.shap_values(feature_array) - return pd.Series([list(row) for row in shapley_values]) - df_w_features = self.df.select( "*", *convert_map_type_to_columns(self.df, f.col("locusToGeneFeatures")) - ) + ).drop("shapleyValues") features_list = [ - col for col in df_w_features.columns if col not in self.df.columns + col for col in df_w_features.columns if col not in self.get_schema().names ] - # Apply UDF and create map of feature names to Shapley values - result_df = ( - df_w_features.withColumn( - "shapley_array", - calculate_shapley_values( - f.array(*[f.col(feature) for feature in features_list]) - ), + pdf = df_w_features.select(features_list).toPandas() + + # Calculate SHAP values + if pdf.shape[0] >= 10_000: + logging.warning( + "Calculating SHAP values for more than 10,000 rows. This may take a while..." + ) + shap_values = explainer.shap_values(pdf.to_numpy()) + for i, feature in enumerate(features_list): + pdf[f"shap_{feature}"] = [row[i] for row in shap_values] + + spark_session = df_w_features.sparkSession + return L2GPrediction( + _df=df_w_features.join( + # Convert df with shapley values to Spark and join with original df + spark_session.createDataFrame(pdf.to_dict(orient="records")), + features_list, ) .withColumn( "shapleyValues", f.create_map( *sum( - ( - ( - f.lit(feature), - f.element_at("shapley_array", f.lit(pos + 1)), - ) - for pos, feature in enumerate(features_list) - ), + ((f.lit(col), f.col(f"shap_{col}")) for col in features_list), (), ) ), ) - .drop("shapley_array") - .select(*[field.name for field in self.get_schema().fields]) - ) - - return L2GPrediction( - _df=result_df, - model=self.model, + .select(*self.get_schema().names), _schema=self.get_schema(), ) From 9403fe620a9303613fbc36c6051b50d3a16aeff6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Fri, 6 Dec 2024 15:23:23 +0000 Subject: [PATCH 08/10] chore: remove reference to chromatin interaction data in HF card --- src/gentropy/method/l2g/model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/gentropy/method/l2g/model.py b/src/gentropy/method/l2g/model.py index 336efeb7f..091a970d4 100644 --- a/src/gentropy/method/l2g/model.py +++ b/src/gentropy/method/l2g/model.py @@ -194,7 +194,6 @@ def _create_hugging_face_model_card( - Distance: (from credible set variants to gene) - Molecular QTL Colocalization - - Chromatin Interaction: (e.g., promoter-capture Hi-C) - Variant Pathogenicity: (from VEP) More information at: https://opentargets.github.io/gentropy/python_api/methods/l2g/_l2g/ From 1bc6f3a1d03c05f6c0e5065ae90ec58acd94d850 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Fri, 6 Dec 2024 16:04:53 +0000 Subject: [PATCH 09/10] fix(l2g_prediction): methods that return new instance preserve attribute --- src/gentropy/dataset/l2g_prediction.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index d4b2e314c..f8534c4f7 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -186,6 +186,7 @@ def explain(self: L2GPrediction) -> L2GPrediction: ) .select(*self.get_schema().names), _schema=self.get_schema(), + model=self.model, ) def add_locus_to_gene_features( @@ -237,4 +238,5 @@ def add_locus_to_gene_features( return L2GPrediction( _df=self.df.join(aggregated_features, on=prediction_id_columns, how="left"), _schema=self.get_schema(), + model=self.model, ) From 8420933ac7fb3f3c4ab23cfa1c186cf8c097eb16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Fri, 6 Dec 2024 16:06:03 +0000 Subject: [PATCH 10/10] feat(dataset): `filter` method preserves all instance attributes --- src/gentropy/dataset/dataset.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/gentropy/dataset/dataset.py b/src/gentropy/dataset/dataset.py index 67fe05eaf..442c0e6ba 100644 --- a/src/gentropy/dataset/dataset.py +++ b/src/gentropy/dataset/dataset.py @@ -156,15 +156,25 @@ def from_parquet( def filter(self: Self, condition: Column) -> Self: """Creates a new instance of a Dataset with the DataFrame filtered by the condition. + Preserves all attributes from the original instance. + Args: condition (Column): Condition to filter the DataFrame Returns: - Self: Filtered Dataset + Self: Filtered Dataset with preserved attributes """ df = self._df.filter(condition) class_constructor = self.__class__ - return class_constructor(_df=df, _schema=class_constructor.get_schema()) + # Get all attributes from the current instance + attrs = { + key: value + for key, value in self.__dict__.items() + if key not in ["_df", "_schema"] + } + return class_constructor( + _df=df, _schema=class_constructor.get_schema(), **attrs + ) def validate_schema(self: Dataset) -> None: """Validate DataFrame schema against expected class schema.