diff --git a/README.md b/README.md index 8e041dd..c4cf669 100644 --- a/README.md +++ b/README.md @@ -26,9 +26,9 @@ from pyguppyclient import GuppyBasecallerClient, yield_reads config = "dna_r9.4.1_450bps_fast" read_file = "reads.fast5" -with GuppyBasecallerClient(config_name=config) as client: +with GuppyBasecallerClient(config_name=config, trace=True) as client: for read in yield_reads(read_file): - called = client.basecall(read, trace=True) + called = client.basecall(read) print(read.read_id, called.seq[:50], called.move) ``` diff --git a/pyguppyclient/client.py b/pyguppyclient/client.py index fd0f1c2..af77e89 100644 --- a/pyguppyclient/client.py +++ b/pyguppyclient/client.py @@ -2,10 +2,10 @@ Guppy Client """ -from collections import deque import time import asyncio import logging +from collections import deque import zmq import zmq.asyncio @@ -26,7 +26,7 @@ class GuppyClientBase: """ Blocking Guppy Base Client """ - def __init__(self, config_name, host="localhost", port=5555, timeout=0.1, retries=50): + def __init__(self, config_name, host="localhost", port=5555, timeout=0.1, retries=50, state=False, trace=False): self.timeout = timeout self.retries = retries self.config_name = parse_config(config_name) @@ -37,6 +37,8 @@ def __init__(self, config_name, host="localhost", port=5555, timeout=0.1, retrie self.socket.connect("tcp://%s:%s" % (host, port)) self.client_id = 0 self.pcl_client = PCLClient("%s:%s" % (host, port), self.config_name) + self.pcl_client.set_params({'state_data_enabled': state}) + self.pcl_client.set_params({'move_and_trace_enabled': trace}) _init_pcl_client(self.pcl_client) def __enter__(self): @@ -77,8 +79,9 @@ def connect(self): pass elif ret != result.success: raise ConnectionError( - "Connect with '{}' failed: {}".format(self.config_name, - self.pcl_client.get_error_message()) + "Connect with '{}' failed: {}".format( + self.config_name, self.pcl_client.get_error_message() + ) ) def disconnect(self): @@ -109,23 +112,19 @@ class GuppyBasecallerClient(GuppyClientBase): """ Blocking Guppy Basecall Client """ - def __init__(self, **kwargs): super().__init__(**kwargs) self.read_cache = deque() - def basecall(self, read, state=False, trace=False): + def basecall(self, read): """ Basecall a `ReadData` object and get a `CalledReadData` object - - :param trace: flag for returning the flipflop trace table from the server. - :param state: flag for returning the state table (requires --post_out). """ n = 0 self.pass_read(read) while n < self.retries: n += 1 - result = self._get_called_read(state=state, trace=trace) + result = self._get_called_read() if result is not None: return result time.sleep(self.timeout) @@ -134,7 +133,7 @@ def basecall(self, read, state=False, trace=False): "Basecall response not received after {}s for read '{}'".format(self.timeout, read.read_id) ) - def _get_called_read(self, state=False, trace=False): + def _get_called_read(self): """ Get the `CalledReadData` object back from the server """ @@ -221,7 +220,7 @@ async def pass_read(self, read): } return await self.pcl_client.pass_read(read_dict) - async def get_called_read(self, trace=False, state=False): + async def get_called_read(self): """ Get the `CalledReadData` object back from the server """ diff --git a/tests/client_tests.py b/tests/client_tests.py index b24884d..9fc4dc7 100644 --- a/tests/client_tests.py +++ b/tests/client_tests.py @@ -16,7 +16,7 @@ class ClientTest(TestCase): def setUp(self): self.read_loader = yield_reads(self.read_file) - self.client = GuppyBasecallerClient(config_name=self.config_fast, port=self.port) + self.client = GuppyBasecallerClient(config_name=self.config_fast, port=self.port, trace=True, state=True) self.client.connect() def tearDown(self): @@ -35,18 +35,20 @@ def test_read_without_state(self): """ test a read without state """ self.client.pass_read(next(self.read_loader)) time.sleep(1) - self.client._get_called_read(state=False) + self.client._get_called_read() def test_read_with_state(self): """ test a read with state """ self.client.pass_read(next(self.read_loader)) time.sleep(1) - self.client._get_called_read(state=True) + res, called = self.client._get_called_read() + self.assertTrue(called.state is not None) + self.assertTrue(called.trace is not None) + self.assertTrue(called.move is not None) def test_invalid_config(self): """ try and load in invalid config """ - bad_client = GuppyBasecallerClient(config_name="not_a_config", - port=self.port) + bad_client = GuppyBasecallerClient(config_name="not_a_config", port=self.port) with self.assertRaises(ConnectionError): bad_client.connect()