Skip to content

Commit

Permalink
Merge pull request #32 from jodyphelan/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
jodyphelan authored Oct 18, 2023
2 parents d65a0ba + 41d8221 commit a2cca8f
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 94 deletions.
47 changes: 23 additions & 24 deletions ntm_profiler/output.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,29 @@

from collections import defaultdict
import os
from pathogenprofiler import filecheck, debug, infolog
from pathogenprofiler import filecheck
import csv
import pathogenprofiler as pp
import time
from tqdm import tqdm
import json
import jinja2
import logging

def write_outputs(args,results):
infolog("\nWriting outputs")
infolog("---------------")
logging.info("\nWriting outputs")
logging.info("---------------")
json_output = args.dir+"/"+args.prefix+".results.json"
text_output = args.dir+"/"+args.prefix+".results.txt"
csv_output = args.dir+"/"+args.prefix+".results.csv"
extra_columns = [x.lower() for x in args.add_columns.split(",")] if args.add_columns else []
infolog(f"Writing json file: {json_output}")
logging.info(f"Writing json file: {json_output}")
json.dump(results,open(json_output,"w"))
if args.txt:
infolog(f"Writing text file: {text_output}")
write_text(results,args.conf,text_output,extra_columns,reporting_af=args.reporting_af)
logging.info(f"Writing text file: {text_output}")
write_text(results,args.conf,text_output,extra_columns)
if args.csv:
infolog(f"Writing csv file: {csv_output}")
logging.info(f"Writing csv file: {csv_output}")
write_text(results,args.conf,csv_output,extra_columns)

default_template = """
Expand Down Expand Up @@ -100,11 +101,7 @@ def write_outputs(args,results):
Species report
-----------------
{{d['species_report']}}
Mash species report
-----------------
{{d['mash_species_report']}}
{{d['sourmash_species_report']}}
Analysis pipeline specifications
--------------------------------
Expand All @@ -124,20 +121,20 @@ def load_text(text_strings,template = None,file_template=None):
return t.render(d=text_strings)


def write_text(json_results,conf,outfile,columns = None,reporting_af = 0.0,sep="\t",template_file=None):
def write_text(json_results,conf,outfile,columns = None,sep="\t",template_file=None):
if "resistance_genes" not in json_results:
return write_species_text(json_results,outfile)
json_results = pp.get_summary(json_results,conf,columns = columns,reporting_af=reporting_af)
json_results = pp.get_summary(json_results,conf,columns = columns)
json_results["drug_table"] = [[y for y in json_results["drug_table"] if y["Drug"].upper()==d.upper()][0] for d in conf["drugs"]]
for var in json_results["dr_variants"]:
var["drug"] = ", ".join([d["drug"] for d in var["drugs"]])
text_strings = {}
text_strings["id"] = json_results["id"]
text_strings["date"] = time.ctime()
if json_results["species"] is not None:
text_strings["species_report"] = pp.dict_list2text(json_results["species"]["prediction"],["species","mean"],{"species":"Species","mean":"Mean kmer coverage"},sep=sep)
if "mash_closest_species" in json_results:
text_strings["mash_species_report"] = pp.dict_list2text(json_results["mash_closest_species"]["prediction"],{"accession":"Accession","species":"Species","mash-ANI":"mash-ANI"},sep=sep)
# if json_results["species"] is not None:
# text_strings["species_report"] = pp.dict_list2text(json_results["species"]["prediction"],["species","mean"],{"species":"Species","mean":"Mean kmer coverage"},sep=sep)
if "species" in json_results and len(json_results['species']['prediction_info'])>0:
text_strings["sourmash_species_report"] = pp.dict_list2text(json_results["species"]["prediction_info"],{"accession":"Accession","species":"Species","ani":"ANI","abundance":"Abundance"},sep=sep)
if "barcode" in json_results:
text_strings["cluster_report"] = pp.dict_list2text(json_results["barcode"],mappings={"annotation":"Cluster","freq":"Frequency"})
text_strings["dr_report"] = pp.dict_list2text(json_results["drug_table"],["Drug","Genotypic Resistance","Mutations"]+columns if columns else [],sep=sep)
Expand All @@ -148,7 +145,7 @@ def write_text(json_results,conf,outfile,columns = None,reporting_af = 0.0,sep="
text_strings["missing_report"] = pp.dict_list2text(json_results["qc"]["missing_positions"],["gene","locus_tag","position","position_type","drug_resistance_position"],sep=sep) if "missing_report" in json_results["qc"] else "N/A"
text_strings["pipeline"] = pp.dict_list2text(json_results["pipeline_software"],["Analysis","Program"],sep=sep)
text_strings["version"] = json_results["software_version"]
debug(json_results["species"]["species_db_version"])

text_strings["species_db_version"] = "%(name)s_%(Author)s_%(Date)s" % json_results["species"]["species_db_version"] if "species_db_version" in json_results['species'] else "N/A"
text_strings["resistance_db_version"] = "%(name)s_%(Author)s_%(Date)s" % json_results["resistance_db_version"] if "resistance_db_version" in json_results else "N/A"
if sep=="\t":
Expand All @@ -165,8 +162,8 @@ def write_species_text(json_results,outfile,sep="\t",template_file=None):
text_strings["id"] = json_results["id"]
text_strings["date"] = time.ctime()
text_strings["species_report"] = pp.dict_list2text(json_results["species"]["prediction"],["species","mean"],{"species":"Species","mean":"Mean kmer coverage"},sep=sep)
if "mash_closest_species" in json_results:
text_strings["mash_species_report"] = pp.dict_list2text(json_results["mash_closest_species"]["prediction"],{"accession":"Accession","species":"Species","mash-ANI":"mash-ANI"},sep=sep)
if "species" in json_results and len(json_results['species']['prediction_info'])>0:
text_strings["sourmash_species_report"] = pp.dict_list2text(json_results["species"]["prediction_info"],{"species":"Species","ani":"ANI","abundance":"Abundance","accession":"Closest accession"},sep=sep)
text_strings["pipeline"] = pp.dict_list2text(json_results["pipeline_software"],["Analysis","Program"],sep=sep)
text_strings["version"] = json_results["software_version"]
text_strings["species_db_version"] = "%(name)s_%(Author)s_%(Date)s" % json_results["species"]["species_db_version"]
Expand All @@ -190,7 +187,7 @@ def collate(args):
samples = [x.replace(args.suffix,"") for x in os.listdir(args.dir) if x[-len(args.suffix):]==args.suffix]

if len(samples)==0:
pp.infolog(f"\nNo result files found in directory '{args.dir}'. Do you need to specify '--dir'?\n")
pp.logging.info(f"\nNo result files found in directory '{args.dir}'. Do you need to specify '--dir'?\n")
quit(0)

# Loop through the sample result files
Expand All @@ -203,8 +200,10 @@ def collate(args):
for s in tqdm(samples):
# Data has the same structure as the .result.json files
data = json.load(open(filecheck("%s/%s%s" % (args.dir,s,args.suffix))))
if len(data["species"]["prediction"])>0:
species[s] = ";".join([d["species"] for d in data["species"]["prediction"]])
if data["species"]["prediction"]:
species[s] = data["species"]["prediction"]
else:
species[s] = "N/A"
if "mash_closest_species" in data:
closest_seq[s] = "|".join(pp.stringify(data["mash_closest_species"]["prediction"][0].values())) if len(data["mash_closest_species"]["prediction"])>0 else ""
if "barcode" in data:
Expand Down
9 changes: 9 additions & 0 deletions ntm_profiler/reformat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@ def reformat_resistance_genes(results):
del d["annotations"]
return results

def add_subspecies(results):
if "barcode" in results:
subspecies_found = [d['annotation'] for d in results["barcode"] if d['annotation'].startswith("subsp.")]
if len(subspecies_found)>0:
subspecies_list = [results["species"]["prediction"] + " "+ d for d in subspecies_found]
results["species"]["prediction"] = ";".join(subspecies_list)


def reformat(results,conf):
results["variants"] = [x for x in results["variants"] if len(x["consequences"])>0]
results["variants"] = pp.select_csq(results["variants"])
Expand All @@ -19,4 +27,5 @@ def reformat(results,conf):
if "missing_positions" in results["qc"]:
results["qc"]["missing_positions"] = pp.reformat_missing_genome_pos(results["qc"]["missing_positions"],conf)

add_subspecies(results)
return results
69 changes: 40 additions & 29 deletions ntm_profiler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,6 @@
from uuid import uuid4
import pathogenprofiler as pp

def infolog(x):
sys.stderr.write('\033[94m' + str(x) + '\033[0m' + '\n')

def errlog(x):
sys.stderr.write('\033[91m' + str(x) + '\033[0m' + '\n')


def test_resistance_genes(conf,results):
resistance_genes = {}
Expand Down Expand Up @@ -37,34 +31,51 @@ def test_resistance_genes(conf,results):
return results


def get_mash_hit(args):
def get_sourmash_hit(args):
args.species_conf = pp.get_db(args.software_name,args.species_db)
db_info = pp.parse_csv(args.species_conf["mash_db_info"])
if args.read1:
if args.read2:
pp.run_cmd(f"cat {args.read1} {args.read2} > {args.files_prefix}.fq.gz")
reads = f"{args.files_prefix}.fq.gz"
fastq = pp.Fastq(args.read1,args.read2)
else:
reads = args.read1
pp.run_cmd(f"mash dist -m 2 {args.species_conf['mash_db']} {reads} | sort -gk3 | head > {args.files_prefix}.mash_dist.txt")
fastq = pp.Fastq(args.read1)
sourmash_sig = fastq.sourmash_sketch(args.files_prefix)
elif args.fasta:
pp.run_cmd(f"mash dist {args.species_conf['mash_db']} {args.fasta} | sort -gk3 | head > {args.files_prefix}.mash_dist.txt")
pp.run_cmd(f"mash dist {args.species_conf['sourmash_db']} {args.fasta} | sort -gk3 | head > {args.files_prefix}.mash_dist.txt")
fasta = pp.Fasta(args.fasta)
sourmash_sig = fasta.sourmash_sketch(args.files_prefix)
elif args.bam:
pp.run_cmd(f"samtools fastq {args.bam} | mash dist -m 2 {args.species_conf['mash_db']} - | sort -gk3 | head > {args.files_prefix}.mash_dist.txt")

result = {
"prediction_method":"mash",
"prediction":[],
"species_db_version":args.species_conf["version"]
}
for l in open(f"{args.files_prefix}.mash_dist.txt"):
row = l.strip().split()
acc = row[0].replace("db/","").replace(".fa","")
species = db_info[acc]["species"]
result["prediction"].append({
"accession":acc,
"species":species,
"mash-ANI":1-float(row[2])
})
pp.run_cmd(f"samtools fastq {args.bam} > {args.files_prefix}.tmp.fastq")
fq_file = f"{args.files_prefix}.tmp.fastq"
fastq = pp.Fastq(fq_file)
sourmash_sig = fastq.sourmash_sketch(args.files_prefix)

sourmash_sig = sourmash_sig.gather(args.species_conf["sourmash_db"],args.species_conf["sourmash_db_info"],intersect_bp=2500000,f_match_threshold=0.1)
result = []

if len(sourmash_sig)>0:
result = sourmash_sig

return result

def summarise_sourmash_hits(sourmash_hits):
species = []
for hit in sourmash_hits:
if hit["species"] not in species:
species.append(hit["species"])
return ";".join(species)

def consolidate_species_predictions(kmer_prediction, sourmash_prediction):
filtered_sourmash_prediction = [d for d in sourmash_prediction if d["ani"]>95]
if len(kmer_prediction)>1:
return None
elif len(kmer_prediction)>0 and len(filtered_sourmash_prediction)>0:
if kmer_prediction[0]["species"]==filtered_sourmash_prediction[0]["species"]:
return kmer_prediction[0]["species"]
else:
return None
elif len(filtered_sourmash_prediction)>0:
return filtered_sourmash_prediction[0]["species"]
elif len(kmer_prediction)>0:
return kmer_prediction[0]["species"]
else:
return None
Loading

0 comments on commit a2cca8f

Please sign in to comment.