Skip to content

Commit

Permalink
Remove column validations from summarize_clades
Browse files Browse the repository at this point in the history
Summarize_clades was doing a collect_schema operation to ensure
that items in the group_by paramater exist as colums in the
sequence metadata. However, testing on a lower-memory laptop
revealed that collect_schema was too memory intensive to
introduce in the middle of LazyFrame handling, so we'll
take it out.
  • Loading branch information
bsweger committed Nov 14, 2024
1 parent 7e17fa2 commit 052c0a2
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 48 deletions.
31 changes: 7 additions & 24 deletions src/cladetime/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import lzma
import os
import re
import warnings
from datetime import datetime
from pathlib import Path
from urllib.parse import urlparse
Expand All @@ -15,7 +14,6 @@
from Bio.SeqIO import FastaIO
from requests import Session

from cladetime.exceptions import CladeTimeSequenceWarning
from cladetime.types import StateFormat
from cladetime.util.reference import _get_date
from cladetime.util.session import _get_session
Expand Down Expand Up @@ -298,32 +296,17 @@ def summarize_clades(sequence_metadata: pl.LazyFrame, group_by: list | None = No
A Frame that summarizes clade counts by the specified columns. If sequence_metadata
is a LazyFrame, returns a LazyFrame. Otherwise, returns a DataFrame.
Raises
------
CladeTimeSequenceWarning
If group_by contains a column name that is not in sequence_metadata or
if group_by contains a column named 'count'
Notes
-----
This function does not validate the group_by columns because doing so on a
large LazyFrame would involve a memory-intensive collect_schema operation.
If the group_by columns are not in the sequence metadata, this function
will succeed, but a subsequent collect() on the returned LazyFrame will
result in an error.
"""
if group_by is None:
group_by = ["clade_nextstrain", "country", "date", "location", "host"]

# Validate group_by columns
metadata_cols = sequence_metadata.collect_schema().names()
warning_msg = ""
if not all(col in metadata_cols for col in group_by):
warning_msg = warning_msg + f"Invalid group_by columns: {group_by} \n"
if "count" in group_by:
warning_msg = warning_msg + "Group_by cannot contain 'count' column \n"
if len(warning_msg) > 0:
warnings.warn(
warning_msg[0],
category=CladeTimeSequenceWarning,
)
if isinstance(sequence_metadata, pl.LazyFrame):
return pl.LazyFrame()
else:
return pl.DataFrame()

counts = (
sequence_metadata.select(group_by).group_by(group_by).agg(pl.len().alias("count")).cast({"count": pl.UInt32})
)
Expand Down
24 changes: 0 additions & 24 deletions tests/unit/test_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from polars.testing import assert_frame_equal

from cladetime import sequence
from cladetime.exceptions import CladeTimeSequenceWarning
from cladetime.types import StateFormat


Expand Down Expand Up @@ -320,26 +319,3 @@ def test_summarize_clades_custom_group():

summarized = sequence.summarize_clades(test_metadata, group_by=["clade_nextstrain"])
assert_frame_equal(expected_summary, summarized, check_column_order=False, check_row_order=False)


def test_summarize_clades_invalid_cols():
test_metadata = pl.DataFrame(
{
"clade_nextstrain": ["11C", "11C", "11C"],
"country": ["Canada", "USA", "USA"],
"date": ["2022-01-01", "2022-01-01", "2023-12-27"],
}
)
with pytest.warns(CladeTimeSequenceWarning):
summarized = sequence.summarize_clades(test_metadata, group_by=["country", "wombat_count"])
assert len(summarized) == 0

test_metadata = pl.DataFrame(
{
"clade_nextstrain": ["11C", "11C", "11C"],
"count": [1, 2, 3],
}
)
with pytest.warns(CladeTimeSequenceWarning):
summarized = sequence.summarize_clades(test_metadata, group_by=["clade_nextstrain", "count"])
assert len(summarized) == 0

0 comments on commit 052c0a2

Please sign in to comment.