diff --git a/py/transcript_abundance.py b/py/transcript_abundance.py index b16e517..f807d20 100755 --- a/py/transcript_abundance.py +++ b/py/transcript_abundance.py @@ -2,16 +2,37 @@ import argparse import gzip from collections import defaultdict +from typing import List, Tuple +import numpy as np from tqdm import tqdm # Includes EM code by Jared Simpson from # https://github.com/jts/nanopore-rna-analysis/blob/master/nanopore_transcript_abundance.py +IUPAC_nts = { + "A": np.array(["A"], dtype=str), + "C": np.array(["C"], dtype=str), + "G": np.array(["G"], dtype=str), + "T": np.array(["T"], dtype=str), + "R": np.array(["A", "G"], dtype=str), + "Y": np.array(["C", "T"], dtype=str), + "K": np.array(["G", "T"], dtype=str), + "M": np.array(["A", "C"], dtype=str), + "S": np.array(["C", "G"], dtype=str), + "W": np.array(["A", "T"], dtype=str), + "B": np.array(["C", "G", "T"], dtype=str), + "D": np.array(["A", "G", "T"], dtype=str), + "H": np.array(["A", "C", "T"], dtype=str), + "V": np.array(["A", "C", "G"], dtype=str), + "N": np.array(["A", "C", "G", "T"], dtype=str), +} + def parse_args(): parser = argparse.ArgumentParser( - description="Output a TSV file of the long-read trasncript expresion." + description="Output a TSV file of the long-read trasncript expresion.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( "-p", @@ -27,43 +48,131 @@ def parse_args(): default="", help="TSV output of the long-reads barcode matching from scTager (e.g. S1.lr_matches.tsv.gz)", ) + parser.add_argument( + "--cb-count", + type=int, + default=0, + help="Number of cell barcodes to simulate." + + " If >0, then the abundance will be split by cell barcodes according to a normal distribution." + + " Cannot be used with --lr-br parameter.", + ) + parser.add_argument( + "--cb-lognorm-params", + type=str, + default="10,1", + help="Parameters for the lognormal distribution of cell barcodes abundance. " + + " Takes two comma-separated values: mean and standard deviation.", + ) + parser.add_argument( + "--cb-pattern", + type=str, + default="NNNNNNNNNNNN", + help="FASTA pattern for the cell barcode. If used, --cb-count parameter must also be used.", + ) + parser.add_argument( + "--cb-dropout", + type=float, + default=0.2, + help="Fraction of reads without cell barcodes. If used, --cb-count parameter must also be used.", + ) + parser.add_argument( + "--cb-txt", + type=str, + default="", + help="Path to a text file of cell barcodes whitelist, one per line." + + "If used, --cb-count parameter must also be used." + + "Do not use with --cb-pattern parameter.", + ) parser.add_argument( "-o", "--output", type=str, required=True, - help="Path for the output file (e.g cDNA.abundance.tsv.tsv.gz or cDNA.abundance.tsv). Will be gzipped if it ends with .gz extenstion.", + help="Path for the output file (e.g cDNA.abundance.tsv.tsv.gz or cDNA.abundance.tsv)." + + "Will be gzipped if it ends with .gz extenstion.", + ) + parser.add_argument( + "-em", + "--em-iterations", + type=int, + default=10, + ) + parser.add_argument( + "--random-seed", + type=int, + default=42, ) - parser.add_argument("-em", "--em-iterations", type=int, default=10) - parser.add_argument("-v", "--verbose", type=int, default=0) + parser.add_argument( + "-v", + "--verbose", + type=int, + default=0, + ) + class ListPrinter(argparse.Action): - def __call__(self, parser, namespace ,values, option_string): - txt = '\n'.join([getattr(k, 'dest') for k in parser._actions]) + def __call__(self, parser, namespace, values, option_string): + txt = "\n".join([getattr(k, "dest") for k in parser._actions]) print(txt) parser.exit() - parser.add_argument( - "--list", - nargs=0, - action=ListPrinter - ) + parser.add_argument("--list", nargs=0, action=ListPrinter) args = parser.parse_args() + if args.cb_count > 0: + if args.lr_br != "": + parser.error("--lr-br must not be set with --cb-count") + if args.cb_pattern == "" and args.cb_txt == "": + parser.error("--cb-pattern or --cb-txt must be set with --cb-count") + for c in args.cb_pattern: + if c not in IUPAC_nts.keys(): + parser.error( + "--cb-pattern must contain only valid IUPAC nucleotide letters: " + + f"<{c}> not in {','.join(IUPAC_nts.keys())}" + ) + if args.cb_dropout < 0 or args.cb_dropout > 1: + parser.error("--cb-dropout must be between 0 and 1") + args.cb_lognorm_params = tuple( + float(x) for x in args.cb_lognorm_params.split(",") + ) + assert len(args.cb_lognorm_params) == 2 + assert args.cb_lognorm_params[1] > 0 return args +def generate_barcodes_from_pattern(pattern: str, count: int): + barcodes: List[str] = list() + for _ in range(count): + barcode: List[str] = list() + for p in pattern: + c: str = np.random.choice(IUPAC_nts[p]) + barcode.append(c) + barcodes.append("".join(barcode)) + return barcodes + + +def parse_barcodes_txt(barcode_txt: str): + barcodes: List[str] = list() + if barcode_txt.endswith(".gz"): + infile = gzip.open(barcode_txt, "rt") + else: + infile = open(barcode_txt, "r") + for l in infile: + l: str = l + barcode = l.rstrip("\n") + barcodes.append(barcode) + return barcodes + + def parse_lr_bc_matches(lr_br_tsv): - rid_to_bc = defaultdict(lambda: ".") - if lr_br_tsv == "": - return rid_to_bc - elif lr_br_tsv.endswith(".gz"): + rid_to_bc: defaultdict[str, str] = defaultdict(lambda: ".") + if lr_br_tsv.endswith(".gz"): infile = gzip.open(lr_br_tsv, "rt") else: infile = open(lr_br_tsv, "r") print("Parsing LR barcode matches TSV...") - for l in tqdm(infile): - l = l.rstrip("\n").split("\t") - rid, _, c, _, bc = l + for l in tqdm(infile): # type: ignore + l: str = l + rid, _, c, _, bc = l.rstrip("\n").split("\t") if c != "1": continue rid_to_bc[rid] = bc @@ -193,10 +302,52 @@ def calculate_split_abundance(compatibility, rid_to_bc): return abundance +def generate_rid_to_bc( + barcodes: List[str], dropout: float, lognorm_params: Tuple[float, float] +): + weights = np.random.lognormal(*lognorm_params, size=len(barcodes)) + total_with_dropout: float = weights.sum() / (1 - dropout) + dropout_weight: float = total_with_dropout * dropout + barcodes.append(".") + weights = np.append(weights, np.array([dropout_weight])) + weights = weights / weights.sum() + + rid_to_bc: defaultdict[str, str] = defaultdict( + lambda: np.random.choice( + a=np.array(barcodes, dtype=str), + size=1, + p=weights, + replace=True, + )[0] + ) + return rid_to_bc + + def main(): args = parse_args() - - rid_to_bc = parse_lr_bc_matches(args.lr_br) + np.random.seed(args.random_seed) + if args.lr_br == "": + if args.cb_count <= 0: + rid_to_bc: defaultdict[str, str] = defaultdict(lambda: ".") + else: + if args.cb_txt != "": + barcodes = parse_barcodes_txt(args.cb_txt) + assert len(barcodes) >= args.cb_count + barcodes = np.random.choice( + barcodes, + size=args.cb_count, + replace=True, + ) + else: + barcodes = generate_barcodes_from_pattern( + args.cb_pattern, + args.cb_count, + ) + rid_to_bc = generate_rid_to_bc( + barcodes, args.cb_dropout, args.cb_lognorm_params + ) + else: + rid_to_bc = parse_lr_bc_matches(args.lr_br) tid_to_tname, alignments = parse_paf(args.paf) transcript_compatibility = get_compatibility(alignments) del alignments @@ -217,7 +368,7 @@ def main(): outfile = open(args.output, "w+") total_reads = len(transcript_compatibility) print(f"Parsed alignments for {total_reads} reads") - outfile.write("target_id\ttpm\tcell\n") + outfile.write("target_id\ttpm\tcell\n") # type: ignore for (tid, cell), a in abundance.items(): tpm = a * 1_000_000 # if you need >100M reads to see this transcript, then skip @@ -230,9 +381,9 @@ def main(): f"{tpm:.3f}", f"{cell}", ] - ) + ) # type: ignore ) - outfile.write("\n") + outfile.write("\n") # type: ignore outfile.close()