async def run_async_in_parallel(
    async_function, iterable, max_concurrency: int, desc: str = ""
):
    """Apply ``async_function`` to every item of ``iterable`` concurrently.

    At most ``max_concurrency`` calls are in flight at any time. Progress is
    reported through a ``tqdm`` bar that advances as tasks finish.

    Args:
        async_function: Coroutine function called with a single item.
        iterable: Items to process. It is consumed fully before any call
            starts, because one coroutine is created per item up front.
        max_concurrency: Maximum number of calls running at once; must be
            at least 1.
        desc: Label shown on the progress bar.

    Returns:
        A list with one entry per input item, in input order. A call that
        raised an ``Exception`` is logged and contributes ``None``, so a
        failure cannot be told apart from a call that returned ``None``.

    Raises:
        ValueError: If ``max_concurrency`` is smaller than 1.
    """
    # A semaphore with zero permits can never be acquired, which would make
    # every task wait forever. Fail fast instead of hanging.
    if max_concurrency < 1:
        raise ValueError(
            f"max_concurrency must be at least 1, got {max_concurrency}"
        )

    semaphore = asyncio.Semaphore(max_concurrency)

    async def run_one(index, item) -> tuple:
        """Run one call under the semaphore; return (index, result, failed)."""
        async with semaphore:
            try:
                return index, await async_function(item), False
            except Exception as e:
                # Best effort: log the traceback and keep the other tasks going.
                logger.exception(f"Task {index} failed with error: {e}")
                return index, None, True

    tasks = [run_one(index, item) for index, item in enumerate(iterable)]

    # Tasks finish in arbitrary order, so results are filled in by index.
    ret = [None] * len(tasks)
    for future in tqdm(
        asyncio.as_completed(tasks), total=len(tasks), smoothing=0, desc=desc
    ):
        index, result, failed = await future
        # An explicit flag (rather than the truthiness of the error text)
        # keeps exceptions with an empty message classified as failures.
        if not failed:
            ret[index] = result

    return ret
@pytest.mark.asyncio(scope="session")
async def test_run_async_in_parallel():
    """Results are returned in input order when every task succeeds."""

    async def echo_after_delay(value):
        # Sleep before echoing the input back.
        await asyncio.sleep(1)
        return value

    inputs = range(10)
    results = await run_async_in_parallel(
        echo_after_delay, inputs, max_concurrency=5, desc="test"
    )
    assert results == list(inputs)