Skip to content

Commit

Permalink
Add run_async_in_parallel
Browse files Browse the repository at this point in the history
  • Loading branch information
s-jse committed Oct 24, 2024
1 parent 364ea7f commit b5867b1
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 0 deletions.
44 changes: 44 additions & 0 deletions chainlite/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import asyncio
import logging
from typing import Optional

from tqdm import tqdm


def get_logger(name: Optional[str] = None):
logger = logging.getLogger(name)
Expand All @@ -12,3 +15,44 @@ def get_logger(name: Optional[str] = None):
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger


logger = get_logger(__name__)


async def run_async_in_parallel(
async_function, iterable, max_concurrency: int, desc: str = ""
):
semaphore = asyncio.Semaphore(max_concurrency) # Limit concurrent tasks

async def async_function_with_semaphore(f, i, original_index) -> tuple:
# Acquire the semaphore to limit the number of concurrent tasks
async with semaphore:
try:
# Execute the asynchronous function and get the result
result = await f(i)
# Return the original index, result, and no error
return original_index, result, None
except Exception as e:
# If an exception occurs, return the original index, no result, and the error message
logger.exception(f"Task {original_index} failed with error: {e}")
return original_index, None, str(e)

tasks = []
for original_index, item in enumerate(iterable):
tasks.append(
async_function_with_semaphore(async_function, item, original_index)
)

ret = [None] * len(tasks)
for future in tqdm(
asyncio.as_completed(tasks), total=len(tasks), smoothing=0, desc=desc
):
original_index, result, error = await future
if error:
# logger.error(f"Task {original_index} failed with error: {error}")
ret[original_index] = None # set it to some error indicator
else:
ret[original_index] = result

return ret
16 changes: 16 additions & 0 deletions tests/test_llm_generate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from datetime import datetime
import time
from typing import List
Expand All @@ -14,6 +15,7 @@
register_prompt_constants,
get_total_cost
)
from chainlite.utils import run_async_in_parallel
from chainlite.llm_config import GlobalVars
from pydantic import BaseModel
import random
Expand Down Expand Up @@ -192,3 +194,17 @@ async def test_cache():
), "The second (cached) LLM call should be much faster than the first call"
assert first_cost > 0, "The cost should be greater than 0"
assert second_cost == first_cost, "The cost should not change after a cached LLM call"


@pytest.mark.asyncio(scope="session")
async def test_run_async_in_parallel():

async def async_function(i):
await asyncio.sleep(1)
return i

test_inputs = range(10)
max_concurrency = 5
desc = "test"
ret = await run_async_in_parallel(async_function, test_inputs, max_concurrency, desc)
assert ret == list(test_inputs)

0 comments on commit b5867b1

Please sign in to comment.