Skip to content

Commit

Permalink
[CompilerFlag] Detect if FlashInfer is enabled from libinfo (#1941)
Browse files Browse the repository at this point in the history
This PR supports the detection of if FlashInfer is enabled when
building TVM, so that FlashInfer won't be enabled when TVM is
not built with FlashInfer enabled.
  • Loading branch information
MasterJH5574 authored Mar 15, 2024
1 parent 01527e9 commit 09fe1bc
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions python/mlc_llm/interface/compiler_flags.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""Flags for overriding model config."""

import dataclasses
from io import StringIO
from typing import Optional

import tvm

from mlc_llm.support import argparse, logging
from mlc_llm.support.config import ConfigOverrideBase

Expand Down Expand Up @@ -65,6 +68,8 @@ def _flashinfer(target) -> bool:
return False
if target.kind.name != "cuda":
return False
if tvm.get_global_func("support.GetLibInfo")()["USE_FLASHINFER"] != "ON":
return False
arch_list = detect_cuda_arch_list(target)
for arch in arch_list:
if arch < 80:
Expand Down

0 comments on commit 09fe1bc

Please sign in to comment.