Skip to content

Commit

Permalink
placate mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
kedhammar committed Jun 12, 2024
1 parent 6140598 commit 7d5a6ee
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 56 deletions.
13 changes: 8 additions & 5 deletions anglerfish/anglerfish.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import numpy as np
import pkg_resources

from .demux.adaptor import Adaptor
from .demux.demux import (
cluster_matches,
layout_matches,
Expand Down Expand Up @@ -86,14 +87,16 @@ def run_demux(args):
adaptor_set: set[tuple[str, str]] = set(adaptor_tuples)

# Create a dictionary with the adaptors as keys and an empty list as value
adaptors_sorted: dict[tuple[str, str], list] = dict([(i, []) for i in adaptor_set])
adaptors_sorted: dict[tuple[str, str], list[tuple[str, Adaptor, str]]] = dict(
[(i, []) for i in adaptor_set]
)

# Populate the dictionary values with sample-specific information
"""
adaptors_sorted = {
( adaptor_name, ont_barcode ) : [
(sample_name, adaptor, fastq),
(sample_name, adaptor, fastq),
( adaptor_name_str, ont_barcode_str ) : [
(sample_name_str, Adaptor, fastq_str),
(sample_name_str, Adaptor, fastq_str),
...
],
...
Expand Down Expand Up @@ -168,7 +171,7 @@ def run_demux(args):
**flips[args.force_rc],
)
flipped_i7, flipped_i5 = flips[args.force_rc].values()
elif args.lenient: # Try reverse complementing the I5 and/or i7 indices and choose the best match
elif args.lenient: # Try reverse complementing the i5 and/or i7 indices and choose the best match
flipped = {}
results = []
pool = multiprocessing.Pool(
Expand Down
25 changes: 13 additions & 12 deletions anglerfish/demux/adaptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,9 @@ def __init__(self, sequence_token: str, name: str, index_seq: str | None):

def _setup(self, sequence_token: str, name: str, index_seq: str | None):
# Assign attributes from args
self.sequence_token: str = sequence_token
self.name: str = name
self.index_seq: str | None = index_seq
self.sequence_token = sequence_token
self.name = name
self.index_seq = index_seq

# Index bool and len
if has_match(INDEX_TOKEN, self.sequence_token):
Expand All @@ -117,9 +117,9 @@ def _setup(self, sequence_token: str, name: str, index_seq: str | None):
)
elif len(umi_tokens) == 1:
self.has_umi = True
self.len_umi = int(
re.search(UMI_LENGTH_TOKEN, self.sequence_token).group(1)
)
umi_token_search = re.search(UMI_LENGTH_TOKEN, self.sequence_token)
assert isinstance(umi_token_search, re.Match)
self.len_umi = int(umi_token_search.group(1))
else:
self.has_umi = False
self.len_umi = 0
Expand Down Expand Up @@ -192,11 +192,12 @@ def get_mask(self, insert_Ns: bool = True) -> str:
else 0
)

umi_mask_length = (
max(self.len_umi_after_index, self.len_umi_before_index)
if insert_Ns and self.has_umi
else 0
)
if insert_Ns and self.has_umi:
assert self.len_umi_before_index is not None
assert self.len_umi_after_index is not None
umi_mask_length = max(self.len_umi_after_index, self.len_umi_before_index)
else:
umi_mask_length = 0

# Test if the index is specified in the adaptor sequence when it shouldn't be
if (
Expand All @@ -216,7 +217,7 @@ def get_mask(self, insert_Ns: bool = True) -> str:
return self.sequence_token


def has_match(pattern: re.Pattern, query: str) -> bool:
def has_match(pattern: re.Pattern | str, query: str) -> bool:
"""General function to check if a string contains a pattern."""
match = re.search(pattern, query)
if match is None:
Expand Down
72 changes: 43 additions & 29 deletions anglerfish/demux/demux.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


def parse_cs(
cs_string: str, index_seq: str, umi_before: int = 0, umi_after: int = 0
cs_string: str, index_seq: str, umi_before: int|None = 0, umi_after: int|None = 0
) -> tuple[str, int]:
"""
Given a cs string, an index sequence, and optional UMI lengths:
Expand All @@ -42,16 +42,16 @@ def parse_cs(


def run_minimap2(
fastq_in: os.PathLike,
index_file: os.PathLike,
output_paf: os.PathLike,
fastq_in: str,
index_file: str,
output_paf: str,
threads: int,
minimap_b: int = 1,
):
"""
Runs Minimap2
"""
cmd = [
cmd: list[str] = [
"minimap2",
"--cs", # Output the cs tag (short)
"-c", # Output cigar string in .paf
Expand All @@ -72,7 +72,7 @@ def run_minimap2(


def parse_paf_lines(
paf: os.PathLike, min_qual: int = 1, complex_identifier: bool = False
paf_path: str, min_qual: int = 1, complex_identifier: bool = False
) -> dict[str, list[dict]]:
"""
Read and parse the alignment lines of one paf file.
Expand All @@ -81,30 +81,30 @@ def parse_paf_lines(
If complex_identifier is True (default False), the keys will be of the form
"{read}_{i5_or_i7}_{strand_str}".
"""
entries = {}
with open(paf) as paf:
entries: dict = {}
with open(paf_path) as paf:
for paf_line in paf:
paf_cols = paf_line.split()
try:
# TODO: objectify this
entry = {
"read": paf_cols[0],
"adapter": paf_cols[5],
"rlen": int(paf_cols[1]), # read length
"rstart": int(paf_cols[2]), # start alignment on read
"rend": int(paf_cols[3]), # end alignment on read
"strand": paf_cols[4],
"cg": paf_cols[-2], # cigar string
"cs": paf_cols[-1], # cigar diff string
"q": int(paf_cols[11]), # Q score
"iseq": None,
"sample": None,
}
read = entry["read"]

# Unpack cols to vars for type annotation
read: str = paf_cols[0]
adapter: str = paf_cols[5]
rlen: int = int(paf_cols[1]) # read length
rstart: int = int(paf_cols[2]) # start alignment on read
rend: int = int(paf_cols[3]) # end alignment on read
strand: str = paf_cols[4]
cg: str = paf_cols[-2] # cigar string
cs: str = paf_cols[-1] # cigar diff string
q: int = int(paf_cols[11]) # Q score
iseq: str | None = None
sample: str | None = None

# Determine identifier
if complex_identifier:
i5_or_i7 = entry["adapter"].split("_")[-1]
if entry["strand"] == "+":
i5_or_i7 = adapter.split("_")[-1]
if strand == "+":
strand_str = "positive"
else:
strand_str = "negative"
Expand All @@ -116,10 +116,25 @@ def parse_paf_lines(
log.debug(f"Could not find all paf columns: {read}")
continue

if entry["q"] < min_qual:
if q < min_qual:
log.debug(f"Low quality alignment: {read}")
continue

# Compile entry
entry = {
"read": read,
"adapter": adapter,
"rlen": rlen,
"rstart": rstart,
"rend": rend,
"strand": strand,
"cg": cg,
"cs": cs,
"q": q,
"iseq": iseq,
"sample": sample,
}

if key in entries.keys():
entries[key].append(entry)
else:
Expand Down Expand Up @@ -177,8 +192,8 @@ def layout_matches(


def cluster_matches(
sample_adaptor: dict[tuple[str, str], list],
matches: tuple[dict, dict, dict, dict],
sample_adaptor: list[tuple[str, Adaptor, str]],
matches: dict,
max_distance: int,
i7_reversed: bool = False,
i5_reversed: bool = False,
Expand All @@ -188,8 +203,7 @@ def cluster_matches(
matched_bed = []
unmatched_bed = []
for read, alignments in matches.items():
i5 = False
i7 = False

if (
alignments[0]["adapter"][-2:] == "i5"
and alignments[1]["adapter"][-2:] == "i7"
Expand Down
12 changes: 7 additions & 5 deletions anglerfish/demux/samplesheet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import re
from dataclasses import dataclass
from itertools import combinations
from typing import cast

import Levenshtein as lev

from anglerfish.demux.adaptor import Adaptor, load_adaptors

ADAPTORS = load_adaptors(raw=True)
ADAPTORS = cast(dict, load_adaptors(raw=True))


@dataclass
Expand All @@ -32,7 +33,8 @@ def __init__(self, input_csv: str, ont_barcodes_enabled: bool):
self.samplesheet = []
try:
csvfile = open(input_csv)
dialect = csv.Sniffer().sniff(csvfile.readline(), [",", ";", "\t"])
csv_first_line: str = csvfile.readline()
dialect = csv.Sniffer().sniff(csv_first_line, ",;\t")
csvfile.seek(0)
data = csv.DictReader(
csvfile,
Expand Down Expand Up @@ -113,14 +115,14 @@ def minimum_bc_distance(self) -> int:
or within each ONT barcode group.
"""

ont_bc_to_adaptors = {}
ont_bc_to_adaptors: dict = {}
for entry in self.samplesheet:
if entry.ont_barcode in ont_bc_to_adaptors:
ont_bc_to_adaptors[entry.ont_barcode].append(entry.adaptor)
else:
ont_bc_to_adaptors[entry.ont_barcode] = [entry.adaptor]

testset = {}
testset: dict = {}
for ont_barcode, adaptors in ont_bc_to_adaptors.items():
testset[ont_barcode] = []
for adaptor in adaptors:
Expand Down Expand Up @@ -148,7 +150,7 @@ def minimum_bc_distance(self) -> int:
min_distances_all_barcodes.append(min(distances_within_barcode))
return min(min_distances_all_barcodes)

def get_fastastring(self, adaptor_name: str = None) -> str:
def get_fastastring(self, adaptor_name: str | None = None) -> str:
fastas = {}
for entry in self.samplesheet:
if entry.adaptor.name == adaptor_name or adaptor_name is None:
Expand Down
14 changes: 9 additions & 5 deletions anglerfish/explore/explore.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import os
import uuid
from typing import cast

import pandas as pd

Expand All @@ -13,8 +14,8 @@


def run_explore(
fastq: os.PathLike,
outdir: os.PathLike,
fastq: str,
outdir: str,
threads: int,
use_existing: bool,
good_hit_threshold: float,
Expand Down Expand Up @@ -42,8 +43,8 @@ def run_explore(
log.info("Running anglerfish explore")
log.info(f"Run uuid {run_uuid}")

adaptors: list[Adaptor] = load_adaptors()
alignments: list[tuple[Adaptor, os.PathLike]] = []
adaptors = cast(list[Adaptor], load_adaptors())
alignments: list[tuple[Adaptor, str]] = []

# Map all reads against all adaptors
for adaptor in adaptors:
Expand All @@ -67,7 +68,7 @@ def run_explore(
)

# Parse alignments
entries = {}
entries: dict = {}
adaptors_included = []
for adaptor, aln_path in alignments:
log.info(f"Parsing {adaptor.name}")
Expand Down Expand Up @@ -114,6 +115,9 @@ def run_explore(
["i5", "i7"], [adaptor.i5, adaptor.i7]
):
if adaptor_end.has_index:
assert adaptor_end.len_before_index is not None
assert adaptor_end.len_after_index is not None

# Alignment thresholds
before_thres = round(adaptor_end.len_before_index * good_hit_threshold)
after_thres = round(adaptor_end.len_after_index * good_hit_threshold)
Expand Down

0 comments on commit 7d5a6ee

Please sign in to comment.