-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Port clade_data_utils function to get clade list for forecasting
- Loading branch information
Showing 1 changed file with 98 additions and 0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
"""Get a list of SARS-CoV-2 clades.""" | ||
|
||
import time | ||
from datetime import timedelta | ||
|
||
import polars as pl | ||
import structlog | ||
from cloudpathlib import AnyPath | ||
|
||
from virus_clade_utils.util.config import Config | ||
from virus_clade_utils.util.sequence import ( | ||
download_covid_genome_metadata, | ||
filter_covid_genome_metadata, | ||
get_clade_counts, | ||
get_covid_genome_metadata, | ||
) | ||
|
||
logger = structlog.get_logger() | ||
|
||
|
||
def clades_to_model(
    clade_counts: pl.LazyFrame, threshold: float = 0.01, threshold_weeks: int = 3, max_clades: int = 9
) -> list[str]:
    """
    Determine list of clades to model

    Parameters
    ----------
    clade_counts : polars.LazyFrame
        Clade counts by date and location, summarized from Nextstrain metadata.
        The ``date``, ``clade`` and ``count`` columns are used.
    threshold : float
        Clades whose share of a week's sequences is greater than ``threshold``
        in at least one week of the look-back window are candidates for
        inclusion.
    threshold_weeks : int
        The number of weeks that we look back to identify clades. The window
        starts on the Sunday ``threshold_weeks`` weeks before the week that
        contains the most recent date in ``clade_counts``; no upper bound is
        applied, so that most recent (possibly partial) week is included too.
    max_clades : int
        Maximum number of clades to return.

    Returns
    -------
    list of strings
        Clade names, at most ``max_clades`` of them. Empty when
        ``clade_counts`` has no rows.
    """
    start = time.perf_counter()

    max_day = clade_counts.select(pl.max("date")).collect().item()
    if max_day is None:
        # No rows (or only null dates): there is nothing to rank.
        logger.warning("no clade counts available, returning empty clade list")
        return []

    # weekday() counts from Monday (Mon=0), which would put the cutoff on a
    # Monday and drop Sunday's data from the oldest weekly bin. isoweekday() % 7
    # counts from Sunday (Sun=0), matching start_by="sunday" below.
    days_since_sunday = max_day.isoweekday() % 7
    window_start = max_day - timedelta(days=days_since_sunday + 7 * threshold_weeks)

    # sum over weeks (Sunday-based) and combine locations, limited to the look-back window
    weekly_counts = (
        clade_counts.filter(pl.col("date") >= window_start)
        .sort("date")
        .group_by_dynamic("date", every="1w", start_by="sunday", group_by="clade")
        .agg(pl.col("count").sum())
    )

    # total sequences per week, across all clades
    weekly_totals = weekly_counts.group_by("date").agg(pl.col("count").sum().alias("total_count"))

    # each clade's share of its week's sequences
    proportions = weekly_counts.join(weekly_totals, on="date").with_columns(
        (pl.col("count") / pl.col("total_count")).alias("proportion")
    )

    # clades that crossed the threshold in at least one week; sorted so the
    # returned order is reproducible (unique() does not guarantee any order)
    selected = (
        proportions.filter(pl.col("proportion") > threshold).select("clade").unique().sort("clade").collect()
    )

    # if more than the specified number of clades cross the threshold,
    # take the clades with the largest counts over the look-back window
    if len(selected) > max_clades:
        # NOTE(review): this ranks every clade in the window, not only those
        # that crossed the threshold -- confirm that is the intended behaviour.
        # collect() is required: a LazyFrame has no get_column().
        selected = (
            proportions.group_by("clade")
            .agg(pl.col("count").sum())
            .sort(["count", "clade"], descending=[True, False])
            .collect()
        )

    variants = selected.get_column("clade").to_list()[:max_clades]

    elapsed = time.perf_counter() - start
    logger.info("generated clade list", elapsed=elapsed)

    return variants
|
||
|
||
def main():
    """Build the clade list from the genome metadata referenced by ``Config``.

    Downloads the metadata into ``~/covid_variant``, loads and filters it,
    summarizes it into clade counts, and returns the result of
    ``clades_to_model`` called with its default settings.
    """
    # FIXME: provide ability to instantiate Config for the
    # get_clade_list function and get the data_path from there
    download_dir = AnyPath(".").home() / "covid_variant"

    metadata_file = download_covid_genome_metadata(
        Config.nextstrain_latest_genome_metadata,
        download_dir,
    )
    filtered_metadata = filter_covid_genome_metadata(get_covid_genome_metadata(metadata_file))

    return clades_to_model(get_clade_counts(filtered_metadata))
|
||
|
||
# Run the clade-list pipeline when this module is executed as a script.
# The list returned by main() is not used here.
if __name__ == "__main__":
    main()