Skip to content

Commit

Permalink
Allow specifying negative compatibilities + disable index_reduce (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
madamczykhabana authored Nov 5, 2024
1 parent 565204e commit 0063520
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 7 deletions.
13 changes: 12 additions & 1 deletion vllm_hpu_extension/capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,27 @@ class Capabilities:
def __init__(self, features, environment):
    """Evaluate every feature check once and record the outcome.

    Args:
        features: mapping of feature name to a callable. Each callable is
            invoked with the entries of ``environment`` as keyword
            arguments; a truthy result marks the feature as enabled.
        environment: mapping unpacked as keyword arguments into each check.

    After construction ``all`` holds every known feature name, ``enabled``
    the names whose check passed, and ``disabled`` the remainder.
    """
    self.all = set(features.keys())
    passing = set()
    for feature, check in features.items():
        if check(**environment):
            passing.add(feature)
    self.enabled = passing
    self.disabled = self.all - passing

def is_enabled(self, *names):
    """Return True when every given feature name is in ``self.enabled``.

    With no names the result is True (nothing is missing).
    """
    for feature in names:
        if feature not in self.enabled:
            return False
    return True

def is_disabled(self, *names):
    """Return True when every given feature name is in ``self.disabled``.

    With no names the result is True.
    """
    return not any(feature not in self.disabled for feature in names)

def __repr__(self):
    """Render all known features, sorted by name, as ``[+on -off ...]``.

    Each name is prefixed with '+' when ``is_enabled`` reports it as
    enabled and with '-' otherwise; entries are separated by one space.
    """
    parts = []
    for feature in sorted(self.all):
        sign = '+' if self.is_enabled(feature) else '-'
        parts.append(sign + feature)
    return '[' + ' '.join(parts) + ']'

def _check(self, name):
    """Evaluate one optionally signed feature name.

    A leading '-' asks whether the feature is disabled. A leading '+' or
    no sign asks whether it is enabled. Only one leading sign character is
    consumed; the rest of the string is treated as the feature name.
    """
    if name.startswith('-'):
        return self.is_disabled(name[1:])
    # '+feature' and plain 'feature' are equivalent positive checks.
    return self.is_enabled(name.removeprefix('+'))

def __contains__(self, names):
    """Support ``"spec" in capabilities`` for comma-separated feature specs.

    Args:
        names: a string of one or more feature names separated by ','.
            Each name may carry a sign: '-name' requires the feature to be
            disabled, '+name' or plain 'name' requires it to be enabled.

    Returns:
        True only when every entry of the spec is satisfied.
    """
    # Every entry goes through the sign-aware _check. An unconditional
    # is_enabled(...) return here would treat '-name' as a literal feature
    # name and make negative requirements impossible to satisfy.
    return all(self._check(name) for name in names.split(','))


@cache
Expand Down
14 changes: 9 additions & 5 deletions vllm_hpu_extension/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
# LICENSE file in the root directory of this source tree.
###############################################################################

from vllm.logger import init_logger
from vllm_hpu_extension.utils import is_fake_hpu
from functools import cache

logger = init_logger(__name__)

@cache
def lazy_logger():
    """Return this module's logger, creating it on the first call.

    ``vllm.logger`` is imported inside the function rather than at module
    import time, and ``@cache`` makes every later call return the logger
    object built by the first one.
    """
    from vllm.logger import init_logger as make_logger
    return make_logger(__name__)


def get_hw():
Expand All @@ -21,9 +24,10 @@ def get_hw():
return "gaudi2"
case htexp.synDeviceType.synDeviceGaudi3:
return "gaudi3"
from vllm_hpu_extension.utils import is_fake_hpu
if is_fake_hpu():
return "cpu"
logger.warning(f'Unknown device type: {device_type}')
lazy_logger().warning(f'Unknown device type: {device_type}')
return None


Expand All @@ -39,7 +43,7 @@ def get_build():
match = version_re.search(output.stdout)
if output.returncode == 0 and match:
return match.group('version')
logger.warning("Unable to detect habana-torch-plugin version!")
lazy_logger().warning("Unable to detect habana-torch-plugin version!")
return None


Expand Down
2 changes: 1 addition & 1 deletion vllm_hpu_extension/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def scatter_reduce(attn, batch_size, block_groups, **rest):
return attn.sub_(grouped_max.unsqueeze(-1).unsqueeze(-1))


# Capability-based default: 'index_reduce' when that capability is reported,
# otherwise 'wsum_head_amax'.
DEFAULT_PA_SOFTMAX_IMPL = 'index_reduce' if 'index_reduce' in capabilities() else 'wsum_head_amax'
# NOTE(review): this second assignment overrides the one above, so the default
# is always 'wsum_head_amax' regardless of the reported capabilities.
DEFAULT_PA_SOFTMAX_IMPL = 'wsum_head_amax'
# The VLLM_PA_SOFTMAX_IMPL environment variable, when set, replaces the default;
# its value is split on ',' and the resulting list is passed to SoftmaxNormalization.
normalize = SoftmaxNormalization(os.environ.get('VLLM_PA_SOFTMAX_IMPL', DEFAULT_PA_SOFTMAX_IMPL).split(','))


Expand Down
8 changes: 8 additions & 0 deletions vllm_hpu_extension/test_capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,11 @@ def test_capability_checks(capabilities):
assert "foo,qux" in capabilities
assert "qux,foo" in capabilities
assert "foo,bar,qux" not in capabilities


def test_capability_signed_checks(capabilities):
    """Signed specs: '-name' must be disabled, '+name' or 'name' enabled."""
    satisfied = ("-bar", "+foo", "+foo,-bar,+qux")
    unsatisfied = ("+foo,bar,+qux", "-foo,-bar,+qux")
    for spec in satisfied:
        assert spec in capabilities
    for spec in unsatisfied:
        assert spec not in capabilities

0 comments on commit 0063520

Please sign in to comment.