
Commit: cleanup
koenvanderveen committed Dec 15, 2023
1 parent 4ef5042 commit 7697188
Showing 4 changed files with 68 additions and 49 deletions.
10 changes: 5 additions & 5 deletions packages/syft/src/syft/node/node.py
@@ -738,10 +738,10 @@ def init_stores(

         # We add the python id of the current node in order
         # to create one connection per Node object in MongoClientCache
-        # so that we avoid closing the connection from a
+        # so that we avoid closing the connection from a
         # different thread through the garbage collection
         if isinstance(self.document_store_config, MongoStoreConfig):
-            self.document_store_config.client_config.node_obj_python_id = id(self)
+            self.document_store_config.client_config.node_obj_python_id = id(self)

         self.document_store = document_store(
             root_verify_key=self.verify_key,
@@ -768,10 +768,10 @@ def init_stores(
         elif isinstance(action_store_config, MongoStoreConfig):
             # We add the python id of the current node in order
             # to create one connection per Node object in MongoClientCache
-            # so that we avoid closing the connection from a
+            # so that we avoid closing the connection from a
             # different thread through the garbage collection
-            action_store_config.client_config.node_obj_python_id = id(self)
+            action_store_config.client_config.node_obj_python_id = id(self)

             self.action_store = MongoActionStore(
                 root_verify_key=self.verify_key, store_config=action_store_config
             )
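The comment touched in these hunks describes why node_obj_python_id exists: MongoClientCache hands out cached clients, and if two Node objects in the same process resolved to the same cache entry, garbage collection of one Node could close a connection another Node is still using from a different thread. Tagging the client config with id(self) gives each Node its own cached client. A minimal sketch of the caching problem, assuming a cache keyed on connection details; the class and function names below are illustrative, not Syft's actual API:

# Minimal sketch (not the Syft implementation): a client cache whose key includes
# the owning Node's id(), so each Node gets its own client and a garbage-collected
# Node cannot close a connection another Node still uses.
from typing import Any, Dict, Optional


class ClientCacheSketch:
    _cache: Dict[str, Any] = {}

    @classmethod
    def get_client(
        cls, connection_string: str, node_obj_python_id: Optional[int] = None
    ) -> Any:
        # Keying only on connection_string would share one client across Nodes;
        # appending the Node's python id gives one cached client per Node object.
        key = f"{connection_string}:{node_obj_python_id}"
        if key not in cls._cache:
            cls._cache[key] = object()  # stand-in for a real pymongo.MongoClient
        return cls._cache[key]


# With distinct node ids, two Nodes sharing a connection string get distinct clients:
a = ClientCacheSketch.get_client("mongodb://localhost:27017", node_obj_python_id=1)
b = ClientCacheSketch.get_client("mongodb://localhost:27017", node_obj_python_id=2)
assert a is not b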
70 changes: 32 additions & 38 deletions packages/syft/src/syft/store/blob_storage/__init__.py
@@ -61,6 +61,7 @@
 from ...types.blob_storage import BlobFileType
 from ...types.blob_storage import BlobStorageEntry
 from ...types.blob_storage import CreateBlobStorageEntry
+from ...types.blob_storage import DEFAULT_CHUNK_SIZE
 from ...types.blob_storage import SecureFilePathLocation
 from ...types.grid_url import GridURL
 from ...types.syft_migration import migrate
@@ -72,6 +73,9 @@
 from ...types.transforms import make_set_default
 from ...types.uid import UID

+DEFAULT_TIMEOUT = 10
+MAX_RETRIES = 20
+

 @serializable()
 class BlobRetrievalV1(SyftObject):
@@ -169,48 +173,35 @@ class BlobRetrievalByURLV1(BlobRetrievalV1):
     url: GridURL


-def generate(blob_url, chunk_size):
-    max_tries = 20
-    pending = None
-    start_byte = 0
-    for attempt in range(max_tries):
+def syft_iter_content(
+    blob_url, chunk_size, max_retries=MAX_RETRIES, timeout=DEFAULT_TIMEOUT
+):
+    """custom iter content with smart retries (start from last byte read)"""
+    current_byte = 0
+    for attempt in range(max_retries):
         try:
-            headers = {'Range': f'bytes={start_byte}-'}
-            with requests.get(str(blob_url), stream=True, headers=headers, timeout=(10, 10)) as response:
+            headers = {"Range": f"bytes={current_byte}-"}
+            with requests.get(
+                str(blob_url), stream=True, headers=headers, timeout=(timeout, timeout)
+            ) as response:
                 response.raise_for_status()
                 for chunk in response.iter_content(
                     chunk_size=chunk_size, decode_unicode=False
                 ):
-                    start_byte += len(chunk)
-                    if b'\n' in chunk:
-                        if pending is not None:
-                            chunk = pending + chunk
-
-                        lines = chunk.splitlines()
-
-                        if lines and lines[-1] and chunk and lines[-1][-1] == chunk[-1]:
-                            pending = lines.pop()
-                        else:
-                            pending = None
-
-                        yield from lines
-                    else:
-                        if pending is None:
-                            pending = chunk
-                        else:
-                            pending = pending + chunk
-
-                if pending is not None:
-                    yield pending
+                    current_byte += len(chunk)
+                    yield chunk
                 return

         except requests.exceptions.RequestException as e:
-            if attempt < max_tries:
-                print(start_byte)
-                print(f"Attempt {attempt}/{max_tries} failed: {e}. Retrying...")
+            if attempt < max_retries:
+                print(
+                    f"Attempt {attempt}/{max_retries} failed: {e} at byte {current_byte}. Retrying..."
+                )
             else:
                 print(f"Max retries reached. Failed with error: {e}")
                 raise


 class BlobRetrievalByURLV2(BlobRetrievalV1):
     __canonical_name__ = "BlobRetrievalByURL"
     __version__ = SYFT_OBJECT_VERSION_2
@@ -237,7 +228,7 @@ def read(self) -> Union[SyftObject, SyftError]:
         else:
             return self._read_data()

-    def _read_data(self, stream=False, chunk_size=512):
+    def _read_data(self, stream=False, chunk_size=DEFAULT_CHUNK_SIZE):
         # relative
         from ...client.api import APIRegistry

@@ -252,14 +243,17 @@ def _read_data(self, stream=False, chunk_size=512):
         else:
             blob_url = self.url
         try:
-            response = requests.get(str(blob_url), stream=stream)  # nosec
-            response.raise_for_status()
             if self.type_ is BlobFileType:
-                if stream:
-                    return generate(blob_url, chunk_size)
+                if stream:
+                    return syft_iter_content(blob_url, chunk_size)
                 else:
+                    response = requests.get(str(blob_url), stream=False)  # nosec
+                    response.raise_for_status()
                     return response.content
-            return deserialize(response.content, from_bytes=True)
+            else:
+                response = requests.get(str(blob_url), stream=stream)  # nosec
+                response.raise_for_status()
+                return deserialize(response.content, from_bytes=True)
         except requests.RequestException as e:
             return SyftError(message=f"Failed to retrieve with Error: {e}")

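With this change, syft_iter_content only streams raw chunks; the line-splitting that used to live in generate moves to BlobFile._iter_lines (see the types/blob_storage.py diff below). On a dropped connection it retries up to max_retries times and sends an HTTP Range header so the download resumes at the last byte already yielded instead of starting over. A rough usage sketch, assuming the import path implied by this file's location in the diff; the URL is a placeholder:

# Rough usage sketch for syft_iter_content; the URL below is a placeholder, not a
# real Syft endpoint, and the import path is assumed from the diff.
from syft.store.blob_storage import syft_iter_content

blob_url = "http://localhost:8081/blob/example"  # placeholder
total_bytes = 0
for chunk in syft_iter_content(blob_url, chunk_size=10000 * 1024, max_retries=3, timeout=5):
    # chunks are raw bytes; an interrupted download resumes at total_bytes via the Range header
    total_bytes += len(chunk)
print(f"downloaded {total_bytes} bytes")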
6 changes: 5 additions & 1 deletion packages/syft/src/syft/store/mongo_client.py
@@ -120,6 +120,10 @@ class MongoStoreClientConfig(StoreClientConfig):
     # Testing and connection reuse
     client: Any = None

+    # this allows us to have one connection per `Node` object
+    # in the MongoClientCache
+    node_obj_python_id: Optional[int] = None
+

 class MongoClientCache:
     __client_cache__: Dict[str, Type["MongoClient"]] = {}
@@ -139,7 +143,7 @@ class MongoClient:
     client: PyMongoClient = None

     def __init__(self, config: MongoStoreClientConfig, cache: bool = True) -> None:
-        self.config=config
+        self.config = config
         if config.client is not None:
             self.client = config.client
         elif cache:
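The __init__ hunk also shows the lookup order MongoClient uses at the top of the constructor: a caller-supplied config.client wins, otherwise a cached client is reused when cache=True. The diff cuts off before the rest of the body, so the sketch below only illustrates that pattern under assumed names, not the actual Syft code:

# Sketch of the "explicit client, else cache, else new" pattern visible at the top
# of MongoClient.__init__; the branches below the visible diff are assumptions.
from typing import Any, Dict, Optional


class CachedClientSketch:
    _cache: Dict[Optional[int], Any] = {}

    def __init__(
        self, client: Optional[Any] = None, node_id: Optional[int] = None, cache: bool = True
    ) -> None:
        if client is not None:
            self.client = client  # caller-supplied client wins
        elif cache and node_id in self._cache:
            self.client = self._cache[node_id]  # reuse the per-node cached client
        else:
            self.client = object()  # stand-in for opening a new connection
            if cache:
                self._cache[node_id] = self.client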
31 changes: 26 additions & 5 deletions packages/syft/src/syft/types/blob_storage.py
@@ -42,6 +42,7 @@
 from .uid import UID

 READ_EXPIRATION_TIME = 1800  # seconds
+DEFAULT_CHUNK_SIZE = 10000 * 1024


 @serializable()
@@ -65,7 +66,7 @@ class BlobFile(SyftObject):

     __repr_attrs__ = ["id", "file_name"]

-    def read(self, stream=False, chunk_size=512, force=False):
+    def read(self, stream=False, chunk_size=DEFAULT_CHUNK_SIZE, force=False):
         # get blob retrieval object from api + syft_blob_storage_entry_id
         read_method = from_api_or_context(
             "blob_storage.read", self.syft_node_location, self.syft_client_verify_key
@@ -80,9 +81,29 @@ def upload_from_path(self, path, client):

         return sy.ActionObject.from_path(path=path).send(client).syft_action_data

-    def _iter_lines(self, chunk_size=512):
-        """Synchronous version of the async iter_lines"""
-        return self.read(stream=True, chunk_size=chunk_size)
+    def _iter_lines(self, chunk_size=DEFAULT_CHUNK_SIZE):
+        """Synchronous version of the async iter_lines. This implementation
+        is also optimized in terms of splitting chunks, making it faster for
+        larger lines"""
+        pending = None
+        for chunk in self.read(stream=True, chunk_size=chunk_size):
+            if b"\n" in chunk:
+                if pending is not None:
+                    chunk = pending + chunk
+                lines = chunk.splitlines()
+                if lines and lines[-1] and chunk and lines[-1][-1] == chunk[-1]:
+                    pending = lines.pop()
+                else:
+                    pending = None
+                yield from lines
+            else:
+                if pending is None:
+                    pending = chunk
+                else:
+                    pending = pending + chunk
+
+        if pending is not None:
+            yield pending

     def read_queue(self, queue, chunk_size, progress=False, buffer_lines=10000):
         total_read = 0
@@ -103,7 +124,7 @@ def read_queue(self, queue, chunk_size, progress=False, buffer_lines=10000):
         # Put anything not a string at the end
         queue.put(0)

-    def iter_lines(self, chunk_size=512, progress=False):
+    def iter_lines(self, chunk_size=DEFAULT_CHUNK_SIZE, progress=False):
         item_queue: Queue = Queue()
         threading.Thread(
             target=self.read_queue,
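The new _iter_lines is where the pending-buffer logic from the old generate now lives: a line that straddles two chunks is buffered until the chunk that completes it arrives, so callers still receive whole lines while the download layer stays a plain byte stream. A self-contained illustration of the same splitting idea follows; it operates on an in-memory list of chunks rather than a blob download, so it is a sketch of the technique, not the Syft method itself:

# Self-contained illustration of the pending-buffer splitting used by _iter_lines;
# chunks come from a list here instead of blob storage.
def iter_lines_from_chunks(chunks):
    pending = None
    for chunk in chunks:
        if b"\n" in chunk:
            if pending is not None:
                chunk = pending + chunk
            lines = chunk.splitlines()
            # if the chunk does not end the last line, keep that piece for later
            if lines and lines[-1] and chunk and lines[-1][-1] == chunk[-1]:
                pending = lines.pop()
            else:
                pending = None
            yield from lines
        else:
            pending = chunk if pending is None else pending + chunk
    if pending is not None:
        yield pending


# a line split across several chunks is reassembled before being yielded
assert list(iter_lines_from_chunks([b"ab", b"c\nde", b"f\n"])) == [b"abc", b"def"]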
