tests: proper get_config dependency override
Fedir Zadniprovskyi authored and fedirz committed Jan 3, 2025
1 parent 74ecebe commit ede9e6a
Showing 5 changed files with 102 additions and 61 deletions.
21 changes: 14 additions & 7 deletions src/faster_whisper_server/dependencies.py
@@ -1,4 +1,5 @@
 from functools import lru_cache
+import logging
 from typing import Annotated
 
 from fastapi import Depends, HTTPException, status
@@ -11,7 +12,13 @@
 from faster_whisper_server.config import Config
 from faster_whisper_server.model_manager import PiperModelManager, WhisperModelManager
 
+logger = logging.getLogger(__name__)
+
+# NOTE: `get_config` is called directly instead of using sub-dependencies so that these functions could be used outside of `FastAPI`  # noqa: E501
+
 
+# https://fastapi.tiangolo.com/advanced/settings/?h=setti#creating-the-settings-only-once-with-lru_cache
+# WARN: Any new module that ends up calling this function directly (not through `FastAPI` dependency injection) should be patched in `tests/conftest.py`  # noqa: E501
 @lru_cache
 def get_config() -> Config:
     return Config()
@@ -22,7 +29,7 @@ def get_config() -> Config:
 
 @lru_cache
 def get_model_manager() -> WhisperModelManager:
-    config = get_config()  # HACK
+    config = get_config()
     return WhisperModelManager(config.whisper)
 
 
@@ -31,8 +38,8 @@ def get_model_manager() -> WhisperModelManager:
 
 @lru_cache
 def get_piper_model_manager() -> PiperModelManager:
-    config = get_config()  # HACK
-    return PiperModelManager(config.whisper.ttl)  # HACK
+    config = get_config()
+    return PiperModelManager(config.whisper.ttl)  # HACK: should have its own config
 
 
 PiperModelManagerDependency = Annotated[PiperModelManager, Depends(get_piper_model_manager)]
@@ -53,7 +60,7 @@ async def verify_api_key(
 
 @lru_cache
 def get_completion_client() -> AsyncCompletions:
-    config = get_config()  # HACK
+    config = get_config()
     oai_client = AsyncOpenAI(base_url=config.chat_completion_base_url, api_key=config.chat_completion_api_key)
     return oai_client.chat.completions
 
@@ -63,9 +70,9 @@ def get_completion_client() -> AsyncCompletions:
 
 @lru_cache
 def get_speech_client() -> AsyncSpeech:
-    config = get_config()  # HACK
+    config = get_config()
     if config.speech_base_url is None:
-        # this might not work as expected if the `speech_router` won't have shared state with the main FastAPI `app`. TODO: verify  # noqa: E501
+        # this might not work as expected if `speech_router` won't have shared state (access to the same `model_manager`) with the main FastAPI `app`. TODO: verify  # noqa: E501
         from faster_whisper_server.routers.speech import (
             router as speech_router,
         )
@@ -86,7 +93,7 @@ def get_speech_client() -> AsyncSpeech:
 def get_transcription_client() -> AsyncTranscriptions:
     config = get_config()
     if config.transcription_base_url is None:
-        # this might not work as expected if the `transcription_router` won't have shared state with the main FastAPI `app`. TODO: verify  # noqa: E501
+        # this might not work as expected if `transcription_router` won't have shared state (access to the same `model_manager`) with the main FastAPI `app`. TODO: verify  # noqa: E501
         from faster_whisper_server.routers.stt import (
             router as stt_router,
         )
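The WARN comment above is the behavior the test changes in this commit revolve around: `app.dependency_overrides` only substitutes a dependency when FastAPI injects it via `Depends(...)`, while helpers such as `get_model_manager` that call `get_config()` directly bypass the override, so tests also have to patch the function itself. A minimal, self-contained sketch of that difference (not code from this repository; the `Config` class, `helper`, the routes, and the model names are made up for illustration):

from functools import lru_cache

from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient
from pydantic import BaseModel


class Config(BaseModel):
    model: str = "real-model"


@lru_cache
def get_config() -> Config:
    return Config()


def helper() -> str:
    # Direct call: FastAPI's dependency injection is not involved, so overrides are ignored.
    return get_config().model


app = FastAPI()


@app.get("/injected")
def injected(config: Config = Depends(get_config)) -> str:
    return config.model


@app.get("/direct")
def direct() -> str:
    return helper()


# Override the dependency the same way tests/conftest.py does for the DI path.
app.dependency_overrides[get_config] = lambda: Config(model="test-model")

client = TestClient(app)
print(client.get("/injected").json())  # "test-model" -- the override is applied
print(client.get("/direct").json())    # "real-model" -- the direct call bypasses the override

This is why `tests/conftest.py` below both sets `app.dependency_overrides[get_config]` and patches `get_config` with `mocker.patch`.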
7 changes: 2 additions & 5 deletions src/faster_whisper_server/logger.py
@@ -1,11 +1,8 @@
 import logging
 
-from faster_whisper_server.dependencies import get_config
-
 
-def setup_logger() -> None:
-    config = get_config()  # HACK
+def setup_logger(log_level: str) -> None:
     logging.getLogger().setLevel(logging.INFO)
     logger = logging.getLogger(__name__)
-    logger.setLevel(config.log_level.upper())
+    logger.setLevel(log_level.upper())
     logging.basicConfig(format="%(asctime)s:%(levelname)s:%(name)s:%(funcName)s:%(lineno)d:%(message)s")
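`setup_logger` no longer reaches into `faster_whisper_server.dependencies` for the config; the caller passes the log level in. A minimal usage sketch, mirroring what `create_app()` does in main.py below (it assumes a default `Config` can be built from the environment):

from faster_whisper_server.dependencies import get_config
from faster_whisper_server.logger import setup_logger

config = get_config()
setup_logger(config.log_level)  # the log level is now an argument; logger.py no longer imports get_config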
9 changes: 4 additions & 5 deletions src/faster_whisper_server/main.py
@@ -27,10 +27,12 @@
 
 
 def create_app() -> FastAPI:
-    setup_logger()
-
+    config = get_config()  # HACK
+    setup_logger(config.log_level)
     logger = logging.getLogger(__name__)
 
+    logger.debug(f"Config: {config}")
+
     if platform.machine() == "x86_64":
         from faster_whisper_server.routers.speech import (
             router as speech_router,
@@ -39,9 +41,6 @@ def create_app() -> FastAPI:
         logger.warning("`/v1/audio/speech` is only supported on x86_64 machines")
         speech_router = None
 
-    config = get_config()  # HACK
-    logger.debug(f"Config: {config}")
-
     model_manager = get_model_manager()  # HACK
 
     @asynccontextmanager
59 changes: 51 additions & 8 deletions tests/conftest.py
@@ -1,37 +1,78 @@
 from collections.abc import AsyncGenerator, Generator
+from contextlib import AbstractAsyncContextManager, asynccontextmanager
 import logging
 import os
+from typing import Protocol
 
 from fastapi.testclient import TestClient
 from httpx import ASGITransport, AsyncClient
 from huggingface_hub import snapshot_download
 from openai import AsyncOpenAI
 import pytest
 import pytest_asyncio
+from pytest_mock import MockerFixture
 
+from faster_whisper_server.config import Config, WhisperConfig
+from faster_whisper_server.dependencies import get_config
 from faster_whisper_server.main import create_app
 
-disable_loggers = ["multipart.multipart", "faster_whisper"]
+DISABLE_LOGGERS = ["multipart.multipart", "faster_whisper"]
+OPENAI_BASE_URL = "https://api.openai.com/v1"
+DEFAULT_WHISPER_MODEL = "Systran/faster-whisper-tiny.en"
+# TODO: figure out a way to initialize the config without parsing environment variables, as those may interfere with the tests  # noqa: E501
+DEFAULT_WHISPER_CONFIG = WhisperConfig(model=DEFAULT_WHISPER_MODEL, ttl=0)
+DEFAULT_CONFIG = Config(
+    whisper=DEFAULT_WHISPER_CONFIG,
+    # disable the UI as it slightly increases the app startup time due to the imports it's doing
+    enable_ui=False,
+)
 
 
 def pytest_configure() -> None:
-    for logger_name in disable_loggers:
+    for logger_name in DISABLE_LOGGERS:
         logger = logging.getLogger(logger_name)
         logger.disabled = True
 
 
-# NOTE: not being used. Keeping just in case
+# NOTE: not being used. Keeping just in case. Needs to be modified to work similarly to `aclient_factory`
 @pytest.fixture
 def client() -> Generator[TestClient, None, None]:
     os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
     with TestClient(create_app()) as client:
         yield client
 
 
+# https://stackoverflow.com/questions/74890214/type-hint-callback-function-with-optional-parameters-aka-callable-with-optional
+class AclientFactory(Protocol):
+    def __call__(self, config: Config = DEFAULT_CONFIG) -> AbstractAsyncContextManager[AsyncClient]: ...
+
+
 @pytest_asyncio.fixture()
-async def aclient() -> AsyncGenerator[AsyncClient, None]:
-    os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
-    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+async def aclient_factory(mocker: MockerFixture) -> AclientFactory:
+    """Returns a context manager that provides an `AsyncClient` instance with `app` using the provided configuration."""
+
+    @asynccontextmanager
+    async def inner(config: Config = DEFAULT_CONFIG) -> AsyncGenerator[AsyncClient, None]:
+        # NOTE: all calls to `get_config` should be patched. One way to test that this works is to update the original `get_config` to raise an exception and see if the tests fail  # noqa: E501
+        mocker.patch("faster_whisper_server.dependencies.get_config", return_value=config)
+        mocker.patch("faster_whisper_server.main.get_config", return_value=config)
+        # NOTE: I couldn't get the following to work but it shouldn't matter
+        # mocker.patch(
+        #     "faster_whisper_server.text_utils.Transcription._ensure_no_word_overlap.get_config", return_value=config
+        # )
+
+        app = create_app()
+        # https://fastapi.tiangolo.com/advanced/testing-dependencies/
+        app.dependency_overrides[get_config] = lambda: config
+        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as aclient:
+            yield aclient
+
+    return inner
+
+
+@pytest_asyncio.fixture()
+async def aclient(aclient_factory: AclientFactory) -> AsyncGenerator[AsyncClient, None]:
+    async with aclient_factory() as aclient:
        yield aclient
 
 
@@ -43,11 +84,13 @@ def openai_client(aclient: AsyncClient) -> AsyncOpenAI:
 @pytest.fixture
 def actual_openai_client() -> AsyncOpenAI:
     return AsyncOpenAI(
-        base_url="https://api.openai.com/v1"
-    )  # `base_url` is provided in case `OPENAI_API_BASE_URL` is set to a different value
+        # `base_url` is provided in case `OPENAI_BASE_URL` is set to a different value
+        base_url=OPENAI_BASE_URL
+    )
 
 
 # TODO: remove the download after running the tests
+# TODO: do not download when not needed
 @pytest.fixture(scope="session", autouse=True)
 def download_piper_voices() -> None:
     # Only download `voices.json` and the default voice
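For reference, a sketch of how a new test could opt into a non-default configuration through `aclient_factory` (the test name is hypothetical; `/api/ps` is one of the endpoints exercised by the tests below and is assumed to return 200 on a plain GET):

import pytest

from faster_whisper_server.config import Config, WhisperConfig
from tests.conftest import DEFAULT_WHISPER_MODEL, AclientFactory


@pytest.mark.asyncio
async def test_with_custom_config(aclient_factory: AclientFactory) -> None:
    # ttl=-1 keeps the model loaded indefinitely, as in test_model_cant_be_loaded_twice below
    config = Config(whisper=WhisperConfig(model=DEFAULT_WHISPER_MODEL, ttl=-1), enable_ui=False)
    async with aclient_factory(config) as aclient:
        res = await aclient.get("/api/ps")
        assert res.status_code == 200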
67 changes: 31 additions & 36 deletions tests/model_manager_test.py
@@ -1,23 +1,22 @@
 import asyncio
-import os
 
 import anyio
-from httpx import ASGITransport, AsyncClient
 import pytest
 
-from faster_whisper_server.main import create_app
+from faster_whisper_server.config import Config, WhisperConfig
+from tests.conftest import DEFAULT_WHISPER_MODEL, AclientFactory
 
+MODEL = DEFAULT_WHISPER_MODEL  # just to make the test more readable
+
 
 @pytest.mark.asyncio
-async def test_model_unloaded_after_ttl() -> None:
+async def test_model_unloaded_after_ttl(aclient_factory: AclientFactory) -> None:
     ttl = 5
-    model = "Systran/faster-whisper-tiny.en"
-    os.environ["WHISPER__TTL"] = str(ttl)
-    os.environ["ENABLE_UI"] = "false"
-    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+    config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
+    async with aclient_factory(config) as aclient:
         res = (await aclient.get("/api/ps")).json()
         assert len(res["models"]) == 0
-        await aclient.post(f"/api/ps/{model}")
+        await aclient.post(f"/api/ps/{MODEL}")
         res = (await aclient.get("/api/ps")).json()
         assert len(res["models"]) == 1
         await asyncio.sleep(ttl + 1)  # wait for the model to be unloaded
@@ -26,13 +25,11 @@ async def test_model_unloaded_after_ttl() -> None:
 
 
 @pytest.mark.asyncio
-async def test_ttl_resets_after_usage() -> None:
+async def test_ttl_resets_after_usage(aclient_factory: AclientFactory) -> None:
     ttl = 5
-    model = "Systran/faster-whisper-tiny.en"
-    os.environ["WHISPER__TTL"] = str(ttl)
-    os.environ["ENABLE_UI"] = "false"
-    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
-        await aclient.post(f"/api/ps/{model}")
+    config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
+    async with aclient_factory(config) as aclient:
+        await aclient.post(f"/api/ps/{MODEL}")
         res = (await aclient.get("/api/ps")).json()
         assert len(res["models"]) == 1
         await asyncio.sleep(ttl - 2)  # sleep for less than the ttl. The model should not be unloaded
@@ -43,7 +40,9 @@ async def test_ttl_resets_after_usage() -> None:
             data = await f.read()
         res = (
             await aclient.post(
-                "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model}
+                "/v1/audio/transcriptions",
+                files={"file": ("audio.wav", data, "audio/wav")},
+                data={"model": MODEL},
             )
         ).json()
         res = (await aclient.get("/api/ps")).json()
@@ -60,28 +59,28 @@ async def test_ttl_resets_after_usage() -> None:
         # this just ensures the model can be loaded again after being unloaded
         res = (
             await aclient.post(
-                "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model}
+                "/v1/audio/transcriptions",
+                files={"file": ("audio.wav", data, "audio/wav")},
+                data={"model": MODEL},
             )
         ).json()
 
 
 @pytest.mark.asyncio
-async def test_model_cant_be_unloaded_when_used() -> None:
+async def test_model_cant_be_unloaded_when_used(aclient_factory: AclientFactory) -> None:
     ttl = 0
-    model = "Systran/faster-whisper-tiny.en"
-    os.environ["WHISPER__TTL"] = str(ttl)
-    os.environ["ENABLE_UI"] = "false"
-    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+    config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
+    async with aclient_factory(config) as aclient:
         async with await anyio.open_file("audio.wav", "rb") as f:
             data = await f.read()
 
         task = asyncio.create_task(
             aclient.post(
-                "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model}
+                "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": MODEL}
             )
         )
         await asyncio.sleep(0.1)  # wait for the server to start processing the request
-        res = await aclient.delete(f"/api/ps/{model}")
+        res = await aclient.delete(f"/api/ps/{MODEL}")
         assert res.status_code == 409
 
         await task
@@ -90,27 +89,23 @@ async def test_model_cant_be_unloaded_when_used() -> None:
 
 
 @pytest.mark.asyncio
-async def test_model_cant_be_loaded_twice() -> None:
+async def test_model_cant_be_loaded_twice(aclient_factory: AclientFactory) -> None:
     ttl = -1
-    model = "Systran/faster-whisper-tiny.en"
-    os.environ["ENABLE_UI"] = "false"
-    os.environ["WHISPER__TTL"] = str(ttl)
-    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
-        res = await aclient.post(f"/api/ps/{model}")
+    config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
+    async with aclient_factory(config) as aclient:
+        res = await aclient.post(f"/api/ps/{MODEL}")
         assert res.status_code == 201
-        res = await aclient.post(f"/api/ps/{model}")
+        res = await aclient.post(f"/api/ps/{MODEL}")
         assert res.status_code == 409
         res = (await aclient.get("/api/ps")).json()
         assert len(res["models"]) == 1
 
 
 @pytest.mark.asyncio
-async def test_model_is_unloaded_after_request_when_ttl_is_zero() -> None:
+async def test_model_is_unloaded_after_request_when_ttl_is_zero(aclient_factory: AclientFactory) -> None:
     ttl = 0
-    os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
-    os.environ["WHISPER__TTL"] = str(ttl)
-    os.environ["ENABLE_UI"] = "false"
-    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+    config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
+    async with aclient_factory(config) as aclient:
        async with await anyio.open_file("audio.wav", "rb") as f:
            data = await f.read()
        res = await aclient.post(
