[JIT] Support overriding optimization flags in JIT (#3032)
This PR adds an optimization-flags override (`"opt"`) for MLCEngine,
chat, and serve when running JIT compilation. Prior to this PR,
JIT compilation always used O2 as the optimization flag set.
MasterJH5574 authored Nov 16, 2024
1 parent 3578e79 commit e283cd0
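For context, a minimal usage sketch (assuming the MLCEngine constructor accepts an engine_config keyword, as the serve path in this diff suggests; the model path below is a hypothetical placeholder):

    from mlc_llm import MLCEngine
    from mlc_llm.serve.config import EngineConfig

    # Hypothetical model reference; replace with any MLC-compiled model.
    model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC"

    # Without a prebuilt model_lib, the engine JIT-compiles the model;
    # opt selects the optimization preset instead of the former fixed O2.
    engine = MLCEngine(model, engine_config=EngineConfig(opt="O3"))

The same knob should be reachable from the CLI overrides string, e.g. `--overrides "opt=O3"` for `mlc_llm serve` and `mlc_llm chat`, per the parsers changed below.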
Showing 5 changed files with 26 additions and 0 deletions.
5 changes: 5 additions & 0 deletions python/mlc_llm/cli/serve.py
@@ -31,6 +31,7 @@ class EngineConfigOverride: # pylint: disable=too-many-instance-attributes
attention_sink_size: Optional[int] = None
tensor_parallel_shards: Optional[int] = None
pipeline_parallel_stages: Optional[int] = None
opt: Optional[str] = None

def __repr__(self) -> str:
out = StringIO()
@@ -53,6 +54,7 @@ def __repr__(self) -> str:
print(f";attention_sink_size={self.attention_sink_size}", file=out, end="")
print(f";tensor_parallel_shards={self.tensor_parallel_shards}", file=out, end="")
print(f";pipeline_parallel_stages={self.pipeline_parallel_stages}", file=out, end="")
print(f";opt={self.opt}", file=out, end="")
return out.getvalue().rstrip()

@staticmethod
@@ -75,6 +77,7 @@ def from_str(source: str) -> "EngineConfigOverride":
parser.add_argument("--attention_sink_size", type=int, default=None)
parser.add_argument("--tensor_parallel_shards", type=int, default=None)
parser.add_argument("--pipeline_parallel_stages", type=int, default=None)
parser.add_argument("--opt", type=str, default=None)
results = parser.parse_args([f"--{i}" for i in source.split(";") if i])
return EngineConfigOverride(
max_num_sequence=results.max_num_sequence,
@@ -92,6 +95,7 @@ def from_str(source: str) -> "EngineConfigOverride":
attention_sink_size=results.attention_sink_size,
tensor_parallel_shards=results.tensor_parallel_shards,
pipeline_parallel_stages=results.pipeline_parallel_stages,
opt=results.opt,
)
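Both EngineConfigOverride.from_str above and ModelConfigOverride.from_str in chat.py below share the same trick: each "key=value" segment of the ";"-separated overrides string is rewritten as "--key=value" and fed to a throwaway argparse parser. A self-contained sketch of that pattern, trimmed to two of the recognized keys:

    import argparse

    def parse_overrides(source: str) -> argparse.Namespace:
        # Each "key=value" segment becomes "--key=value" for argparse.
        parser = argparse.ArgumentParser(description="override values")
        parser.add_argument("--tensor_parallel_shards", type=int, default=None)
        parser.add_argument("--opt", type=str, default=None)
        return parser.parse_args([f"--{i}" for i in source.split(";") if i])

    print(parse_overrides("tensor_parallel_shards=2;opt=O3"))
    # Namespace(opt='O3', tensor_parallel_shards=2)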


@@ -210,6 +214,7 @@ def main(argv):
additional_models=additional_models,
tensor_parallel_shards=parsed.overrides.tensor_parallel_shards,
pipeline_parallel_stages=parsed.overrides.pipeline_parallel_stages,
opt=parsed.overrides.opt,
speculative_mode=parsed.speculative_mode,
prefix_cache_mode=parsed.prefix_cache_mode,
max_num_sequence=parsed.overrides.max_num_sequence,
4 changes: 4 additions & 0 deletions python/mlc_llm/interface/chat.py
@@ -89,13 +89,15 @@ class ModelConfigOverride(ConfigOverrideBase): # pylint: disable=too-many-instance-attributes
attention_sink_size: Optional[int] = None
tensor_parallel_shards: Optional[int] = None
pipeline_parallel_stages: Optional[int] = None
opt: Optional[str] = None

@staticmethod
def from_str(source: str) -> "ModelConfigOverride":
"""Parse model config override values from a string."""
parser = argparse.ArgumentParser(description="model config override values")
parser.add_argument("--tensor_parallel_shards", type=int, default=None)
parser.add_argument("--pipeline_parallel_stages", type=int, default=None)
parser.add_argument("--opt", type=str, default=None)
parser.add_argument("--context_window_size", type=int, default=None)
parser.add_argument("--sliding_window_size", type=int, default=None)
parser.add_argument("--prefill_chunk_size", type=int, default=None)
@@ -105,6 +107,7 @@ def from_str(source: str) -> "ModelConfigOverride":
return ModelConfigOverride(
tensor_parallel_shards=results.tensor_parallel_shards,
pipeline_parallel_stages=results.pipeline_parallel_stages,
opt=results.opt,
context_window_size=results.context_window_size,
sliding_window_size=results.sliding_window_size,
prefill_chunk_size=results.prefill_chunk_size,
@@ -294,6 +297,7 @@ def chat(
attention_sink_size=overrides.attention_sink_size,
tensor_parallel_shards=overrides.tensor_parallel_shards,
pipeline_parallel_stages=overrides.pipeline_parallel_stages,
opt=overrides.opt,
),
)
).chat()
2 changes: 2 additions & 0 deletions python/mlc_llm/interface/serve.py
@@ -28,6 +28,7 @@ def serve(
additional_models: List[Union[str, Tuple[str, str]]],
tensor_parallel_shards: Optional[int],
pipeline_parallel_stages: Optional[int],
opt: Optional[str],
max_num_sequence: Optional[int],
max_total_sequence_length: Optional[int],
max_single_sequence_length: Optional[int],
@@ -61,6 +62,7 @@
additional_models=additional_models,
tensor_parallel_shards=tensor_parallel_shards,
pipeline_parallel_stages=pipeline_parallel_stages,
opt=opt,
max_num_sequence=max_num_sequence,
max_total_sequence_length=max_total_sequence_length,
max_single_sequence_length=max_single_sequence_length,
14 changes: 14 additions & 0 deletions python/mlc_llm/serve/config.py
@@ -46,9 +46,22 @@ class EngineConfig: # pylint: disable=too-many-instance-attributes
tensor_parallel_shards : Optional[int]
Number of shards to split the model into in tensor parallelism multi-gpu inference.
When "model_lib" is given, this field will be ignored, and the tensor_parallel_shards
in the model_lib metadata will be used.
pipeline_parallel_stages : Optional[int]
Number of pipeline stages to split the model layers for pipeline parallelism.
When "model_lib" is given, this field will be ignored, and the pipeline_parallel_stages
in the model_lib metadata will be used.
opt : Optional[str]
The optimization flags for JIT compilation.
When "model_lib" is given, this field will be ignored.
MLC LLM maintains a predefined set of optimization levels,
denoted O0, O1, O2, and O3, where O0 means no optimization, O2 enables the majority
of optimizations, and O3 represents extreme optimization that could potentially break the system.
Alternatively, optimization flags can be specified explicitly via individual knobs, e.g.
"cublas_gemm=1;cudagraph=0".
gpu_memory_utilization : Optional[float]
A number in (0, 1) denoting the fraction of GPU memory used by the server in total.
@@ -127,6 +140,7 @@ class EngineConfig: # pylint: disable=too-many-instance-attributes
mode: Optional[Literal["local", "interactive", "server"]] = None
tensor_parallel_shards: Optional[int] = None
pipeline_parallel_stages: Optional[int] = None
opt: Optional[str] = None
gpu_memory_utilization: Optional[float] = None
kv_cache_page_size: int = 16
max_num_sequence: Optional[int] = None
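As an illustration of the two forms the docstring describes (a sketch, assuming EngineConfig is a keyword-constructible dataclass, as its defaulted field list suggests):

    from mlc_llm.serve.config import EngineConfig

    # Preset level: O2 matches what JIT compilation used before this change.
    preset = EngineConfig(opt="O2")

    # Explicit knobs instead of a preset, reusing the docstring's example string.
    knobs = EngineConfig(opt="cublas_gemm=1;cudagraph=0")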
1 change: 1 addition & 0 deletions python/mlc_llm/serve/engine_base.py
@@ -159,6 +159,7 @@ def _convert_model_info(model: ModelInfo) -> Tuple[str, str]:
"tensor_parallel_shards": engine_config.tensor_parallel_shards,
"pipeline_parallel_stages": engine_config.pipeline_parallel_stages,
"max_batch_size": engine_config.max_num_sequence,
"opt": engine_config.opt,
}

model_lib = jit.jit(
