diff --git a/src/basecaller.py b/src/basecaller.py index 9fd5792..068d21e 100644 --- a/src/basecaller.py +++ b/src/basecaller.py @@ -44,6 +44,11 @@ def start_guppy_server_and_client(args, server_args): basecaller_bin = arg gotten = True continue + if arg in ["--do_read_splitting", "--detect_adapter", "--detect_mid_strand_adapter", "--detect_mid_strand_barcodes"]: + print("==========================================================================\n Depricated Arguments Detected\n==========================================================================") + print("{} no longer used, please check docs to learn more".format(arg)) + print("\n") + continue tmp_args.append(arg) if basecaller_bin is None: @@ -189,27 +194,9 @@ def submit_reads(args, client, sk, batch): break if result: read_counter += 1 - if args.duplex: - # send fake reads with different channels to trick server to flushing cache - for chhh in range(90000, 90011): - result = client.pass_read(helper_functions.package_read( - raw_data=np.frombuffer(read['signal'], np.int16), - read_id=str(chhh), - start_time=read['start_time'], - daq_offset=read['offset'], - daq_scaling=scale, - sampling_rate=read['sampling_rate'], - mux=read['start_mux'], - channel=int(chhh), - run_id=read_store[read_id]["header_array"]["run_id"], - duration=read['len_raw_signal'], - )) - if result: - read_counter += 1 if len(skipped) > 0: for i in skipped: sk.put(i) - return read_counter, read_store @@ -241,15 +228,21 @@ def get_reads(args, client, read_counter, sk, read_store): if len(calls) > 1: split_reads = True for call in calls: - # if int(call['metadata']['channel']) > 20000: - # print("Fake read:", call['metadata']['channel']) - # continue - # for i in call: - # if isinstance(call[i], dict): - # for j in call[i]: - # print("{}: {}".format(j, call[i][j])) - # else: - # print("{}: {}".format(i, call[i])) + if int(call['metadata']['channel']) > 20000: + # print("Fake read:", call['metadata']['channel']) + # print("Channel: {} - fake read {}/{}".format(call['metadata']['channel'], done, read_counter)) + done -= 1 + # print("ending set") + # ending = True + continue + # print("Channel: {} - read {}/{}".format(call['metadata']['channel'], done, read_counter)) + # if call['metadata']['read_id'] == "d209a644-9086-4296-9a6f-7d037dcb959f": + # for i in call: + # if isinstance(call[i], dict): + # for j in call[i]: + # print("{}: {}".format(j, call[i][j])) + # else: + # print("{}: {}".format(i, call[i])) try: bcalled_read = {} bcalled_read["sam_record"] = "" @@ -261,8 +254,16 @@ def get_reads(args, client, read_counter, sk, read_store): bcalled_read["read_id"] = read_id bcalled_read["read_qscore"] = call['metadata']['mean_qscore'] bcalled_read["int_read_qscore"] = int(call['metadata']['mean_qscore']) - bcalled_read["header"] = "@{} parent_read_id={} model_version_id={} mean_qscore={}".format(bcalled_read["read_id"], bcalled_read["parent_read_id"], call['metadata'].get('model_version_id', model_id), bcalled_read["int_read_qscore"]) + if args.call_mods: + bcalled_read["header"] = "@{} parent_read_id={} model_version_id={} modbase_model_version_id={} mean_qscore={}".format(bcalled_read["read_id"], bcalled_read["parent_read_id"], call['metadata'].get('model_version_id', model_id), call['metadata'].get('modbase_model_version_id', model_id), bcalled_read["int_read_qscore"]) + else: + bcalled_read["header"] = "@{} parent_read_id={} model_version_id={} mean_qscore={}".format(bcalled_read["read_id"], bcalled_read["parent_read_id"], call['metadata'].get('model_version_id', model_id), bcalled_read["int_read_qscore"]) bcalled_read["sequence"] = call['datasets']['sequence'] + if args.duplex: + bcalled_read["duplex_parent"] = call['metadata']['is_duplex_parent'] + bcalled_read["duplex_strand_1"] = call['metadata'].get('duplex_strand_1', None) + bcalled_read["duplex_strand_2"] = call['metadata'].get('duplex_strand_2', None) + except Exception as error: # handle the exception print("An exception occurred in stage 1:", type(error).__name__, "-", error) @@ -386,21 +387,74 @@ def basecaller_proc(args, iq, rq, sk, address, config, params, N): client_sub.set_params(params) # submit a batch of reads to be basecalled with client_sub as client: - while True: - batch = iq.get() - if batch is None: - break - # print("[BASECALLER] - submitting channel: {}".format(batch[0]["channel_number"])) - # Submit to be basecalled - read_counter, read_store = submit_reads(args, client, sk, batch) - # now collect the basecalled reads - # print("[BASECALLER] - getting basecalled channel: {}".format(batch[0]["channel_number"])) - bcalled_list = get_reads(args, client, read_counter, sk, read_store) - # TODO: make a skipped queue to handle skipped reads - # print("[BASECALLER] - writing channel: {}".format(batch[0]["channel_number"])) - rq.put(bcalled_list) - iq.task_done() - + if args.duplex: + fake_channel_start = 90000 # we are going to increment this so it doesn't try storing the same channels + read_counter = 0 + read_store = {} + while True: + batch = iq.get() + if batch is None: + break + if batch == "end": + # send 10 fake reads with different channels to trick server to flushing cache + # server flushes after 10 + read_id = list(read_store.keys())[0] + read = read_store[read_id] + print("[BASECALLER] - Sending fake reads to trick basecaller for channel: {}".format(ch)) + for chhh in range(fake_channel_start, fake_channel_start+10): + result = client.pass_read(helper_functions.package_read( + raw_data=np.frombuffer(read['signal'], np.int16), + read_id=str(chhh), + start_time=read['start_time'], + daq_offset=read['offset'], + daq_scaling=calibration(read['digitisation'], read['range']), + sampling_rate=read['sampling_rate'], + mux=read['start_mux'], + channel=int(chhh), + run_id=read_store[read_id]["header_array"]["run_id"], + duration=read['len_raw_signal'], + )) + # if result: + # read_counter += 0 + # else: + # # TODO: put in some throttle/while loop for this + # print("failed to stuff in fake reads") + # sys.exit(1) + # increase counter by 1 to get the 1 fake read but not the other 10 + bcalled_list = get_reads(args, client, read_counter, sk, read_store) + # TODO: make a skipped queue to handle skipped reads + print("[BASECALLER] - writing channel: {}".format(ch)) + rq.put(bcalled_list) + last = False + read_counter = 0 + read_store = {} + fake_channel_start += 10 + else: + print("[BASECALLER] - submitting channel: {}".format(batch[0]["channel_number"])) + # Submit to be basecalled + rc, rs = submit_reads(args, client, sk, batch) + read_counter += rc + read_store.update(rs) + # now collect the basecalled reads + print("[BASECALLER] - getting basecalled channel: {}".format(batch[0]["channel_number"])) + ch = batch[0]["channel_number"] + iq.task_done() + else: + while True: + batch = iq.get() + if batch is None: + break + print("[BASECALLER] - submitting channel: {}".format(batch[0]["channel_number"])) + # Submit to be basecalled + read_counter, read_store = submit_reads(args, client, sk, batch, last) + # now collect the basecalled reads + print("[BASECALLER] - getting basecalled channel: {}".format(batch[0]["channel_number"])) + bcalled_list = get_reads(args, client, read_counter, sk, read_store, last) + # TODO: make a skipped queue to handle skipped reads + print("[BASECALLER] - writing channel: {}".format(batch[0]["channel_number"])) + rq.put(bcalled_list) + iq.task_done() + if args.profile: pr.disable() s = io.StringIO()