diff --git a/src/faster_whisper_server/dependencies.py b/src/faster_whisper_server/dependencies.py index 3266b14..980a2d1 100644 --- a/src/faster_whisper_server/dependencies.py +++ b/src/faster_whisper_server/dependencies.py @@ -1,4 +1,5 @@ from functools import lru_cache +import logging from typing import Annotated from fastapi import Depends, HTTPException, status @@ -11,7 +12,13 @@ from faster_whisper_server.config import Config from faster_whisper_server.model_manager import PiperModelManager, WhisperModelManager +logger = logging.getLogger(__name__) +# NOTE: `get_config` is called directly instead of using sub-dependencies so that these functions could be used outside of `FastAPI` # noqa: E501 + + +# https://fastapi.tiangolo.com/advanced/settings/?h=setti#creating-the-settings-only-once-with-lru_cache +# WARN: Any new module that ends up calling this function directly (not through `FastAPI` dependency injection) should be patched in `tests/conftest.py` # noqa: E501 @lru_cache def get_config() -> Config: return Config() @@ -22,7 +29,7 @@ def get_config() -> Config: @lru_cache def get_model_manager() -> WhisperModelManager: - config = get_config() # HACK + config = get_config() return WhisperModelManager(config.whisper) @@ -31,8 +38,8 @@ def get_model_manager() -> WhisperModelManager: @lru_cache def get_piper_model_manager() -> PiperModelManager: - config = get_config() # HACK - return PiperModelManager(config.whisper.ttl) # HACK + config = get_config() + return PiperModelManager(config.whisper.ttl) # HACK: should have its own config PiperModelManagerDependency = Annotated[PiperModelManager, Depends(get_piper_model_manager)] @@ -53,7 +60,7 @@ async def verify_api_key( @lru_cache def get_completion_client() -> AsyncCompletions: - config = get_config() # HACK + config = get_config() oai_client = AsyncOpenAI(base_url=config.chat_completion_base_url, api_key=config.chat_completion_api_key) return oai_client.chat.completions @@ -63,9 +70,9 @@ def get_completion_client() -> AsyncCompletions: @lru_cache def get_speech_client() -> AsyncSpeech: - config = get_config() # HACK + config = get_config() if config.speech_base_url is None: - # this might not work as expected if the `speech_router` won't have shared state with the main FastAPI `app`. TODO: verify # noqa: E501 + # this might not work as expected if `speech_router` won't have shared state (access to the same `model_manager`) with the main FastAPI `app`. TODO: verify # noqa: E501 from faster_whisper_server.routers.speech import ( router as speech_router, ) @@ -86,7 +93,7 @@ def get_speech_client() -> AsyncSpeech: def get_transcription_client() -> AsyncTranscriptions: config = get_config() if config.transcription_base_url is None: - # this might not work as expected if the `transcription_router` won't have shared state with the main FastAPI `app`. TODO: verify # noqa: E501 + # this might not work as expected if `transcription_router` won't have shared state (access to the same `model_manager`) with the main FastAPI `app`. TODO: verify # noqa: E501 from faster_whisper_server.routers.stt import ( router as stt_router, ) diff --git a/src/faster_whisper_server/logger.py b/src/faster_whisper_server/logger.py index fb283ef..da7de12 100644 --- a/src/faster_whisper_server/logger.py +++ b/src/faster_whisper_server/logger.py @@ -1,11 +1,8 @@ import logging -from faster_whisper_server.dependencies import get_config - -def setup_logger() -> None: - config = get_config() # HACK +def setup_logger(log_level: str) -> None: logging.getLogger().setLevel(logging.INFO) logger = logging.getLogger(__name__) - logger.setLevel(config.log_level.upper()) + logger.setLevel(log_level.upper()) logging.basicConfig(format="%(asctime)s:%(levelname)s:%(name)s:%(funcName)s:%(lineno)d:%(message)s") diff --git a/src/faster_whisper_server/main.py b/src/faster_whisper_server/main.py index f59036b..d367539 100644 --- a/src/faster_whisper_server/main.py +++ b/src/faster_whisper_server/main.py @@ -27,10 +27,12 @@ def create_app() -> FastAPI: - setup_logger() - + config = get_config() # HACK + setup_logger(config.log_level) logger = logging.getLogger(__name__) + logger.debug(f"Config: {config}") + if platform.machine() == "x86_64": from faster_whisper_server.routers.speech import ( router as speech_router, @@ -39,9 +41,6 @@ def create_app() -> FastAPI: logger.warning("`/v1/audio/speech` is only supported on x86_64 machines") speech_router = None - config = get_config() # HACK - logger.debug(f"Config: {config}") - model_manager = get_model_manager() # HACK @asynccontextmanager diff --git a/tests/conftest.py b/tests/conftest.py index 2236926..608c6e8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,8 @@ from collections.abc import AsyncGenerator, Generator +from contextlib import AbstractAsyncContextManager, asynccontextmanager import logging import os +from typing import Protocol from fastapi.testclient import TestClient from httpx import ASGITransport, AsyncClient @@ -8,19 +10,31 @@ from openai import AsyncOpenAI import pytest import pytest_asyncio +from pytest_mock import MockerFixture +from faster_whisper_server.config import Config, WhisperConfig +from faster_whisper_server.dependencies import get_config from faster_whisper_server.main import create_app -disable_loggers = ["multipart.multipart", "faster_whisper"] +DISABLE_LOGGERS = ["multipart.multipart", "faster_whisper"] +OPENAI_BASE_URL = "https://api.openai.com/v1" +DEFAULT_WHISPER_MODEL = "Systran/faster-whisper-tiny.en" +# TODO: figure out a way to initialize the config without parsing environment variables, as those may interfere with the tests # noqa: E501 +DEFAULT_WHISPER_CONFIG = WhisperConfig(model=DEFAULT_WHISPER_MODEL, ttl=0) +DEFAULT_CONFIG = Config( + whisper=DEFAULT_WHISPER_CONFIG, + # disable the UI as it slightly increases the app startup time due to the imports it's doing + enable_ui=False, +) def pytest_configure() -> None: - for logger_name in disable_loggers: + for logger_name in DISABLE_LOGGERS: logger = logging.getLogger(logger_name) logger.disabled = True -# NOTE: not being used. Keeping just in case +# NOTE: not being used. Keeping just in case. Needs to be modified to work similarly to `aclient_factory` @pytest.fixture def client() -> Generator[TestClient, None, None]: os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en" @@ -28,10 +42,37 @@ def client() -> Generator[TestClient, None, None]: yield client +# https://stackoverflow.com/questions/74890214/type-hint-callback-function-with-optional-parameters-aka-callable-with-optional +class AclientFactory(Protocol): + def __call__(self, config: Config = DEFAULT_CONFIG) -> AbstractAsyncContextManager[AsyncClient]: ... + + @pytest_asyncio.fixture() -async def aclient() -> AsyncGenerator[AsyncClient, None]: - os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en" - async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient: +async def aclient_factory(mocker: MockerFixture) -> AclientFactory: + """Returns a context manager that provides an `AsyncClient` instance with `app` using the provided configuration.""" + + @asynccontextmanager + async def inner(config: Config = DEFAULT_CONFIG) -> AsyncGenerator[AsyncClient, None]: + # NOTE: all calls to `get_config` should be patched. One way to test that this works is to update the original `get_config` to raise an exception and see if the tests fail # noqa: E501 + mocker.patch("faster_whisper_server.dependencies.get_config", return_value=config) + mocker.patch("faster_whisper_server.main.get_config", return_value=config) + # NOTE: I couldn't get the following to work but it shouldn't matter + # mocker.patch( + # "faster_whisper_server.text_utils.Transcription._ensure_no_word_overlap.get_config", return_value=config + # ) + + app = create_app() + # https://fastapi.tiangolo.com/advanced/testing-dependencies/ + app.dependency_overrides[get_config] = lambda: config + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as aclient: + yield aclient + + return inner + + +@pytest_asyncio.fixture() +async def aclient(aclient_factory: AclientFactory) -> AsyncGenerator[AsyncClient, None]: + async with aclient_factory() as aclient: yield aclient @@ -43,11 +84,13 @@ def openai_client(aclient: AsyncClient) -> AsyncOpenAI: @pytest.fixture def actual_openai_client() -> AsyncOpenAI: return AsyncOpenAI( - base_url="https://api.openai.com/v1" - ) # `base_url` is provided in case `OPENAI_API_BASE_URL` is set to a different value + # `base_url` is provided in case `OPENAI_BASE_URL` is set to a different value + base_url=OPENAI_BASE_URL + ) # TODO: remove the download after running the tests +# TODO: do not download when not needed @pytest.fixture(scope="session", autouse=True) def download_piper_voices() -> None: # Only download `voices.json` and the default voice diff --git a/tests/model_manager_test.py b/tests/model_manager_test.py index e05b1f3..66701a3 100644 --- a/tests/model_manager_test.py +++ b/tests/model_manager_test.py @@ -1,23 +1,22 @@ import asyncio -import os import anyio -from httpx import ASGITransport, AsyncClient import pytest -from faster_whisper_server.main import create_app +from faster_whisper_server.config import Config, WhisperConfig +from tests.conftest import DEFAULT_WHISPER_MODEL, AclientFactory + +MODEL = DEFAULT_WHISPER_MODEL # just to make the test more readable @pytest.mark.asyncio -async def test_model_unloaded_after_ttl() -> None: +async def test_model_unloaded_after_ttl(aclient_factory: AclientFactory) -> None: ttl = 5 - model = "Systran/faster-whisper-tiny.en" - os.environ["WHISPER__TTL"] = str(ttl) - os.environ["ENABLE_UI"] = "false" - async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient: + config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False) + async with aclient_factory(config) as aclient: res = (await aclient.get("/api/ps")).json() assert len(res["models"]) == 0 - await aclient.post(f"/api/ps/{model}") + await aclient.post(f"/api/ps/{MODEL}") res = (await aclient.get("/api/ps")).json() assert len(res["models"]) == 1 await asyncio.sleep(ttl + 1) # wait for the model to be unloaded @@ -26,13 +25,11 @@ async def test_model_unloaded_after_ttl() -> None: @pytest.mark.asyncio -async def test_ttl_resets_after_usage() -> None: +async def test_ttl_resets_after_usage(aclient_factory: AclientFactory) -> None: ttl = 5 - model = "Systran/faster-whisper-tiny.en" - os.environ["WHISPER__TTL"] = str(ttl) - os.environ["ENABLE_UI"] = "false" - async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient: - await aclient.post(f"/api/ps/{model}") + config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False) + async with aclient_factory(config) as aclient: + await aclient.post(f"/api/ps/{MODEL}") res = (await aclient.get("/api/ps")).json() assert len(res["models"]) == 1 await asyncio.sleep(ttl - 2) # sleep for less than the ttl. The model should not be unloaded @@ -43,7 +40,9 @@ async def test_ttl_resets_after_usage() -> None: data = await f.read() res = ( await aclient.post( - "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model} + "/v1/audio/transcriptions", + files={"file": ("audio.wav", data, "audio/wav")}, + data={"model": MODEL}, ) ).json() res = (await aclient.get("/api/ps")).json() @@ -60,28 +59,28 @@ async def test_ttl_resets_after_usage() -> None: # this just ensures the model can be loaded again after being unloaded res = ( await aclient.post( - "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model} + "/v1/audio/transcriptions", + files={"file": ("audio.wav", data, "audio/wav")}, + data={"model": MODEL}, ) ).json() @pytest.mark.asyncio -async def test_model_cant_be_unloaded_when_used() -> None: +async def test_model_cant_be_unloaded_when_used(aclient_factory: AclientFactory) -> None: ttl = 0 - model = "Systran/faster-whisper-tiny.en" - os.environ["WHISPER__TTL"] = str(ttl) - os.environ["ENABLE_UI"] = "false" - async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient: + config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False) + async with aclient_factory(config) as aclient: async with await anyio.open_file("audio.wav", "rb") as f: data = await f.read() task = asyncio.create_task( aclient.post( - "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model} + "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": MODEL} ) ) await asyncio.sleep(0.1) # wait for the server to start processing the request - res = await aclient.delete(f"/api/ps/{model}") + res = await aclient.delete(f"/api/ps/{MODEL}") assert res.status_code == 409 await task @@ -90,27 +89,23 @@ async def test_model_cant_be_unloaded_when_used() -> None: @pytest.mark.asyncio -async def test_model_cant_be_loaded_twice() -> None: +async def test_model_cant_be_loaded_twice(aclient_factory: AclientFactory) -> None: ttl = -1 - model = "Systran/faster-whisper-tiny.en" - os.environ["ENABLE_UI"] = "false" - os.environ["WHISPER__TTL"] = str(ttl) - async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient: - res = await aclient.post(f"/api/ps/{model}") + config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False) + async with aclient_factory(config) as aclient: + res = await aclient.post(f"/api/ps/{MODEL}") assert res.status_code == 201 - res = await aclient.post(f"/api/ps/{model}") + res = await aclient.post(f"/api/ps/{MODEL}") assert res.status_code == 409 res = (await aclient.get("/api/ps")).json() assert len(res["models"]) == 1 @pytest.mark.asyncio -async def test_model_is_unloaded_after_request_when_ttl_is_zero() -> None: +async def test_model_is_unloaded_after_request_when_ttl_is_zero(aclient_factory: AclientFactory) -> None: ttl = 0 - os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en" - os.environ["WHISPER__TTL"] = str(ttl) - os.environ["ENABLE_UI"] = "false" - async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient: + config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False) + async with aclient_factory(config) as aclient: async with await anyio.open_file("audio.wav", "rb") as f: data = await f.read() res = await aclient.post(