# Patch summary (commentary only; text before the first "diff --git" header
# is ignored by patch tools).
#
# vllm_hpu_extension/capabilities.py
#   * Capabilities.__init__ also stores `disabled` = `all` - `enabled`.
#   * New is_disabled(*names): true only if every name is in `disabled`.
#   * New _check(name): a leading '-' routes to is_disabled, a leading '+'
#     or no prefix routes to is_enabled (prefix stripped before lookup).
#   * __contains__ now splits on ',' and requires _check() to hold for
#     every item, so "+foo,-bar" style queries work with `in`.
#   * NOTE(review): a name that is not in `all` is in neither set, so both
#     "name" and "-name" evaluate to False for unknown features.
#
# vllm_hpu_extension/environment.py
#   * The module-level `logger` and the top-level imports of init_logger
#     and is_fake_hpu are removed.
#   * New cached lazy_logger() imports init_logger inside the function and
#     returns init_logger(__name__); both warning call sites use it.
#   * get_hw() imports is_fake_hpu locally, right before calling it.
#
# vllm_hpu_extension/ops.py
#   * DEFAULT_PA_SOFTMAX_IMPL is now always 'wsum_head_amax'; it no longer
#     depends on whether 'index_reduce' is in capabilities().
#
# vllm_hpu_extension/test_capabilities.py
#   * New test_capability_signed_checks covers '+'/'-' prefixed queries.
#
# NOTE(review): the line structure below was reconstructed from a
# whitespace-collapsed copy. Blank-line placement agrees with the hunk
# header counts, but the leading indentation of context lines is assumed;
# verify against the upstream files before applying.
diff --git a/vllm_hpu_extension/capabilities.py b/vllm_hpu_extension/capabilities.py
index e845d5a9..b6126c54 100644
--- a/vllm_hpu_extension/capabilities.py
+++ b/vllm_hpu_extension/capabilities.py
@@ -45,16 +45,27 @@ class Capabilities:
     def __init__(self, features, environment):
         self.all = set(features.keys())
         self.enabled = set(name for name, check in features.items() if check(**environment))
+        self.disabled = self.all - self.enabled
 
     def is_enabled(self, *names):
         return all(n in self.enabled for n in names)
 
+    def is_disabled(self, *names):
+        return all(n in self.disabled for n in names)
+
     def __repr__(self):
         feature_list = [('+' if self.is_enabled(f) else '-') + f for f in sorted(self.all)]
         return f'[{(" ").join(feature_list)}]'
 
+    def _check(self, name):
+        if name.startswith('-'):
+            return self.is_disabled(name[1:])
+        if name.startswith('+'):
+            return self.is_enabled(name[1:])
+        return self.is_enabled(name)
+
     def __contains__(self, names):
-        return self.is_enabled(*names.split(','))
+        return all(self._check(name) for name in names.split(','))
 
 
 @cache
diff --git a/vllm_hpu_extension/environment.py b/vllm_hpu_extension/environment.py
index 2b93f081..79da22b1 100644
--- a/vllm_hpu_extension/environment.py
+++ b/vllm_hpu_extension/environment.py
@@ -5,10 +5,13 @@
 # LICENSE file in the root directory of this source tree.
 ###############################################################################
 
-from vllm.logger import init_logger
-from vllm_hpu_extension.utils import is_fake_hpu
+from functools import cache
 
-logger = init_logger(__name__)
+
+@cache
+def lazy_logger():
+    from vllm.logger import init_logger
+    return init_logger(__name__)
 
 
 def get_hw():
@@ -21,9 +24,10 @@ def get_hw():
             return "gaudi2"
         case htexp.synDeviceType.synDeviceGaudi3:
             return "gaudi3"
+    from vllm_hpu_extension.utils import is_fake_hpu
     if is_fake_hpu():
         return "cpu"
-    logger.warning(f'Unknown device type: {device_type}')
+    lazy_logger().warning(f'Unknown device type: {device_type}')
     return None
 
 
@@ -39,7 +43,7 @@ def get_build():
     match = version_re.search(output.stdout)
     if output.returncode == 0 and match:
         return match.group('version')
-    logger.warning("Unable to detect habana-torch-plugin version!")
+    lazy_logger().warning("Unable to detect habana-torch-plugin version!")
     return None
 
 
diff --git a/vllm_hpu_extension/ops.py b/vllm_hpu_extension/ops.py
index 0bd542bc..97ef7923 100644
--- a/vllm_hpu_extension/ops.py
+++ b/vllm_hpu_extension/ops.py
@@ -110,7 +110,7 @@ def scatter_reduce(attn, batch_size, block_groups, **rest):
     return attn.sub_(grouped_max.unsqueeze(-1).unsqueeze(-1))
 
 
-DEFAULT_PA_SOFTMAX_IMPL = 'index_reduce' if 'index_reduce' in capabilities() else 'wsum_head_amax'
+DEFAULT_PA_SOFTMAX_IMPL = 'wsum_head_amax'
 normalize = SoftmaxNormalization(os.environ.get('VLLM_PA_SOFTMAX_IMPL', DEFAULT_PA_SOFTMAX_IMPL).split(','))
 
 
diff --git a/vllm_hpu_extension/test_capabilities.py b/vllm_hpu_extension/test_capabilities.py
index 6d632fd3..ac9ab26d 100644
--- a/vllm_hpu_extension/test_capabilities.py
+++ b/vllm_hpu_extension/test_capabilities.py
@@ -95,3 +95,11 @@ def test_capability_checks(capabilities):
     assert "foo,qux" in capabilities
     assert "qux,foo" in capabilities
     assert "foo,bar,qux" not in capabilities
+
+
+def test_capability_signed_checks(capabilities):
+    assert "-bar" in capabilities
+    assert "+foo" in capabilities
+    assert "+foo,-bar,+qux" in capabilities
+    assert "+foo,bar,+qux" not in capabilities
+    assert "-foo,-bar,+qux" not in capabilities