working with 7.3.10
Psy-Fer committed May 15, 2024
1 parent 0860e66 commit 55c63b5
Showing 2 changed files with 60 additions and 41 deletions.
requirements.txt (3 changes: 2 additions & 1 deletion)

@@ -1,3 +1,4 @@
 numpy
 pyslow5>=1.1.0
-ont-pyguppy-client-lib==7.1.4
+# ont-pyguppy-client-lib==7.2.13
+ont-pybasecall-client-lib==7.3.10
src/buttery_eel.py (98 changes: 58 additions & 40 deletions)

@@ -13,9 +13,21 @@
 
 import pyslow5
 
-import pyguppy_client_lib
-from pyguppy_client_lib.pyclient import PyGuppyClient
-from pyguppy_client_lib import helper_functions
+try:
+    import pybasecall_client_lib
+    from pybasecall_client_lib.pyclient import PyBasecallClient as pclient
+    from pybasecall_client_lib import helper_functions
+
+except ImportError:
+    # maybe I can do a version check or something? hard to do from this side of the software
+    print("Could not load pybasecall, trying the pyguppy lib for earlier versions (<=7.2.15)")
+    try:
+        import pyguppy_client_lib
+        from pyguppy_client_lib.pyclient import PyGuppyClient as pclient
+        from pyguppy_client_lib import helper_functions
+    except ImportError:
+        print("Can't import pybasecall_client_lib or pyguppy_client_lib, please check environment and try again.")
+        sys.exit(1)
 
 import cProfile, pstats, io
 
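The in-code comment above wonders about a version check rather than import fallbacks. One possible approach, sketched here under the assumption that the PyPI distribution names match those pinned in requirements.txt (this sketch is not part of the commit):

    # Sketch only: resolve which client library is installed before importing it.
    from importlib.metadata import version, PackageNotFoundError

    def detect_client_lib():
        candidates = [("ont-pybasecall-client-lib", "pybasecall_client_lib"),
                      ("ont-pyguppy-client-lib", "pyguppy_client_lib")]
        for dist, module in candidates:
            try:
                print("found {} {}".format(dist, version(dist)))
                return module
            except PackageNotFoundError:
                continue
        return None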
@@ -100,7 +112,7 @@ def start_guppy_server_and_client(args, server_args):
         address = "{}".format(port)
     else:
         address = "localhost:{}".format(port)
-    client = PyGuppyClient(address=address, config=args.config)
+    client = pclient(address=address, config=args.config)
 
 
     print("Setting params...")
@@ -205,10 +217,10 @@ def read_worker(args, iq):
             if sfile.endswith(('.blow5', '.slow5')):
                 s5 = pyslow5.Open(os.path.join(dirpath, sfile), 'r')
                 reads = s5.seq_reads_multi(threads=args.slow5_threads, batchsize=args.slow5_batchsize, aux='all')
-                if args.seq_sum:
-                    num_read_groups = s5.get_num_read_groups()
-                    for read_group in range(num_read_groups):
-                        header_array[read_group] = s5.get_all_headers(read_group=read_group)
+                # if args.seq_sum:
+                num_read_groups = s5.get_num_read_groups()
+                for read_group in range(num_read_groups):
+                    header_array[read_group] = s5.get_all_headers(read_group=read_group)
                 batches = get_slow5_batch(args, s5, reads, size=args.slow5_batchsize, slow5_filename=sfile, header_array=header_array)
                 # put batches of reads onto the queue
                 for batch in chain(batches):
@@ -224,10 +236,10 @@ def read_worker(args, iq):
         s5 = pyslow5.Open(args.input, 'r')
         filename_slow5 = args.input.split("/")[-1]
         reads = s5.seq_reads_multi(threads=args.slow5_threads, batchsize=args.slow5_batchsize, aux='all')
-        if args.seq_sum:
-            num_read_groups = s5.get_num_read_groups()
-            for read_group in range(num_read_groups):
-                header_array[read_group] = s5.get_all_headers(read_group=read_group)
+        # if args.seq_sum:
+        num_read_groups = s5.get_num_read_groups()
+        for read_group in range(num_read_groups):
+            header_array[read_group] = s5.get_all_headers(read_group=read_group)
         batches = get_slow5_batch(args, s5, reads, size=args.slow5_batchsize, slow5_filename=filename_slow5, header_array=header_array)
         # this adds a limit to how many reads it will load into memory so we
         # don't blow the ram up
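The comment above is about backpressure: because the input queue is bounded, put() blocks once the limit is reached, so the reader cannot pull the whole file into memory ahead of the basecaller. A minimal sketch of the pattern; the queue name and limit are illustrative, not from this commit:

    # Illustration only: a bounded queue applies backpressure to the producer.
    import multiprocessing as mp

    iq = mp.Queue(maxsize=10)     # at most 10 batches buffered in memory
    # producer (read_worker):  iq.put(batch) blocks while the queue is full
    # consumer (submit_read):  batch = iq.get() frees a slot as it consumes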
@@ -474,7 +486,7 @@ def submit_read(args, iq, rq, address, config, params, N):
     done = 0
     bcalled_list = []
 
-    client_sub = PyGuppyClient(address=address, config=config)
+    client_sub = pclient(address=address, config=config)
     client_sub.set_params(params)
     # submit a batch of reads to be basecalled
     with client_sub as client:
@@ -485,21 +497,25 @@ def submit_read(args, iq, rq, address, config, params, N):
                 break
             for read in batch:
                 read_id = read['read_id']
-                if args.seq_sum:
-                    read_store[read_id] = read
+                # if args.seq_sum:
+                read_store[read_id] = read
                 # calculate scale
                 scale = calibration(read['digitisation'], read['range'])
                 result = False
                 tries = 0
                 while not result:
-                    result = client.pass_read(
-                            helper_functions.package_read(
-                                read_id=read_id,
+                    result = client.pass_read(helper_functions.package_read(
                                 raw_data=np.frombuffer(read['signal'], np.int16),
+                                read_id=read_id,
+                                start_time=read['start_time'],
                                 daq_offset=read['offset'],
                                 daq_scaling=scale,
-                            )
-                        )
+                                sampling_rate=read['sampling_rate'],
+                                mux=read['start_mux'],
+                                channel=int(read["channel_number"]),
+                                run_id=read_store[read_id]["header_array"]["run_id"],
+                                duration=read['len_raw_signal'],
+                            ))
                     if tries > 1:
                         time.sleep(client.throttle)
                     tries += 1
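calibration() is defined elsewhere in buttery_eel.py and does not appear in this diff. The daq_offset/daq_scaling arguments above follow the standard nanopore signal conversion, which the following sketch makes explicit (assumed definition, not from this commit):

    # Assumed definition (standard slow5/nanopore scaling):
    def calibration(digitisation, signal_range):
        # picoamperes represented by one DAQ unit
        return signal_range / digitisation

    # the basecall server then recovers current as:
    #   current_pA = (raw_sample + daq_offset) * daq_scaling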
@@ -574,8 +590,10 @@ def submit_read(args, iq, rq, address, config, params, N):
                     bcalled_read["trimmed_samples"] = call['metadata']['trimmed_samples']
                     trimmed_duration = call['metadata']['trimmed_duration']
                     bcalled_read["num_samples"] = trimmed_duration + bcalled_read["trimmed_samples"]
-                    if bcalled_read["num_samples"] != raw_num_samples:
-                        print("WARNING: {} ns:i:{} != raw_num_samples:{}".format(bcalled_read["read_id"], bcalled_read["num_samples"], raw_num_samples))
+                    # turning this warning off, as it was put here to alert us to the ns tag changing.
+                    # ONT has changed the definition of this field, and this warning told us about it.
+                    # if bcalled_read["num_samples"] != raw_num_samples:
+                    #     print("WARNING: {} ns:i:{} != raw_num_samples:{}".format(bcalled_read["read_id"], bcalled_read["num_samples"], raw_num_samples))
                 except Exception as error:
                     # handle the exception
                     print("An exception occurred in stage 2:", type(error).__name__, "-", error)
@@ -692,25 +710,25 @@ def get_slow5_batch(args, slow5_obj, reads, size=4096, slow5_filename=None, header_array=None):
     """
     batch = []
     for read in reads:
-        if args.seq_sum:
-            # get header once for each read group
-            read_group = read["read_group"]
-            # if read_group not in header_array:
-            #     header_array[read_group] = slow5_obj.get_all_headers(read_group=read_group)
-            # get aux data for each read
+        # if args.seq_sum:
+        # get header once for each read group
+        read_group = read["read_group"]
+        # if read_group not in header_array:
+        #     header_array[read_group] = slow5_obj.get_all_headers(read_group=read_group)
+        # get aux data for each read
 
 
-            aux_data = {"channel_number": read["channel_number"],
-                        "start_mux": read["start_mux"],
-                        "start_time": read["start_time"],
-                        "read_number": read["read_number"],
-                        "end_reason": read["end_reason"],
-                        "median_before": read["median_before"],
-                        "end_reason_labels": slow5_obj.get_aux_enum_labels('end_reason')
-                        }
-            read["aux_data"] = aux_data
-            read["header_array"] = header_array[read_group]
-            read["slow5_filename"] = slow5_filename
+        aux_data = {"channel_number": read["channel_number"],
+                    "start_mux": read["start_mux"],
+                    "start_time": read["start_time"],
+                    "read_number": read["read_number"],
+                    "end_reason": read["end_reason"],
+                    "median_before": read["median_before"],
+                    "end_reason_labels": slow5_obj.get_aux_enum_labels('end_reason')
+                    }
+        read["aux_data"] = aux_data
+        read["header_array"] = header_array[read_group]
+        read["slow5_filename"] = slow5_filename
 
         batch.append(read)
         if len(batch) >= size:
