diff --git a/docs/reference/filter_covid_genome_metadata.rst b/docs/reference/filter_covid_genome_metadata.rst deleted file mode 100644 index d52c80f..0000000 --- a/docs/reference/filter_covid_genome_metadata.rst +++ /dev/null @@ -1,6 +0,0 @@ -==================================================== -cladetime.util.sequence.filter_covid_genome_metadata -==================================================== - -.. autofunction:: cladetime.util.sequence.filter_covid_genome_metadata - diff --git a/docs/reference/index.rst b/docs/reference/index.rst index b1ea9ff..2ddb1d7 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -4,6 +4,6 @@ API Reference .. toctree:: cladetime - filter_covid_genome_metadata + sequence types diff --git a/docs/reference/sequence.rst b/docs/reference/sequence.rst new file mode 100644 index 0000000..49f4800 --- /dev/null +++ b/docs/reference/sequence.rst @@ -0,0 +1,6 @@ +========= +sequence +========= + +.. autofunction:: cladetime.sequence.filter_sequence_metadata + diff --git a/docs/reference/types.rst b/docs/reference/types.rst index 75f8178..a577dd1 100644 --- a/docs/reference/types.rst +++ b/docs/reference/types.rst @@ -1,5 +1,5 @@ ===== -Types +types ===== diff --git a/src/cladetime/assign_clades.py b/src/cladetime/assign_clades.py index 0bd5a76..6aaec51 100644 --- a/src/cladetime/assign_clades.py +++ b/src/cladetime/assign_clades.py @@ -8,13 +8,13 @@ import rich_click as click import structlog -from cladetime.util.config import Config -from cladetime.util.reference import get_nextclade_dataset -from cladetime.util.sequence import ( +from cladetime.sequence import ( _unzip_sequence_package, get_covid_genome_data, parse_sequence_assignments, ) +from cladetime.util.config import Config +from cladetime.util.reference import get_nextclade_dataset logger = structlog.get_logger() diff --git a/src/cladetime/cladetime.py b/src/cladetime/cladetime.py index 47a2589..91c300f 100644 --- a/src/cladetime/cladetime.py +++ 
b/src/cladetime/cladetime.py @@ -7,9 +7,9 @@ import structlog from cladetime.exceptions import CladeTimeFutureDateWarning, CladeTimeInvalidDateError, CladeTimeInvalidURLError +from cladetime.sequence import _get_ncov_metadata, get_covid_genome_metadata from cladetime.util.config import Config from cladetime.util.reference import _get_s3_object_url -from cladetime.util.sequence import _get_ncov_metadata, get_covid_genome_metadata logger = structlog.get_logger() diff --git a/src/cladetime/get_clade_list.py b/src/cladetime/get_clade_list.py index a5e63be..06ba22b 100644 --- a/src/cladetime/get_clade_list.py +++ b/src/cladetime/get_clade_list.py @@ -7,13 +7,13 @@ import structlog from cloudpathlib import AnyPath -from cladetime.util.config import Config -from cladetime.util.sequence import ( +from cladetime.sequence import ( download_covid_genome_metadata, - filter_covid_genome_metadata, + filter_sequence_metadata, get_clade_counts, get_covid_genome_metadata, ) +from cladetime.util.config import Config from cladetime.util.session import _get_session from cladetime.util.timing import time_function @@ -107,7 +107,7 @@ def main( data_dir, ) lf_metadata = get_covid_genome_metadata(genome_metadata_path) - lf_metadata_filtered = filter_covid_genome_metadata(lf_metadata) + lf_metadata_filtered = filter_sequence_metadata(lf_metadata) counts = get_clade_counts(lf_metadata_filtered) clade_list = get_clades(counts, threshold, threshold_weeks, max_clades) diff --git a/src/cladetime/sequence.py b/src/cladetime/sequence.py new file mode 100644 index 0000000..83a5be3 --- /dev/null +++ b/src/cladetime/sequence.py @@ -0,0 +1,331 @@ +"""Functions for retrieving and parsing SARS-CoV-2 virus genome data.""" + +import json +import lzma +import zipfile +from datetime import datetime, timezone +from pathlib import Path + +import polars as pl +import structlog +import us +from requests import Session + +from cladetime.types import StateFormat +from cladetime.util.reference import 
_get_s3_object_url +from cladetime.util.session import _check_response, _get_session +from cladetime.util.timing import time_function + +logger = structlog.get_logger() + + +@time_function +def get_covid_genome_data(released_since_date: str, base_url: str, filename: str): + """ + Download genome data package from NCBI. + FIXME: Download the Nextclade-processed GenBank sequence data (which originates from NCBI) + from https://data.nextstrain.org/files/ncov/open/sequences.fasta.zst instead of using + the NCBI API. + """ + headers = { + "Accept": "application/zip", + } + session = _get_session() + session.headers.update(headers) + + # TODO: this might be a better as an item in the forthcoming config file + request_body = { + "released_since": released_since_date, + "taxon": "SARS-CoV-2", + "refseq_only": False, + "annotated_only": False, + "host": "Homo sapiens", + "complete_only": False, + "table_fields": ["unspecified"], + "include_sequence": ["GENOME"], + "aux_report": ["DATASET_REPORT"], + "format": "tsv", + "use_psg": False, + } + + logger.info("NCBI API call starting", released_since_date=released_since_date) + + response = session.post(base_url, data=json.dumps(request_body), timeout=(300, 300)) + _check_response(response) + + # Originally tried saving the NCBI package via a stream call and iter_content (to prevent potential + # memory issues that can arise when download large files). However, ran into an intermittent error: + # ChunkedEncodingError(ProtocolError('Response ended prematurely'). + # We may need to revisit this at some point, depending on how much data we place to request via the + # API and what kind of machine the pipeline will run on. 
+ with open(filename, "wb") as f: + f.write(response.content) + + +@time_function +def download_covid_genome_metadata( + session: Session, bucket: str, key: str, data_path: Path, as_of: str | None = None, use_existing: bool = False +) -> Path: + """Download the latest GenBank genome metadata data from Nextstrain.""" + + if as_of is None: + as_of_datetime = datetime.now().replace(tzinfo=timezone.utc) + else: + as_of_datetime = datetime.strptime(as_of, "%Y-%m-%d").replace(tzinfo=timezone.utc) + + (s3_version, s3_url) = _get_s3_object_url(bucket, key, as_of_datetime) + filename = data_path / f"{as_of_datetime.date().strftime('%Y-%m-%d')}-{Path(key).name}" + + if use_existing and filename.exists(): + logger.info("using existing genome metadata file", metadata_file=str(filename)) + return filename + + logger.info("starting genome metadata download", source=s3_url, destination=str(filename)) + with session.get(s3_url, stream=True) as result: + result.raise_for_status() + with open(filename, "wb") as f: + for chunk in result.iter_content(chunk_size=None): + f.write(chunk) + + return filename + + +def get_covid_genome_metadata( + metadata_path: Path | None = None, metadata_url: str | None = None, num_rows: int | None = None +) -> pl.LazyFrame: + """ + Read GenBank genome metadata into a Polars LazyFrame. + + Parameters + ---------- + metadata_path : Path | None + Path to location of a NextStrain GenBank genome metadata file. + Cannot be used with metadata_url. + metadata_url: str | None + URL to a NextStrain GenBank genome metadata file. + Cannot be used with metadata_path. + num_rows : int | None, default = None + The number of genome metadata rows to request. + When not supplied, request all rows. + """ + + path_flag = metadata_path is not None + url_flag = metadata_url is not None + + assert path_flag + url_flag == 1, "Specify metadata_path or metadata_url, but not both." 
+ + if metadata_url: + metadata = pl.scan_csv(metadata_url, separator="\t", n_rows=num_rows) + return metadata + + if metadata_path: + if (compression_type := metadata_path.suffix) in [".tsv", ".zst"]: + metadata = pl.scan_csv(metadata_path, separator="\t", n_rows=num_rows) + elif compression_type == ".xz": + metadata = pl.read_csv( + lzma.open(metadata_path), separator="\t", n_rows=num_rows, infer_schema_length=100000 + ).lazy() + + return metadata + + +def _get_ncov_metadata( + url_ncov_metadata: str, + session: Session | None = None, +) -> dict: + """Return metadata emitted by the Nextstrain ncov pipeline.""" + if not session: + session = _get_session(retry=False) + + response = session.get(url_ncov_metadata) + if not response.ok: + logger.warn( + "Failed to retrieve ncov metadata", + status_code=response.status_code, + response_text=response.text, + request=response.request.url, + request_body=response.request.body, + ) + return {} + + metadata = response.json() + if metadata.get("nextclade_dataset_name", "").lower() == "sars-cov-2": + metadata["nextclade_dataset_name_full"] = "nextstrain/sars-cov-2/wuhan-hu-1/orfs" + + return metadata + + +def filter_sequence_metadata( + metadata: pl.LazyFrame, cols: list | None = None, state_format: StateFormat = StateFormat.ABBR +) -> pl.LazyFrame: + """Apply standard filters to Nextstrain's SARS-CoV-2 sequence metadata. + + A helper function to apply commonly-used filters to a Polars LazyFrame + that represents Nextstrain's SARS-CoV-2 sequence metadata. It filters + on human sequences from the United States (including Puerto Rico and + Washington, DC). + + This function also performs small transformations to the metadata, + such as casting the collection date to a date type, renaming columns, + and returning alternate state formats if requested. + + Parameters + ---------- + metadata : :class:`polars.LazyFrame` + The :attr:`cladetime.CladeTime.url_sequence_metadata` + attribute of a :class:`cladetime.CladeTime` object. 
This parameter + represents SARS-CoV-2 sequence metadata produced by Nextstrain + as an intermediate file in their daily workflow + cols : list + Optional. A list of columns to include in the filtered metadata. + The default columns included in the filtered metadata are: + clade_nextstrain, country, date, division, genbank_accession, + genbank_accession_rev, host + state_format : :class:`cladetime.types.StateFormat` + Optional. The state name format returned in the filtered metadata's + location column. Defaults to `StateFormat.ABBR` + + Returns + ------- + :class:`polars.LazyFrame` + A Polars LazyFrame that represents the filtered SARS-CoV-2 sequence + metadata. + + Raises + ------ + ValueError + If the state_format parameter is not a valid + :class:`cladetime.types.StateFormat`. + + Notes + ----- + This function will filter out metadata rows with invalid state names or + date strings that cannot be cast to a Polars date format. + + Example: + -------- + >>> from cladetime import CladeTime + >>> from cladetime.sequence import filter_sequence_metadata + + Apply common filters to the sequence metadata of a CladeTime object: + + >>> ct = CladeTime(seq_as_of="2024-10-15") + >>> ct = CladeTime(sequence_as_of="2024-10-15") + >>> filtered_metadata = filter_sequence_metadata(ct.sequence_metadata) + >>> filtered_metadata.collect().head(5) + shape: (5, 7) + ┌───────┬─────────┬────────────┬────────────┬────────────┬──────────────┬──────┬ + │ clade ┆ country ┆ date ┆ genbank_ ┆ genbank_ac ┆ host ┆ loca │ + │ ┆ ┆ ┆ accession ┆ cession_rev┆ ┆ tion │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ str ┆ str ┆ date ┆ str ┆ str ┆ str ┆ str │ + │ ┆ ┆ ┆ ┆ ┆ ┆ │ + ╞═══════╪═════════╪════════════╪════════════╪════════════╪══════════════╪══════╡ + │ 22A ┆ USA ┆ 2022-07-07 ┆ PP223234 ┆ PP223234.1 ┆ Homo sapiens ┆ AL │ + │ 22B ┆ USA ┆ 2022-07-02 ┆ PP223435 ┆ PP223435.1 ┆ Homo sapiens ┆ AZ │ + │ 22B ┆ USA ┆ 2022-07-19 ┆ PP223235 ┆ PP223235.1 ┆ Homo sapiens ┆ AZ │ + │ 22B ┆ USA ┆
2022-07-15 ┆ PP223236 ┆ PP223236.1 ┆ Homo sapiens ┆ AZ │ + │ 22B ┆ USA ┆ 2022-07-20 ┆ PP223237 ┆ PP223237.1 ┆ Homo sapiens ┆ AZ │ + └───────┴─────────┴────────────┴────────────┴────────────┴─────────────────────┴ + """ + if state_format not in StateFormat: + raise ValueError(f"Invalid state_format. Must be one of: {list(StateFormat.__members__.items())}") + + # Default columns to include in the filtered metadata + if not cols: + cols = [ + "clade_nextstrain", + "country", + "date", + "division", + "genbank_accession", + "genbank_accession_rev", + "host", + ] + + # There are some other odd divisions in the data, but these are 50 states, DC and PR + states = [state.name for state in us.states.STATES] + states.extend(["Washington DC", "District of Columbia", "Puerto Rico"]) + + # Filter dataset and do some general tidying + filtered_metadata = ( + metadata.select(cols) + .filter( + pl.col("country") == "USA", + pl.col("division").is_in(states), + pl.col("host") == "Homo sapiens", + ) + .rename({"clade_nextstrain": "clade"}) + .cast({"date": pl.Date}, strict=False) + # date filtering at the end ensures we filter out null + # values created by the above .cast operation + .filter( + pl.col("date").is_not_null(), + ) + ) + + # Create state mappings based on state_format parameter, including a DC alias, since + # Nextrain's metadata uses a different name than the us package + if state_format == StateFormat.FIPS: + state_dict = {state.name: state.fips for state in us.states.STATES_AND_TERRITORIES} + state_dict["Washington DC"] = us.states.DC.fips + elif state_format == StateFormat.ABBR: + state_dict = {state.name: state.abbr for state in us.states.STATES_AND_TERRITORIES} + state_dict["Washington DC"] = us.states.DC.abbr + else: + state_dict = {state.name: state.name for state in us.states.STATES_AND_TERRITORIES} + state_dict["Washington DC"] = "Washington DC" + + filtered_metadata = 
filtered_metadata.with_columns(pl.col("division").replace(state_dict).alias("location")).drop( + "division" + ) + + return filtered_metadata + + +def get_clade_counts(filtered_metadata: pl.LazyFrame) -> pl.LazyFrame: + """Return a count of clades by location and date.""" + + cols = [ + "clade", + "country", + "date", + "location", + "host", + ] + + counts = filtered_metadata.select(cols).group_by("location", "date", "clade").agg(pl.len().alias("count")) + + return counts + + +def _unzip_sequence_package(filename: Path, data_path: Path): + """Unzip the downloaded virus genome data package.""" + with zipfile.ZipFile(filename, "r") as package_zip: + zip_contents = package_zip.namelist() + is_metadata = next((s for s in zip_contents if "data_report" in s), None) + is_sequence = next((s for s in zip_contents if "genomic" in s), None) + if is_metadata and is_sequence: + package_zip.extractall(data_path) + else: + logger.error("NCBI package is missing expected files", zip_contents=zip_contents) + # Exit the pipeline without displaying a traceback + raise SystemExit("Error downloading NCBI package") + + +def parse_sequence_assignments(df_assignments: pl.DataFrame) -> pl.DataFrame: + """Parse out the sequence number from the seqName column returned by the clade assignment tool.""" + + # polars apparently can't split out the sequence number from that big name column + # without resorting an apply, so here we're dropping into pandas to do that + # (might be a premature optimization, since this manoever requires both pandas and pyarrow) + seq = pl.from_pandas(df_assignments.to_pandas()["seqName"].str.split(" ").str[0].rename("seq")) + + # we're expecting one row per sequence + if seq.n_unique() != df_assignments.shape[0]: + raise ValueError("Clade assignment data contains duplicate sequence. 
Stopping assignment process.") + + # add the parsed sequence number as a new column + df_assignments = df_assignments.insert_column(1, seq) # type: ignore + + return df_assignments diff --git a/src/cladetime/util/sequence.py b/src/cladetime/util/sequence.py index 77969d2..71b55e3 100644 --- a/src/cladetime/util/sequence.py +++ b/src/cladetime/util/sequence.py @@ -1,331 +1,6 @@ -"""Functions for retrieving and parsing SARS-CoV-2 virus genome data.""" +"""cladetime.util.sequence moved to cladetime.sequence.""" -import json -import lzma -import zipfile -from datetime import datetime, timezone -from pathlib import Path - -import polars as pl -import structlog -import us -from requests import Session - -from cladetime.types import StateFormat -from cladetime.util.reference import _get_s3_object_url -from cladetime.util.session import _check_response, _get_session -from cladetime.util.timing import time_function - -logger = structlog.get_logger() - - -@time_function -def get_covid_genome_data(released_since_date: str, base_url: str, filename: str): - """ - Download genome data package from NCBI. - FIXME: Download the Nextclade-processed GenBank sequence data (which originates from NCBI) - from https://data.nextstrain.org/files/ncov/open/sequences.fasta.zst instead of using - the NCBI API. 
- """ - headers = { - "Accept": "application/zip", - } - session = _get_session() - session.headers.update(headers) - - # TODO: this might be a better as an item in the forthcoming config file - request_body = { - "released_since": released_since_date, - "taxon": "SARS-CoV-2", - "refseq_only": False, - "annotated_only": False, - "host": "Homo sapiens", - "complete_only": False, - "table_fields": ["unspecified"], - "include_sequence": ["GENOME"], - "aux_report": ["DATASET_REPORT"], - "format": "tsv", - "use_psg": False, - } - - logger.info("NCBI API call starting", released_since_date=released_since_date) - - response = session.post(base_url, data=json.dumps(request_body), timeout=(300, 300)) - _check_response(response) - - # Originally tried saving the NCBI package via a stream call and iter_content (to prevent potential - # memory issues that can arise when download large files). However, ran into an intermittent error: - # ChunkedEncodingError(ProtocolError('Response ended prematurely'). - # We may need to revisit this at some point, depending on how much data we place to request via the - # API and what kind of machine the pipeline will run on. 
- with open(filename, "wb") as f: - f.write(response.content) - - -@time_function -def download_covid_genome_metadata( - session: Session, bucket: str, key: str, data_path: Path, as_of: str | None = None, use_existing: bool = False -) -> Path: - """Download the latest GenBank genome metadata data from Nextstrain.""" - - if as_of is None: - as_of_datetime = datetime.now().replace(tzinfo=timezone.utc) - else: - as_of_datetime = datetime.strptime(as_of, "%Y-%m-%d").replace(tzinfo=timezone.utc) - - (s3_version, s3_url) = _get_s3_object_url(bucket, key, as_of_datetime) - filename = data_path / f"{as_of_datetime.date().strftime('%Y-%m-%d')}-{Path(key).name}" - - if use_existing and filename.exists(): - logger.info("using existing genome metadata file", metadata_file=str(filename)) - return filename - - logger.info("starting genome metadata download", source=s3_url, destination=str(filename)) - with session.get(s3_url, stream=True) as result: - result.raise_for_status() - with open(filename, "wb") as f: - for chunk in result.iter_content(chunk_size=None): - f.write(chunk) - - return filename - - -def get_covid_genome_metadata( - metadata_path: Path | None = None, metadata_url: str | None = None, num_rows: int | None = None -) -> pl.LazyFrame: - """ - Read GenBank genome metadata into a Polars LazyFrame. - - Parameters - ---------- - metadata_path : Path | None - Path to location of a NextStrain GenBank genome metadata file. - Cannot be used with metadata_url. - metadata_url: str | None - URL to a NextStrain GenBank genome metadata file. - Cannot be used with metadata_path. - num_rows : int | None, default = None - The number of genome metadata rows to request. - When not supplied, request all rows. - """ - - path_flag = metadata_path is not None - url_flag = metadata_url is not None - - assert path_flag + url_flag == 1, "Specify metadata_path or metadata_url, but not both." 
- - if metadata_url: - metadata = pl.scan_csv(metadata_url, separator="\t", n_rows=num_rows) - return metadata - - if metadata_path: - if (compression_type := metadata_path.suffix) in [".tsv", ".zst"]: - metadata = pl.scan_csv(metadata_path, separator="\t", n_rows=num_rows) - elif compression_type == ".xz": - metadata = pl.read_csv( - lzma.open(metadata_path), separator="\t", n_rows=num_rows, infer_schema_length=100000 - ).lazy() - - return metadata - - -def _get_ncov_metadata( - url_ncov_metadata: str, - session: Session | None = None, -) -> dict: - """Return metadata emitted by the Nextstrain ncov pipeline.""" - if not session: - session = _get_session(retry=False) - - response = session.get(url_ncov_metadata) - if not response.ok: - logger.warn( - "Failed to retrieve ncov metadata", - status_code=response.status_code, - response_text=response.text, - request=response.request.url, - request_body=response.request.body, - ) - return {} - - metadata = response.json() - if metadata.get("nextclade_dataset_name", "").lower() == "sars-cov-2": - metadata["nextclade_dataset_name_full"] = "nextstrain/sars-cov-2/wuhan-hu-1/orfs" - - return metadata - - -def filter_covid_genome_metadata( - metadata: pl.LazyFrame, cols: list | None = None, state_format: StateFormat = StateFormat.ABBR -) -> pl.LazyFrame: - """Apply standard filters to Nextstrain's SARS-CoV-2 sequence metadata. - - A helper function to apply commonly-used filters to a Polars LazyFrame - that represents Nextstrain's SARS-CoV-2 sequence metadata. It filters - on human sequences from the United States (including Puerto Rico and - Washington, DC). - - This function also performs small transformations to the metadata, - such as casting the collection date to a date type, renaming columns, - and returning alternate state formats if requested. - - Parameters - ---------- - metadata : :class:`polars.LazyFrame` - The :attr:`cladetime.CladeTime.url_sequence_metadata` - attribute of a :class:`cladetime.CladeTime` object. 
This parameter - represents SARS-CoV-2 sequence metadata produced by Nextstrain - as an intermediate file in their daily workflow - cols : list - Optional. A list of columns to include in the filtered metadata. - The default columns included in the filtered metadata are: - clade_nextstrain, country, date, division, genbank_accession, - genbank_accession_rev, host - state_format : :class:`cladetime.types.StateFormat` - Optional. The state name format returned in the filtered metadata's - location column. Defaults to `StateFormat.ABBR` - - Returns - ------- - :class:`polars.LazyFrame` - A Polars LazyFrame that represents the filtered SARS-CoV-2 sequence - metadata. - - Raises - ------ - ValueError - If the state_format parameter is not a valid - :class:`cladetime.types.StateFormat`. - - Notes - ----- - This function will filter out metadata rows with invalid state names or - date strings that cannot be cast to a Polars date format. - - Example: - -------- - >>> from cladetime import CladeTime - >>> from cladetime.util.sequence import filter_covid_genome_metadata - - Apply common filters to the sequence metadata of a CladeTime object: - - >>> ct = CladeTime(seq_as_of="2024-10-15") - >>> ct = CladeTime(sequence_as_of="2024-10-15") - >>> filtered_metadata = filter_covid_genome_metadata(ct.sequence_metadata) - >>> filtered_metadata.collect().head(5) - shape: (5, 7) - ┌───────┬─────────┬────────────┬────────────┬────────────┬──────────────┬──────┬ - │ clade ┆ country ┆ date ┆ genbank_ ┆ genbank_ac ┆ host ┆ loca │ - │ ┆ ┆ ┆ accession ┆ cession_rev┆ ┆ tion │ - │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ - │ str ┆ str ┆ date ┆ str ┆ str ┆ str ┆ str │ - │ ┆ ┆ ┆ ┆ ┆ ┆ │ - ╞═══════╪═════════╪════════════╪════════════╪════════════╪══════════════╪══════╡ - │ 22A ┆ USA ┆ 2022-07-07 ┆ PP223234 ┆ PP223234.1 ┆ Homo sapiens ┆ AL │ - │ 22B ┆ USA ┆ 2022-07-02 ┆ PP223435 ┆ PP223435.1 ┆ Homo sapiens ┆ AZ │ - │ 22B ┆ USA ┆ 2022-07-19 ┆ PP223235 ┆ PP223235.1 ┆ Homo sapiens ┆ AZ │ - │ 22B ┆ 
USA ┆ 2022-07-15 ┆ PP223236 ┆ PP223236.1 ┆ Homo sapiens ┆ AZ │ - │ 22B ┆ USA ┆ 2022-07-20 ┆ PP223237 ┆ PP223237.1 ┆ Homo sapiens ┆ AZ │ - └───────┴─────────┴────────────┴────────────┴────────────┴─────────────────────┴ - """ - if state_format not in StateFormat: - raise ValueError(f"Invalid state_format. Must be one of: {list(StateFormat.__members__.items())}") - - # Default columns to include in the filtered metadata - if not cols: - cols = [ - "clade_nextstrain", - "country", - "date", - "division", - "genbank_accession", - "genbank_accession_rev", - "host", - ] - - # There are some other odd divisions in the data, but these are 50 states, DC and PR - states = [state.name for state in us.states.STATES] - states.extend(["Washington DC", "District of Columbia", "Puerto Rico"]) - - # Filter dataset and do some general tidying - filtered_metadata = ( - metadata.select(cols) - .filter( - pl.col("country") == "USA", - pl.col("division").is_in(states), - pl.col("host") == "Homo sapiens", - ) - .rename({"clade_nextstrain": "clade"}) - .cast({"date": pl.Date}, strict=False) - # date filtering at the end ensures we filter out null - # values created by the above .cast operation - .filter( - pl.col("date").is_not_null(), - ) - ) - - # Create state mappings based on state_format parameter, including a DC alias, since - # Nextrain's metadata uses a different name than the us package - if state_format == StateFormat.FIPS: - state_dict = {state.name: state.fips for state in us.states.STATES_AND_TERRITORIES} - state_dict["Washington DC"] = us.states.DC.fips - elif state_format == StateFormat.ABBR: - state_dict = {state.name: state.abbr for state in us.states.STATES_AND_TERRITORIES} - state_dict["Washington DC"] = us.states.DC.abbr - else: - state_dict = {state.name: state.name for state in us.states.STATES_AND_TERRITORIES} - state_dict["Washington DC"] = "Washington DC" - - filtered_metadata = 
filtered_metadata.with_columns(pl.col("division").replace(state_dict).alias("location")).drop( - "division" - ) - - return filtered_metadata - - -def get_clade_counts(filtered_metadata: pl.LazyFrame) -> pl.LazyFrame: - """Return a count of clades by location and date.""" - - cols = [ - "clade", - "country", - "date", - "location", - "host", - ] - - counts = filtered_metadata.select(cols).group_by("location", "date", "clade").agg(pl.len().alias("count")) - - return counts - - -def _unzip_sequence_package(filename: Path, data_path: Path): - """Unzip the downloaded virus genome data package.""" - with zipfile.ZipFile(filename, "r") as package_zip: - zip_contents = package_zip.namelist() - is_metadata = next((s for s in zip_contents if "data_report" in s), None) - is_sequence = next((s for s in zip_contents if "genomic" in s), None) - if is_metadata and is_sequence: - package_zip.extractall(data_path) - else: - logger.error("NCBI package is missing expected files", zip_contents=zip_contents) - # Exit the pipeline without displaying a traceback - raise SystemExit("Error downloading NCBI package") - - -def parse_sequence_assignments(df_assignments: pl.DataFrame) -> pl.DataFrame: - """Parse out the sequence number from the seqName column returned by the clade assignment tool.""" - - # polars apparently can't split out the sequence number from that big name column - # without resorting an apply, so here we're dropping into pandas to do that - # (might be a premature optimization, since this manoever requires both pandas and pyarrow) - seq = pl.from_pandas(df_assignments.to_pandas()["seqName"].str.split(" ").str[0].rename("seq")) - - # we're expecting one row per sequence - if seq.n_unique() != df_assignments.shape[0]: - raise ValueError("Clade assignment data contains duplicate sequence. 
Stopping assignment process.") - - # add the parsed sequence number as a new column - df_assignments = df_assignments.insert_column(1, seq) # type: ignore - - return df_assignments +# For temporary backwards compatibility +from cladetime.sequence import _get_ncov_metadata as _get_ncov_metadata # noqa: F401 +from cladetime.sequence import filter_sequence_metadata as filter_covid_genome_metadata # noqa: F401 +from cladetime.sequence import get_clade_counts as get_clade_counts diff --git a/tests/unit/util/test_sequence.py b/tests/unit/util/test_sequence.py index 4fe14f9..c4eb5aa 100644 --- a/tests/unit/util/test_sequence.py +++ b/tests/unit/util/test_sequence.py @@ -4,13 +4,13 @@ import polars as pl import pytest -from cladetime.types import StateFormat -from cladetime.util.sequence import ( +from cladetime.sequence import ( download_covid_genome_metadata, - filter_covid_genome_metadata, + filter_sequence_metadata, get_covid_genome_metadata, parse_sequence_assignments, ) +from cladetime.types import StateFormat @pytest.fixture @@ -115,7 +115,7 @@ def test_filter_covid_genome_metadata(): } lf_metadata = pl.LazyFrame(test_genome_metadata) - lf_filtered = filter_covid_genome_metadata(lf_metadata).collect() + lf_filtered = filter_sequence_metadata(lf_metadata).collect() assert len(lf_filtered) == 2 @@ -152,7 +152,7 @@ def test_filter_covid_genome_metadata_state_name(): } lf_metadata = pl.LazyFrame(test_genome_metadata) - lf_filtered = filter_covid_genome_metadata(lf_metadata, state_format=StateFormat.NAME) + lf_filtered = filter_sequence_metadata(lf_metadata, state_format=StateFormat.NAME) lf_filtered = lf_filtered.collect() # Un-mapped states are dropped from dataset @@ -176,7 +176,7 @@ def test_filter_covid_genome_metadata_state_fips(): } lf_metadata = pl.LazyFrame(test_genome_metadata) - lf_filtered = filter_covid_genome_metadata(lf_metadata, state_format=StateFormat.FIPS) + lf_filtered = filter_sequence_metadata(lf_metadata, state_format=StateFormat.FIPS) lf_filtered 
= lf_filtered.collect() # Un-mapped states are dropped from dataset