Skip to content

Commit

Permalink
Port clade_data_utils function to get clade list for forecasting
Browse files Browse the repository at this point in the history
  • Loading branch information
bsweger committed Sep 6, 2024
1 parent 9268d25 commit bd8e1b8
Showing 1 changed file with 98 additions and 0 deletions.
98 changes: 98 additions & 0 deletions src/virus_clade_utils/get_clade_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""Get a list of SARS-CoV-2 clades."""

import time
from datetime import timedelta

import polars as pl
import structlog
from cloudpathlib import AnyPath

from virus_clade_utils.util.config import Config
from virus_clade_utils.util.sequence import (
download_covid_genome_metadata,
filter_covid_genome_metadata,
get_clade_counts,
get_covid_genome_metadata,
)

logger = structlog.get_logger()


def clades_to_model(
    clade_counts: pl.LazyFrame, threshold: float = 0.01, threshold_weeks: int = 3, max_clades: int = 9
) -> list[str]:
    """
    Determine the list of clades to model.

    Parameters
    ----------
    clade_counts : polars.LazyFrame
        Clade counts by date and location, summarized from Nextstrain metadata.
    threshold : float
        Clades that account for at least ``threshold`` proportion of reported
        sequences in a week are candidates for inclusion.
    threshold_weeks : int
        The number of weeks that we look back to identify clades.
    max_clades : int
        Upper bound on the number of clades returned. If more than
        ``max_clades`` clades cross ``threshold``, the clades with the
        largest total counts over the lookback window are kept.

    Returns
    -------
    list of str
        Clade names to model.
    """
    start = time.perf_counter()

    # based on the data's most recent date, get the week start
    # threshold_weeks ago (not including the current week)
    max_day = clade_counts.select(pl.max("date")).collect().item()
    lookback_start = max_day - timedelta(days=max_day.weekday() + 7 * threshold_weeks)

    # sum over weeks, combine states, and limit to the lookback window
    # (not including the current week)
    lf = (
        clade_counts.filter(pl.col("date") >= lookback_start)
        .sort("date")
        .group_by_dynamic("date", every="1w", start_by="sunday", group_by="clade")
        .agg(pl.col("count").sum())
    )

    # create a separate frame with the total counts per week
    total_counts = lf.group_by("date").agg(pl.col("count").sum().alias("total_count"))

    # join with count data to add a total counts per week column
    prop_dat = lf.join(total_counts, on="date").with_columns(
        (pl.col("count") / pl.col("total_count")).alias("proportion")
    )

    # retrieve list of variants which have crossed the threshold over the past threshold_weeks
    high_prev_variants = prop_dat.filter(pl.col("proportion") > threshold).select("clade").unique().collect()

    # if more than the specified number of clades cross the threshold,
    # take the clades with the largest counts over the past threshold_weeks
    if len(high_prev_variants) > max_clades:
        # BUG FIX: this lazy query must be collected before get_column is
        # called below, otherwise the branch raises AttributeError on a LazyFrame
        high_prev_variants = (
            prop_dat.group_by("clade").agg(pl.col("count").sum()).sort("count", descending=True).collect()
        )

    variants = high_prev_variants.get_column("clade").to_list()[:max_clades]

    end = time.perf_counter()
    elapsed = end - start
    logger.info("generated clade list", elapsed=elapsed)

    return variants


def main():
    """Download the latest genome metadata and return the clade list to model."""
    # FIXME: provide ability to instantiate Config for the
    # get_clade_list function and get the data_path from there
    data_path = AnyPath(".").home() / "covid_variant"

    # fetch the Nextstrain metadata file and derive clade counts from it
    metadata_file = download_covid_genome_metadata(
        Config.nextstrain_latest_genome_metadata,
        data_path,
    )
    filtered_metadata = filter_covid_genome_metadata(get_covid_genome_metadata(metadata_file))

    return clades_to_model(get_clade_counts(filtered_metadata))


# Script entry point; the clade list returned by main() is discarded here.
if __name__ == "__main__":
    main()

0 comments on commit bd8e1b8

Please sign in to comment.