Skip to content

Commit

Permalink
hidden args, move fake reads, duplex pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
Psy-Fer committed Aug 20, 2024
1 parent 6b50789 commit cc7934d
Showing 1 changed file with 97 additions and 43 deletions.
140 changes: 97 additions & 43 deletions src/basecaller.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ def start_guppy_server_and_client(args, server_args):
basecaller_bin = arg
gotten = True
continue
if arg in ["--do_read_splitting", "--detect_adapter", "--detect_mid_strand_adapter", "--detect_mid_strand_barcodes"]:
print("==========================================================================\n Depricated Arguments Detected\n==========================================================================")
print("{} no longer used, please check docs to learn more".format(arg))
print("\n")
continue
tmp_args.append(arg)

if basecaller_bin is None:
Expand Down Expand Up @@ -189,27 +194,9 @@ def submit_reads(args, client, sk, batch):
break
if result:
read_counter += 1
if args.duplex:
# send fake reads with different channels to trick server to flushing cache
for chhh in range(90000, 90011):
result = client.pass_read(helper_functions.package_read(
raw_data=np.frombuffer(read['signal'], np.int16),
read_id=str(chhh),
start_time=read['start_time'],
daq_offset=read['offset'],
daq_scaling=scale,
sampling_rate=read['sampling_rate'],
mux=read['start_mux'],
channel=int(chhh),
run_id=read_store[read_id]["header_array"]["run_id"],
duration=read['len_raw_signal'],
))
if result:
read_counter += 1
if len(skipped) > 0:
for i in skipped:
sk.put(i)

return read_counter, read_store


Expand Down Expand Up @@ -241,15 +228,21 @@ def get_reads(args, client, read_counter, sk, read_store):
if len(calls) > 1:
split_reads = True
for call in calls:
# if int(call['metadata']['channel']) > 20000:
# print("Fake read:", call['metadata']['channel'])
# continue
# for i in call:
# if isinstance(call[i], dict):
# for j in call[i]:
# print("{}: {}".format(j, call[i][j]))
# else:
# print("{}: {}".format(i, call[i]))
if int(call['metadata']['channel']) > 20000:
# print("Fake read:", call['metadata']['channel'])
# print("Channel: {} - fake read {}/{}".format(call['metadata']['channel'], done, read_counter))
done -= 1
# print("ending set")
# ending = True
continue
# print("Channel: {} - read {}/{}".format(call['metadata']['channel'], done, read_counter))
# if call['metadata']['read_id'] == "d209a644-9086-4296-9a6f-7d037dcb959f":
# for i in call:
# if isinstance(call[i], dict):
# for j in call[i]:
# print("{}: {}".format(j, call[i][j]))
# else:
# print("{}: {}".format(i, call[i]))
try:
bcalled_read = {}
bcalled_read["sam_record"] = ""
Expand All @@ -261,8 +254,16 @@ def get_reads(args, client, read_counter, sk, read_store):
bcalled_read["read_id"] = read_id
bcalled_read["read_qscore"] = call['metadata']['mean_qscore']
bcalled_read["int_read_qscore"] = int(call['metadata']['mean_qscore'])
bcalled_read["header"] = "@{} parent_read_id={} model_version_id={} mean_qscore={}".format(bcalled_read["read_id"], bcalled_read["parent_read_id"], call['metadata'].get('model_version_id', model_id), bcalled_read["int_read_qscore"])
if args.call_mods:
bcalled_read["header"] = "@{} parent_read_id={} model_version_id={} modbase_model_version_id={} mean_qscore={}".format(bcalled_read["read_id"], bcalled_read["parent_read_id"], call['metadata'].get('model_version_id', model_id), call['metadata'].get('modbase_model_version_id', model_id), bcalled_read["int_read_qscore"])
else:
bcalled_read["header"] = "@{} parent_read_id={} model_version_id={} mean_qscore={}".format(bcalled_read["read_id"], bcalled_read["parent_read_id"], call['metadata'].get('model_version_id', model_id), bcalled_read["int_read_qscore"])
bcalled_read["sequence"] = call['datasets']['sequence']
if args.duplex:
bcalled_read["duplex_parent"] = call['metadata']['is_duplex_parent']
bcalled_read["duplex_strand_1"] = call['metadata'].get('duplex_strand_1', None)
bcalled_read["duplex_strand_2"] = call['metadata'].get('duplex_strand_2', None)

except Exception as error:
# handle the exception
print("An exception occurred in stage 1:", type(error).__name__, "-", error)
Expand Down Expand Up @@ -386,21 +387,74 @@ def basecaller_proc(args, iq, rq, sk, address, config, params, N):
client_sub.set_params(params)
# submit a batch of reads to be basecalled
with client_sub as client:
while True:
batch = iq.get()
if batch is None:
break
# print("[BASECALLER] - submitting channel: {}".format(batch[0]["channel_number"]))
# Submit to be basecalled
read_counter, read_store = submit_reads(args, client, sk, batch)
# now collect the basecalled reads
# print("[BASECALLER] - getting basecalled channel: {}".format(batch[0]["channel_number"]))
bcalled_list = get_reads(args, client, read_counter, sk, read_store)
# TODO: make a skipped queue to handle skipped reads
# print("[BASECALLER] - writing channel: {}".format(batch[0]["channel_number"]))
rq.put(bcalled_list)
iq.task_done()

if args.duplex:
fake_channel_start = 90000 # we are going to increment this so it doesn't try storing the same channels
read_counter = 0
read_store = {}
while True:
batch = iq.get()
if batch is None:
break
if batch == "end":
# send 10 fake reads with different channels to trick server to flushing cache
# server flushes after 10
read_id = list(read_store.keys())[0]
read = read_store[read_id]
print("[BASECALLER] - Sending fake reads to trick basecaller for channel: {}".format(ch))
for chhh in range(fake_channel_start, fake_channel_start+10):
result = client.pass_read(helper_functions.package_read(
raw_data=np.frombuffer(read['signal'], np.int16),
read_id=str(chhh),
start_time=read['start_time'],
daq_offset=read['offset'],
daq_scaling=calibration(read['digitisation'], read['range']),
sampling_rate=read['sampling_rate'],
mux=read['start_mux'],
channel=int(chhh),
run_id=read_store[read_id]["header_array"]["run_id"],
duration=read['len_raw_signal'],
))
# if result:
# read_counter += 0
# else:
# # TODO: put in some throttle/while loop for this
# print("failed to stuff in fake reads")
# sys.exit(1)
# increase counter by 1 to get the 1 fake read but not the other 10
bcalled_list = get_reads(args, client, read_counter, sk, read_store)
# TODO: make a skipped queue to handle skipped reads
print("[BASECALLER] - writing channel: {}".format(ch))
rq.put(bcalled_list)
last = False
read_counter = 0
read_store = {}
fake_channel_start += 10
else:
print("[BASECALLER] - submitting channel: {}".format(batch[0]["channel_number"]))
# Submit to be basecalled
rc, rs = submit_reads(args, client, sk, batch)
read_counter += rc
read_store.update(rs)
# now collect the basecalled reads
print("[BASECALLER] - getting basecalled channel: {}".format(batch[0]["channel_number"]))
ch = batch[0]["channel_number"]
iq.task_done()
else:
while True:
batch = iq.get()
if batch is None:
break
print("[BASECALLER] - submitting channel: {}".format(batch[0]["channel_number"]))
# Submit to be basecalled
read_counter, read_store = submit_reads(args, client, sk, batch, last)
# now collect the basecalled reads
print("[BASECALLER] - getting basecalled channel: {}".format(batch[0]["channel_number"]))
bcalled_list = get_reads(args, client, read_counter, sk, read_store, last)
# TODO: make a skipped queue to handle skipped reads
print("[BASECALLER] - writing channel: {}".format(batch[0]["channel_number"]))
rq.put(bcalled_list)
iq.task_done()

if args.profile:
pr.disable()
s = io.StringIO()
Expand Down

0 comments on commit cc7934d

Please sign in to comment.