Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add assign_clade method to CladeTime class #57

Merged
merged 13 commits into from
Nov 13, 2024
12 changes: 12 additions & 0 deletions src/cladetime/_clade.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from dataclasses import dataclass

import polars as pl


@dataclass
class Clade:
"""Holds detailed and summarized information about clade assignments."""

meta: dict
detail: pl.LazyFrame
summary: pl.LazyFrame
27 changes: 23 additions & 4 deletions src/cladetime/cladetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import structlog

from cladetime import Tree, sequence
from cladetime._clade import Clade
from cladetime.exceptions import CladeTimeDateWarning, CladeTimeInvalidURLError, CladeTimeSequenceWarning
from cladetime.util.config import Config
from cladetime.util.reference import _get_clade_assignments, _get_date, _get_nextclade_dataset, _get_s3_object_url
Expand Down Expand Up @@ -233,9 +234,11 @@ def assign_clades(self, sequence_metadata: pl.LazyFrame, output_file: str | None

Returns
-------
metadata_clades : polars.LazyFrame
Nextstrain sequence_metadata with an additional column for clade assignments
metadata_clades : Clade
A Clade object that contains detailed and summarized information
about clades assigned to the sequences in sequence_metadata.
"""
assignment_date = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M")
if output_file is not None:
output_file = Path(output_file)
else:
Expand All @@ -249,7 +252,7 @@ def assign_clades(self, sequence_metadata: pl.LazyFrame, output_file: str | None
msg,
category=CladeTimeSequenceWarning,
)
return pl.LazyFrame()
return Clade(meta={}, detail=pl.LazyFrame(), summary=pl.LazyFrame())

# if there are many sequences in the filtered metadata, warn that clade assignment will
# take a long time and require a lot of resources
Expand Down Expand Up @@ -307,7 +310,23 @@ def assign_clades(self, sequence_metadata: pl.LazyFrame, output_file: str | None

assigned_clades = pl.read_csv(assignments, separator=";", infer_schema_length=100000)

# join the assigned clades with the original sequence metadata, create a summarized LazyFrame
# of clade counts by location, date, and host, and return both (along with metadata) in a
# Clade object
assigned_clades = sequence_metadata.join(
assigned_clades.lazy(), left_on="strain", right_on="seqName", how="left"
)
return assigned_clades
summarized_clades = sequence.summarize_clades(
assigned_clades, group_by=["location", "date", "host", "clade_nextstrain", "country"]
)
metadata = {
"sequence_as_of": self.sequence_as_of,
"tree_as_of": self.tree_as_of,
"nextclade_dataset_version": tree.ncov_metadata.get("nextclade_dataset_version"),
"nextclade_dataset_name": tree.ncov_metadata.get("nextclade_dataset_name"),
"nextclade_version_num": tree.ncov_metadata.get("nextclade_version_num"),
"assignment_as_of": assignment_date,
}
metadata_clades = Clade(meta=metadata, detail=assigned_clades, summary=summarized_clades)

return metadata_clades
60 changes: 59 additions & 1 deletion src/cladetime/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import lzma
import os
import re
import warnings
from datetime import datetime
from pathlib import Path
from urllib.parse import urlparse
Expand All @@ -14,6 +15,7 @@
from Bio.SeqIO import FastaIO
from requests import Session

from cladetime.exceptions import CladeTimeSequenceWarning
from cladetime.types import StateFormat
from cladetime.util.reference import _get_date
from cladetime.util.session import _get_session
Expand Down Expand Up @@ -258,7 +260,12 @@ def filter_metadata(


def get_clade_counts(filtered_metadata: pl.LazyFrame) -> pl.LazyFrame:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left this here because the variant-nowcast-hub scripts still reference it. It's replacement (summarize clades) does the same thing but:

  • has a better name
  • allows a configurable list of group_by columns

"""Return a count of clades by location and date."""
"""Return a count of clades by location and date.

Notes:
------
Deprecated in favor of summarize_clades
"""

cols = [
"clade",
Expand All @@ -273,6 +280,57 @@ def get_clade_counts(filtered_metadata: pl.LazyFrame) -> pl.LazyFrame:
return counts


def summarize_clades(sequence_metadata: pl.LazyFrame, group_by: list | None = None) -> pl.LazyFrame:
"""Return clade counts summarized by specific sequence metadata columns.

Parameters
----------
sequence_metadata : :class:`polars.DataFrame` or :class:`polars.LazyFrame`
A Polars DataFrame or LazyFrame that represents
Nextstrain SARS-CoV-2 sequence metadata
group_by : list
Optional. A list of columns to group the clade counts by. Defaults
to ["clade_nextstrain", "country", "date", "location", "host"]

Returns
-------
:class:`polars.DataFrame` | :class:`polars.LazyFrame`
A Frame that summarizes clade counts by the specified columns. If sequence_metadata
is a LazyFrame, returns a LazyFrame. Otherwise, returns a DataFrame.

Raises
------
CladeTimeSequenceWarning
If group_by contains a column name that is not in sequence_metadata or
if group_by contains a column named 'count'
"""
if group_by is None:
group_by = ["clade_nextstrain", "country", "date", "location", "host"]

# Validate group_by columns
metadata_cols = sequence_metadata.collect_schema().names()
warning_msg = ""
if not all(col in metadata_cols for col in group_by):
warning_msg = warning_msg + f"Invalid group_by columns: {group_by} \n"
if "count" in group_by:
warning_msg = warning_msg + "Group_by cannot contain 'count' column \n"
if len(warning_msg) > 0:
warnings.warn(
warning_msg[0],
category=CladeTimeSequenceWarning,
)
if isinstance(sequence_metadata, pl.LazyFrame):
return pl.LazyFrame()
else:
return pl.DataFrame()

counts = (
sequence_metadata.select(group_by).group_by(group_by).agg(pl.len().alias("count")).cast({"count": pl.UInt32})
)

return counts


def get_metadata_ids(sequence_metadata: pl.DataFrame | pl.LazyFrame) -> set:
"""Return sequence IDs for a specified set of Nextstrain sequence metadata.

Expand Down
69 changes: 59 additions & 10 deletions tests/integration/test_cladetime_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,44 @@ def test_cladetime_assign_clades(tmp_path, metadata_100k):
assigned_clades = ct.assign_clades(metadata_filtered, output_file=assignment_file)

# clade assignments via cladetime should match the original clade assignments
check_clade_assignments = original_clade_assignments.join(assigned_clades, on=["strain", "clade"]).collect()
check_clade_assignments = original_clade_assignments.join(
assigned_clades.detail, on=["strain", "clade"]
).collect()
assert len(check_clade_assignments) == len(metadata_filtered.collect())
unmatched_clade_count = check_clade_assignments.filter(pl.col("clade").is_null()).shape[0]
assert unmatched_clade_count == 0

# summarized clade assignments should also match summarized clade assignments from the
# original metadata file
assert_frame_equal(
sequence.summarize_clades(metadata_filtered.rename({"clade": "clade_nextstrain"})),
assigned_clades.summary,
check_column_order=False,
check_row_order=False,
)

# metadata should reflect ncov metadata as of 2024-11-01
assert assigned_clades.meta.get("sequence_as_of") == datetime(2024, 11, 1, tzinfo=timezone.utc)
assert assigned_clades.meta.get("tree_as_of") == datetime(2024, 11, 1, tzinfo=timezone.utc)
assert assigned_clades.meta.get("nextclade_dataset_version") == "2024-10-17--16-48-48Z"
assert assigned_clades.meta.get("nextclade_version_num") == "3.9.1"
assert assigned_clades.meta.get("assignment_as_of") == "2024-11-01 00:00"


@pytest.mark.skipif(not docker_enabled, reason="Docker is not installed")
def test_assign_old_tree(test_file_path, tmp_path, test_sequences):
sequence_file, sequence_set = test_sequences
sequence_list = list(sequence_set)
sequence_list.sort()

fasta_mock = MagicMock(return_value=test_file_path / sequence_file, name="cladetime.sequence.filter")
test_filtered_metadata = {"date": ["2022-01-01", "2022-01-02", "2023-12-27"], "strain": list(sequence_set)}
test_filtered_metadata = {
"country": ["USA", "USA", "USA"],
"date": ["2022-01-02", "2022-01-02", "2023-02-01"],
"host": ["Homo sapiens", "Homo sapiens", "Homo sapiens"],
"location": ["Hawaii", "Hawaii", "Utah"],
"strain": sequence_list,
}
metadata_filtered = pl.LazyFrame(test_filtered_metadata)

# expected clade assignments for 2024-08-02 (as retrieved from Nextrain metadata)
Expand All @@ -89,17 +115,38 @@ def test_assign_old_tree(test_file_path, tmp_path, test_sequences):
ct_current_tree = CladeTime()
with patch("cladetime.sequence.filter", fasta_mock):
current_assigned_clades = ct_current_tree.assign_clades(metadata_filtered, output_file=current_file)
current_assigned_clades = current_assigned_clades.select(["strain", "clade"]).collect()
current_assigned_clades = current_assigned_clades.detail.select(["strain", "clade"]).collect()

old_file = tmp_path / "old_assignments.csv"
ct_old_tree = CladeTime(tree_as_of="2024-08-02")
with patch("cladetime.sequence.filter", fasta_mock):
old_assigned_clades = ct_old_tree.assign_clades(metadata_filtered, output_file=old_file)
old_assigned_clades = old_assigned_clades.select(["strain", "clade"]).collect()
old_assigned_clade_detail = old_assigned_clades.detail.select(["strain", "clade"]).collect()

assert_frame_equal(current_assigned_clades.select("strain"), old_assigned_clade_detail.select("strain"))
assert_frame_not_equal(current_assigned_clades.select("clade"), old_assigned_clade_detail.select("clade"))
assert_frame_equal(old_assigned_clade_detail.sort("strain"), expected_assignments.sort("strain"))

expected_summary = pl.DataFrame(
{
"clade_nextstrain": ["24B", "24C"],
"country": ["USA", "USA"],
"date": ["2022-01-02", "2023-02-01"],
"host": ["Homo sapiens", "Homo sapiens"],
"location": ["Hawaii", "Utah"],
"count": [2, 1],
}
).cast({"count": pl.UInt32})
assert_frame_equal(
expected_summary, old_assigned_clades.summary.collect(), check_column_order=False, check_row_order=False
)

assert_frame_equal(current_assigned_clades.select("strain"), old_assigned_clades.select("strain"))
assert_frame_not_equal(current_assigned_clades.select("clade"), old_assigned_clades.select("clade"))
assert_frame_equal(old_assigned_clades.sort("strain"), expected_assignments.sort("strain"))
# metadata should reflect ncov metadata as of 2024-11-01
assert old_assigned_clades.meta.get("sequence_as_of") == datetime(2024, 11, 1, tzinfo=timezone.utc)
assert old_assigned_clades.meta.get("tree_as_of") == datetime(2024, 8, 2, tzinfo=timezone.utc)
assert old_assigned_clades.meta.get("nextclade_dataset_version") == "2024-07-17--12-57-03Z"
assert old_assigned_clades.meta.get("nextclade_version_num") == "3.8.2"
assert old_assigned_clades.meta.get("assignment_as_of") == "2024-11-01 00:00"


@pytest.mark.skipif(not docker_enabled, reason="Docker is not installed")
Expand Down Expand Up @@ -128,7 +175,7 @@ def test_assign_date_filters(test_file_path, tmp_path, test_sequences, min_date,
assignment_file = tmp_path / "assignments.csv"
with patch("cladetime.sequence.filter", fasta_mock):
assigned_clades = ct.assign_clades(metadata_filtered, output_file=assignment_file)
assert len(assigned_clades.collect()) == expected_rows
assert len(assigned_clades.detail.collect()) == expected_rows


def test_assign_too_many_sequences_warning(tmp_path, test_file_path, test_sequences):
Expand All @@ -143,7 +190,7 @@ def test_assign_too_many_sequences_warning(tmp_path, test_file_path, test_sequen
with pytest.warns(CladeTimeSequenceWarning):
assignments = ct.assign_clades(metadata_filtered, output_file=tmp_path / "assignments.csv")
# clade assignment should proceed, despite the warning
assert len(assignments.collect()) == 3
assert len(assignments.detail.collect()) == 3


def test_assign_clades_no_sequences():
Expand All @@ -152,4 +199,6 @@ def test_assign_clades_no_sequences():
assignments = ct.assign_clades(
pl.LazyFrame(),
)
assert assignments.collect().shape == (0, 0)
assert assignments.detail.collect().shape == (0, 0)
assert assignments.summary.collect().shape == (0, 0)
assert assignments.meta == {}
96 changes: 96 additions & 0 deletions tests/unit/test_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
import polars as pl
import pytest
from Bio import SeqIO
from polars.testing import assert_frame_equal

from cladetime import sequence
from cladetime.exceptions import CladeTimeSequenceWarning
from cladetime.types import StateFormat


Expand Down Expand Up @@ -247,3 +249,97 @@ def test_filter_empty_fasta(tmpdir):
seq_filtered = sequence.filter(test_sequence_set, "http://thisismocked.com", tmpdir)
contents = seq_filtered.read_text(encoding=None)
assert len(contents) == 0


def test_summarize_clades():
test_metadata = pl.DataFrame(
{
"clade_nextstrain": ["11C", "11C", "11C"],
"country": ["USA", "USA", "USA"],
"date": ["2022-01-01", "2022-01-01", "2023-12-27"],
"host": ["Homo sapiens", "Homo sapiens", "Homo sapiens"],
"location": ["Utah", "Utah", "Utah"],
"strain": ["abc/123", "abc/456", "def/123"],
"wombat_count": [2, 22, 222],
}
)

expected_summary = pl.DataFrame(
{
"clade_nextstrain": ["11C", "11C"],
"country": ["USA", "USA"],
"date": ["2022-01-01", "2023-12-27"],
"host": ["Homo sapiens", "Homo sapiens"],
"location": ["Utah", "Utah"],
"count": [2, 1],
}
).cast({"count": pl.UInt32})

summarized = sequence.summarize_clades(test_metadata)
assert_frame_equal(expected_summary, summarized, check_column_order=False, check_row_order=False)


def test_summarize_clades_custom_group():
test_metadata = pl.LazyFrame(
{
"clade_nextstrain": ["11C", "11C", "11C"],
"country": ["Canada", "USA", "USA"],
"date": ["2022-01-01", "2022-01-01", "2023-12-27"],
"host": ["Homo sapiens", "Homo sapiens", "Homo sapiens"],
"location": ["Utah", "Utah", "Utah"],
"strain": ["abc/123", "abc/456", "def/123"],
"wombat_count": [2, 22, 22],
}
)

expected_summary = pl.LazyFrame(
{
"country": ["Canada", "USA"],
"wombat_count": [2, 22],
"count": [1, 2],
}
).cast({"count": pl.UInt32})

summarized = sequence.summarize_clades(test_metadata, group_by=["country", "wombat_count"])
assert_frame_equal(expected_summary, summarized, check_column_order=False, check_row_order=False)

test_metadata = pl.LazyFrame(
{
"clade_nextstrain": ["11C", "11C", "11C"],
"country": ["Canada", "USA", "USA"],
"date": ["2022-01-01", "2022-01-01", "2023-12-27"],
}
)

expected_summary = pl.LazyFrame(
{
"clade_nextstrain": ["11C"],
"count": [3],
}
).cast({"count": pl.UInt32})

summarized = sequence.summarize_clades(test_metadata, group_by=["clade_nextstrain"])
assert_frame_equal(expected_summary, summarized, check_column_order=False, check_row_order=False)


def test_summarize_clades_invalid_cols():
test_metadata = pl.DataFrame(
{
"clade_nextstrain": ["11C", "11C", "11C"],
"country": ["Canada", "USA", "USA"],
"date": ["2022-01-01", "2022-01-01", "2023-12-27"],
}
)
with pytest.warns(CladeTimeSequenceWarning):
summarized = sequence.summarize_clades(test_metadata, group_by=["country", "wombat_count"])
assert len(summarized) == 0

test_metadata = pl.DataFrame(
{
"clade_nextstrain": ["11C", "11C", "11C"],
"count": [1, 2, 3],
}
)
with pytest.warns(CladeTimeSequenceWarning):
summarized = sequence.summarize_clades(test_metadata, group_by=["clade_nextstrain", "count"])
assert len(summarized) == 0