Merge pull request #151 from mobiusml/remove_test_cache
Removal of Test Cache and Redesign of Deployment Tests
movchan74 authored Aug 8, 2024
2 parents 309b9b0 + fb80a26 commit 8af6495
Showing 287 changed files with 8,932 additions and 5,143 deletions.
2 changes: 0 additions & 2 deletions .env
@@ -1,5 +1,3 @@
CUDA_VISIBLE_DEVICES=""
USE_DEPLOYMENT_CACHE = False
SAVE_DEPLOYMENT_CACHE = False
HF_HUB_ENABLE_HF_TRANSFER = 1
HF_TOKEN=""
20 changes: 1 addition & 19 deletions aana/configs/deployments.py
@@ -12,10 +12,6 @@
HfTextGenerationDeployment,
)
from aana.deployments.idefics_2_deployment import Idefics2Config, Idefics2Deployment
from aana.deployments.stablediffusion2_deployment import (
StableDiffusion2Config,
StableDiffusion2Deployment,
)
from aana.deployments.vad_deployment import VadConfig, VadDeployment
from aana.deployments.vllm_deployment import VLLMConfig, VLLMDeployment
from aana.deployments.whisper_deployment import (
@@ -130,17 +126,6 @@
)
available_deployments["whisper_medium_deployment"] = whisper_medium_deployment

stablediffusion2_deployment = StableDiffusion2Deployment.options(
num_replicas=1,
max_ongoing_requests=1000,
ray_actor_options={"num_gpus": 1},
user_config=StableDiffusion2Config(
model="stabilityai/stable-diffusion-2",
dtype=Dtype.FLOAT16,
).model_dump(mode="json"),
)
available_deployments["stablediffusion2_deployment"] = stablediffusion2_deployment

vad_deployment = VadDeployment.options(
num_replicas=1,
max_ongoing_requests=1000,
@@ -201,9 +186,7 @@
).model_dump(mode="json"),
)

available_deployments[
"idefics_2_deployment"
] = idefics_2_deployment
available_deployments["idefics_2_deployment"] = idefics_2_deployment

__all__ = [
"vllm_llama2_7b_chat_deployment",
@@ -212,7 +195,6 @@
"microsoft_phi_3_mini_instruct_deployment",
"hf_blip2_opt_2_7b_deployment",
"whisper_medium_deployment",
"stablediffusion2_deployment",
"vad_deployment",
"hf_blip2_opt_2_7b_pipeline_deployment",
"hf_phi3_mini_4k_instruct_text_gen_deployment",
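
For orientation, a minimal sketch of how this registry is consumed (assuming available_deployments is importable from the module shown above; the variable name is illustrative):

    from aana.configs.deployments import available_deployments

    # available_deployments maps a deployment name to a configured Ray Serve
    # deployment; after this change, "stablediffusion2_deployment" is no
    # longer a valid key.
    whisper_deployment = available_deployments["whisper_medium_deployment"]
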
6 changes: 2 additions & 4 deletions aana/configs/settings.py
@@ -11,13 +11,11 @@ class TestSettings(BaseSettings):
Attributes:
test_mode (bool): Flag indicating if the SDK is in test mode.
use_deployment_cache (bool): Flag indicating if the SDK should use cached deployment results for testing.
save_deployment_cache (bool): Flag indicating if the SDK should save deployment results to cache for testing.
save_expected_output (bool): Flag indicating if the expected output should be saved (to create test cases).
"""

test_mode: bool = False
use_deployment_cache: bool = False # use cached deployment results for testing
save_deployment_cache: bool = False # save deployment results to cache for testing
save_expected_output: bool = False


class TaskQueueSettings(BaseSettings):
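
As a rough illustration of how the remaining TestSettings flags are read (the settings.test access pattern matches the usage visible in base_deployment.py below; the helper itself is hypothetical):

    from aana.configs.settings import settings

    def maybe_save_expected_output(result: dict) -> None:
        """Hypothetical helper: persist a result only when the test flags allow it."""
        # test_mode and the new save_expected_output flag are the TestSettings
        # fields shown in the hunk above.
        if settings.test.test_mode and settings.test.save_expected_output:
            print("saving expected output:", result)  # placeholder for real persistence
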
6 changes: 5 additions & 1 deletion aana/core/models/types.py
@@ -8,18 +8,20 @@
class Dtype(str, Enum):
"""Data types.
Possible values are "auto", "float32", "float16", and "int8".
Possible values are "auto", "float32", "float16", "bfloat16" and "int8".
Attributes:
AUTO (str): auto
FLOAT32 (str): float32
FLOAT16 (str): float16
BFLOAT16 (str): bfloat16
INT8 (str): int8
"""

AUTO = "auto"
FLOAT32 = "float32"
FLOAT16 = "float16"
BFLOAT16 = "bfloat16"
INT8 = "int8"

def to_torch(self) -> torch.dtype | str:
@@ -38,6 +40,8 @@ def to_torch(self) -> torch.dtype | str:
return torch.float32
case self.FLOAT16:
return torch.float16
case self.BFLOAT16:
return torch.bfloat16
case self.INT8:
return torch.int8
case _:
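
For context, a quick sketch of what the new enum member does (assuming torch and the aana package are installed):

    import torch

    from aana.core.models.types import Dtype

    # The new member's string value is "bfloat16", and to_torch() maps it to
    # torch.bfloat16 via the match case added above.
    assert Dtype.BFLOAT16.value == "bfloat16"
    assert Dtype.BFLOAT16.to_torch() is torch.bfloat16
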
186 changes: 2 additions & 184 deletions aana/deployments/base_deployment.py
@@ -1,171 +1,6 @@
import inspect
import pickle
from functools import wraps
from importlib import resources
from pathlib import Path
from typing import Any

import rapidfuzz

from aana.configs.settings import settings
from aana.utils.core import get_object_hash
from aana.utils.json import jsonify


def test_cache(func): # noqa: C901
"""Decorator for caching and loading the results of a deployment function in testing mode.
Keep in mind that this decorator only works for async functions and async generator functions.
Use this decorator to annotate deployment functions that you want to cache in testing mode.
There are 3 environment variables that control the behavior of the decorator:
- TEST_MODE: set to "true" to enable testing mode
(default is "false", should only be set to "true" if you are running tests)
- USE_DEPLOYMENT_CACHE: set to "true" to enable cache usage
- SAVE_DEPLOYMENT_CACHE: set to "true" to enable cache saving
The decorator behaves differently in testing and production modes.
In production mode, the decorator is a no-op.
In testing mode, the behavior of the decorator is controlled by the environment variables USE_DEPLOYMENT_CACHE and SAVE_DEPLOYMENT_CACHE.
If USE_DEPLOYMENT_CACHE is set to "true", the decorator will load the result from the cache if it exists. SAVE_DEPLOYMENT_CACHE is ignored.
The decorator takes a hash of the deployment configuration and the function arguments and keyword arguments (args and kwargs) to locate the cache file.
If the cache file exists, the decorator will load the result from the cache and return it.
If the cache file does not exist, the decorator will try to find the cache file with the closest args and load the result from that cache file
(function name and deployment configuration should match exactly, fuzzy matching only applies to args and kwargs).
If USE_DEPLOYMENT_CACHE is set to "false", the decorator will execute the function and save the result to the cache if SAVE_DEPLOYMENT_CACHE is set to "true".
"""
if not settings.test.test_mode:
# If we are in production, the decorator is a no-op
return func

def get_cache_path(args, kwargs):
"""Get the path to the cache file."""
self = args[0]

func_name = func.__name__
deployment_name = self.__class__.__name__

config = args[0].config
config_hash = get_object_hash(config)

args_hash = get_object_hash({"args": args[1:], "kwargs": kwargs})

return (
resources.path("aana.tests.files.cache", "")
/ Path(deployment_name)
/ Path(f"{func_name}_{config_hash}_{args_hash}.pkl")
)

def save_cache(cache_path, cache, args, kwargs):
"""Save the cache to a file."""
cache_obj = {
"args": jsonify({"args": args[1:], "kwargs": kwargs}),
}
if "exception" in cache:
cache_obj["exception"] = cache[
"exception"
] # if the cache contains an exception, save it
else:
cache_obj["cache"] = cache # otherwise, cache the result
cache_path.parent.mkdir(parents=True, exist_ok=True)
cache_path.open("wb").write(pickle.dumps(cache_obj))

def find_matching_cache(cache_path, args, kwargs):
"""Find the cache file with the closest args."""

def get_args(path):
cache = pickle.loads(path.open("rb").read()) # noqa: S301
return cache["args"]

args_str = jsonify({"args": args[1:], "kwargs": kwargs})
pattern = cache_path.name.replace(cache_path.name.split("_")[-1], "*")
candidate_cache_files = list(cache_path.parent.glob(pattern))

if len(candidate_cache_files) == 0:
raise FileNotFoundError(f"{cache_path.parent}/{pattern}")

# find the cache with the closest args
path = min(
candidate_cache_files,
key=lambda path: rapidfuzz.distance.Levenshtein.distance(
args_str, get_args(path)
),
)
return Path(path)

@wraps(func)
async def wrapper(*args, **kwargs):
"""Wrapper for the deployment function."""
cache_path = get_cache_path(args, kwargs)

if settings.test.use_deployment_cache:
# load from cache
if not cache_path.exists():
# raise FileNotFoundError(cache_path)
cache_path = find_matching_cache(cache_path, args, kwargs)
cache = pickle.loads(cache_path.open("rb").read()) # noqa: S301
# raise exception if the cache contains an exception
if "exception" in cache:
raise cache["exception"]
return cache["cache"]
else:
# execute the function
try:
result = await func(*args, **kwargs)
except Exception as e:
result = {"exception": e}
raise
finally:
if settings.test.save_deployment_cache and not cache_path.exists():
# save to cache
save_cache(cache_path, result, args, kwargs)
return result

@wraps(func)
async def wrapper_generator(*args, **kwargs):
"""Wrapper for the deployment generator function."""
cache_path = get_cache_path(args, kwargs)

if settings.test.use_deployment_cache:
# load from cache
if not cache_path.exists():
# raise FileNotFoundError(cache_path)
cache_path = find_matching_cache(cache_path, args, kwargs)

cache = pickle.loads(cache_path.open("rb").read()) # noqa: S301
# raise exception if the cache contains an exception
if "exception" in cache:
raise cache["exception"]
for item in cache["cache"]:
yield item
else:
cache = []
try:
# execute the function
async for item in func(*args, **kwargs):
yield item
if settings.test.save_deployment_cache:
cache.append(item)
except Exception as e:
cache = {"exception": e}
raise
finally:
if settings.test.save_deployment_cache and not cache_path.exists():
# save to cache
save_cache(cache_path, cache, args, kwargs)

wrapper_generator.test_cache_enabled = True
wrapper.test_cache_enabled = True

if inspect.isasyncgenfunction(func):
return wrapper_generator
else:
return wrapper


class BaseDeployment:
"""Base class for all deployments.
@@ -185,18 +20,8 @@ async def reconfigure(self, config: dict[str, Any]):
The method is called when the deployment is updated.
"""
self.config = config
if (
settings.test.test_mode
and settings.test.use_deployment_cache
and self.__check_test_cache_enabled()
):
# If we are in testing mode and we want to use the cache,
# we don't need to load the model
self._configured = True
return
else:
await self.apply_config(config)
self._configured = True
await self.apply_config(config)
self._configured = True

async def apply_config(self, config: dict[str, Any]):
"""Apply the configuration.
@@ -236,10 +61,3 @@ async def get_methods(self) -> dict:
if method.__doc__:
methods_info[name]["doc"] = method.__doc__
return methods_info

def __check_test_cache_enabled(self):
"""Check if the deployment has any methods decorated with test_cache."""
for method in self.__class__.__dict__.values():
if callable(method) and getattr(method, "test_cache_enabled", False):
return True
return False
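
With the test-cache short circuit removed, reconfigure() now always stores the config and awaits apply_config(). A minimal, hypothetical subclass illustrating that contract (the class and its fields are made up for illustration; the Ray Serve deployment wrapping used by real deployments is omitted):

    from typing import Any

    from aana.deployments.base_deployment import BaseDeployment

    class EchoDeployment(BaseDeployment):
        """Hypothetical deployment that only stores a configured prefix."""

        async def apply_config(self, config: dict[str, Any]):
            # reconfigure() assigns `config` to self.config and then awaits this
            # method, so model loading and other heavy setup belongs here.
            self.prefix = config.get("prefix", "")

        async def echo(self, text: str) -> str:
            return f"{self.prefix}{text}"
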
7 changes: 1 addition & 6 deletions aana/deployments/base_text_generation_deployment.py
@@ -9,7 +9,7 @@
from aana.core.chat.chat_template import apply_chat_template
from aana.core.models.chat import ChatDialog, ChatMessage
from aana.core.models.sampling import SamplingParams
from aana.deployments.base_deployment import BaseDeployment, test_cache
from aana.deployments.base_deployment import BaseDeployment


class LLMOutput(TypedDict):
@@ -57,7 +57,6 @@ class BaseTextGenerationDeployment(BaseDeployment):
You can also override these methods to implement custom inference logic.
"""

@test_cache
async def generate_stream(
self, prompt: str, sampling_params: SamplingParams | None = None
) -> AsyncGenerator[LLMOutput, None]:
@@ -72,7 +71,6 @@ async def generate_stream(
"""
raise NotImplementedError

@test_cache
async def generate(
self, prompt: str, sampling_params: SamplingParams | None = None
) -> LLMOutput:
@@ -90,7 +88,6 @@ async def generate(
generated_text += chunk["text"]
return LLMOutput(text=generated_text)

@test_cache
async def generate_batch(
self, prompts: list[str], sampling_params: SamplingParams | None = None
) -> LLMBatchOutput:
@@ -111,7 +108,6 @@ async def generate_batch(

return LLMBatchOutput(texts=texts)

@test_cache
async def chat(
self, dialog: ChatDialog, sampling_params: SamplingParams | None = None
) -> ChatOutput:
@@ -131,7 +127,6 @@ async def chat(
response_message = ChatMessage(content=response["text"], role="assistant")
return ChatOutput(message=response_message)

@test_cache
async def chat_stream(
self, dialog: ChatDialog, sampling_params: SamplingParams | None = None
) -> AsyncGenerator[LLMOutput, None]:
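
For reference, a rough sketch of how the now-undecorated streaming method is consumed; `deployment` stands for any BaseTextGenerationDeployment subclass instance and is not constructed here:

    from aana.deployments.base_text_generation_deployment import (
        BaseTextGenerationDeployment,
    )

    async def stream_completion(
        deployment: BaseTextGenerationDeployment, prompt: str
    ) -> str:
        """Collect streamed chunks into a single string."""
        text = ""
        # generate_stream yields LLMOutput dicts with a "text" key, mirroring how
        # generate() accumulates chunk["text"] in the hunk above.
        async for chunk in deployment.generate_stream(prompt):
            text += chunk["text"]
        return text
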
4 changes: 1 addition & 3 deletions aana/deployments/hf_blip2_deployment.py
@@ -10,7 +10,7 @@
from aana.core.models.captions import Caption, CaptionsList
from aana.core.models.image import Image
from aana.core.models.types import Dtype
from aana.deployments.base_deployment import BaseDeployment, test_cache
from aana.deployments.base_deployment import BaseDeployment
from aana.exceptions.runtime import InferenceException
from aana.processors.batch import BatchProcessor

@@ -103,7 +103,6 @@ async def apply_config(self, config: dict[str, Any]):
self.processor = Blip2Processor.from_pretrained(self.model_id)
self.model.to(self.device)

@test_cache
async def generate(self, image: Image) -> CaptioningOutput:
"""Generate captions for the given image.
@@ -122,7 +121,6 @@ async def generate(self, image: Image) -> CaptioningOutput:
)
return CaptioningOutput(caption=captions["captions"][0])

@test_cache
async def generate_batch(self, **kwargs) -> CaptioningBatchOutput:
"""Generate captions for the given images.
3 changes: 1 addition & 2 deletions aana/deployments/hf_pipeline_deployment.py
@@ -8,7 +8,7 @@

from aana.core.models.custom_config import CustomConfig
from aana.core.models.image import Image
from aana.deployments.base_deployment import BaseDeployment, test_cache
from aana.deployments.base_deployment import BaseDeployment


class HfPipelineConfig(BaseModel):
@@ -77,7 +77,6 @@ async def apply_config(self, config: dict[str, Any]):
else:
raise

@test_cache
async def call(self, *args, **kwargs):
"""Call the pipeline.