Skip to content

Commit

Permalink
fix: make SharedResource threadsafe (#307)
Browse files Browse the repository at this point in the history
  • Loading branch information
jezekra1 authored Feb 2, 2024
1 parent 0d007b6 commit 70a889f
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 9 deletions.
22 changes: 13 additions & 9 deletions src/genai/_utils/shared_instance.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import threading
from abc import abstractmethod
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from typing import Generic, Optional, TypeVar
Expand All @@ -16,6 +17,7 @@ class SharedResource(Generic[T], AbstractContextManager):
def __init__(self):
self._ref_count = 0
self._resource: Optional[T] = None
self._lock = threading.Lock()

@abstractmethod
def _enter(self) -> T:
Expand All @@ -35,18 +37,20 @@ def _exit(self) -> None:
raise NotImplementedError

def __enter__(self) -> T:
self._ref_count += 1
if self._ref_count == 1:
self._resource = self._enter()
with self._lock:
self._ref_count += 1
if self._ref_count == 1:
self._resource = self._enter()

assert self._resource
return self._resource
assert self._resource
return self._resource

def __exit__(self, exc_type, exc_val, exc_tb):
self._ref_count -= 1
if self._ref_count == 0:
self._exit()
self._resource = None
with self._lock:
self._ref_count -= 1
if self._ref_count == 0:
self._exit()
self._resource = None


class AsyncSharedResource(Generic[T], AbstractAsyncContextManager):
Expand Down
23 changes: 23 additions & 0 deletions tests/unit/utils/test_async_executor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
from asyncio import sleep
from unittest.mock import Mock
Expand Down Expand Up @@ -109,3 +110,25 @@ async def handler(input: str, *args) -> str:
def test_execute_empty_inputs(self):
for _ in execute_async(inputs=[], handler=Mock(), http_client=Mock(), throw_on_error=True):
...

@pytest.mark.asyncio
async def test_async_executor_can_be_used_in_async_context(self, http_client):
"""Async executor can be used in asyncio event loop using asyncio.to_thread"""

def _execute(input: str):
return list(
execute_async(
inputs=[input],
handler=self.get_handler([input]),
http_client=lambda: AsyncHttpxClient(),
throw_on_error=True,
ordered=True,
limiters=[LoopBoundLimiter(lambda: LocalLimiter(limit=10))],
)
)[0]

inputs = ["Hello", "World", "here", "are", "some", "inputs"] * 50
tasks = [asyncio.to_thread(_execute, input) for input in inputs]
results = await asyncio.gather(*tasks)

assert results == inputs

0 comments on commit 70a889f

Please sign in to comment.