Skip to content

Commit

Permalink
Merge pull request #303 from VariantEffect/feature/bencap/302/require…
Browse files Browse the repository at this point in the history
…-fully-qualified-variants-for-accession-based-score-sets

Require fully qualified variants for accession based score sets
  • Loading branch information
bencap authored Oct 11, 2024
2 parents b1ed277 + 074c863 commit bc29fc3
Show file tree
Hide file tree
Showing 10 changed files with 131 additions and 96 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import sqlalchemy as sa
from sqlalchemy.orm import Session, configure_mappers

from mavedb.models import *

from mavedb.models.score_set import ScoreSet
from mavedb.models.variant import Variant
from mavedb.models.target_gene import TargetGene
from mavedb.models.target_accession import TargetAccession

from mavedb.db.session import SessionLocal

# Configure all ORM mappers up front, before any query is built in this standalone script.
# NOTE(review): presumably the wildcard `mavedb.models` import above exists so every model
# is registered before this call -- confirm against the models package.
configure_mappers()


def do_migration(db: Session):
    """Fully qualify the HGVS strings of single-target, accession-based score sets.

    For every score set joined to a target gene with a non-null ``accession_id``,
    and having exactly one target gene, each non-empty ``hgvs_nt``, ``hgvs_pro``
    and ``hgvs_splice`` value on its variants is rewritten as
    ``<target accession>:<original value>``.

    Score sets with more than one target are skipped, as their variants are
    already stored in the fully qualified format. Values that already start
    with ``<target accession>:`` are left untouched, so running the migration
    more than once does not produce doubly prefixed variants.

    The function only stages changes on the session; the caller is responsible
    for committing (or rolling back) and closing it.

    Parameters
    ----------
    db : Session
        The database session used to read score sets and update variants.

    Raises
    ------
    ValueError
        If a single-target accession-based score set has no target accession.
    """
    accession_based_score_sets = db.execute(
        sa.select(ScoreSet).join(TargetGene).where(TargetGene.accession_id.isnot(None))
    ).scalars()

    for score_set in accession_based_score_sets:
        total_targets = len(
            db.execute(sa.select(TargetGene).where(TargetGene.score_set_id == score_set.id)).scalars().all()
        )

        # Variants from score sets with multiple targets are already in the desired format.
        if total_targets > 1:
            continue

        target_accession = db.execute(
            sa.select(TargetAccession.accession).join(TargetGene).where(TargetGene.score_set_id == score_set.id)
        ).scalar()

        # Fail before loading any variants: without an accession there is nothing to prefix with.
        if target_accession is None:
            raise ValueError("target accession should never be None.")

        prefix = f"{target_accession}:"
        variants = db.execute(sa.select(Variant).where(Variant.score_set_id == score_set.id)).scalars()

        for variant in variants:
            for hgvs_column in ("hgvs_nt", "hgvs_pro", "hgvs_splice"):
                hgvs_string = getattr(variant, hgvs_column)

                # Skip empty values, and values that are already qualified so a re-run is a no-op.
                if hgvs_string and not hgvs_string.startswith(prefix):
                    setattr(variant, hgvs_column, f"{prefix}{hgvs_string}")

            db.add(variant)


if __name__ == "__main__":
    # Script entry point: run the migration in a single transaction on a fresh session.
    db = SessionLocal()
    db.current_user = None  # type: ignore

    try:
        do_migration(db)
        db.commit()
    except Exception:
        # Discard any partially staged updates so nothing half-migrated is persisted,
        # then surface the original error to the operator.
        db.rollback()
        raise
    finally:
        # Always release the connection, whether the migration succeeded or failed.
        db.close()
25 changes: 11 additions & 14 deletions src/mavedb/lib/validation/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,20 +202,23 @@ def validate_dataframe(df: pd.DataFrame, kind: str, targets: list["TargetGene"],
if df[column_mapping[c]].isna().all() and not is_index:
continue

score_set_is_accession_based = all(target.target_accession for target in targets)
score_set_is_sequence_based = all(target.target_sequence for target in targets)

# This is typesafe, despite Pylance's claims otherwise
if all(target.target_accession for target in targets):
if score_set_is_accession_based and not score_set_is_sequence_based:
validate_hgvs_genomic_column(
df[column_mapping[c]], is_index, [target.target_accession for target in targets], hdp # type: ignore
)
elif all(target.target_sequence for target in targets):
elif score_set_is_sequence_based and not score_set_is_accession_based:
validate_hgvs_transgenic_column(
df[column_mapping[c]], is_index, {target.target_sequence.label: target.target_sequence for target in targets} # type: ignore
)
else:
raise MixedTargetError("Could not validate dataframe against provided mixed target types.")

# post validation, handle prefixes. We've already established these columns are non-null
if len(targets) > 1:
if score_set_is_accession_based or len(targets) > 1:
prefixes[c] = (
df[column_mapping[c]].dropna()[0].split(" ")[0].split(":")[1][0]
) # Just take the first prefix, we validate consistency elsewhere
Expand Down Expand Up @@ -374,7 +377,7 @@ def validate_hgvs_transgenic_column(column: pd.Series, is_index: bool, targets:
valid_sequence_types = ("dna", "protein")
validate_variant_column(column, is_index)
prefixes = generate_variant_prefixes(column)
validate_variant_formatting(column, prefixes, list(targets.keys()))
validate_variant_formatting(column, prefixes, list(targets.keys()), len(targets) > 1)

observed_sequence_types = [target.sequence_type for target in targets.values()]
invalid_sequence_types = set(observed_sequence_types) - set(valid_sequence_types)
Expand Down Expand Up @@ -454,9 +457,6 @@ def validate_hgvs_genomic_column(
This function also validates all individual variants in the column and checks for agreement against the target
sequence (for non-splice variants).
Implementation NOTE: We assume variants will only be presented as fully qualified (accession:variant)
if this column is being validated against multiple targets.
Parameters
----------
column : pd.Series
Expand All @@ -482,7 +482,7 @@ def validate_hgvs_genomic_column(
validate_variant_column(column, is_index)
prefixes = generate_variant_prefixes(column)
validate_variant_formatting(
column, prefixes, [target.accession for target in targets if target.accession is not None]
column, prefixes, [target.accession for target in targets if target.accession is not None], True
)

# validate the individual variant strings
Expand All @@ -508,12 +508,9 @@ def validate_hgvs_genomic_column(
for i, s in column.items():
if s is not None:
for variant in s.split(" "):
# Add accession info when we only have one target
if len(targets) == 1:
s = f"{targets[0].accession}:{variant}"
try:
# We set strict to `False` to suppress validation warnings about intronic variants.
vr.validate(hp.parse(s), strict=False)
vr.validate(hp.parse(variant), strict=False)
except hgvs.exceptions.HGVSError as e:
invalid_variants.append(f"Failed to parse row {i} with HGVS exception: {e}")

Expand All @@ -524,7 +521,7 @@ def validate_hgvs_genomic_column(
)


def validate_variant_formatting(column: pd.Series, prefixes: list[str], targets: list[str]):
def validate_variant_formatting(column: pd.Series, prefixes: list[str], targets: list[str], fully_qualified: bool):
"""
Validate the formatting of HGVS variants present in the passed column against
lists of prefixes and targets
Expand Down Expand Up @@ -554,7 +551,7 @@ def validate_variant_formatting(column: pd.Series, prefixes: list[str], targets:
variants = [variant for s in column.dropna() for variant in s.split(" ")]

# if the caller requires fully qualified variants, every variant must be of the form `target:variant`
if len(targets) > 1:
if fully_qualified:
if not all(len(str(v).split(":")) == 2 for v in variants):
raise ValidationError(
f"variant column '{column.name}' needs fully qualified coordinates when validating against multiple targets"
Expand Down
14 changes: 11 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,14 @@ def session(postgresql):
)

engine = create_engine(connection, echo=False, poolclass=NullPool)
session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
session = sessionmaker(autocommit=False, autoflush=False, bind=engine)()

Base.metadata.create_all(bind=engine)

try:
yield session()
yield session
finally:
session.close()
Base.metadata.drop_all(bind=engine)


Expand Down Expand Up @@ -170,7 +171,14 @@ async def on_job(ctx):

@pytest.fixture
def standalone_worker_context(session, data_provider, arq_redis):
yield {"db": session, "hdp": data_provider, "state": {}, "job_id": "test_job", "redis": arq_redis, "pool": futures.ProcessPoolExecutor()}
yield {
"db": session,
"hdp": data_provider,
"state": {},
"job_id": "test_job",
"redis": arq_redis,
"pool": futures.ProcessPoolExecutor(),
}


@pytest.fixture()
Expand Down
2 changes: 1 addition & 1 deletion tests/helpers/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def mock_worker_variant_insertion(client, db, data_provider, score_set, scores_c
score_df = csv_data_to_df(score_file)

if counts_csv_path is not None:
with open(scores_csv_path, "rb") as score_file:
with open(scores_csv_path, "rb") as counts_file:
counts_df = csv_data_to_df(counts_file)
else:
counts_df = None
Expand Down
6 changes: 3 additions & 3 deletions tests/routers/data/counts_acc.csv
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
hgvs_nt,c_0,c_1
c.1G>C,10,20
c.2A>G,8,8
c.6C>A,90,2
NM_001637.3:c.1G>C,10,20
NM_001637.3:c.2A>G,8,8
NM_001637.3:c.6C>A,90,2
6 changes: 3 additions & 3 deletions tests/routers/data/scores_acc.csv
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
hgvs_nt,score
c.1G>C,0.3
c.2A>G,0.0
c.6C>A,-1.65
NM_001637.3:c.1G>C,0.3
NM_001637.3:c.2A>G,0.0
NM_001637.3:c.6C>A,-1.65
59 changes: 35 additions & 24 deletions tests/validation/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,34 +515,34 @@ def setUp(self) -> None:
self.valid_targets = ["test1", "test2"]

def test_single_target_valid_variants(self):
validate_variant_formatting(self.valid, self.valid_prefixes, self.valid_target)
validate_variant_formatting(self.valid, self.valid_prefixes, self.valid_target, False)

def test_single_target_inconsistent_variants(self):
with self.assertRaises(ValidationError):
validate_variant_formatting(self.inconsistent, self.valid_prefixes, self.valid_target)
validate_variant_formatting(self.inconsistent, self.valid_prefixes, self.valid_target, False)

def test_single_target_invalid_prefixes(self):
with self.assertRaises(ValidationError):
validate_variant_formatting(self.valid, self.invalid_prefixes, self.valid_target)
validate_variant_formatting(self.valid, self.invalid_prefixes, self.valid_target, False)

def test_multi_target_valid_variants(self):
validate_variant_formatting(self.valid_multi, self.valid_prefixes, self.valid_targets)
validate_variant_formatting(self.valid_multi, self.valid_prefixes, self.valid_targets, True)

def test_multi_target_inconsistent_variants(self):
with self.assertRaises(ValidationError):
validate_variant_formatting(self.inconsistent_multi, self.valid_prefixes, self.valid_targets)
validate_variant_formatting(self.inconsistent_multi, self.valid_prefixes, self.valid_targets, True)

def test_multi_target_invalid_prefixes(self):
with self.assertRaises(ValidationError):
validate_variant_formatting(self.valid_multi, self.invalid_prefixes, self.valid_targets)
validate_variant_formatting(self.valid_multi, self.invalid_prefixes, self.valid_targets, True)

def test_multi_target_lacking_full_coords(self):
with self.assertRaises(ValidationError):
validate_variant_formatting(self.valid, self.valid_prefixes, self.valid_targets)
validate_variant_formatting(self.valid, self.valid_prefixes, self.valid_targets, True)

def test_multi_target_invalid_accessions(self):
with self.assertRaises(ValidationError):
validate_variant_formatting(self.invalid_multi, self.valid_prefixes, self.valid_targets)
validate_variant_formatting(self.invalid_multi, self.valid_prefixes, self.valid_targets, True)


class TestGenerateVariantPrefixes(DfTestCase):
Expand Down Expand Up @@ -910,27 +910,38 @@ def setUp(self):

self.accession_test_case = AccessionTestCase()

self.valid_hgvs_column = pd.Series(["c.1G>A", "c.2A>T"], name=hgvs_nt_column)
self.missing_data = pd.Series(["c.3T>G", None], name=hgvs_nt_column)
self.duplicate_data = pd.Series(["c.4A>G", "c.4A>G"], name=hgvs_nt_column)
self.valid_hgvs_column = pd.Series(
[f"{VALID_ACCESSION}:c.1G>A", f"{VALID_ACCESSION}:c.2A>T"], name=hgvs_nt_column
)
self.missing_data = pd.Series([f"{VALID_ACCESSION}:c.3T>G", None], name=hgvs_nt_column)
self.duplicate_data = pd.Series([f"{VALID_ACCESSION}:c.4A>G", f"{VALID_ACCESSION}:c.4A>G"], name=hgvs_nt_column)

self.invalid_hgvs_columns_by_name = [
pd.Series(["g.1A>G", "g.1A>T"], name=hgvs_splice_column),
pd.Series(["g.1A>G", "g.1A>T"], name=hgvs_pro_column),
pd.Series(["c.1A>G", "c.1A>T"], name=hgvs_pro_column),
pd.Series(["n.1A>G", "n.1A>T"], name=hgvs_pro_column),
pd.Series(["p.Met1Val", "p.Met1Leu"], name=hgvs_nt_column),
pd.Series([f"{VALID_ACCESSION}:g.1A>G", f"{VALID_ACCESSION}:g.1A>T"], name=hgvs_splice_column),
pd.Series([f"{VALID_ACCESSION}:g.1A>G", f"{VALID_ACCESSION}:g.1A>T"], name=hgvs_pro_column),
pd.Series([f"{VALID_ACCESSION}:c.1A>G", f"{VALID_ACCESSION}:c.1A>T"], name=hgvs_pro_column),
pd.Series([f"{VALID_ACCESSION}:n.1A>G", f"{VALID_ACCESSION}:n.1A>T"], name=hgvs_pro_column),
pd.Series([f"{VALID_ACCESSION}:p.Met1Val", f"{VALID_ACCESSION}:p.Met1Leu"], name=hgvs_nt_column),
]

self.invalid_hgvs_columns_by_contents = [
pd.Series(["r.1a>g", "r.1a>u"], name=hgvs_splice_column), # rna not allowed
pd.Series(["r.1a>g", "r.1a>u"], name=hgvs_nt_column), # rna not allowed
pd.Series(["c.1A>G", "c.5A>T"], name=hgvs_nt_column), # out of bounds for target
pd.Series(["c.1A>G", "_wt"], name=hgvs_nt_column), # old special variant
pd.Series(["p.Met1Leu", "_sy"], name=hgvs_pro_column), # old special variant
pd.Series(["n.1A>G", "c.1A>T"], name=hgvs_nt_column), # mixed prefix
pd.Series(["c.1A>G", "p.Met1Leu"], name=hgvs_pro_column), # mixed types/prefix
pd.Series(["c.1A>G", 2.5], name=hgvs_nt_column), # contains numeric
pd.Series(
[f"{VALID_ACCESSION}:r.1a>g", f"{VALID_ACCESSION}:r.1a>u"], name=hgvs_splice_column
), # rna not allowed
pd.Series(
[f"{VALID_ACCESSION}:r.1a>g", f"{VALID_ACCESSION}:r.1a>u"], name=hgvs_nt_column
), # rna not allowed
pd.Series(
[f"{VALID_ACCESSION}:c.1A>G", f"{VALID_ACCESSION}:c.5A>T"], name=hgvs_nt_column
), # out of bounds for target
pd.Series([f"{VALID_ACCESSION}:c.1A>G", "_wt"], name=hgvs_nt_column), # old special variant
pd.Series([f"{VALID_ACCESSION}:p.Met1Leu", "_sy"], name=hgvs_pro_column), # old special variant
pd.Series([f"{VALID_ACCESSION}:n.1A>G", f"{VALID_ACCESSION}:c.1A>T"], name=hgvs_nt_column), # mixed prefix
pd.Series(
[f"{VALID_ACCESSION}:c.1A>G", f"{VALID_ACCESSION}:p.Met1Leu"], name=hgvs_pro_column
), # mixed types/prefix
pd.Series(["c.1A>G", "p.Met1Leu"], name=hgvs_pro_column), # variants should be fully qualified
pd.Series([f"{VALID_ACCESSION}:c.1A>G", 2.5], name=hgvs_nt_column), # contains numeric
pd.Series([1.0, 2.5], name=hgvs_nt_column), # contains numeric
pd.Series([1.0, 2.5], name=hgvs_splice_column), # contains numeric
pd.Series([1.0, 2.5], name=hgvs_pro_column), # contains numeric
Expand Down
6 changes: 3 additions & 3 deletions tests/worker/data/counts_acc.csv
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
hgvs_nt,c_0,c_1
c.1G>C,10,20
c.2A>G,8,8
c.6C>A,90,2
NM_001637.3:c.1G>C,10,20
NM_001637.3:c.2A>G,8,8
NM_001637.3:c.6C>A,90,2
6 changes: 3 additions & 3 deletions tests/worker/data/scores_acc.csv
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
hgvs_nt,score
c.1G>C,0.3
c.2A>G,0.0
c.6C>A,-1.65
NM_001637.3:c.1G>C,0.3
NM_001637.3:c.2A>G,0.0
NM_001637.3:c.6C>A,-1.65
Loading

0 comments on commit bc29fc3

Please sign in to comment.