diff --git a/.github/workflows/artifact.yml b/.github/workflows/artifact.yml index 161d2767a..a122575a2 100644 --- a/.github/workflows/artifact.yml +++ b/.github/workflows/artifact.yml @@ -34,7 +34,9 @@ jobs: run: "gcloud info" - name: Set up QEMU - uses: docker/setup-qemu-action@v3 + uses: docker/setup-qemu-action@v3.6.0 + with: + image: tonistiigi/binfmt:qemu-v7.0.0-28 - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 961f0bef5..8acf208f6 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -2,7 +2,10 @@ name: Release "on": push: - branches: ["main", "release/**", "dev"] + branches: + - main + + workflow_dispatch: concurrency: group: deploy @@ -13,19 +16,15 @@ env: jobs: release: + # Ensure the workflow can be run only from main & dev branches! + if: ${{ github.ref == 'refs/heads/main' || github.ref == 'refs/heads/dev' }} runs-on: ubuntu-latest - concurrency: release outputs: released: ${{ steps.semrelease.outputs.released }} permissions: - # NOTE: this enables trusted publishing. 
- # See https://github.com/pypa/gh-action-pypi-publish/tree/release/v1#trusted-publishing - # and https://blog.pypi.org/posts/2023-04-20-introducing-trusted-publishers/ id-token: write contents: write - steps: - # NOTE: commits using GITHUB_TOKEN does not trigger workflows - uses: actions/create-github-app-token@v1 id: trigger-token with: @@ -34,77 +33,62 @@ jobs: - uses: actions/checkout@v4 with: fetch-depth: 0 + ref: ${{ github.ref_name }} repository: opentargets/gentropy token: ${{ secrets.GITHUB_TOKEN }} persist-credentials: false - - - name: Python Semantic Release + - uses: python-semantic-release/python-semantic-release@v9.19.1 id: semrelease - uses: python-semantic-release/python-semantic-release@v9.16.1 with: github_token: ${{ steps.trigger-token.outputs.token }} - - - name: Publish package to GitHub Release - uses: python-semantic-release/upload-to-gh-release@main - # NOTE: semrelease output is a string, so we need to compare it to a string - if: steps.semrelease.outputs.released == 'true' + - uses: python-semantic-release/publish-action@v9.21.0 + if: ${{ steps.semrelease.outputs.released == 'true' }} with: - # NOTE: allow to start the workflow when push action on tag gets executed - # requires using GH_APP to authenitcate, otherwise push authorised with - # the GITHUB_TOKEN does not trigger the tag artifact workflow. 
- # see https://github.com/actions/create-github-app-token - github_token: ${{ secrets.GITHUB_TOKEN }} + github_token: ${{ steps.trigger-token.outputs.token }} tag: ${{ steps.semrelease.outputs.tag }} - - - name: Store the distribution packages + - uses: actions/upload-artifact@v4 if: steps.semrelease.outputs.released == 'true' - uses: actions/upload-artifact@v4 with: name: python-package-distributions path: dist/ - publish-to-pypi: + publish-to-testpypi: + name: Publish 📦 in TestPyPI needs: release - name: Publish 📦 in PyPI if: github.ref == 'refs/heads/main' && needs.release.outputs.released == 'true' runs-on: ubuntu-latest environment: - name: pypi - url: https://pypi.org/p/gentropy + name: testpypi + url: https://test.pypi.org/p/gentropy permissions: id-token: write # IMPORTANT: mandatory for trusted publishing steps: - - name: Download all the dists - uses: actions/download-artifact@v4 + - uses: actions/download-artifact@v4 with: name: python-package-distributions path: dist/ - - name: Publish distribution 📦 to PyPI - uses: pypa/gh-action-pypi-publish@release/v1 + - uses: pypa/gh-action-pypi-publish@release/v1 + with: + repository-url: https://test.pypi.org/legacy/ - publish-to-testpypi: - name: Publish 📦 in TestPyPI - needs: release + publish-to-pypi: + needs: + - release + - publish-to-testpypi + name: Publish 📦 in PyPI if: github.ref == 'refs/heads/main' && needs.release.outputs.released == 'true' runs-on: ubuntu-latest - environment: - name: testpypi - url: https://test.pypi.org/p/gentropy - + environment: + name: pypi + url: https://pypi.org/p/gentropy permissions: id-token: write # IMPORTANT: mandatory for trusted publishing - steps: - - name: Download all the dists - uses: actions/download-artifact@v4 + - uses: actions/download-artifact@v4 with: name: python-package-distributions path: dist/ - - name: Publish distribution 📦 to TestPyPI - uses: pypa/gh-action-pypi-publish@release/v1 - with: - repository-url: https://test.pypi.org/legacy/ + - uses: 
pypa/gh-action-pypi-publish@release/v1 documentation: needs: release @@ -115,23 +99,10 @@ jobs: with: fetch-depth: 0 token: ${{ secrets.GITHUB_TOKEN }} - - name: Set up Python - uses: actions/setup-python@v4 + - uses: actions/setup-python@v4 with: python-version: ${{ env.PYTHON_VERSION_DEFAULT }} - - name: Install uv - uses: astral-sh/setup-uv@v5 - - name: Load cached venv - id: cached-dependencies - uses: actions/cache@v4 - with: - path: .venv - key: | - venv-${{ runner.os }}-\ - ${{ env.PYTHON_VERSION_DEFAULT }}-\ - ${{ hashFiles('**/uv.lock') }} - - name: Install dependencies - if: steps.cached-dependencies.outputs.cache-hit != 'true' - run: uv sync --group docs + - uses: astral-sh/setup-uv@v5 + - run: uv sync --group docs - name: Publish docs run: uv run mkdocs gh-deploy --force diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 843a326dc..c4f794000 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ ci: autofix_commit_msg: "chore: pre-commit auto fixes [...]" repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.7.4 + rev: v0.9.9 hooks: - id: ruff args: @@ -57,14 +57,14 @@ repos: exclude: "CHANGELOG.md" - repo: https://github.com/alessandrojcm/commitlint-pre-commit-hook - rev: v9.18.0 + rev: v9.21.0 hooks: - id: commitlint additional_dependencies: ["@commitlint/config-conventional@18.6.3"] stages: [commit-msg] - repo: https://github.com/pre-commit/mirrors-mypy - rev: "v1.13.0" + rev: "v1.15.0" hooks: - id: mypy args: @@ -97,10 +97,10 @@ repos: - id: beautysh - repo: https://github.com/jsh9/pydoclint - rev: 0.5.9 + rev: 0.6.2 hooks: - id: pydoclint - repo: https://github.com/astral-sh/uv-pre-commit - rev: 0.5.22 + rev: 0.6.3 hooks: - id: uv-lock diff --git a/Makefile b/Makefile index afd2fa30b..36aea8f53 100644 --- a/Makefile +++ b/Makefile @@ -4,12 +4,13 @@ REGION ?= europe-west1 APP_NAME ?= $$(cat pyproject.toml | grep -m 1 "name" | cut -d" " -f3 | sed 's/"//g') PACKAGE_VERSION ?= $(shell grep 
-m 1 'version = ' pyproject.toml | sed 's/version = "\(.*\)"/\1/') USER_SAFE ?= $(shell echo $(USER) | tr '[:upper:]' '[:lower:]') +CLUSTER_TIMEOUT ?= 60m # NOTE: git rev-parse will always return the HEAD if it sits in the tag, # this way we can distinguish the tag vs branch name ifeq ($(shell git rev-parse --abbrev-ref HEAD),HEAD) - REF := $(shell git describe --exact-match --tags) + REF ?= $(shell git describe --exact-match --tags) else - REF := $(shell git rev-parse --abbrev-ref HEAD) + REF ?= $(shell git rev-parse --abbrev-ref HEAD) endif CLEAN_PACKAGE_VERSION := $(shell echo "$(PACKAGE_VERSION)" | tr -cd '[:alnum:]') @@ -54,8 +55,8 @@ sync-gentropy-cli-script: ## Synchronize the gentropy cli script @gcloud storage cp src/gentropy/cli.py ${BUCKET_NAME}/cli.py create-dev-cluster: sync-cluster-init-script sync-gentropy-cli-script ## Spin up a simple dataproc cluster with all dependencies for development purposes - @echo "Making sure the branch is in sync with remote, so cluster can install gentropy dev version..." - @./utils/clean_status.sh || (echo "ERROR: Commit and push or stash local changes, to have up to date cluster"; exit 1) + @echo "Making sure the cluster can reference to ${REF} branch to install gentropy..." 
+ @./utils/clean_status.sh ${REF} || (echo "ERROR: Commit and push local changes, to have up to date cluster"; exit 1) @echo "Creating Dataproc Dev Cluster" gcloud config set project ${PROJECT_ID} gcloud dataproc clusters create "ot-genetics-dev-${CLEAN_PACKAGE_VERSION}-$(USER_SAFE)" \ @@ -72,7 +73,7 @@ create-dev-cluster: sync-cluster-init-script sync-gentropy-cli-script ## Spin up --optional-components=JUPYTER \ --enable-component-gateway \ --labels team=open-targets,subteam=gentropy,created_by=${USER_SAFE},environment=development, \ - --max-idle=60m + --max-idle=${CLUSTER_TIMEOUT} update-dev-cluster: build ## Reinstalls the package on the dev-cluster @echo "Updating Dataproc Dev Cluster" diff --git a/docs/assets/imgs/development-flow.png b/docs/assets/imgs/development-flow.png new file mode 100644 index 000000000..ee91ab17d Binary files /dev/null and b/docs/assets/imgs/development-flow.png differ diff --git a/docs/development/contributing.md b/docs/development/contributing.md index 99b199de1..917b4c273 100644 --- a/docs/development/contributing.md +++ b/docs/development/contributing.md @@ -84,3 +84,27 @@ For more details on each of these steps, see the sections below. ### Support for python versions As of version 2.1.X gentropy supports multiple python versions. To ensure compatibility with all supported versions, unit tests are run for each of the minor python release from 3.10 to 3.12. Make sure your changes are compatible with all supported versions. + +### Development process + +The development follows simplified Git Flow process that includes usage of + +- `dev` (development branch) +- `feature` branches +- `main` (production branch) + +The development starts with creating new `feature` branch based on the `dev` branch. Once the feature is ready, the Pull Request for the `dev` branch is created and CI/CD Checks are performed to ensure that the code is compliant with the project conventions. 
Once the PR is approved, the feature branch is merged into the `dev` branch. + +#### Development releases + +One can create the dev release tagged by `vX.Y.Z-dev.V` tag. This release will not trigger the CI/CD pipeline to publish the package to the PyPi repository. The release is done by triggering the `Release` GitHub action. + +#### Production releases + +Once per week, the `Trigger PR for release` github action creates a Pull Request from `dev` to `main` branch, when the PR is approved, the `Release` GitHub action is triggered to create a production release tagged by `vX.Y.Z` tag. This release triggers the CI/CD pipeline to publish the package to the _TestPyPi_ repository. If it is successful, then the actual deployment to the _PyPI_ repository is done. The deployment to the PyPi repository must be verified by the gentropy maintainer. + +Below you can find a simplified diagram of the development process. + +
+ development process +
diff --git a/docs/development/troubleshooting.md b/docs/development/troubleshooting.md index 762d40dd0..de5937b8c 100644 --- a/docs/development/troubleshooting.md +++ b/docs/development/troubleshooting.md @@ -42,19 +42,45 @@ This can be resolved by adding the follow line to your `~/.zshrc`: ## Creating development dataproc cluster (OT users only) -To start dataproc cluster in the development mode run +!!! info "Requirements" + + To create the cluster, you need to auth to the google cloud + + ```bash + gcloud auth login + ``` + +To start dataproc cluster in the development mode run. ```bash -make create-dev-cluster +make create-dev-cluster REF=dev ``` -!!! note "Tip" -This command will work, provided you have fully commited and pushed all your changes to the remote repository. +`REF` - remote branch available at the [gentropy repository](https://github.com/opentargets/gentropy) + +During cluster [initialization actions](https://cloud.google.com/dataproc/docs/concepts/configuring-clusters/init-actions#important_considerations_and_guidelines) the `utils/install_dependencies_on_cluster.sh` script is run, that installs `gentropy` package from the remote repository by using VCS support, hence it does not require the **gentropy package whl artifact** to be prepared in the Google Cloud Storage before the make command can be run. + +Check details how to make a package installable by VCS in [pip documentation](https://pip.pypa.io/en/stable/topics/vcs-support/). + +!!! note "How `create-dev-cluster` works" + + This command will work, provided you have done one of: + + - run `make create-dev-cluster REF=dev`, since the REF is requested, the cluster will attempt to install it from the remote repository. 
+ - run `make create-dev-cluster` without specifying the REF or specifying REF that points to your local branch will request branch name you are checkout on your local repository, if any changes are pending locally, the cluster can not be created, it requires stashing or pushing the changes to the remote. + + The command will create a new dataproc cluster with the following configuration: + + - package installed from the requested **REF** (for example `dev` or `feature/xxx`) + - uv installed in the cluster (to speed up the installation and dependency resolution process) + - cli script to run gentropy steps + +!!! tip "Dataproc cluster timeout" -The command will create a new dataproc cluster with the following configuration: + By default the cluster will **delete itself** when running for **60 minutes after the last submitted job to the cluster was successfully completed** (running jobs interactively via Jupyter or Jupyter lab is not treated as submitted job). To preserve the cluster for arbitrary period (**for instance when the cluster is used only for interactive jobs**) increase the cluster timeout: -- package installed from the current branch you are checkout on (for example `dev` or `feature/xxx`) -- uv installed in the cluster (to speed up the installation and dependency resolution process) -- cli script to run gentropy steps + ```bash + make create-dev-cluster CLUSTER_TIMEOUT=1d REF=dev # 60m 1h 1d (by default 60m) + ``` -This process requires gentropy to be installable by git repository - see VCS support in [pip documentation](https://pip.pypa.io/en/stable/topics/vcs-support/). 
+ For the reference on timeout format check [gcloud documentation](https://cloud.google.com/sdk/gcloud/reference/dataproc/clusters/create#--max-idle) diff --git a/docs/python_api/datasources/ukb_ppp_eur/_ukb_ppp_eur.md b/docs/python_api/datasources/ukb_ppp_eur/_ukb_ppp_eur.md index e416e1f32..e05e4c183 100644 --- a/docs/python_api/datasources/ukb_ppp_eur/_ukb_ppp_eur.md +++ b/docs/python_api/datasources/ukb_ppp_eur/_ukb_ppp_eur.md @@ -1,7 +1,8 @@ --- -title: UKB-PPP (EUR) +title: UK Biobank Pharma Proteomics Project (UKB-PPP) (EUR) --- The UKB-PPP is a collaboration between the UK Biobank (UKB) and thirteen biopharmaceutical companies characterising the plasma proteomic profiles of 54,219 UKB participants. -The original data is available at https://www.synapse.org/#!Synapse:syn51364943/. The associated paper is https://www.nature.com/articles/s41586-023-06592-6. +The original data is available here: https://www.synapse.org/Synapse:syn51364943/wiki/622119. +The associated paper is https://www.nature.com/articles/s41586-023-06592-6. 
diff --git a/pyproject.toml b/pyproject.toml index 161b9bfb6..e60bf68d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "skops (>=0.11.0, <0.12.0)", "shap (>=0.46, <0.47)", "matplotlib (>=3.10.0, <3.11.0)", - "google-cloud-secret-manager (>=2.12.6, <2.13.0)", + "google-cloud-secret-manager (>=2.12.6, <2.24.0)", "google-cloud-storage (>=2.14.0, <3.1.0)", ] classifiers = [ @@ -92,8 +92,8 @@ packages = ["src/gentropy"] match = "(main|master)" prerelease = false -[tool.semantic_release.branches."release"] -match = "release/*" +[tool.semantic_release.branches.dev] +match = "dev" prerelease = true prerelease_token = "rc" @@ -101,19 +101,11 @@ prerelease_token = "rc" dist_glob_patterns = ["dist/*"] upload_to_vcs_release = true -[tool.semantic_release.changelog] +[tool.semantic-release.changelog.default_templates] changelog_file = "CHANGELOG.md" -exclude_commit_patterns = ["chore\\(release\\):"] - -[tool.semantic_release.branches."step"] -match = "(build|chore|ci|docs|feat|fix|perf|style|refactor|test)" -prerelease = true -prerelease_token = "alpha" -[tool.semantic_release.branches."dev"] -match = "dev" -prerelease = true -prerelease_token = "dev" +[tool.semantic-release.changelog] +exclude_commit_patterns = ["chore\\(release\\):"] [build-system] requires = ["hatchling"] diff --git a/src/gentropy/assets/schemas/amino_acid_variants.json b/src/gentropy/assets/schemas/amino_acid_variants.json index b2fce522b..7aa876401 100644 --- a/src/gentropy/assets/schemas/amino_acid_variants.json +++ b/src/gentropy/assets/schemas/amino_acid_variants.json @@ -14,7 +14,7 @@ }, { "metadata": {}, - "name": "inSilicoPredictors", + "name": "variantEffect", "nullable": true, "type": { "containsNull": true, diff --git a/src/gentropy/assets/schemas/l2g_predictions.json b/src/gentropy/assets/schemas/l2g_predictions.json index 57247a49a..1d100bf94 100644 --- a/src/gentropy/assets/schemas/l2g_predictions.json +++ b/src/gentropy/assets/schemas/l2g_predictions.json 
@@ -21,14 +21,41 @@ }, { "metadata": {}, - "name": "locusToGeneFeatures", + "name": "features", "nullable": true, "type": { - "keyType": "string", - "type": "map", - "valueContainsNull": true, - "valueType": "float" + "containsNull": false, + "elementType": { + "fields": [ + { + "metadata": {}, + "name": "name", + "nullable": false, + "type": "string" + }, + { + "metadata": {}, + "name": "value", + "nullable": false, + "type": "float" + }, + { + "metadata": {}, + "name": "shapValue", + "nullable": true, + "type": "float" + } + ], + "type": "struct" + }, + "type": "array" } + }, + { + "name": "shapBaseValue", + "type": "float", + "nullable": true, + "metadata": {} } ] } diff --git a/src/gentropy/assets/schemas/variant_index.json b/src/gentropy/assets/schemas/variant_index.json index 1f1ef787c..6f5f5869e 100644 --- a/src/gentropy/assets/schemas/variant_index.json +++ b/src/gentropy/assets/schemas/variant_index.json @@ -32,7 +32,7 @@ }, { "metadata": {}, - "name": "inSilicoPredictors", + "name": "variantEffect", "nullable": true, "type": { "containsNull": true, diff --git a/src/gentropy/assets/schemas/vep_json_output.json b/src/gentropy/assets/schemas/vep_json_output.json index 14aae6b84..ce55862b6 100644 --- a/src/gentropy/assets/schemas/vep_json_output.json +++ b/src/gentropy/assets/schemas/vep_json_output.json @@ -20,30 +20,12 @@ "containsNull": true, "elementType": { "fields": [ - { - "metadata": {}, - "name": "conservation", - "nullable": true, - "type": "double" - }, { "metadata": {}, "name": "hgvsg", "nullable": true, "type": "string" }, - { - "metadata": {}, - "name": "cadd_phred", - "nullable": true, - "type": "double" - }, - { - "metadata": {}, - "name": "cadd_raw", - "nullable": true, - "type": "double" - }, { "metadata": {}, "name": "consequence_terms", @@ -65,12 +47,6 @@ "name": "variant_allele", "nullable": true, "type": "string" - }, - { - "metadata": {}, - "name": "gene_id", - "nullable": true, - "type": "string" } ], "type": "struct" diff --git 
a/src/gentropy/config.py b/src/gentropy/config.py index c5b6783fb..9bac9dbf0 100644 --- a/src/gentropy/config.py +++ b/src/gentropy/config.py @@ -436,6 +436,16 @@ class PanUKBBConfig(StepConfig): _target_: str = "gentropy.pan_ukb_ingestion.PanUKBBVariantIndexStep" +@dataclass +class LOFIngestionConfig(StepConfig): + """Step configuration for the ingestion of Loss-of-Function variant data generated by OTAR2075.""" + + lof_curation_dataset_path: str = MISSING + lof_curation_variant_annotations_path: str = MISSING + + _target_: str = "gentropy.lof_curation_ingestion.LOFIngestionStep" + + @dataclass class VariantIndexConfig(StepConfig): """Variant index step configuration.""" @@ -454,7 +464,7 @@ class _ConsequenceToPathogenicityScoreMap(TypedDict): ) vep_output_json_path: str = MISSING variant_index_path: str = MISSING - gnomad_variant_annotations_path: str | None = None + variant_annotations_path: list[str] | None = None hash_threshold: int = 300 consequence_to_pathogenicity_score: ClassVar[ list[_ConsequenceToPathogenicityScoreMap] @@ -739,6 +749,7 @@ def register_config() -> None: cs.store(group="step", name="pics", node=PICSConfig) cs.store(group="step", name="gnomad_variants", node=GnomadVariantConfig) cs.store(group="step", name="ukb_ppp_eur_sumstat_preprocess", node=UkbPppEurConfig) + cs.store(group="step", name="lof_curation_ingestion", node=LOFIngestionConfig) cs.store(group="step", name="variant_index", node=VariantIndexConfig) cs.store(group="step", name="variant_to_vcf", node=ConvertToVcfStepConfig) cs.store( diff --git a/src/gentropy/dataset/dataset.py b/src/gentropy/dataset/dataset.py index fa06faec2..1262043d4 100644 --- a/src/gentropy/dataset/dataset.py +++ b/src/gentropy/dataset/dataset.py @@ -174,15 +174,17 @@ def from_parquet( def filter(self: Self, condition: Column) -> Self: """Creates a new instance of a Dataset with the DataFrame filtered by the condition. + Preserves all attributes from the original instance. 
+ Args: condition (Column): Condition to filter the DataFrame Returns: - Self: Filtered Dataset + Self: Filtered Dataset with preserved attributes """ - df = self._df.filter(condition) - class_constructor = self.__class__ - return class_constructor(_df=df, _schema=class_constructor.get_schema()) + filtered_df = self._df.filter(condition) + attrs = {k: v for k, v in self.__dict__.items() if k != "_df"} + return self.__class__(_df=filtered_df, **attrs) def validate_schema(self: Dataset) -> None: """Validate DataFrame schema against expected class schema. diff --git a/src/gentropy/dataset/l2g_features/other.py b/src/gentropy/dataset/l2g_features/other.py index f68b05cb4..e71b58544 100644 --- a/src/gentropy/dataset/l2g_features/other.py +++ b/src/gentropy/dataset/l2g_features/other.py @@ -100,6 +100,9 @@ def is_protein_coding_feature_logic( Returns: DataFrame: Feature dataset, with 1 if the gene is protein-coding, 0 if not. + + Raises: + AssertionError: when provided `genomic_window` is more or equal to 500kb. """ assert genomic_window <= 500_000, "Genomic window must be less than 500kb." 
genes_in_window = ( diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index 255722414..d864c3aa5 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -2,14 +2,18 @@ from __future__ import annotations +import logging from dataclasses import dataclass, field from typing import TYPE_CHECKING import pyspark.sql.functions as f +import shap from pyspark.sql import DataFrame +from pyspark.sql.types import StructType from gentropy.common.schemas import parse_spark_schema from gentropy.common.session import Session +from gentropy.common.spark_helpers import pivot_df from gentropy.dataset.dataset import Dataset from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix from gentropy.dataset.study_index import StudyIndex @@ -17,6 +21,7 @@ from gentropy.method.l2g.model import LocusToGeneModel if TYPE_CHECKING: + from pandas import DataFrame as pd_dataframe from pyspark.sql.types import StructType @@ -47,6 +52,7 @@ def from_credible_set( credible_set: StudyLocus, feature_matrix: L2GFeatureMatrix, model_path: str | None, + features_list: list[str] | None = None, hf_token: str | None = None, download_from_hub: bool = True, ) -> L2GPrediction: @@ -57,19 +63,29 @@ def from_credible_set( credible_set (StudyLocus): Dataset containing credible sets from GWAS only feature_matrix (L2GFeatureMatrix): Dataset containing all credible sets and their annotations model_path (str | None): Path to the model file. It can be either in the filesystem or the name on the Hugging Face Hub (in the form of username/repo_name). + features_list (list[str] | None): Default list of features the model uses. Only used if the model is not downloaded from the Hub. CAUTION: This default list can differ from the actual list the model was trained on. hf_token (str | None): Hugging Face token to download the model from the Hub. Only required if the model is private. 
download_from_hub (bool): Whether to download the model from the Hugging Face Hub. Defaults to True. Returns: L2GPrediction: L2G scores for a set of credible sets. + + Raises: + AttributeError: If `features_list` is not provided and the model is not downloaded from the Hub. """ # Load the model if download_from_hub: # Model ID defaults to "opentargets/locus_to_gene" and it assumes the name of the classifier is "classifier.skops". model_id = model_path or "opentargets/locus_to_gene" - l2g_model = LocusToGeneModel.load_from_hub(model_id, hf_token) + l2g_model = LocusToGeneModel.load_from_hub(session, model_id, hf_token) elif model_path: - l2g_model = LocusToGeneModel.load_from_disk(model_path) + if not features_list: + raise AttributeError( + "features_list is required if the model is not downloaded from the Hub" + ) + l2g_model = LocusToGeneModel.load_from_disk( + session, path=model_path, features_list=features_list + ) # Prepare data fm = ( @@ -79,7 +95,7 @@ def from_credible_set( .select("studyLocusId") .join(feature_matrix._df, "studyLocusId") .filter(f.col("isProteinCoding") == 1) - ) + ), ) .fill_na() .select_features(l2g_model.features_list) @@ -127,7 +143,131 @@ def to_disease_target_evidence( ) ) - def add_locus_to_gene_features( + def explain( + self: L2GPrediction, feature_matrix: L2GFeatureMatrix | None = None + ) -> L2GPrediction: + """Extract Shapley values for the L2G predictions and add them as a map in an additional column. + + Args: + feature_matrix (L2GFeatureMatrix | None): Feature matrix in case the predictions are missing the feature annotation. If None, the features are fetched from the dataset. 
+ + Returns: + L2GPrediction: L2GPrediction object with additional column containing feature name to Shapley value mappings + + Raises: + ValueError: If the model is not set or If feature matrix is not provided and the predictions do not have features + """ + # Fetch features if they are not present: + if "features" not in self.df.columns: + if feature_matrix is None: + raise ValueError( + "Feature matrix is required to explain the L2G predictions" + ) + self.add_features(feature_matrix) + + if self.model is None: + raise ValueError("Model not set, explainer cannot be created") + + # Format and pivot the dataframe to pass them before calculating shapley values + pdf = pivot_df( + df=self.df.withColumn("feature", f.explode("features")).select( + "studyLocusId", + "geneId", + "score", + f.col("feature.name").alias("feature_name"), + f.col("feature.value").alias("feature_value"), + ), + pivot_col="feature_name", + value_col="feature_value", + grouping_cols=[f.col("studyLocusId"), f.col("geneId"), f.col("score")], + ).toPandas() + pdf = pdf.rename( + # trim the suffix that is added after pivoting the df + columns={ + col: col.replace("_feature_value", "") + for col in pdf.columns + if col.endswith("_feature_value") + } + ) + + features_list = self.model.features_list # The matrix needs to present the features in the same order that the model was trained on) + base_value, shap_values = L2GPrediction._explain( + model=self.model, + pdf=pdf.filter(items=features_list), + ) + for i, feature in enumerate(features_list): + pdf[f"shap_{feature}"] = [row[i] for row in shap_values] + + spark_session = self.df.sparkSession + return L2GPrediction( + _df=( + spark_session.createDataFrame(pdf.to_dict(orient="records")) + .withColumn( + "features", + f.array( + *( + f.struct( + f.lit(feature).alias("name"), + f.col(feature).cast("float").alias("value"), + f.col(f"shap_{feature}") + .cast("float") + .alias("shapValue"), + ) + for feature in features_list + ) + ), + ) + 
.withColumn("shapBaseValue", f.lit(base_value).cast("float")) + .select(*L2GPrediction.get_schema().names) + ), + _schema=self.get_schema(), + model=self.model, + ) + + @staticmethod + def _explain( + model: LocusToGeneModel, pdf: pd_dataframe + ) -> tuple[float, list[list[float]]]: + """Calculate SHAP values. Output is in probability form (approximated from the log odds ratios). + + Args: + model (LocusToGeneModel): L2G model + pdf (pd_dataframe): Pandas dataframe containing the feature matrix in the same order that the model was trained on + + Returns: + tuple[float, list[list[float]]]: A tuple containing: + - base_value (float): Base value of the model + - shap_values (list[list[float]]): SHAP values for prediction + + Raises: + AttributeError: If model.training_data is not set, seed dataset to get shapley values cannot be created. + """ + if not model.training_data: + raise AttributeError( + "`model.training_data` is missing, seed dataset to get shapley values cannot be created." + ) + background_data = ( + model.training_data._df.select(*model.features_list) + .toPandas() + .sample(n=1_000) + ) + explainer = shap.TreeExplainer( + model.model, + data=background_data, + model_output="probability", + ) + if pdf.shape[0] >= 10_000: + logging.warning( + "Calculating SHAP values for more than 10,000 rows. This may take a while..." 
+ ) + shap_values = explainer.shap_values( + pdf.to_numpy(), + check_additivity=False, + ) + base_value = explainer.expected_value + return (base_value, shap_values) + + def add_features( self: L2GPrediction, feature_matrix: L2GFeatureMatrix, ) -> L2GPrediction: @@ -137,41 +277,30 @@ def add_locus_to_gene_features( feature_matrix (L2GFeatureMatrix): Feature matrix dataset Returns: - L2GPrediction: L2G predictions with additional features + L2GPrediction: L2G predictions with additional column `features` Raises: ValueError: If model is not set, feature list won't be available """ if self.model is None: raise ValueError("Model not set, feature annotation cannot be created.") - # Testing if `locusToGeneFeatures` column already exists: - if "locusToGeneFeatures" in self.df.columns: - self.df = self.df.drop("locusToGeneFeatures") - - # Aggregating all features into a single map column: - aggregated_features = ( - feature_matrix._df.withColumn( - "locusToGeneFeatures", - f.create_map( - *sum( - ( - (f.lit(feature), f.col(feature)) - for feature in self.model.features_list - ), - (), - ) - ), - ) - .withColumn( - "locusToGeneFeatures", - f.expr("map_filter(locusToGeneFeatures, (k, v) -> v != 0)"), - ) - .drop(*self.model.features_list) - ) - return L2GPrediction( - _df=self.df.join( - aggregated_features, on=["studyLocusId", "geneId"], how="left" - ), - _schema=self.get_schema(), - model=self.model, + # Testing if `features` column already exists: + if "features" in self.df.columns: + self.df = self.df.drop("features") + + features_list = self.model.features_list + feature_expressions = [ + f.struct(f.lit(col).alias("name"), f.col(col).alias("value")) + for col in features_list + ] + self.df = self.df.join( + feature_matrix._df.select(*features_list, "studyLocusId", "geneId"), + on=["studyLocusId", "geneId"], + how="left", + ).select( + "studyLocusId", + "geneId", + "score", + f.array(*feature_expressions).alias("features"), ) + return self diff --git 
a/src/gentropy/dataset/pairwise_ld.py b/src/gentropy/dataset/pairwise_ld.py index ab68a74ab..6db570ba9 100644 --- a/src/gentropy/dataset/pairwise_ld.py +++ b/src/gentropy/dataset/pairwise_ld.py @@ -30,6 +30,9 @@ def __post_init__(self: PairwiseLD) -> None: """Validating the dataset upon creation. - Besides the schema, a pairwise LD table is expected have rows being a square number. + + Raises: + AssertionError: When the number of rows in the provided dataframe to construct the LD matrix is not even after applying square root. """ row_count = self.df.count() diff --git a/src/gentropy/dataset/study_locus.py b/src/gentropy/dataset/study_locus.py index 30b91541b..b829990ce 100644 --- a/src/gentropy/dataset/study_locus.py +++ b/src/gentropy/dataset/study_locus.py @@ -433,7 +433,8 @@ def _qc_subsignificant_associations( def qc_abnormal_pips( self: StudyLocus, sum_pips_lower_threshold: float = 0.99, - sum_pips_upper_threshold: float = 1.0001, # Set slightly above 1 to account for floating point errors + # Set slightly above 1 to account for floating point errors + sum_pips_upper_threshold: float = 1.0001, ) -> StudyLocus: """Filter study-locus by sum of posterior inclusion probabilities to ensure that the sum of PIPs is within a given range. @@ -691,6 +692,7 @@ def flag_trans_qtls( """Flagging transQTL credible sets based on genomic location of the measured gene. Process: + 0. Make sure that the `isTransQtl` column does not exist (remove if exists) 1. Enrich study-locus dataset with geneId based on study metadata. (only QTL studies are considered) 2. Enrich with transcription start site and chromosome of the studied gegne. 3. Flagging any tagging variant of QTL credible sets, if chromosome is different from the gene or distance is above the threshold. 
@@ -709,6 +711,12 @@ def flag_trans_qtls( if "geneId" not in study_index.df.columns: return self + # We have to remove the column `isTransQtl` to ensure the column is not duplicated + # The duplication can happen when one reads the StudyLocus from parquet with + # predefined schema that already contains the `isTransQtl` column. + if "isTransQtl" in self.df.columns: + self.df = self.df.drop("isTransQtl") + # Process study index: processed_studies = ( study_index.df diff --git a/src/gentropy/dataset/variant_index.py b/src/gentropy/dataset/variant_index.py index 6622c822c..a44356ca6 100644 --- a/src/gentropy/dataset/variant_index.py +++ b/src/gentropy/dataset/variant_index.py @@ -130,7 +130,7 @@ def add_annotation( """Import annotation from an other variant index dataset. At this point the annotation can be extended with extra cross-references, - in-silico predictions and allele frequencies. + variant effects, allele frequencies, and variant descriptions. Args: annotation_source (VariantIndex): Annotation to add to the dataset @@ -168,7 +168,12 @@ def add_annotation( f.col(column), f.col(f"{prefix}{column}"), fields_order ).alias(column) ) - # Non-array columns are coalesced: + # variantDescription columns are concatenated: + elif column == "variantDescription": + select_expressions.append( + f.concat_ws(" ", f.col(column), f.col(f"{prefix}{column}")).alias(column) + ) + # All other non-array columns are coalesced: else: select_expressions.append( f.coalesce(f.col(column), f.col(f"{prefix}{column}")).alias( @@ -222,10 +227,13 @@ def filter_by_variant(self: VariantIndex, df: DataFrame) -> VariantIndex: """Filter variant annotation dataset by a variant dataframe. Args: - df (DataFrame): A dataframe of variants + df (DataFrame): A dataframe of variants. Returns: - VariantIndex: A filtered variant annotation dataset + VariantIndex: A filtered variant annotation dataset. 
+ + Raises: + AssertionError: When the variant dataframe does not contain eiter `variantId` or `chromosome` column. """ join_columns = ["variantId", "chromosome"] @@ -279,7 +287,7 @@ def get_distance_to_gene( def annotate_with_amino_acid_consequences( self: VariantIndex, annotation: AminoAcidVariants ) -> VariantIndex: - """Enriching in silico predictors with amino-acid derived predicted consequences. + """Enriching variant effect assessments with amino-acid derived predicted consequences. Args: annotation (AminoAcidVariants): amio-acid level variant consequences. @@ -287,7 +295,7 @@ def annotate_with_amino_acid_consequences( Returns: VariantIndex: where amino-acid causing variants are enriched with extra annotation """ - w = Window.partitionBy("variantId").orderBy(f.size("inSilicoPredictors").desc()) + w = Window.partitionBy("variantId").orderBy(f.size("variantEffect").desc()) return VariantIndex( _df=self.df @@ -308,17 +316,17 @@ def annotate_with_amino_acid_consequences( ) # Joining with amino-acid predictions: .join( - annotation.df.withColumnRenamed("inSilicoPredictors", "annotations"), + annotation.df.withColumnRenamed("variantEffect", "annotations"), on=["uniprotAccession", "aminoAcidChange"], how="left", ) # Merge predictors: .withColumn( - "inSilicoPredictors", + "variantEffect", f.when( f.col("annotations").isNotNull(), - f.array_union("inSilicoPredictors", "annotations"), - ).otherwise(f.col("inSilicoPredictors")), + f.array_union("variantEffect", "annotations"), + ).otherwise(f.col("variantEffect")), ) # Dropping unused columns: .drop("uniprotAccession", "aminoAcidChange", "annotations") @@ -356,33 +364,33 @@ def get_loftee(self: VariantIndex) -> DataFrame: ) -class InSilicoPredictorNormaliser: - """Class to normalise in silico predictor assessments. +class VariantEffectNormaliser: + """Class to normalise variant effect assessments. 
Essentially based on the raw scores, it normalises the scores to a range between -1 and 1, and appends the normalised - value to the in silico predictor struct. + value to the variant effect struct. The higher negative values indicate increasingly confident prediction to be a benign variant, while the higher positive values indicate increasingly deleterious predicted effect. - The point of these operations to make the scores comparable across different in silico predictors. + The point of these operations to make the scores comparable across different variant effect assessments. """ @classmethod - def normalise_in_silico_predictors( - cls: type[InSilicoPredictorNormaliser], - in_silico_predictors: Column, + def normalise_variant_effect( + cls: type[VariantEffectNormaliser], + variant_effect: Column, ) -> Column: - """Normalise in silico predictors. Appends a normalised score to the in silico predictor struct. + """Normalise variant effect assessments. Appends a normalised score to the variant effect struct. Args: - in_silico_predictors (Column): Column containing in silico predictors (list of structs). + variant_effect (Column): Column containing variant effect assessments (list of structs). Returns: - Column: Normalised in silico predictors. + Column: Normalised variant effect assessments. """ return f.transform( - in_silico_predictors, + variant_effect, lambda predictor: f.struct( # Extracing all existing columns: predictor.method.alias("method"), @@ -399,20 +407,20 @@ def normalise_in_silico_predictors( @classmethod def resolve_predictor_methods( - cls: type[InSilicoPredictorNormaliser], + cls: type[VariantEffectNormaliser], score: Column, method: Column, assessment: Column, ) -> Column: - """It takes a score, a method, and an assessment, and returns a normalized score for the in silico predictor. + """It takes a score, a method, and an assessment, and returns a normalized score for the variant effect. 
Args: - score (Column): The raw score from the in silico predictor. + score (Column): The raw score from the variant effect. method (Column): The method used to generate the score. assessment (Column): The assessment of the score. Returns: - Column: Normalised score for the in silico predictor. + Column: Normalised score for the variant effect. """ return ( f.when(method == "LOFTEE", cls._normalise_loftee(assessment)) @@ -421,6 +429,7 @@ def resolve_predictor_methods( .when(method == "AlphaMissense", cls._normalise_alpha_missense(score)) .when(method == "CADD", cls._normalise_cadd(score)) .when(method == "Pangolin", cls._normalise_pangolin(score)) + .when(method == "LossOfFunctionCuration", cls._normalise_lof(assessment)) # The following predictors are not normalised: .when(method == "SpliceAI", score) .when(method == "VEP", score) @@ -454,7 +463,7 @@ def _rescaleColumnValue( @classmethod def _normalise_foldx( - cls: type[InSilicoPredictorNormaliser], score: Column + cls: type[VariantEffectNormaliser], score: Column ) -> Column: """Normalise FoldX ddG energies. @@ -477,7 +486,7 @@ def _normalise_foldx( @classmethod def _normalise_cadd( - cls: type[InSilicoPredictorNormaliser], + cls: type[VariantEffectNormaliser], score: Column, ) -> Column: """Normalise CADD scores. @@ -503,7 +512,7 @@ def _normalise_cadd( @classmethod def _normalise_gerp( - cls: type[InSilicoPredictorNormaliser], + cls: type[VariantEffectNormaliser], score: Column, ) -> Column: """Normalise GERP scores. @@ -533,9 +542,38 @@ def _normalise_gerp( .when(score < -3, f.lit(-1.0)) ) + @classmethod + def _normalise_lof( + cls: type[VariantEffectNormaliser], + assessment: Column, + ) -> Column: + """Normalise loss-of-function verdicts. + + There are five ordinal verdicts. + The normalised score is determined by the verdict: + - lof: 1 + - likely_lof: 0.5 + - uncertain: 0 + - likely_not_lof: -0.5 + - not_lof: -1 + + Args: + assessment (Column): Loss-of-function assessment. 
+ + Returns: + Column: Normalised loss-of-function score. + """ + return ( + f.when(assessment == "lof", f.lit(1)) + .when(assessment == "likely_lof", f.lit(0.5)) + .when(assessment == "uncertain", f.lit(0)) + .when(assessment == "likely_not_lof", f.lit(-0.5)) + .when(assessment == "not_lof", f.lit(-1)) + ) + @classmethod def _normalise_loftee( - cls: type[InSilicoPredictorNormaliser], + cls: type[VariantEffectNormaliser], assessment: Column, ) -> Column: """Normalise LOFTEE scores. @@ -557,7 +595,7 @@ def _normalise_loftee( @classmethod def _normalise_sift( - cls: type[InSilicoPredictorNormaliser], + cls: type[VariantEffectNormaliser], score: Column, assessment: Column, ) -> Column: @@ -601,7 +639,7 @@ def _normalise_sift( @classmethod def _normalise_polyphen( - cls: type[InSilicoPredictorNormaliser], + cls: type[VariantEffectNormaliser], assessment: Column, score: Column, ) -> Column: @@ -632,7 +670,7 @@ def _normalise_polyphen( @classmethod def _normalise_alpha_missense( - cls: type[InSilicoPredictorNormaliser], + cls: type[VariantEffectNormaliser], score: Column, ) -> Column: """Normalise AlphaMissense scores. @@ -656,7 +694,7 @@ def _normalise_alpha_missense( @classmethod def _normalise_pangolin( - cls: type[InSilicoPredictorNormaliser], + cls: type[VariantEffectNormaliser], score: Column, ) -> Column: """Normalise Pangolin scores. 
diff --git a/src/gentropy/datasource/ensembl/vep_parser.py b/src/gentropy/datasource/ensembl/vep_parser.py index c5c985f3f..0aad3b01f 100644 --- a/src/gentropy/datasource/ensembl/vep_parser.py +++ b/src/gentropy/datasource/ensembl/vep_parser.py @@ -16,7 +16,7 @@ order_array_of_structs_by_field, order_array_of_structs_by_two_fields, ) -from gentropy.dataset.variant_index import InSilicoPredictorNormaliser, VariantIndex +from gentropy.dataset.variant_index import VariantEffectNormaliser, VariantIndex if TYPE_CHECKING: from pyspark.sql import Column, DataFrame @@ -33,9 +33,9 @@ class VariantEffectPredictorParser: DBXREF_SCHEMA = VariantIndex.get_schema()["dbXrefs"].dataType - # Schema description of the in silico predictor object: - IN_SILICO_PREDICTOR_SCHEMA = get_nested_struct_schema( - VariantIndex.get_schema()["inSilicoPredictors"] + # Schema description of the variant effect object: + VARIANT_EFFECT_SCHEMA = get_nested_struct_schema( + VariantIndex.get_schema()["variantEffect"] ) # Schema for the allele frequency column: @@ -325,6 +325,9 @@ def _get_most_severe_transcript( |{0.6, transcript3} | +----------------------+ + + Raises: + AssertionError: When `transcript_column_name` is not a string. """ assert isinstance( transcript_column_name, str @@ -341,7 +344,7 @@ def _get_most_severe_transcript( )[0] @classmethod - @enforce_schema(IN_SILICO_PREDICTOR_SCHEMA) + @enforce_schema(VARIANT_EFFECT_SCHEMA) def _get_vep_prediction(cls, most_severe_consequence: Column) -> Column: return f.struct( f.lit("VEP").alias("method"), @@ -352,7 +355,7 @@ def _get_vep_prediction(cls, most_severe_consequence: Column) -> Column: ) @staticmethod - @enforce_schema(IN_SILICO_PREDICTOR_SCHEMA) + @enforce_schema(VARIANT_EFFECT_SCHEMA) def _get_max_alpha_missense(transcripts: Column) -> Column: """Return the most severe alpha missense prediction from all transcripts. 
@@ -410,8 +413,8 @@ def _get_max_alpha_missense(transcripts: Column) -> Column: ) @classmethod - @enforce_schema(IN_SILICO_PREDICTOR_SCHEMA) - def _vep_in_silico_prediction_extractor( + @enforce_schema(VARIANT_EFFECT_SCHEMA) + def _vep_variant_effect_extractor( cls: type[VariantEffectPredictorParser], transcript_column_name: str, method_name: str, @@ -419,17 +422,17 @@ def _vep_in_silico_prediction_extractor( assessment_column_name: str | None = None, assessment_flag_column_name: str | None = None, ) -> Column: - """Extract in silico prediction from VEP output. + """Extract variant effect from VEP output. Args: transcript_column_name (str): Name of the column containing the list of transcripts. - method_name (str): Name of the in silico predictor. + method_name (str): Name of the variant effect. score_column_name (str | None): Name of the column containing the score. assessment_column_name (str | None): Name of the column containing the assessment. assessment_flag_column_name (str | None): Name of the column containing the assessment flag. Returns: - Column: In silico predictor. + Column: Variant effect. 
""" # Get transcript with the highest score: most_severe_transcript: Column = ( @@ -634,34 +637,34 @@ def process_vep_output( cls._extract_clinvar_xrefs(f.col("colocated_variants")).alias( "clinvar_xrefs" ), - # Extracting in silico predictors + # Extracting variant effect assessments f.when( - # The following in-silico predictors are only available for variants with transcript consequences: + # The following variant effect assessments are only available for variants with transcript consequences: f.col("transcript_consequences").isNotNull(), f.filter( f.array( # Extract CADD scores: - cls._vep_in_silico_prediction_extractor( + cls._vep_variant_effect_extractor( transcript_column_name="transcript_consequences", method_name="CADD", score_column_name="cadd_phred", ), # Extract polyphen scores: - cls._vep_in_silico_prediction_extractor( + cls._vep_variant_effect_extractor( transcript_column_name="transcript_consequences", method_name="PolyPhen", score_column_name="polyphen_score", assessment_column_name="polyphen_prediction", ), # Extract sift scores: - cls._vep_in_silico_prediction_extractor( + cls._vep_variant_effect_extractor( transcript_column_name="transcript_consequences", method_name="SIFT", score_column_name="sift_score", assessment_column_name="sift_prediction", ), # Extract loftee scores: - cls._vep_in_silico_prediction_extractor( + cls._vep_variant_effect_extractor( method_name="LOFTEE", transcript_column_name="transcript_consequences", score_column_name="lof", @@ -669,7 +672,7 @@ def process_vep_output( assessment_flag_column_name="lof_filter", ), # Extract GERP conservation score: - cls._vep_in_silico_prediction_extractor( + cls._vep_variant_effect_extractor( method_name="GERP", transcript_column_name="transcript_consequences", score_column_name="conservation", @@ -685,24 +688,12 @@ def process_vep_output( ), ) .otherwise( - # Extract CADD scores from intergenic object: f.array( - cls._vep_in_silico_prediction_extractor( - 
transcript_column_name="intergenic_consequences", - method_name="CADD", - score_column_name="cadd_phred", - ), - # Extract GERP conservation score: - cls._vep_in_silico_prediction_extractor( - method_name="GERP", - transcript_column_name="intergenic_consequences", - score_column_name="conservation", - ), # Extract VEP prediction: cls._get_vep_prediction(f.col("most_severe_consequence")), ) ) - .alias("inSilicoPredictors"), + .alias("variantEffect"), # Convert consequence to SO: map_column_by_dictionary( f.col("most_severe_consequence"), cls.SEQUENCE_ONTOLOGY_MAP @@ -882,11 +873,11 @@ def process_vep_output( )[f.size("proteinCodingTranscripts") - 1], ), ) - # Normalising in silico predictor assessments: + # Normalising variant effect assessments: .withColumn( - "inSilicoPredictors", - InSilicoPredictorNormaliser.normalise_in_silico_predictors( - f.col("inSilicoPredictors") + "variantEffect", + VariantEffectNormaliser.normalise_variant_effect( + f.col("variantEffect") ), ) # Dropping intermediate xref columns: diff --git a/src/gentropy/datasource/open_targets/foldex_integration.py b/src/gentropy/datasource/open_targets/foldex_integration.py index 5354fc7b6..8f71f56ae 100644 --- a/src/gentropy/datasource/open_targets/foldex_integration.py +++ b/src/gentropy/datasource/open_targets/foldex_integration.py @@ -8,26 +8,26 @@ from gentropy.common.spark_helpers import enforce_schema from gentropy.dataset.amino_acid_variants import AminoAcidVariants -from gentropy.dataset.variant_index import InSilicoPredictorNormaliser +from gentropy.dataset.variant_index import VariantEffectNormaliser class OpenTargetsFoldX: """Class to parser FoldX dataset generated by the OTAR2081 project.""" - INSILICO_SCHEMA = AminoAcidVariants.get_schema()[ - "inSilicoPredictors" + VARIANT_EFFECT_SCHEMA = AminoAcidVariants.get_schema()[ + "variantEffect" ].dataType.elementType @staticmethod - @enforce_schema(INSILICO_SCHEMA) + @enforce_schema(VARIANT_EFFECT_SCHEMA) def 
get_foldx_prediction(score_column: Column) -> Column: - """Generate inSilicoPredictor object from ddG column. + """Generate variantEffect object from ddG column. Args: score_column (Column): ddG column from the FoldX dataset. Returns: - Column: struct with the right shape of the in silico predictors. + Column: struct with the right shape of the variantEffect field. """ return f.struct( f.lit("FoldX").alias("method"), @@ -58,21 +58,21 @@ def ingest_foldx_data( f.col("wild_type"), f.col("position"), f.col("mutated_type") ).alias("aminoAcidChange"), cls.get_foldx_prediction(f.col("foldx_ddg")).alias( - "inSilicoPredictor" + "foldx_prediction" ), ) # Collapse all predictors for a single array object to avoid variant explosions: .groupBy("uniprotAccession", "aminoAcidChange") .agg( - f.collect_set(f.col("inSilicoPredictor")).alias( - "inSilicoPredictors" + f.collect_set(f.col("foldx_prediction")).alias( + "variantEffect" ) ) # Normalise FoldX free energy changes: .withColumn( - "inSilicoPredictors", - InSilicoPredictorNormaliser.normalise_in_silico_predictors( - f.col("inSilicoPredictors") + "variantEffect", + VariantEffectNormaliser.normalise_variant_effect( + f.col("variantEffect") ), ) ), diff --git a/src/gentropy/datasource/open_targets/lof_curation.py b/src/gentropy/datasource/open_targets/lof_curation.py new file mode 100644 index 000000000..67edcfdd7 --- /dev/null +++ b/src/gentropy/datasource/open_targets/lof_curation.py @@ -0,0 +1,98 @@ +"""Parser for Loss-of-Function variant data from Open Targets Project OTAR2075.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pyspark.sql.functions as f +import pyspark.sql.types as t + +from gentropy.common.spark_helpers import enforce_schema +from gentropy.dataset.variant_index import VariantEffectNormaliser, VariantIndex + +if TYPE_CHECKING: + from pyspark.sql import Column, DataFrame + + +class OpenTargetsLOF: + """Class to parse Loss-of-Function variant data from Open Targets 
Project OTAR2075.""" + + VARIANT_EFFECT_SCHEMA = VariantIndex.get_schema()[ + "variantEffect" + ].dataType.elementType + + @staticmethod + @enforce_schema(VARIANT_EFFECT_SCHEMA) + def _get_lof_assessment(verdict: Column) -> Column: + """Get curated Loss-of-Function assessment from verdict column. + + Args: + verdict (Column): verdict column from the input dataset. + + Returns: + Column: struct following the variant effect schema. + """ + return f.struct( + f.lit("LossOfFunctionCuration").alias("method"), + verdict.alias("assessment"), + ) + + @staticmethod + def _compose_lof_description(verdict: Column) -> Column: + """Compose variant description based on loss-of-function assessment. + + Args: + verdict (Column): verdict column from the input dataset. + + Returns: + Column: variant description. + """ + lof_description = ( + f.when(verdict == "lof", f.lit("Assessed to cause LoF")) + .when(verdict == "likely_lof", f.lit("Suspected to cause LoF")) + .when(verdict == "uncertain", f.lit("Uncertain LoF assessment")) + .when(verdict == "likely_not_lof", f.lit("Suspected not to cause LoF")) + .when(verdict == "not_lof", f.lit("Assessed not to cause LoF")) + ) + + return f.concat(lof_description, f.lit(" by OTAR2075 variant curation effort.")) + + @classmethod + def as_variant_index( + cls: type[OpenTargetsLOF], + lof_dataset: DataFrame + ) -> VariantIndex: + """Ingest Loss-of-Function information as a VariantIndex object. + + Args: + lof_dataset (DataFrame): curated input dataset from OTAR2075. + + Returns: + VariantIndex: variant annotations with loss-of-function assessments. 
+ """ + return VariantIndex( + _df=( + lof_dataset + .select( + f.from_csv(f.col("Variant ID GRCh37"), "chr string, pos string, ref string, alt string", {"sep": "-"}).alias("h37"), + f.from_csv(f.col("Variant ID GRCh38"), "chr string, pos string, ref string, alt string", {"sep": "-"}).alias("h38"), + "Verdict" + ) + .select( + # As some GRCh37 variants do not correctly lift over to the correct GRCh38 variant, + # chr_pos is taken from the GRCh38 variant id, and ref_alt from the GRCh37 variant id + f.concat_ws("_", f.col("h38.chr"), f.col("h38.pos"), f.col("h37.ref"), f.col("h37.alt")).alias("variantId"), + # Mandatory fields for VariantIndex: + f.col("h38.chr").alias("chromosome"), + f.col("h38.pos").cast(t.IntegerType()).alias("position"), + f.col("h37.ref").alias("referenceAllele"), + f.col("h37.alt").alias("alternateAllele"), + # Populate variantEffect and variantDescription fields: + f.array(cls._get_lof_assessment(f.col("Verdict"))).alias("variantEffect"), + cls._compose_lof_description(f.col("Verdict")).alias("variantDescription"), + ) + # Convert assessments to normalised scores: + .withColumn("variantEffect", VariantEffectNormaliser.normalise_variant_effect(f.col("variantEffect"))) + ), + _schema=VariantIndex.get_schema(), + ) diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py index 5f22471e3..4c3d6c867 100644 --- a/src/gentropy/l2g.py +++ b/src/gentropy/l2g.py @@ -151,7 +151,6 @@ def __init__( self.session = session self.run_mode = run_mode - self.model_path = model_path self.predictions_path = predictions_path self.features_list = list(features_list) if features_list else None self.hyperparameters = dict(hyperparameters) @@ -164,6 +163,11 @@ def __init__( self.gold_standard_curation_path = gold_standard_curation_path self.gene_interactions_path = gene_interactions_path self.variant_index_path = variant_index_path + self.model_path = ( + hf_hub_repo_id + if not model_path and download_from_hub and hf_hub_repo_id + else model_path + ) # Load common 
inputs self.credible_set = StudyLocus.from_parquet( @@ -284,14 +288,13 @@ def run_predict(self) -> None: self.credible_set, self.feature_matrix, model_path=self.model_path, + features_list=self.features_list, hf_token=access_gcp_secret("hfhub-key", "open-targets-genetics-dev"), download_from_hub=self.download_from_hub, ) - predictions.filter( - f.col("score") >= self.l2g_threshold - ).add_locus_to_gene_features( + predictions.filter(f.col("score") >= self.l2g_threshold).add_features( self.feature_matrix, - ).df.coalesce(self.session.output_partitions).write.mode( + ).explain().df.coalesce(self.session.output_partitions).write.mode( self.session.write_mode ).parquet(self.predictions_path) self.session.logger.info("L2G predictions saved successfully.") @@ -331,12 +334,10 @@ def run_train(self) -> None: "hfhub-key", "open-targets-genetics-dev" ) trained_model.export_to_hugging_face_hub( - # we upload the model in the filesystem + # we upload the model saved in the filesystem self.model_path.split("/")[-1], hf_hub_token, - data=trained_model.training_data._df.drop( - "goldStandardSet", "geneId" - ).toPandas(), + data=trained_model.training_data._df.toPandas(), repo_id=self.hf_hub_repo_id, commit_message=self.hf_model_commit_message, ) diff --git a/src/gentropy/lof_curation_ingestion.py b/src/gentropy/lof_curation_ingestion.py new file mode 100644 index 000000000..a2d66bc3c --- /dev/null +++ b/src/gentropy/lof_curation_ingestion.py @@ -0,0 +1,38 @@ +"""Ingest Loss-of-Function variant data generated by OTAR2075.""" + +from gentropy.common.session import Session +from gentropy.datasource.open_targets.lof_curation import OpenTargetsLOF + + +class LOFIngestionStep: + """Step to ingest the Loss-of-Function dataset generated by OTAR2075.""" + + def __init__( + self, + session: Session, + lof_curation_dataset_path: str, + lof_curation_variant_annotations_path: str, + ) -> None: + """Initialize step. + + Args: + session (Session): Session object. 
+ lof_curation_dataset_path (str): path of the curated LOF dataset. + lof_curation_variant_annotations_path (str): path of the resulting variant annotations. + """ + # Read in data: + lof_dataset = session.spark.read.csv( + lof_curation_dataset_path, + sep=",", + header=True, + multiLine=True + ) + # Extract relevant information to a VariantIndex + lof_variant_annotations = OpenTargetsLOF.as_variant_index(lof_dataset) + # Write to file: + ( + lof_variant_annotations.df + .coalesce(session.output_partitions) + .write.mode(session.write_mode) + .parquet(lof_curation_variant_annotations_path) + ) diff --git a/src/gentropy/method/colocalisation.py b/src/gentropy/method/colocalisation.py index b45b91920..3f36ab4db 100644 --- a/src/gentropy/method/colocalisation.py +++ b/src/gentropy/method/colocalisation.py @@ -216,7 +216,7 @@ class Coloc(ColocalisationMethodInterface): METHOD_METRIC: str = "h4" PSEUDOCOUNT: float = 1e-10 OVERLAP_SIZE_CUTOFF: int = 5 - POSTERIOR_CUTOFF: float = 0.5 + POSTERIOR_CUTOFF: float = 0.1 @staticmethod def _get_posteriors(all_bfs: NDArray[np.float64]) -> DenseVector: diff --git a/src/gentropy/method/l2g/model.py b/src/gentropy/method/l2g/model.py index 9d9011332..c695f8a68 100644 --- a/src/gentropy/method/l2g/model.py +++ b/src/gentropy/method/l2g/model.py @@ -3,6 +3,7 @@ from __future__ import annotations import json +import logging from dataclasses import dataclass, field from pathlib import Path from typing import TYPE_CHECKING, Any @@ -16,9 +17,9 @@ from gentropy.common.session import Session from gentropy.common.utils import copy_to_gcs +from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix if TYPE_CHECKING: - from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix from gentropy.dataset.l2g_prediction import L2GPrediction @@ -27,9 +28,7 @@ class LocusToGeneModel: """Wrapper for the Locus to Gene classifier.""" model: Any = GradientBoostingClassifier(random_state=42) - features_list: list[str] = field( - 
default_factory=list - ) # TODO: default to list in config if not provided + features_list: list[str] = field(default_factory=list) hyperparameters: dict[str, Any] = field( default_factory=lambda: { "n_estimators": 100, @@ -55,12 +54,18 @@ def __post_init__(self: LocusToGeneModel) -> None: @classmethod def load_from_disk( - cls: type[LocusToGeneModel], path: str, **kwargs: Any + cls: type[LocusToGeneModel], + session: Session, + path: str, + model_name: str = "classifier.skops", + **kwargs: Any, ) -> LocusToGeneModel: """Load a fitted model from disk. Args: - path (str): Path to the model + session (Session): Session object that loads the training data + path (str): Path to the directory containing model and metadata + model_name (str): Name of the persisted model to load. Defaults to "classifier.skops". **kwargs(Any): Keyword arguments to pass to the constructor Returns: @@ -69,8 +74,9 @@ def load_from_disk( Raises: ValueError: If the model has not been fitted yet """ - if path.startswith("gs://"): - path = path.removeprefix("gs://") + model_path = (Path(path) / model_name).as_posix() + if model_path.startswith("gs://"): + path = model_path.removeprefix("gs://") bucket_name = path.split("/")[0] blob_name = "/".join(path.split("/")[1:]) from google.cloud import storage @@ -81,25 +87,41 @@ def load_from_disk( data = blob.download_as_string(client=client) loaded_model = sio.loads(data, trusted=sio.get_untrusted_types(data=data)) else: - loaded_model = sio.load(path, trusted=sio.get_untrusted_types(file=path)) + loaded_model = sio.load( + model_path, trusted=sio.get_untrusted_types(file=model_path) + ) + try: + # Try loading the training data if it is in the model directory + training_data = L2GFeatureMatrix( + _df=session.spark.createDataFrame( + # Parquet is read with Pandas to easily read local files + pd.read_parquet( + (Path(path) / "training_data.parquet").as_posix() + ) + ), + features_list=kwargs.get("features_list"), + ) + except Exception as e: + 
logging.error("Training data set to none. Error: %s", e) + training_data = None if not loaded_model._is_fitted(): raise ValueError("Model has not been fitted yet.") - return cls(model=loaded_model, **kwargs) + return cls(model=loaded_model, training_data=training_data, **kwargs) @classmethod def load_from_hub( cls: type[LocusToGeneModel], + session: Session, model_id: str, hf_token: str | None = None, - model_name: str = "classifier.skops", ) -> LocusToGeneModel: """Load a model from the Hugging Face Hub. This will download the model from the hub and load it from disk. Args: + session (Session): Session object to load the training data model_id (str): Model ID on the Hugging Face Hub hf_token (str | None): Hugging Face Hub token to download the model (only required if private) - model_name (str): Name of the persisted model to load. Defaults to "classifier.skops". Returns: LocusToGeneModel: L2G model loaded from the Hugging Face Hub @@ -119,14 +141,22 @@ def get_features_list_from_metadata() -> list[str]: return [ column for column in model_config["sklearn"]["columns"] - if column != "studyLocusId" + if column + not in [ + "studyLocusId", + "geneId", + "traitFromSourceMappedId", + "goldStandardSet", + ] ] - local_path = Path(model_id) + local_path = model_id hub_utils.download(repo_id=model_id, dst=local_path, token=hf_token) features_list = get_features_list_from_metadata() return cls.load_from_disk( - str(Path(local_path) / model_name), features_list=features_list + session, + local_path, + features_list=features_list, ) @property @@ -196,6 +226,8 @@ def save(self: LocusToGeneModel, path: str) -> None: sio.dump(self.model, local_path) copy_to_gcs(local_path, path) else: + # create directory if path does not exist + Path(path).parent.mkdir(parents=True, exist_ok=True) sio.dump(self.model, path) @staticmethod @@ -231,7 +263,6 @@ def _create_hugging_face_model_card( - Distance: (from credible set variants to gene) - Molecular QTL Colocalization - - Chromatin 
Interaction: (e.g., promoter-capture Hi-C) - Variant Pathogenicity: (from VEP) More information at: https://opentargets.github.io/gentropy/python_api/methods/l2g/_l2g/ @@ -270,7 +301,7 @@ def export_to_hugging_face_hub( repo_id: str = "opentargets/locus_to_gene", local_repo: str = "locus_to_gene", ) -> None: - """Share the model on Hugging Face Hub. + """Share the model and training dataset on Hugging Face Hub. Args: model_path (str): The path to the L2G model file. @@ -294,6 +325,7 @@ def export_to_hugging_face_hub( data=data, ) self._create_hugging_face_model_card(local_repo) + data.to_parquet(f"{local_repo}/training_data.parquet") hub_utils.push( repo_id=repo_id, source=local_repo, diff --git a/src/gentropy/method/l2g/trainer.py b/src/gentropy/method/l2g/trainer.py index a123cfda9..3288d9ea8 100644 --- a/src/gentropy/method/l2g/trainer.py +++ b/src/gentropy/method/l2g/trainer.py @@ -88,15 +88,16 @@ def fit( Raises: ValueError: Train data not set, nothing to fit. + AssertionError: When x_train_size or y_train_size are not zero. """ if ( self.x_train is not None and self.y_train is not None and self.features_list is not None ): - assert self.x_train.size != 0 and self.y_train.size != 0, ( - "Train data not set, nothing to fit." - ) + assert ( + self.x_train.size != 0 and self.y_train.size != 0 + ), "Train data not set, nothing to fit." fitted_model = self.model.model.fit(X=self.x_train, y=self.y_train) self.model = LocusToGeneModel( model=fitted_model, @@ -111,7 +112,7 @@ def _get_shap_explanation( self: LocusToGeneTrainer, model: LocusToGeneModel, ) -> Explanation: - """Get the SHAP values for the given model and data. We pass the full X matrix (without the labels) to interpret their shap values. + """Get the SHAP values for the given model and data. We sample the full X matrix (without the labels) to interpret their shap values. Args: model (LocusToGeneModel): Model to explain. 
@@ -132,12 +133,15 @@ def _get_shap_explanation( model.model, data=training_data, feature_perturbation="interventional", + model_output="probability", ) try: - return explainer(training_data) + return explainer(training_data.sample(n=1_000)) except Exception as e: if "Additivity check failed in TreeExplainer" in repr(e): - return explainer(training_data, check_additivity=False) + return explainer( + training_data.sample(n=1_000), check_additivity=False + ) else: raise @@ -180,6 +184,7 @@ def log_to_wandb( Raises: RuntimeError: If dependencies are not available. + AssertionError: When x_train_size or y_train_size are not zero. """ if ( self.x_train is None @@ -189,9 +194,9 @@ or self.features_list is None ): raise RuntimeError("Train data not set, we cannot log to W&B.") - assert self.x_train.size != 0 and self.y_train.size != 0, ( - "Train data not set, nothing to evaluate." - ) + assert ( + self.x_train.size != 0 and self.y_train.size != 0 + ), "Train data not set, nothing to evaluate." fitted_classifier = self.model.model y_predicted = fitted_classifier.predict(self.x_test) y_probas = fitted_classifier.predict_proba(self.x_test) diff --git a/src/gentropy/method/susie_inf.py b/src/gentropy/method/susie_inf.py index e8a4a57b1..c53d6a939 100644 --- a/src/gentropy/method/susie_inf.py +++ b/src/gentropy/method/susie_inf.py @@ -493,6 +493,9 @@ def credible_set_qc( Returns: StudyLocus: Credible sets which pass filters and LD clumping. + + Raises: + AssertionError: When running in clump mode, but no study_index or ld_index or ld_min_r2 were provided. 
""" cred_sets.df = ( cred_sets.df.withColumn( diff --git a/src/gentropy/variant_index.py b/src/gentropy/variant_index.py index 595546f17..2064d7b67 100644 --- a/src/gentropy/variant_index.py +++ b/src/gentropy/variant_index.py @@ -27,7 +27,7 @@ def __init__( vep_output_json_path: str, variant_index_path: str, hash_threshold: int, - gnomad_variant_annotations_path: str | None = None, + variant_annotations_path: list[str] | None = None, amino_acid_change_annotations: list[str] | None = None, ) -> None: """Run VariantIndex step. @@ -37,7 +37,7 @@ def __init__( vep_output_json_path (str): Variant effect predictor output path (in json format). variant_index_path (str): Variant index dataset path to save resulting data. hash_threshold (int): Hash threshold for variant identifier length. - gnomad_variant_annotations_path (str | None): Path to extra variant annotation dataset. + variant_annotations_path (list[str] | None): List of paths to extra variant annotation datasets. amino_acid_change_annotations (list[str] | None): list of paths to amino-acid based variant annotations. 
""" # Extract variant annotations from VEP output: @@ -46,19 +46,20 @@ def __init__( ) # Process variant annotations if provided: - if gnomad_variant_annotations_path: - # Read variant annotations from parquet: - annotations = VariantIndex.from_parquet( - session=session, - path=gnomad_variant_annotations_path, - recursiveFileLookup=True, - id_threshold=hash_threshold, - ) + if variant_annotations_path: + for annotation_path in variant_annotations_path: + # Read variant annotations from parquet: + annotations = VariantIndex.from_parquet( + session=session, + path=annotation_path, + recursiveFileLookup=True, + id_threshold=hash_threshold, + ) - # Update index with extra annotations: - variant_index = variant_index.add_annotation(annotations) + # Update index with extra annotations: + variant_index = variant_index.add_annotation(annotations) - # If provided read amion-acid based annotation and enrich variant index: + # If provided read amino-acid based annotation and enrich variant index: if amino_acid_change_annotations: for annotation_path in amino_acid_change_annotations: annotation_data = AminoAcidVariants.from_parquet( @@ -105,6 +106,9 @@ def __init__( source_formats (list[str]): Format of the input dataset. output_path (str): Output VCF file path. partition_size (int): Approximate number of variants in each output partition. + + Raises: + AssertionError: When the length of `source_paths` does not match the lenght of `source_formats`. """ assert len(source_formats) == len( source_paths diff --git a/tests/gentropy/conftest.py b/tests/gentropy/conftest.py index 185c0e177..9d436817d 100644 --- a/tests/gentropy/conftest.py +++ b/tests/gentropy/conftest.py @@ -306,7 +306,7 @@ def mock_variant_index(spark: SparkSession) -> VariantIndex: # https://github.com/databrickslabs/dbldatagen/issues/135 # It's a workaround for nested column handling in dbldatagen. 
.withColumnSpec( - "inSilicoPredictors", + "variantEffect", expr=""" array( named_struct( diff --git a/tests/gentropy/dataset/test_colocalisation.py b/tests/gentropy/dataset/test_colocalisation.py index c15653787..c2fbc62b0 100644 --- a/tests/gentropy/dataset/test_colocalisation.py +++ b/tests/gentropy/dataset/test_colocalisation.py @@ -63,7 +63,9 @@ def test_append_study_metadata_right( assert ( observed_df.select(f"{colocalisation_side}GeneId").collect()[0][0] == expected_geneId - ), f"Expected {colocalisation_side}GeneId {expected_geneId}, but got {observed_df.select(f'{colocalisation_side}GeneId').collect()[0][0]}" + ), ( + f"Expected {colocalisation_side}GeneId {expected_geneId}, but got {observed_df.select(f'{colocalisation_side}GeneId').collect()[0][0]}" + ) @pytest.fixture(autouse=True) def _setup(self: TestAppendStudyMetadata, spark: SparkSession) -> None: diff --git a/tests/gentropy/dataset/test_dataset.py b/tests/gentropy/dataset/test_dataset.py index 96a96ec27..2db065e50 100644 --- a/tests/gentropy/dataset/test_dataset.py +++ b/tests/gentropy/dataset/test_dataset.py @@ -42,9 +42,9 @@ def test_initialize_without_schema(self: TestDataset, spark: SparkSession) -> No """Test if Dataset derived class collects the schema from assets if schema is not provided.""" df = spark.createDataFrame([(1,)], schema=MockDataset.get_schema()) ds = MockDataset(_df=df) - assert ( - ds.schema == MockDataset.get_schema() - ), "Schema should be inferred from df" + assert ds.schema == MockDataset.get_schema(), ( + "Schema should be inferred from df" + ) def test_passing_incorrect_types(self: TestDataset, spark: SparkSession) -> None: """Test if passing incorrect object types to Dataset raises an error.""" @@ -97,6 +97,6 @@ def test_process_class_params(spark: SparkSession) -> None: } class_params, spark_params = Dataset._process_class_params(params) assert "_df" in class_params, "Class params should contain _df" - assert ( - "recursiveFileLookup" in spark_params - ), "Spark 
params should contain recursiveFileLookup" + assert "recursiveFileLookup" in spark_params, ( + "Spark params should contain recursiveFileLookup" + ) diff --git a/tests/gentropy/dataset/test_l2g.py b/tests/gentropy/dataset/test_l2g.py index 293735edd..ad8982c64 100644 --- a/tests/gentropy/dataset/test_l2g.py +++ b/tests/gentropy/dataset/test_l2g.py @@ -29,9 +29,9 @@ def test_process_gene_interactions(sample_otp_interactions: DataFrame) -> None: """Tests processing of gene interactions from OTP.""" expected_cols = ["geneIdA", "geneIdB", "score"] observed_df = L2GGoldStandard.process_gene_interactions(sample_otp_interactions) - assert ( - observed_df.columns == expected_cols - ), "Gene interactions has a different schema." + assert observed_df.columns == expected_cols, ( + "Gene interactions has a different schema." + ) def test_predictions(mock_l2g_predictions: L2GPrediction) -> None: @@ -171,9 +171,9 @@ def test_l2g_feature_constructor_with_schema_mismatch( ), with_gold_standard=False, ) - assert ( - fm._df.schema["distanceTssMean"].dataType == FloatType() - ), "Feature `distanceTssMean` is not being casted to FloatType. Check L2GFeatureMatrix constructor." + assert fm._df.schema["distanceTssMean"].dataType == FloatType(), ( + "Feature `distanceTssMean` is not being casted to FloatType. Check L2GFeatureMatrix constructor." + ) def test_calculate_feature_missingness_rate( @@ -185,9 +185,9 @@ def test_calculate_feature_missingness_rate( assert isinstance(observed_missingness, dict) assert mock_l2g_feature_matrix.features_list is not None and len( observed_missingness - ) == len( - mock_l2g_feature_matrix.features_list - ), "Missing features in the missingness rate dictionary." - assert ( - observed_missingness == expected_missingness - ), "Missingness rate is incorrect." + ) == len(mock_l2g_feature_matrix.features_list), ( + "Missing features in the missingness rate dictionary." 
+ ) + assert observed_missingness == expected_missingness, ( + "Missingness rate is incorrect." + ) diff --git a/tests/gentropy/dataset/test_l2g_feature.py b/tests/gentropy/dataset/test_l2g_feature.py index 3302338e4..fec93b564 100644 --- a/tests/gentropy/dataset/test_l2g_feature.py +++ b/tests/gentropy/dataset/test_l2g_feature.py @@ -295,9 +295,9 @@ def test__common_colocalisation_feature_logic( }, ], ).select("studyLocusId", "geneId", "eQtlColocH4Maximum") - assert ( - observed_df.collect() == expected_df.collect() - ), "The feature values are not as expected." + assert observed_df.collect() == expected_df.collect(), ( + "The feature values are not as expected." + ) def test_extend_missing_colocalisation_to_neighbourhood_genes( self: TestCommonColocalisationFeatureLogic, @@ -330,9 +330,9 @@ def test_extend_missing_colocalisation_to_neighbourhood_genes( expected_df = spark.createDataFrame( [{"geneId": "gene3", "studyLocusId": "1", "eQtlColocH4Maximum": 0.0}] ).select("studyLocusId", "geneId", "eQtlColocH4Maximum") - assert ( - observed_df.collect() == expected_df.collect() - ), "The feature values are not as expected." + assert observed_df.collect() == expected_df.collect(), ( + "The feature values are not as expected." + ) def test_common_neighbourhood_colocalisation_feature_logic( self: TestCommonColocalisationFeatureLogic, @@ -369,9 +369,9 @@ def test_common_neighbourhood_colocalisation_feature_logic( }, ], ).select("geneId", "studyLocusId", "eQtlColocH4MaximumNeighbourhood") - assert ( - observed_df.collect() == expected_df.collect() - ), "The expected and observed dataframes do not match." + assert observed_df.collect() == expected_df.collect(), ( + "The expected and observed dataframes do not match." 
+ ) @pytest.fixture(autouse=True) def _setup(self: TestCommonColocalisationFeatureLogic, spark: SparkSession) -> None: @@ -555,9 +555,9 @@ def test_common_distance_feature_logic( .select("studyLocusId", "geneId", feature_name) .orderBy(feature_name) ) - assert ( - observed_df.collect() == expected_df.collect() - ), f"Expected and observed dataframes are not equal for feature {feature_name}." + assert observed_df.collect() == expected_df.collect(), ( + f"Expected and observed dataframes are not equal for feature {feature_name}." + ) def test_common_neighbourhood_distance_feature_logic( self: TestCommonDistanceFeatureLogic, @@ -584,9 +584,9 @@ def test_common_neighbourhood_distance_feature_logic( ), # 0.91/0.91 ["geneId", "studyLocusId", feature_name], ).orderBy(feature_name) - assert ( - observed_df.collect() == expected_df.collect() - ), "Output doesn't meet the expectation." + assert observed_df.collect() == expected_df.collect(), ( + "Output doesn't meet the expectation." + ) @pytest.fixture(autouse=True) def _setup( @@ -773,9 +773,9 @@ def test_common_vep_feature_logic( .orderBy(feature_name) .select("studyLocusId", "geneId", feature_name) ) - assert ( - observed_df.collect() == expected_df.collect() - ), f"Expected and observed dataframes are not equal for feature {feature_name}." + assert observed_df.collect() == expected_df.collect(), ( + f"Expected and observed dataframes are not equal for feature {feature_name}." + ) def test_common_neighbourhood_vep_feature_logic( self: TestCommonVepFeatureLogic, @@ -807,9 +807,9 @@ def test_common_neighbourhood_vep_feature_logic( .orderBy(feature_name) .select("studyLocusId", "geneId", feature_name) ) - assert ( - observed_df.collect() == expected_df.collect() - ), "Output doesn't meet the expectation." + assert observed_df.collect() == expected_df.collect(), ( + "Output doesn't meet the expectation." 
+ ) @pytest.fixture(autouse=True) def _setup(self: TestCommonVepFeatureLogic, spark: SparkSession) -> None: @@ -890,9 +890,9 @@ def test_common_genecount_feature_logic( .orderBy("studyLocusId", "geneId") ) - assert ( - observed_df.collect() == expected_df.collect() - ), f"Expected and observed dataframes do not match for feature {feature_name}." + assert observed_df.collect() == expected_df.collect(), ( + f"Expected and observed dataframes do not match for feature {feature_name}." + ) @pytest.fixture(autouse=True) def _setup(self: TestCommonGeneCountFeatureLogic, spark: SparkSession) -> None: @@ -981,9 +981,9 @@ def test_is_protein_coding_feature_logic( .select("studyLocusId", "geneId", "isProteinCoding") .orderBy("studyLocusId", "geneId") ) - assert ( - observed_df.collect() == expected_df.collect() - ), "Expected and observed DataFrames do not match." + assert observed_df.collect() == expected_df.collect(), ( + "Expected and observed DataFrames do not match." + ) @pytest.fixture(autouse=True) def _setup( diff --git a/tests/gentropy/dataset/test_l2g_feature_matrix.py b/tests/gentropy/dataset/test_l2g_feature_matrix.py index 02fed80ba..4fe40f804 100644 --- a/tests/gentropy/dataset/test_l2g_feature_matrix.py +++ b/tests/gentropy/dataset/test_l2g_feature_matrix.py @@ -60,9 +60,9 @@ def test_study_locus( self.sample_study_locus, features_list, loader ) for feature in features_list: - assert ( - feature in fm._df.columns - ), f"Feature {feature} not found in feature matrix." + assert feature in fm._df.columns, ( + f"Feature {feature} not found in feature matrix." + ) def test_gold_standard( self: TestFromFeaturesList, @@ -78,9 +78,9 @@ def test_gold_standard( self.sample_gold_standard, features_list, loader ) for feature in features_list: - assert ( - feature in fm._df.columns - ), f"Feature {feature} not found in feature matrix." + assert feature in fm._df.columns, ( + f"Feature {feature} not found in feature matrix." 
+ ) @pytest.fixture(autouse=True) def _setup(self: TestFromFeaturesList, spark: SparkSession) -> None: diff --git a/tests/gentropy/dataset/test_study_locus.py b/tests/gentropy/dataset/test_study_locus.py index f8e59d97e..72cfd91c5 100644 --- a/tests/gentropy/dataset/test_study_locus.py +++ b/tests/gentropy/dataset/test_study_locus.py @@ -2,6 +2,7 @@ from __future__ import annotations +from pathlib import Path from typing import Any import pyspark.sql.functions as f @@ -18,6 +19,8 @@ StructType, ) +from gentropy.common.schemas import SchemaValidationError +from gentropy.common.session import Session from gentropy.dataset.colocalisation import Colocalisation from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix from gentropy.dataset.ld_index import LDIndex @@ -519,9 +522,9 @@ def test_filter_ld_set(spark: SparkSession) -> None: observed_data, ["studyLocusId", "ldSet"] ).withColumn("ldSet", StudyLocus.filter_ld_set(f.col("ldSet"), 0.5)) expected_tags_in_ld = 0 - assert ( - observed_df.filter(f.size("ldSet") > 1).count() == expected_tags_in_ld - ), "Expected tags in ld set differ from observed." + assert observed_df.filter(f.size("ldSet") > 1).count() == expected_tags_in_ld, ( + "Expected tags in ld set differ from observed." 
+ ) def test_annotate_locus_statistics_boundaries( @@ -862,9 +865,9 @@ def test_build_feature_matrix( study_locus=mock_study_locus, ) fm = mock_study_locus.build_feature_matrix(features_list, loader) - assert isinstance( - fm, L2GFeatureMatrix - ), "Feature matrix should be of type L2GFeatureMatrix" + assert isinstance(fm, L2GFeatureMatrix), ( + "Feature matrix should be of type L2GFeatureMatrix" + ) class TestStudyLocusRedundancyFlagging: @@ -1209,7 +1212,6 @@ class TestTransQtlFlagging: ] STUDY_LOCUS_COLUMNS = ["studyLocusId", "variantId", "studyId"] - STUDY_DATA = [ ("s1", "p1", "qtl", "g1"), ("s2", "p2", "gwas", None), @@ -1221,21 +1223,21 @@ class TestTransQtlFlagging: GENE_COLUMNS = ["id", "strand", "start", "end", "chromosome", "tss"] @pytest.fixture(autouse=True) - def _setup(self: TestTransQtlFlagging, spark: SparkSession) -> None: + def _setup(self: TestTransQtlFlagging, session: Session) -> None: """Setup study locus for testing.""" self.study_locus = StudyLocus( _df=( - spark.createDataFrame( + session.spark.createDataFrame( self.STUDY_LOCUS_DATA, self.STUDY_LOCUS_COLUMNS ).withColumn("locus", f.array(f.struct("variantId"))) ) ) self.study_index = StudyIndex( - _df=spark.createDataFrame(self.STUDY_DATA, self.STUDY_COLUMNS) + _df=session.spark.createDataFrame(self.STUDY_DATA, self.STUDY_COLUMNS) ) self.target_index = TargetIndex( _df=( - spark.createDataFrame(self.GENE_DATA, self.GENE_COLUMNS).select( + session.spark.createDataFrame(self.GENE_DATA, self.GENE_COLUMNS).select( f.struct( f.col("strand").cast(IntegerType()).alias("strand"), "start", @@ -1280,6 +1282,30 @@ def test_correctness_all_qlts_are_flagged(self: TestTransQtlFlagging) -> None: def test_correctness_found_trans(self: TestTransQtlFlagging) -> None: """Make sure trans qtls are flagged.""" - assert ( - self.qtl_flagged.df.filter(f.col("isTransQtl")).count() == 2 - ), "Expected number of rows differ from observed." 
+ assert self.qtl_flagged.df.filter(f.col("isTransQtl")).count() == 2, ( + "Expected number of rows differ from observed." + ) + + def test_add_flag_if_column_is_present( + self: TestTransQtlFlagging, tmp_path: Path, session: Session + ) -> None: + """Test adding flag if the `isTransQtl` column is already present. + + When reading the dataset, the reader will add the `isTransQtl` column to + the schema, which can cause column duplication captured only by Dataset schema validation. + + This test ensures that the column is dropped before the `flag_trans_qtls` is run. + """ + dataset_path = str(tmp_path / "study_locus") + self.study_locus.df.write.parquet(dataset_path) + schema_validated_study_locus = StudyLocus.from_parquet(session, dataset_path) + assert "isTransQtl" in schema_validated_study_locus.df.columns, ( + "`isTransQtl` column is missing after reading the dataset." + ) + # Rerun the flag addition and check if any error is raised by the schema validation + try: + schema_validated_study_locus.flag_trans_qtls( + self.study_index, self.target_index, self.THRESHOLD + ) + except SchemaValidationError: + pytest.fail("Failed to validate the schema when adding isTransQtl flag") diff --git a/tests/gentropy/datasource/biosample_ontologies/test_biosample_ontology.py b/tests/gentropy/datasource/biosample_ontologies/test_biosample_ontology.py index a9099048c..56cf319b0 100644 --- a/tests/gentropy/datasource/biosample_ontologies/test_biosample_ontology.py +++ b/tests/gentropy/datasource/biosample_ontologies/test_biosample_ontology.py @@ -28,15 +28,15 @@ def test_ontology_parser(self: TestOntologyParger, spark: SparkSession) -> None: self.SAMPLE_EFO_PATH, spark ).retain_rows_with_ancestor_id(["CL_0000000"]) - assert isinstance( - cell_ontology, BiosampleIndex - ), "Cell ontology subset is not parsed correctly to BiosampleIndex." - assert isinstance( - uberon, BiosampleIndex - ), "Uberon subset is not parsed correctly to BiosampleIndex." 
- assert isinstance( - efo_cell_line, BiosampleIndex - ), "EFO cell line subset is not parsed correctly to BiosampleIndex." + assert isinstance(cell_ontology, BiosampleIndex), ( + "Cell ontology subset is not parsed correctly to BiosampleIndex." + ) + assert isinstance(uberon, BiosampleIndex), ( + "Uberon subset is not parsed correctly to BiosampleIndex." + ) + assert isinstance(efo_cell_line, BiosampleIndex), ( + "EFO cell line subset is not parsed correctly to BiosampleIndex." + ) def test_merge_biosample_indices( self: TestOntologyParger, spark: SparkSession @@ -49,6 +49,6 @@ def test_merge_biosample_indices( efo = extract_ontology_from_json(self.SAMPLE_EFO_PATH, spark) merged = cell_ontology.merge_indices([uberon, efo]) - assert isinstance( - merged, BiosampleIndex - ), "Merging of biosample indices is not correct." + assert isinstance(merged, BiosampleIndex), ( + "Merging of biosample indices is not correct." + ) diff --git a/tests/gentropy/datasource/ensembl/test_vep_variants.py b/tests/gentropy/datasource/ensembl/test_vep_variants.py index f0127b9b2..f51e98914 100644 --- a/tests/gentropy/datasource/ensembl/test_vep_variants.py +++ b/tests/gentropy/datasource/ensembl/test_vep_variants.py @@ -15,8 +15,8 @@ from pyspark.sql import SparkSession -class TestVEPParserInSilicoExtractor: - """Testing the _vep_in_silico_prediction_extractor method of the VEP parser class. +class TestVEPParserVariantEffectExtractor: + """Testing the _vep_variant_effect_extractor method of the VEP parser class. These tests assumes that the _get_most_severe_transcript() method works correctly, as it's not tested. 
@@ -42,7 +42,7 @@ class TestVEPParserInSilicoExtractor: SAMPLE_COLUMNS = ["variantId", "assessment", "score", "gene_id", "flag"] @pytest.fixture(autouse=True) - def _setup(self: TestVEPParserInSilicoExtractor, spark: SparkSession) -> None: + def _setup(self: TestVEPParserVariantEffectExtractor, spark: SparkSession) -> None: """Setup fixture.""" parsed_df = ( spark.createDataFrame(self.SAMPLE_DATA, self.SAMPLE_COLUMNS) @@ -59,31 +59,28 @@ def _setup(self: TestVEPParserInSilicoExtractor, spark: SparkSession) -> None: ) .select( "variantId", - VariantEffectPredictorParser._vep_in_silico_prediction_extractor( + VariantEffectPredictorParser._vep_variant_effect_extractor( "transcripts", "method_name", "score", "assessment", "flag" - ).alias("in_silico_predictions"), + ).alias("variant_effect"), ) ).persist() self.df = parsed_df - def test_in_silico_output_missing_value( - self: TestVEPParserInSilicoExtractor, + def test_variant_effect_missing_value( + self: TestVEPParserVariantEffectExtractor, ) -> None: - """Test if the in silico output count is correct.""" + """Test if the variant effect count is correct.""" variant_with_missing_score = [ x[0] for x in filter(lambda x: x[2] is None, self.SAMPLE_DATA) ] # Assert that the correct variants return null: - assert ( - [ - x["variantId"] - for x in self.df.filter( - f.col("in_silico_predictions").isNull() - ).collect() - ] - == variant_with_missing_score - ), "Not the right variants got nullified in-silico predictor object." + assert [ + x["variantId"] + for x in self.df.filter(f.col("variant_effect").isNull()).collect() + ] == variant_with_missing_score, ( + "Not the right variants got nullified in variant effect object." + ) class TestVEPParser: @@ -120,18 +117,18 @@ def test_conversion(self: TestVEPParser) -> None: _schema=VariantIndex.get_schema(), ) - assert isinstance( - variant_index, VariantIndex - ), "VariantIndex object not created." 
+ assert isinstance(variant_index, VariantIndex), ( + "VariantIndex object not created." + ) def test_variant_count(self: TestVEPParser) -> None: """Test if the number of variants is correct. It is expected that all rows from the parsed VEP output are present in the processed VEP output. """ - assert ( - self.raw_vep_output.count() == self.processed_vep_output.count() - ), f"Incorrect number of variants in processed VEP output: expected {self.raw_vep_output.count()}, got {self.processed_vep_output.count()}." + assert self.raw_vep_output.count() == self.processed_vep_output.count(), ( + f"Incorrect number of variants in processed VEP output: expected {self.raw_vep_output.count()}, got {self.processed_vep_output.count()}." + ) def test_collection(self: TestVEPParser) -> None: """Test if the collection of VEP variantIndex runs without failures.""" @@ -150,6 +147,6 @@ def test_ensembl_transcripts_no_duplicates(self: TestVEPParser) -> None: ) asserted_targets = [t["targetId"] for t in targets] - assert len(asserted_targets) == len( - set(asserted_targets) - ), "Duplicate ensembl transcripts in a single row." + assert len(asserted_targets) == len(set(asserted_targets)), ( + "Duplicate ensembl transcripts in a single row." 
+ ) diff --git a/tests/gentropy/datasource/finngen/test_finngen_study_index.py b/tests/gentropy/datasource/finngen/test_finngen_study_index.py index c85629a09..635b64d2a 100644 --- a/tests/gentropy/datasource/finngen/test_finngen_study_index.py +++ b/tests/gentropy/datasource/finngen/test_finngen_study_index.py @@ -354,9 +354,9 @@ def test_finngen_validate_release_prefix( ) -> None: """Test validate_release_prefix.""" if not xfail: - assert ( - FinnGenStudyIndex.validate_release_prefix(prefix) == expected_output - ), "Incorrect match object" + assert FinnGenStudyIndex.validate_release_prefix(prefix) == expected_output, ( + "Incorrect match object" + ) else: with pytest.raises(ValueError): FinnGenStudyIndex.validate_release_prefix(prefix) diff --git a/tests/gentropy/datasource/gnomad/test_gnomad_ld.py b/tests/gentropy/datasource/gnomad/test_gnomad_ld.py index 78b96ad84..f7a58c007 100644 --- a/tests/gentropy/datasource/gnomad/test_gnomad_ld.py +++ b/tests/gentropy/datasource/gnomad/test_gnomad_ld.py @@ -135,15 +135,15 @@ def test_get_ld_matrix_slice__count(self: TestGnomADLDMatrixSlice) -> None: included_indices = self.slice_end_index - self.slice_start_index + 1 expected_pariwise_count = included_indices**2 - assert ( - self.matrix_slice.count() == expected_pariwise_count - ), "The matrix is not complete." + assert self.matrix_slice.count() == expected_pariwise_count, ( + "The matrix is not complete." + ) def test_get_ld_matrix_slice__type(self: TestGnomADLDMatrixSlice) -> None: """Test LD matrix slice.""" - assert isinstance( - self.matrix_slice, DataFrame - ), "The returned data is not a dataframe." + assert isinstance(self.matrix_slice, DataFrame), ( + "The returned data is not a dataframe." + ) def test_get_ld_matrix_slice__symmetry( self: TestGnomADLDMatrixSlice, @@ -162,9 +162,9 @@ def test_get_ld_matrix_slice__symmetry( how="inner", ) - assert ( - compared.count() == self.matrix_slice.count() - ), "The matrix is not complete." 
+ assert compared.count() == self.matrix_slice.count(), ( + "The matrix is not complete." + ) assert ( compared.filter(f.col("r") == f.col("r_sym")).count() == compared.count() ), "The matrix is not symmetric." diff --git a/tests/gentropy/datasource/gwas_catalog/test_gwas_catalog_curation.py b/tests/gentropy/datasource/gwas_catalog/test_gwas_catalog_curation.py index 0dd1a7363..bb58e0401 100644 --- a/tests/gentropy/datasource/gwas_catalog/test_gwas_catalog_curation.py +++ b/tests/gentropy/datasource/gwas_catalog/test_gwas_catalog_curation.py @@ -80,13 +80,17 @@ def test_curation__return_type( assert isinstance( mock_gwas_study_index.annotate_from_study_curation(None), StudyIndexGWASCatalog, - ), f"Applying curation without curation table should yield a study table, but got: {type(mock_gwas_study_index.annotate_from_study_curation(None))}" + ), ( + f"Applying curation without curation table should yield a study table, but got: {type(mock_gwas_study_index.annotate_from_study_curation(None))}" + ) # Return type should work: assert isinstance( mock_gwas_study_index.annotate_from_study_curation(mock_study_curation), StudyIndexGWASCatalog, - ), f"Applying curation should return a study table, however got: {type(mock_gwas_study_index.annotate_from_study_curation(mock_study_curation))}" + ), ( + f"Applying curation should return a study table, however got: {type(mock_gwas_study_index.annotate_from_study_curation(mock_study_curation))}" + ) @staticmethod def test_curation__returned_rows( @@ -101,14 +105,14 @@ def test_curation__returned_rows( mock_study_curation ).df.count() # Method should work on empty curation: - assert ( - zero_return_count == expected_count - ), f"When applied None to curation function, the size of the returned data was not as expected ({zero_return_count} vs {expected_count})." 
+ assert zero_return_count == expected_count, ( + f"When applied None to curation function, the size of the returned data was not as expected ({zero_return_count} vs {expected_count})." + ) # Return type should work: - assert ( - return_count == expected_count - ), f"When applied curation data, the size of the returned data was not as expected ({return_count} vs {expected_count})." + assert return_count == expected_count, ( + f"When applied curation data, the size of the returned data was not as expected ({return_count} vs {expected_count})." + ) # Test updated type @staticmethod diff --git a/tests/gentropy/datasource/open_targets/test_l2g_gold_standard.py b/tests/gentropy/datasource/open_targets/test_l2g_gold_standard.py index 54bcbf8d0..dee560928 100644 --- a/tests/gentropy/datasource/open_targets/test_l2g_gold_standard.py +++ b/tests/gentropy/datasource/open_targets/test_l2g_gold_standard.py @@ -66,9 +66,9 @@ def test_expand_gold_standard_with_negatives_logic( self: TestExpandGoldStandardWithNegatives, spark: SparkSession ) -> None: """Test expanding positive set with negative set coincides with expected results.""" - assert ( - self.observed_df.collect() == self.expected_expanded_gs.collect() - ), "GS expansion is not as expected." + assert self.observed_df.collect() == self.expected_expanded_gs.collect(), ( + "GS expansion is not as expected." + ) def test_expand_gold_standard_with_negatives_same_positives( self: TestExpandGoldStandardWithNegatives, spark: SparkSession diff --git a/tests/gentropy/datasource/open_targets/test_variants.py b/tests/gentropy/datasource/open_targets/test_variants.py index 6aa22e628..57784e98c 100644 --- a/tests/gentropy/datasource/open_targets/test_variants.py +++ b/tests/gentropy/datasource/open_targets/test_variants.py @@ -49,9 +49,9 @@ def test_as_vcf_df_credible_set( ], vcf_cols, ) - assert ( - observed_df.collect() == df_credible_set_expected_df.collect() - ), "Unexpected VCF dataframe." 
+ assert observed_df.collect() == df_credible_set_expected_df.collect(), ( + "Unexpected VCF dataframe." + ) def test_as_vcf_df_without_variant_id( self: TestOpenTargetsVariant, @@ -85,6 +85,6 @@ def test_as_vcf_df_without_rs_id( vcf_cols, ) - assert ( - observed_df.collect() == df_without_rs_id_expected_df.collect() - ), "Unexpected VCF dataframe." + assert observed_df.collect() == df_without_rs_id_expected_df.collect(), ( + "Unexpected VCF dataframe." + ) diff --git a/tests/gentropy/method/test_colocalisation_method.py b/tests/gentropy/method/test_colocalisation_method.py index 78a66f732..9e2541543 100644 --- a/tests/gentropy/method/test_colocalisation_method.py +++ b/tests/gentropy/method/test_colocalisation_method.py @@ -249,7 +249,7 @@ def test_coloc(mock_study_locus_overlap: StudyLocusOverlap) -> None: "right_logBF": 10.5, "left_beta": 0.5, "right_beta": 0.2, - "left_posteriorProbability": 0.36, + "left_posteriorProbability": 0.09, "right_posteriorProbability": 0.92, }, }, @@ -297,9 +297,9 @@ def test_coloc_semantic( expected_coloc_pdf = expected_coloc_df.toPandas() if expected_coloc_pdf.empty: - assert ( - observed_coloc_pdf.empty - ), f"Expected an empty DataFrame, but got:\n{observed_coloc_pdf}" + assert observed_coloc_pdf.empty, ( + f"Expected an empty DataFrame, but got:\n{observed_coloc_pdf}" + ) else: assert_frame_equal( observed_coloc_pdf, @@ -366,12 +366,12 @@ def test_coloc_no_logbf( StudyLocusOverlap.get_schema(), ) observed_coloc_df = Coloc.colocalise(observed_overlap).df - assert ( - observed_coloc_df.select("h0").collect()[0]["h0"] > minimum_expected_h0 - ), "COLOC should return a high h0 (no association) when the input data has irrelevant logBF." - assert ( - observed_coloc_df.select("h4").collect()[0]["h4"] < maximum_expected_h4 - ), "COLOC should return a low h4 (traits are associated) when the input data has irrelevant logBF." 
+ assert observed_coloc_df.select("h0").collect()[0]["h0"] > minimum_expected_h0, ( + "COLOC should return a high h0 (no association) when the input data has irrelevant logBF." + ) + assert observed_coloc_df.select("h4").collect()[0]["h4"] < maximum_expected_h4, ( + "COLOC should return a low h4 (traits are associated) when the input data has irrelevant logBF." + ) def test_coloc_no_betas(spark: SparkSession) -> None: diff --git a/tests/gentropy/method/test_locus_breaker_clumping.py b/tests/gentropy/method/test_locus_breaker_clumping.py index c2c23eca5..3eb2290c6 100644 --- a/tests/gentropy/method/test_locus_breaker_clumping.py +++ b/tests/gentropy/method/test_locus_breaker_clumping.py @@ -92,29 +92,29 @@ def test_return_type( self: TestLocusBreakerClumping, clumped_data: StudyLocus ) -> None: """Testing return type.""" - assert isinstance( - clumped_data, StudyLocus - ), f"Unexpected return type: {type(clumped_data)}" + assert isinstance(clumped_data, StudyLocus), ( + f"Unexpected return type: {type(clumped_data)}" + ) def test_number_of_loci( self: TestLocusBreakerClumping, clumped_data: StudyLocus ) -> None: """Testing return type.""" - assert ( - clumped_data.df.count() == 5 - ), f"Unexpected number of loci: {clumped_data.df.count()}" + assert clumped_data.df.count() == 5, ( + f"Unexpected number of loci: {clumped_data.df.count()}" + ) def test_top_loci(self: TestLocusBreakerClumping, clumped_data: StudyLocus) -> None: """Testing selected top-loci.""" top_loci_variants = clumped_data.df.select("variantId").distinct().collect() - assert ( - len(top_loci_variants) == 1 - ), f"Unexpected number of top loci: {len(top_loci_variants)} ({top_loci_variants})" + assert len(top_loci_variants) == 1, ( + f"Unexpected number of top loci: {len(top_loci_variants)} ({top_loci_variants})" + ) - assert ( - top_loci_variants[0]["variantId"] == "top_loci" - ), f"Unexpected top locus: {top_loci_variants[0]['variantId']}" + assert top_loci_variants[0]["variantId"] == "top_loci", ( + 
f"Unexpected top locus: {top_loci_variants[0]['variantId']}" + ) def test_locus_boundaries( self: TestLocusBreakerClumping, clumped_data: StudyLocus diff --git a/tests/gentropy/method/test_susie_inf.py b/tests/gentropy/method/test_susie_inf.py index 4885a3d8a..ad2f8b43b 100644 --- a/tests/gentropy/method/test_susie_inf.py +++ b/tests/gentropy/method/test_susie_inf.py @@ -24,9 +24,9 @@ def test_SUSIE_inf_lbf_moments( lbf_moments = sample_data_for_susie_inf[2] susie_output = SUSIE_inf.susie_inf(z=z, LD=ld, est_tausq=True, method="moments") lbf_calc = susie_output["lbf_variable"][:, 0] - assert np.allclose( - lbf_calc, lbf_moments - ), "LBFs for method of moments are not equal" + assert np.allclose(lbf_calc, lbf_moments), ( + "LBFs for method of moments are not equal" + ) def test_SUSIE_inf_lbf_mle( self: TestSUSIE_inf, sample_data_for_susie_inf: list[np.ndarray] @@ -37,9 +37,9 @@ def test_SUSIE_inf_lbf_mle( lbf_mle = sample_data_for_susie_inf[3] susie_output = SUSIE_inf.susie_inf(z=z, LD=ld, est_tausq=True, method="MLE") lbf_calc = susie_output["lbf_variable"][:, 0] - assert np.allclose( - lbf_calc, lbf_mle, atol=1e-1 - ), "LBFs for maximum likelihood estimation are not equal" + assert np.allclose(lbf_calc, lbf_mle, atol=1e-1), ( + "LBFs for maximum likelihood estimation are not equal" + ) def test_SUSIE_inf_cred( self: TestSUSIE_inf, sample_data_for_susie_inf: list[np.ndarray] diff --git a/tests/gentropy/step/test_colocalisation_step.py b/tests/gentropy/step/test_colocalisation_step.py index 2d2a43707..5663e2f88 100644 --- a/tests/gentropy/step/test_colocalisation_step.py +++ b/tests/gentropy/step/test_colocalisation_step.py @@ -215,9 +215,9 @@ def test_get_colocalisation_class( ) -> None: """Test _get_colocalisation_class method on ColocalisationStep.""" method = ColocalisationStep._get_colocalisation_class(label) - assert ( - method is expected_method - ), "Incorrect colocalisation class returned by ColocalisationStep._get_colocalisation_class(label)" + assert 
method is expected_method, ( + "Incorrect colocalisation class returned by ColocalisationStep._get_colocalisation_class(label)" + ) def test_label_with_invalid_method(self) -> None: """Test what happens when invalid method_label is passed to the _get_colocalisation_class.""" @@ -286,10 +286,10 @@ def test_colocalise( values = [c[column] for c in coloc_dataset.df.collect()] for v, e in zip(values, expected_values): if isinstance(e, float): - assert ( - e == pytest.approx(v, 1e-1) - ), f"Incorrect value {v} at {column} found in {coloc_method}, expected {e}" + assert e == pytest.approx(v, 1e-1), ( + f"Incorrect value {v} at {column} found in {coloc_method}, expected {e}" + ) else: - assert ( - e == v - ), f"Incorrect value {v} at {column} found in {coloc_method}, expected {e}" + assert e == v, ( + f"Incorrect value {v} at {column} found in {coloc_method}, expected {e}" + ) diff --git a/tests/gentropy/step/test_convert_to_vcf_step.py b/tests/gentropy/step/test_convert_to_vcf_step.py index cc4ec800d..7d43f88b1 100644 --- a/tests/gentropy/step/test_convert_to_vcf_step.py +++ b/tests/gentropy/step/test_convert_to_vcf_step.py @@ -90,15 +90,15 @@ def test_step( variants_df = session.spark.read.csv(output_path, sep="\t", header=True) # 40 variants (10 variants from each source) expected_variant_count = sum(c["n_variants"] for c in sources) - assert ( - variants_df.count() == expected_variant_count - ), "Found incorrect number of variants" + assert variants_df.count() == expected_variant_count, ( + "Found incorrect number of variants" + ) partitions = [ str(p) for p in Path(output_path).iterdir() if str(p).endswith("csv") ] - assert ( - len(partitions) == expected_partition_number - ), "Found incorrect number of partitions" + assert len(partitions) == expected_partition_number, ( + "Found incorrect number of partitions" + ) def test_sorting( self, diff --git a/tests/gentropy/step/test_credible_set_qc.py b/tests/gentropy/step/test_credible_set_qc.py index c7fb58c8c..f4a9d2e0f 
100644 --- a/tests/gentropy/step/test_credible_set_qc.py +++ b/tests/gentropy/step/test_credible_set_qc.py @@ -135,9 +135,9 @@ def test_step(self, session: Session) -> None: for p in Path(self.output_path).iterdir() if str(p).endswith(".parquet") ] - assert ( - len(partitions) == self.n_partitions - ), "Incorrect number of partitions in the output." + assert len(partitions) == self.n_partitions, ( + "Incorrect number of partitions in the output." + ) cs = StudyLocus.from_parquet( session, self.output_path, recursiveFileLookup=True ) diff --git a/tests/gentropy/test_schemas.py b/tests/gentropy/test_schemas.py index d86108491..a82b7bb40 100644 --- a/tests/gentropy/test_schemas.py +++ b/tests/gentropy/test_schemas.py @@ -69,9 +69,9 @@ def test_schema_columns_camelcase(schema_json: str) -> None: # CamelCase starts with a lowercase letter and has uppercase letters in between. for field in schema.fields: - assert is_camelcase( - field.name - ), f"Column name '{field.name}' is not in camelCase." + assert is_camelcase(field.name), ( + f"Column name '{field.name}' is not in camelCase." 
+ ) class TestValidateSchema: diff --git a/tests/gentropy/test_spark_helpers.py b/tests/gentropy/test_spark_helpers.py index 1adb686cd..d3ccecec9 100644 --- a/tests/gentropy/test_spark_helpers.py +++ b/tests/gentropy/test_spark_helpers.py @@ -39,9 +39,9 @@ def test_get_record_with_minimum_value_group_one_col( df = mock_variant_df.transform( lambda df: get_record_with_minimum_value(df, grouping_col, sorting_col) ) - assert ( - df.filter(f.col("chromosome") == 16).collect()[0].__getitem__("position") - ), 10116 + assert df.filter(f.col("chromosome") == 16).collect()[0].__getitem__("position"), ( + 10116 + ) def test_get_record_with_maximum_value_group_two_cols( diff --git a/utils/clean_status.sh b/utils/clean_status.sh index 583bb881c..73e3ad50a 100755 --- a/utils/clean_status.sh +++ b/utils/clean_status.sh @@ -1,4 +1,12 @@ #!/usr/bin/env bash +REQUESTED_REF=$1 + +CURRENT_REF=$(git rev-parse --abbrev-ref HEAD) + +if [ "$REQUESTED_REF" != "$CURRENT_REF" ]; then + echo "Requested branch $REQUESTED_REF is not the current branch $CURRENT_REF, skipping status checks" + exit 0 +fi echo "Fetching version changes..." git fetch diff --git a/uv.lock b/uv.lock index c615f6593..1d1385cdb 100644 --- a/uv.lock +++ b/uv.lock @@ -985,7 +985,7 @@ test = [ [package.metadata] requires-dist = [ - { name = "google-cloud-secret-manager", specifier = ">=2.12.6,<2.13.0" }, + { name = "google-cloud-secret-manager", specifier = ">=2.12.6,<2.24.0" }, { name = "google-cloud-storage", specifier = ">=2.14.0,<3.1.0" }, { name = "hail", specifier = ">=0.2.133,<0.3.0" }, { name = "hydra-core", specifier = ">=1.3.2,<1.4.0" },