perf: serve quantized versions of phi2, neuralchat, and psyfighter1 (#80)

* perf: serve quantized versions of phi2, neuralchat, and psyfighters

* revert: psyfighter2 back to unquantized & A100

* refactor: remove unnecessary usage fields & code

* perf: limit max concurrency to 5 and drop batch size to 4

* chore: clean up dead code
sambarnes authored Mar 19, 2024
1 parent cb1180c commit 2b1cb4a
Showing 3 changed files with 32 additions and 95 deletions.
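
In short: the public model names are unchanged, but the containers behind phi-2, neural-chat, and Psyfighter-13B now load TheBloke's GPTQ checkpoints on a single A10G, with concurrency capped at 4 inputs per container and at most 5 containers. As a rough sketch of what the swap amounts to at the vLLM layer (not part of this commit; it just uses vLLM's standard async engine API, and the explicit `quantization` argument is optional since vLLM can usually infer it from the checkpoint):

```python
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

# Point the engine at the quantized repo instead of the full-precision one.
engine_args = AsyncEngineArgs(
    model="TheBloke/phi-2-GPTQ",  # served internally in place of microsoft/phi-2
    tensor_parallel_size=1,       # fits on a single A10G
    quantization="gptq",          # optional hint; vLLM can detect this itself
)
engine = AsyncLLMEngine.from_engine_args(engine_args)
```

Keeping the GPTQ detail internal means callers still request the original names; the remap at the bottom of vllm_unified.py below handles the translation.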
50 changes: 26 additions & 24 deletions modal/runner/containers/vllm_unified.py
@@ -10,7 +10,6 @@
get_logger,
get_observability_secrets,
)
from shared.protocol import GPUType
from shared.volumes import (
does_model_exist,
get_model_path,
@@ -31,12 +30,6 @@ def _make_container(
"""Helper function to create a container with the given GPU configuration."""

num_gpus = gpu.count
if isinstance(gpu, modal.gpu.A100):
gpu_type = GPUType.A100_80G if gpu.memory == 80 else GPUType.A100_40G
elif isinstance(gpu, modal.gpu.H100):
gpu_type = GPUType.H100_80G
else:
raise ValueError(f"Unknown GPU type: {gpu}")

# Avoid wasting resources & money in dev
if keep_warm and is_env_dev():
@@ -63,21 +56,13 @@ def __init__(self):
ray.init(num_gpus=num_gpus, ignore_reinit_error=True)

super().__init__(
gpu_type=gpu_type,
params=VllmParams(
model=str(model_path),
tensor_parallel_size=num_gpus,
**vllm_opts,
),
)

# For any containers with keep_warm, we need to skip cold-start usage
# billing. This is because the first request might be minutes after
# the container is started, and we don't want to record that time as
# usage.
if keep_warm:
self.is_first_request = False

# Performance improvement from https://github.com/vllm-project/vllm/issues/2073#issuecomment-1853422529
if num_gpus > 1:
import subprocess
@@ -118,24 +103,34 @@ def __init__(self):
# Automatically populated by _make_container.
REGISTERED_CONTAINERS = {}

_phi2 = "TheBloke/phi-2-GPTQ"
VllmContainer_MicrosoftPhi2 = _make_container(
name="VllmContainer_MicrosoftPhi2",
model_name="microsoft/phi-2",
gpu=modal.gpu.A100(count=1, memory=40),
concurrent_inputs=120,
model_name=_phi2,
gpu=modal.gpu.A10G(count=1),
concurrent_inputs=4,
max_containers=5,
)

_neural_chat = "TheBloke/neural-chat-7b-v3-1-GPTQ"
VllmContainer_IntelNeuralChat7B = _make_container(
name="VllmContainer_IntelNeuralChat7B",
model_name="Intel/neural-chat-7b-v3-1",
gpu=modal.gpu.A100(count=1, memory=40),
concurrent_inputs=100,
model_name=_neural_chat,
gpu=modal.gpu.A10G(count=1),
concurrent_inputs=4,
max_containers=5,
)

_psyfighter = "TheBloke/Psyfighter-13B-GPTQ"
VllmContainer_JebCarterPsyfighter13B = _make_container(
"VllmContainer_JebCarterPsyfighter13B",
model_name="jebcarter/Psyfighter-13B",
gpu=modal.gpu.A100(count=1, memory=40),
concurrent_inputs=32,
model_name=_psyfighter,
gpu=modal.gpu.A10G(count=1),
concurrent_inputs=4,
max_containers=5,
)

# TODO: quantize this one too. shipping the others first to limit blast radius
VllmContainer_KoboldAIPsyfighter2 = _make_container(
name="VllmContainer_KoboldAIPsyfighter2",
model_name="KoboldAI/LLaMA2-13B-Psyfighter2",
@@ -171,7 +166,14 @@ def __init__(self):
# A re-mapping of model names to their respective quantized models.
# From the outside, the model name is the original, but internally,
# we use the quantized model name.
#
# NOTE: When serving quantized models, the throughput can suffer a ton
# at high batch sizes. Read this thread to learn why:
# https://github.com/vllm-project/vllm/issues/1002#issuecomment-1712824199
QUANTIZED_MODELS = {
"microsoft/phi-2": _phi2,
"Intel/neural-chat-7b-v3-1": _neural_chat,
"jebcarter/Psyfighter-13B": _psyfighter,
"NeverSleep/Noromaid-v0.1-mixtral-8x7b-Instruct-v3": _noromaid,
"jondurbin/bagel-34b-v0.2": _bagel,
}
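
A minimal, hypothetical sketch of how this remap would be consulted at routing time (the actual lookup lives elsewhere in the runner and is not shown in this diff): callers keep requesting the public name, and unquantized models such as Psyfighter2 simply fall through unchanged.

```python
def resolve_served_model(requested: str) -> str:
    """Translate a public model name into the internally served repo (hypothetical helper)."""
    return QUANTIZED_MODELS.get(requested, requested)

assert resolve_served_model("microsoft/phi-2") == "TheBloke/phi-2-GPTQ"
# Psyfighter2 was reverted to unquantized, so it maps to itself.
assert resolve_served_model("KoboldAI/LLaMA2-13B-Psyfighter2") == "KoboldAI/LLaMA2-13B-Psyfighter2"
```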
50 changes: 5 additions & 45 deletions modal/runner/engines/vllm.py
@@ -11,7 +11,6 @@
)
from shared.protocol import (
CompletionPayload,
GPUType,
ResponseBody,
Usage,
create_error_text,
@@ -62,14 +61,7 @@ class VllmParams(BaseModel):


class VllmEngine(BaseEngine):
def __init__(
self,
gpu_type: GPUType,
params: VllmParams,
):
self.gpu_type = gpu_type
self.is_first_request = True
self.t_cold_start = time.time()
def __init__(self, params: VllmParams):
self.engine = None
self.engine_args = AsyncEngineArgs(
**params.dict(),
@@ -81,44 +73,14 @@ def startup(self):
with timer("engine init", model=self.engine_args.model):
self.engine = AsyncLLMEngine.from_engine_args(self.engine_args)

@property
def gpu_count(self) -> int:
return self.engine_args.tensor_parallel_size

@property
def cost_per_second(self) -> float:
return self.gpu_count * self.gpu_type.cost_per_second

# @method()
# async def tokenize_prompt(self, payload: Payload) -> List[int]:
# return self.tokenizer(payload.prompt).input_ids

# @method()
# async def max_model_len(self) -> int:
# engine_model_config = await self.engine.get_model_config()
# return engine_model_config.max_model_len

@method()
async def generate(self, payload: CompletionPayload, params):
assert self.engine is not None, "Engine not initialized"

# Track usage as a running total. For the first request to the
# container, cold-start time is included in the usage duration.
t_start_inference = time.time()
t_start_usage_duration = t_start_inference
if self.is_first_request:
self.is_first_request = False
t_start_usage_duration = self.t_cold_start

resp = ResponseBody(
text="",
usage=Usage(
prompt_tokens=0,
completion_tokens=0,
duration=0.0,
gpu_type=self.gpu_type,
gpu_count=self.gpu_count,
),
usage=Usage(prompt_tokens=0, completion_tokens=0),
)

try:
@@ -134,7 +96,6 @@ async def generate(self, payload: CompletionPayload, params):
finish_reason = current.outputs[0].finish_reason
resp.usage.prompt_tokens = len(current.prompt_token_ids)
resp.usage.completion_tokens = len(current.outputs[0].token_ids)
resp.usage.duration = time.time() - t_start_usage_duration

# Non-streaming requests continue generating w/o yielding intermediate results
if not payload.stream:
@@ -158,15 +119,14 @@ async def generate(self, payload: CompletionPayload, params):
data = resp.json(ensure_ascii=False)
yield sse(data) if payload.stream else data

duration = time.time() - t_start_inference
logger.info(
"Completed generation",
extra={
"model": self.engine_args.model,
"tokens": resp.usage.completion_tokens,
"tps": resp.usage.completion_tokens
/ (time.time() - t_start_inference),
"duration": resp.usage.duration,
"cost": resp.usage.duration * self.cost_per_second,
"tps": resp.usage.completion_tokens / duration,
"duration": duration,
},
)
except Exception as err:
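
For context on the `yield sse(data) if payload.stream else data` line above, here is a minimal sketch of what an `sse` helper of that shape typically does; the real implementation is defined elsewhere in the repo and is not part of this diff.

```python
def sse(data: str) -> str:
    """Wrap a JSON string as a single server-sent-events data frame (assumed shape)."""
    return f"data: {data}\n\n"
```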
27 changes: 1 addition & 26 deletions modal/shared/protocol.py
@@ -1,29 +1,8 @@
from enum import Enum
from typing import Final, List, Optional, Union
from typing import List, Optional, Union

from fastapi.responses import JSONResponse, PlainTextResponse
from pydantic import BaseModel

_COST_PER_SECOND_A100_40G: Final[float] = 0.001036
_COST_PER_SECOND_A100_80G: Final[float] = 0.001553
_COST_PER_SECOND_H100_80G: Final[float] = 0.002125


class GPUType(Enum):
A100_40G = "A100_40G"
A100_80G = "A100_80G"
H100_80G = "H100_80G"

@property
def cost_per_second(self) -> float:
match self:
case GPUType.A100_40G:
return _COST_PER_SECOND_A100_40G
case GPUType.A100_80G:
return _COST_PER_SECOND_A100_80G
case GPUType.H100_80G:
return _COST_PER_SECOND_H100_80G


# https://github.com/vllm-project/vllm/blob/320a622ec4d098f2da5d097930f4031517e7327b/vllm/sampling_params.py#L7-L52
# Lines were sorted for consistency
@@ -68,10 +47,6 @@ class Usage(BaseModel):
prompt_tokens: int
completion_tokens: int

duration: float
gpu_type: GPUType
gpu_count: int


class ResponseBody(BaseModel):
text: str
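
After the removals above, the billing-oriented fields are gone and the protocol models reduce to token counts plus generated text. A sketch of how the trimmed models read, reconstructed from the hunks in this diff rather than copied from the file:

```python
from pydantic import BaseModel


class Usage(BaseModel):
    prompt_tokens: int
    completion_tokens: int


class ResponseBody(BaseModel):
    text: str
    usage: Usage
```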
