--single for duplex testing
Psy-Fer committed Jul 1, 2024
1 parent 34ee6b0 commit 843b003
Showing 4 changed files with 94 additions and 15 deletions.
5 changes: 5 additions & 0 deletions src/basecaller.py
@@ -120,6 +120,7 @@ def submit_reads(args, client, batch):
    read_store = {}
    for read in batch:
        read_id = read['read_id']
        # TODO: remove the signal from the read_store to reduce memory usage
        # if args.seq_sum:
        read_store[read_id] = read
        # calculate scale
@@ -316,6 +317,7 @@ def basecaller_proc(args, iq, rq, address, config, params, N):
    if args.profile:
        pr = cProfile.Profile()
        pr.enable()

    client_sub = pclient(address=address, config=config)
    client_sub.set_params(params)
    # submit a batch of reads to be basecalled
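
For context, the if args.profile bracket at the top of this hunk opens a cProfile session that spans the whole worker; the matching dump (visible later in the reader.py diff) sorts by cumulative time and writes a per-worker log. A compact sketch of that enable/dump cycle, with a toy workload standing in for the worker body:

import cProfile, io, pstats

pr = cProfile.Profile()
pr.enable()
sum(i * i for i in range(100000))   # stand-in for the worker's real work
pr.disable()

s = io.StringIO()
pstats.Stats(pr, stream=s).sort_stats('cumulative').print_stats()
with open('worker.log', 'w') as f:   # hypothetical log name for this sketch
    print(s.getvalue(), file=f)
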
@@ -324,11 +326,14 @@
        batch = iq.get()
        if batch is None:
            break
        print("[BASECALLER] - submitting channel: {}".format(batch[0]["channel_number"]))
        # Submit to be basecalled
        read_counter, read_store = submit_reads(args, client, batch)
        # now collect the basecalled reads
        print("[BASECALLER] - getting basecalled channel: {}".format(batch[0]["channel_number"]))
        bcalled_list = get_reads(args, client, read_counter, read_store)
        # TODO: make a skipped queue to handle skipped reads
        print("[BASECALLER] - writing channel: {}".format(batch[0]["channel_number"]))
        rq.put(bcalled_list)
        iq.task_done()

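The hunk above is a sentinel-terminated worker loop: the worker blocks on iq.get(), treats None as a shutdown signal, and calls task_done() for each item so the producer can join() the queue. A minimal, self-contained sketch of the same pattern, with a toy transform standing in for the real submit_reads()/get_reads() calls:

import multiprocessing as mp

def toy_worker(iq, rq):
    # mirror basecaller_proc: block on the input queue, stop on the None sentinel
    while True:
        batch = iq.get()
        if batch is None:
            iq.task_done()
            break
        # stand-in for submit_reads()/get_reads(): just uppercase the read ids
        rq.put([read_id.upper() for read_id in batch])
        iq.task_done()

if __name__ == '__main__':
    iq, rq = mp.JoinableQueue(), mp.Queue()
    worker = mp.Process(target=toy_worker, args=(iq, rq), daemon=True)
    worker.start()
    iq.put(['read_a', 'read_b'])
    iq.put(None)        # sentinel: tell the worker to finish up
    iq.join()           # returns once both queue items are marked done
    print(rq.get())     # ['READ_A', 'READ_B']
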
50 changes: 35 additions & 15 deletions src/buttery_eel.py
@@ -19,7 +19,7 @@

from ._version import __version__
from .cli import get_args
from .reader import read_worker, duplex_read_worker
from .reader import read_worker, duplex_read_worker, duplex_read_worker_single
from .writer import write_worker
from .basecaller import start_guppy_server_and_client, basecaller_proc

@@ -238,22 +238,42 @@ def main():
    processes = []

    if args.duplex:
        print("Duplex mode active - a duplex model must be used to output duplex reads")
        print("Buttery-eel does not have checks for this, as the model names are in flux")
        print()
        duplex_pre_queue = mp.JoinableQueue()
        # create the same number of queues as there are worker processes so each has its own queue
        queue_names = range(args.procs)
        duplex_queues = {name: mp.JoinableQueue() for name in queue_names}
        reader = mp.Process(target=duplex_read_worker, args=(args, duplex_queues, duplex_pre_queue), name='duplex_read_worker')
        reader.start()
        out_writer = mp.Process(target=write_worker, args=(args, result_queue, OUT, SAM_OUT), name='write_worker')
        out_writer.start()
        # set up each worker to have a unique queue, so it only processes 1 channel at a time
        for name in queue_names:
            basecall_worker = mp.Process(target=basecaller_proc, args=(args, duplex_queues[name], result_queue, address, config, params, name), daemon=True, name='basecall_worker_{}'.format(name))
        if args.single:
            print("Duplex mode active - a duplex model must be used to output duplex reads")
            print("Buttery-eel does not have checks for this, as the model names are in flux")
            print("SINGLE MODE ACTIVATED - FOR TESTING")
            print()
            duplex_pre_queue = mp.JoinableQueue()
            # create the same number of queues as there are worker processes so each has its own queue
            # queue_names = range(args.procs)
            # duplex_queues = {name: mp.JoinableQueue() for name in queue_names}
            duplex_queue = mp.JoinableQueue()
            reader = mp.Process(target=duplex_read_worker_single, args=(args, duplex_queue, duplex_pre_queue), name='duplex_read_worker_single')
            reader.start()
            out_writer = mp.Process(target=write_worker, args=(args, result_queue, OUT, SAM_OUT), name='write_worker')
            out_writer.start()
            # a single worker with a single queue, so reads are processed one channel at a time
            basecall_worker = mp.Process(target=basecaller_proc, args=(args, duplex_queue, result_queue, address, config, params, 0), daemon=True, name='basecall_worker_{}'.format(0))
            basecall_worker.start()
            processes.append(basecall_worker)

        else:
            print("Duplex mode active - a duplex model must be used to output duplex reads")
            print("Buttery-eel does not have checks for this, as the model names are in flux")
            print()
            duplex_pre_queue = mp.JoinableQueue()
            # create the same number of queues as there are worker processes so each has its own queue
            queue_names = range(args.procs)
            duplex_queues = {name: mp.JoinableQueue() for name in queue_names}
            reader = mp.Process(target=duplex_read_worker, args=(args, duplex_queues, duplex_pre_queue), name='duplex_read_worker')
            reader.start()
            out_writer = mp.Process(target=write_worker, args=(args, result_queue, OUT, SAM_OUT), name='write_worker')
            out_writer.start()
            # set up each worker to have a unique queue, so it only processes 1 channel at a time
            for name in queue_names:
                basecall_worker = mp.Process(target=basecaller_proc, args=(args, duplex_queues[name], result_queue, address, config, params, name), daemon=True, name='basecall_worker_{}'.format(name))
                basecall_worker.start()
                processes.append(basecall_worker)
    else:
        reader = mp.Process(target=read_worker, args=(args, input_queue), name='read_worker')
        reader.start()
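The two duplex branches differ only in queue topology: --single wires one JoinableQueue to one basecall worker, while the default path gives each of args.procs workers a private queue so one channel's reads are never interleaved with another's. A condensed sketch of that wiring, where wire_duplex is a hypothetical helper rather than buttery-eel code:

import multiprocessing as mp

def wire_duplex(procs, single, worker_fn, result_queue):
    # hypothetical helper: one JoinableQueue per worker, collapsing to a
    # single queue/worker pair when single is True
    names = [0] if single else list(range(procs))
    queues = {name: mp.JoinableQueue() for name in names}
    workers = []
    for name in names:
        w = mp.Process(target=worker_fn, args=(queues[name], result_queue, name),
                       daemon=True, name='basecall_worker_{}'.format(name))
        w.start()
        workers.append(w)
    return queues, workers

Because each worker drains only its own queue, the multi-queue mode keeps per-channel ordering while still basecalling channels in parallel; --single trades that parallelism away to make test runs easier to reproduce.
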
2 changes: 2 additions & 0 deletions src/cli.py
@@ -95,6 +95,8 @@ def get_args():
    # Duplex
    duplex.add_argument("--duplex", action="store_true",
                        help="Turn on duplex calling - channel based")
    duplex.add_argument("--single", action="store_true",
                        help="Use only a single worker process (for testing)")


    # parser.add_argument("--max_queued_reads", default="2000",
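Both flags are plain store_true booleans, so --single only has an effect when --duplex is also given (main() checks args.single inside the args.duplex branch). A minimal sketch of the equivalent argparse wiring, assuming a group named duplex as in the hunk above:

import argparse

parser = argparse.ArgumentParser(prog='buttery-eel')
duplex = parser.add_argument_group('duplex')   # group title is an assumption
duplex.add_argument('--duplex', action='store_true',
                    help='Turn on duplex calling - channel based')
duplex.add_argument('--single', action='store_true',
                    help='Use only a single worker process (for testing)')

args = parser.parse_args(['--duplex', '--single'])
print(args.duplex, args.single)   # True True
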
52 changes: 52 additions & 0 deletions src/reader.py
@@ -263,6 +263,58 @@ def duplex_read_worker(args, dq, pre_dq):
    for qname in dq_names:
        dq[qname].put(None)

    # if profiling, dump info into log files in current dir
    if args.profile:
        pr.disable()
        s = io.StringIO()
        sortby = 'cumulative'
        ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
        ps.print_stats()
        with open("read_worker.log", 'w') as f:
            print(s.getvalue(), file=f)


def duplex_read_worker_single(args, dq, pre_dq):
    '''
    Single-process duplex read worker, used by --single for testing
    '''
    if args.profile:
        pr = cProfile.Profile()
        pr.enable()

    # read the file once, and get all the channels into the pre_dq queue
    get_data_by_channel(args, pre_dq)

    s5 = pyslow5.Open(args.input, 'r')
    filename_slow5 = args.input.split("/")[-1]
    header_array = {}
    num_read_groups = s5.get_num_read_groups()
    for read_group in range(num_read_groups):
        header_array[read_group] = s5.get_all_headers(read_group=read_group)
    # reads = s5.seq_reads_multi(threads=args.slow5_threads, batchsize=args.slow5_batchsize, aux='all')

    # readers = {}
    # break call
    # ending = False
    # pull channels from the pre_dq
    while True:
        ch = pre_dq.get()
        if ch is None:
            break
        channel = ch[0]
        print("[READER] - processing channel: {}".format(channel))
        data = ch[1]
        read_list = [i for i, _ in data]
        reads = s5.get_read_list_multi(read_list, threads=args.slow5_threads, batchsize=args.slow5_batchsize, aux='all')
        batches = _get_slow5_batch(args, s5, reads, size=args.slow5_batchsize, slow5_filename=filename_slow5, header_array=header_array)
        for batch in chain(batches):
            # wait for the basecall queue to drain below 5 batches before
            # adding more, so in-flight memory stays bounded
            while dq.qsize() >= 5:
                time.sleep(0.1)
            dq.put(batch)

    dq.put(None)

    # if profiling, dump info into log files in current dir
    if args.profile:
        pr.disable()
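duplex_read_worker_single feeds a single queue and throttles itself by polling qsize(): a batch is only enqueued once fewer than five are waiting, which bounds in-flight memory while the lone basecall worker catches up. A minimal sketch of that polling backpressure, assuming the same five-batch high-water mark (note that multiprocessing's qsize() is approximate and raises NotImplementedError on macOS):

import multiprocessing as mp
import time

def put_with_backpressure(q, item, high_water=5, poll=0.1):
    # block the producer until the consumer has drained the queue a little;
    # qsize() is approximate and unavailable on macOS, so treat this as a
    # best-effort bound rather than a hard limit
    while q.qsize() >= high_water:
        time.sleep(poll)
    q.put(item)
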
