"""Get a list of SARS-CoV-2 clades."""

import time
from datetime import timedelta

import polars as pl
import structlog
from cloudpathlib import AnyPath

from virus_clade_utils.util.config import Config
from virus_clade_utils.util.sequence import (
    download_covid_genome_metadata,
    filter_covid_genome_metadata,
    get_clade_counts,
    get_covid_genome_metadata,
)

logger = structlog.get_logger()


def clades_to_model(
    clade_counts: pl.LazyFrame, threshold: float = 0.01, threshold_weeks: int = 3, max_clades: int = 9
) -> list[str]:
    """
    Determine the list of clades to model.

    Parameters
    ----------
    clade_counts : polars.LazyFrame
        Clade counts by date and location, summarized from Nextstrain metadata.
    threshold : float
        Clades that account for at least ``threshold`` proportion of reported
        sequences are candidates for inclusion.
    threshold_weeks : int
        The number of full weeks (Sunday-start, not including the current,
        partial week) to look back when identifying clades.
    max_clades : int
        Upper bound on the number of clades returned. If more than
        ``max_clades`` clades cross ``threshold``, the clades with the
        largest total counts over the lookback window are kept.

    Returns
    -------
    list of str
        Names of the clades to model.
    """
    start = time.perf_counter()

    # based on the data's most recent date, get the week start
    # `threshold_weeks` weeks ago (not including the current week)
    max_day = clade_counts.select(pl.max("date")).collect().item()
    window_start = max_day - timedelta(days=max_day.weekday() + 7 * threshold_weeks)

    # sum over weeks, combine states, and limit to the past
    # `threshold_weeks` weeks (not including the current week)
    lf = (
        clade_counts.filter(pl.col("date") >= window_start)
        .sort("date")
        .group_by_dynamic("date", every="1w", start_by="sunday", group_by="clade")
        .agg(pl.col("count").sum())
    )

    # create a separate frame with the total counts per week
    total_counts = lf.group_by("date").agg(pl.col("count").sum().alias("total_count"))

    # join with count data to compute each clade's weekly proportion
    prop_dat = lf.join(total_counts, on="date").with_columns(
        (pl.col("count") / pl.col("total_count")).alias("proportion")
    )

    # retrieve clades which have crossed the threshold over the past threshold_weeks
    high_prev_variants = prop_dat.filter(pl.col("proportion") > threshold).select("clade").unique().collect()

    # if more than the specified number of clades cross the threshold,
    # take the clades with the largest counts over the past threshold_weeks.
    # BUG FIX: this fallback LazyFrame was previously never collected, so the
    # .get_column() call below would fail (LazyFrame has no get_column method).
    if len(high_prev_variants) > max_clades:
        high_prev_variants = (
            prop_dat.group_by("clade").agg(pl.col("count").sum()).sort("count", descending=True).collect()
        )

    variants = high_prev_variants.get_column("clade").to_list()[:max_clades]

    end = time.perf_counter()
    elapsed = end - start
    logger.info("generated clade list", elapsed=elapsed)

    return variants


def main():
    """Download the latest Nextstrain genome metadata and derive the clade list."""
    # FIXME: provide ability to instantiate Config for the
    # get_clade_list function and get the data_path from there
    data_path = AnyPath(".").home() / "covid_variant"

    genome_metadata_path = download_covid_genome_metadata(
        Config.nextstrain_latest_genome_metadata,
        data_path,
    )
    lf_metadata = get_covid_genome_metadata(genome_metadata_path)
    lf_metadata_filtered = filter_covid_genome_metadata(lf_metadata)
    counts = get_clade_counts(lf_metadata_filtered)
    clade_list = clades_to_model(counts)

    return clade_list


if __name__ == "__main__":
    main()