diff --git a/a_sync/a_sync/modifiers/semaphores.py b/a_sync/a_sync/modifiers/semaphores.py index 59c31d64..8a781c94 100644 --- a/a_sync/a_sync/modifiers/semaphores.py +++ b/a_sync/a_sync/modifiers/semaphores.py @@ -6,9 +6,6 @@ from a_sync import exceptions, primitives from a_sync._typing import * -# We keep this here for now so we don't break downstream deps. Eventually will be removed. -from a_sync.primitives import ThreadsafeSemaphore, DummySemaphore - @overload def apply_semaphore( # type: ignore [misc] @@ -134,7 +131,7 @@ def apply_semaphore( `primitives.Semaphore` is a subclass of `asyncio.Semaphore`. Therefore, when the documentation refers to `asyncio.Semaphore`, it also includes `primitives.Semaphore` and any other subclasses. """ # Parse Inputs - if isinstance(coro_fn, (int, asyncio.Semaphore)): + if isinstance(coro_fn, (int, asyncio.Semaphore, primitives.Semaphore)): if semaphore is not None: raise ValueError("You can only pass in one arg.") semaphore = coro_fn @@ -146,7 +143,7 @@ def apply_semaphore( # Create the semaphore if necessary if isinstance(semaphore, int): semaphore = primitives.ThreadsafeSemaphore(semaphore) - elif not isinstance(semaphore, asyncio.Semaphore): + elif not isinstance(semaphore, (asyncio.Semaphore, primitives.Semaphore)): raise TypeError( f"'semaphore' must either be an integer or a Semaphore object. You passed {semaphore}" ) diff --git a/tests/a_sync/modifiers/test_apply_semaphore.py b/tests/a_sync/modifiers/test_apply_semaphore.py new file mode 100644 index 00000000..7c53271d --- /dev/null +++ b/tests/a_sync/modifiers/test_apply_semaphore.py @@ -0,0 +1,42 @@ +import asyncio +import pytest + +from a_sync import Semaphore, apply_semaphore +from a_sync.exceptions import FunctionNotAsync + + +@pytest.mark.asyncio_cooperative +async def test_apply_semaphore_int(): + apply_semaphore(asyncio.sleep, 1) + + +@pytest.mark.asyncio_cooperative +async def test_apply_semaphore_asyncio_semaphore(): + apply_semaphore(asyncio.sleep, asyncio.Semaphore(1)) + + +@pytest.mark.asyncio_cooperative +async def test_apply_semaphore_a_sync_semaphore(): + apply_semaphore(asyncio.sleep, Semaphore(1)) + + +def fail(): + pass + + +@pytest.mark.asyncio_cooperative +async def test_apply_semaphore_failure_int(): + with pytest.raises(FunctionNotAsync): + apply_semaphore(fail, 1) + + +@pytest.mark.asyncio_cooperative +async def test_apply_semaphore_failure_asyncio_semaphore(): + with pytest.raises(FunctionNotAsync): + apply_semaphore(fail, asyncio.Semaphore(1)) + + +@pytest.mark.asyncio_cooperative +async def test_apply_semaphore_failure_a_sync_semaphore(): + with pytest.raises(FunctionNotAsync): + apply_semaphore(fail, Semaphore(1))