diff --git a/czsc/utils/__init__.py b/czsc/utils/__init__.py index c91105b16..c7f0d664d 100644 --- a/czsc/utils/__init__.py +++ b/czsc/utils/__init__.py @@ -1,6 +1,7 @@ # coding: utf-8 import os import functools +import threading import pandas as pd from typing import List, Union from loguru import logger @@ -230,28 +231,63 @@ def to_arrow(df: pd.DataFrame): return sink.getvalue() +# def timeout_decorator(timeout): +# """超时装饰器 +# +# :param timeout: int, 超时时间,单位秒 +# """ +# +# def decorator(func): +# @functools.wraps(func) +# def wrapper(*args, **kwargs): +# from concurrent.futures import ThreadPoolExecutor, TimeoutError +# +# with ThreadPoolExecutor() as executor: +# future = executor.submit(func, *args, **kwargs) +# try: +# result = future.result(timeout=timeout) +# return result +# except TimeoutError: +# logger.warning( +# f"{func.__name__} timed out after {timeout} seconds;" f"args: {args}; kwargs: {kwargs}" +# ) +# return None +# +# return wrapper +# +# return decorator + + def timeout_decorator(timeout): - """超时装饰器 + """Timeout decorator using threading - :param timeout: int, 超时时间,单位秒 + :param timeout: int, timeout duration in seconds """ def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): - from concurrent.futures import ThreadPoolExecutor, TimeoutError + result = [None] + exception = [None] - with ThreadPoolExecutor() as executor: - future = executor.submit(func, *args, **kwargs) + def target(): try: - result = future.result(timeout=timeout) - return result - except TimeoutError: - # print(f"{func.__name__} timed out after {timeout} seconds") - logger.warning( - f"{func.__name__} timed out after {timeout} seconds;" f"args: {args}; kwargs: {kwargs}" - ) - raise ValueError(f"{func.__name__} timed out after {timeout} seconds") + result[0] = func(*args, **kwargs) + except Exception as e: + exception[0] = e + + thread = threading.Thread(target=target) + thread.start() + thread.join(timeout) + + if thread.is_alive(): + logger.warning(f"{func.__name__} timed out after {timeout} seconds; args: {args}; kwargs: {kwargs}") + return None + + if exception[0]: + raise exception[0] + + return result[0] return wrapper diff --git a/test/test_utils.py b/test/test_utils.py index 05bc49024..ac95ae04b 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -360,8 +360,7 @@ def fast_function(): def test_timeout_decorator_timeout(): @timeout_decorator(1) def slow_function(): - time.sleep(2) + time.sleep(5) return "Completed" - with pytest.raises(ValueError, match="timed out after 1 seconds"): - slow_function() + assert slow_function() is None