diff --git a/src/cladetime/sequence.py b/src/cladetime/sequence.py index 922e0fc..7b31df9 100644 --- a/src/cladetime/sequence.py +++ b/src/cladetime/sequence.py @@ -3,7 +3,6 @@ import lzma import os import re -import warnings from datetime import datetime from pathlib import Path from urllib.parse import urlparse @@ -15,7 +14,6 @@ from Bio.SeqIO import FastaIO from requests import Session -from cladetime.exceptions import CladeTimeSequenceWarning from cladetime.types import StateFormat from cladetime.util.reference import _get_date from cladetime.util.session import _get_session @@ -298,32 +296,17 @@ def summarize_clades(sequence_metadata: pl.LazyFrame, group_by: list | None = No A Frame that summarizes clade counts by the specified columns. If sequence_metadata is a LazyFrame, returns a LazyFrame. Otherwise, returns a DataFrame. - Raises - ------ - CladeTimeSequenceWarning - If group_by contains a column name that is not in sequence_metadata or - if group_by contains a column named 'count' + Notes + ----- + This function does not validate the group_by columns because doing so on a + large LazyFrame would involve a memory-intensive collect_schema operation. + If the group_by columns are not in the sequence metadata, this function + will succeed, but a subsequent collect() on the returned LazyFrame will + result in an error. """ if group_by is None: group_by = ["clade_nextstrain", "country", "date", "location", "host"] - # Validate group_by columns - metadata_cols = sequence_metadata.collect_schema().names() - warning_msg = "" - if not all(col in metadata_cols for col in group_by): - warning_msg = warning_msg + f"Invalid group_by columns: {group_by} \n" - if "count" in group_by: - warning_msg = warning_msg + "Group_by cannot contain 'count' column \n" - if len(warning_msg) > 0: - warnings.warn( - warning_msg[0], - category=CladeTimeSequenceWarning, - ) - if isinstance(sequence_metadata, pl.LazyFrame): - return pl.LazyFrame() - else: - return pl.DataFrame() - counts = ( sequence_metadata.select(group_by).group_by(group_by).agg(pl.len().alias("count")).cast({"count": pl.UInt32}) ) diff --git a/tests/unit/test_sequence.py b/tests/unit/test_sequence.py index 77865e9..f79d988 100644 --- a/tests/unit/test_sequence.py +++ b/tests/unit/test_sequence.py @@ -9,7 +9,6 @@ from polars.testing import assert_frame_equal from cladetime import sequence -from cladetime.exceptions import CladeTimeSequenceWarning from cladetime.types import StateFormat @@ -320,26 +319,3 @@ def test_summarize_clades_custom_group(): summarized = sequence.summarize_clades(test_metadata, group_by=["clade_nextstrain"]) assert_frame_equal(expected_summary, summarized, check_column_order=False, check_row_order=False) - - -def test_summarize_clades_invalid_cols(): - test_metadata = pl.DataFrame( - { - "clade_nextstrain": ["11C", "11C", "11C"], - "country": ["Canada", "USA", "USA"], - "date": ["2022-01-01", "2022-01-01", "2023-12-27"], - } - ) - with pytest.warns(CladeTimeSequenceWarning): - summarized = sequence.summarize_clades(test_metadata, group_by=["country", "wombat_count"]) - assert len(summarized) == 0 - - test_metadata = pl.DataFrame( - { - "clade_nextstrain": ["11C", "11C", "11C"], - "count": [1, 2, 3], - } - ) - with pytest.warns(CladeTimeSequenceWarning): - summarized = sequence.summarize_clades(test_metadata, group_by=["clade_nextstrain", "count"]) - assert len(summarized) == 0