diff --git a/ibllib/pipes/misc.py b/ibllib/pipes/misc.py index 39871ad00..44b025f0c 100644 --- a/ibllib/pipes/misc.py +++ b/ibllib/pipes/misc.py @@ -9,8 +9,9 @@ import sys import time import logging +from functools import wraps from pathlib import Path -from typing import Union, List +from typing import Union, List, Callable, Any from inspect import signature import uuid import socket @@ -1148,13 +1149,54 @@ class WindowsInhibitor: ES_CONTINUOUS = 0x80000000 ES_SYSTEM_REQUIRED = 0x00000001 - def __init__(self): - pass + @staticmethod + def _set_thread_execution_state(state: int) -> None: + result = ctypes.windll.kernel32.SetThreadExecutionState(state) + if result == 0: + log.error("Failed to set thread execution state.") - def inhibit(self): - print("Preventing Windows from going to sleep") - ctypes.windll.kernel32.SetThreadExecutionState(WindowsInhibitor.ES_CONTINUOUS | WindowsInhibitor.ES_SYSTEM_REQUIRED) + @staticmethod + def inhibit(quiet: bool = False): + if quiet: + log.debug("Preventing Windows from going to sleep") + else: + print("Preventing Windows from going to sleep") + WindowsInhibitor._set_thread_execution_state(WindowsInhibitor.ES_CONTINUOUS | WindowsInhibitor.ES_SYSTEM_REQUIRED) + + @staticmethod + def uninhibit(quiet: bool = False): + if quiet: + log.debug("Allowing Windows to go to sleep") + else: + print("Allowing Windows to go to sleep") + WindowsInhibitor._set_thread_execution_state(WindowsInhibitor.ES_CONTINUOUS) + + +def sleepless(func: Callable[..., Any]) -> Callable[..., Any]: + """ + Decorator to ensure that the system doesn't enter sleep or idle mode during a long-running task. + + This decorator wraps a function and sets the thread execution state to prevent + the system from entering sleep or idle mode while the decorated function is + running. + + Parameters + ---------- + func : callable + The function to decorate. + + Returns + ------- + callable + The decorated function. + """ - def uninhibit(self): - print("Allowing Windows to go to sleep") - ctypes.windll.kernel32.SetThreadExecutionState(WindowsInhibitor.ES_CONTINUOUS) + @wraps(func) + def inner(*args, **kwargs) -> Any: + if os.name == 'nt': + WindowsInhibitor().inhibit(quiet=True) + result = func(*args, **kwargs) + if os.name == 'nt': + WindowsInhibitor().uninhibit(quiet=True) + return result + return inner diff --git a/ibllib/tests/test_pipes.py b/ibllib/tests/test_pipes.py index ba5c282dd..cbe86462a 100644 --- a/ibllib/tests/test_pipes.py +++ b/ibllib/tests/test_pipes.py @@ -21,6 +21,7 @@ import ibllib.io.extractors.base import ibllib.tests.fixtures.utils as fu from ibllib.pipes import misc +from ibllib.pipes.misc import sleepless from ibllib.tests import TEST_DB import ibllib.pipes.scan_fix_passive_files as fix from ibllib.pipes.base_tasks import RegisterRawDataTask @@ -698,5 +699,23 @@ def test_rename_files(self): self.assertCountEqual(expected, files) +class TestSleeplessDecorator(unittest.TestCase): + + def test_decorator_argument_passing(self): + + def dummy_function(arg1, arg2): + return arg1, arg2 + + # Applying the decorator to the dummy function + decorated_func = sleepless(dummy_function) + + # Check if the function name is maintained + self.assertEqual(decorated_func.__name__, 'dummy_function') + + # Check if arguments are passed correctly + result = decorated_func("test1", "test2") + self.assertEqual(result, ("test1", "test2")) + + if __name__ == '__main__': unittest.main(exit=False, verbosity=2)