Support mllama (llama 3.2) model for HPU (#491)
yisonzhu authored Dec 10, 2024
1 parent 3473bc1 commit 239739c
Showing 7 changed files with 935 additions and 42 deletions.
76 changes: 75 additions & 1 deletion tests/conftest.py
@@ -649,6 +649,80 @@ def hf_runner():
    return HfRunner


class HfHPURunner(HfRunner):

    def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
        if device is None:
            device = "cpu" if current_platform.is_cpu() else "hpu"

        if isinstance(x, dict):
            return {k: self.wrap_device(v, device) for k, v in x.items()}

        if hasattr(x, "device") and x.device.type == device:
            return x

        return x.to(device)

    def __init__(
        self,
        model_name: str,
        dtype: str = "half",
        *,
        model_kwargs: Optional[Dict[str, Any]] = None,
        is_embedding_model: bool = False,
        auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM,
        postprocess_inputs: Callable[[BatchEncoding],
                                     BatchEncoding] = identity,
    ) -> None:
        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

        self.model_name = model_name

        model_kwargs = model_kwargs if model_kwargs is not None else {}
        self.model = self.wrap_device(
            auto_cls.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
                **model_kwargs,
            ).eval())

        from habana_frameworks.torch.hpu import wrap_in_hpu_graph
        wrap_done = False
        if hasattr(self.model, "language_model"):
            self.model.language_model = wrap_in_hpu_graph(
                self.model.language_model)
            wrap_done = True
        if hasattr(self.model, "vision_model"):
            self.model.vision_model = wrap_in_hpu_graph(
                self.model.vision_model)
            wrap_done = True
        if not wrap_done:
            self.model = wrap_in_hpu_graph(self.model)

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )

        # don't put this import at the top level
        # it will call torch.cuda.device_count()
        from transformers import AutoProcessor  # noqa: F401
        self.processor = AutoProcessor.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )
        self.dtype = dtype
        self.postprocess_inputs = postprocess_inputs


@pytest.fixture(scope="session")
def hf_hpu_runner():
    return HfHPURunner


class VllmRunner:

    def __init__(
@@ -663,7 +737,7 @@ def __init__(
        dtype: str = "half",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        block_size: int = 16 if not current_platform.is_hpu() else 128,
        enable_chunked_prefill: bool = False,
        swap_space: int = 4,
        enforce_eager: Optional[bool] = False,
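
For reference, a minimal usage sketch of the new HfHPURunner driven directly, outside the pytest fixtures. It assumes the helper API inherited from vLLM's existing HfRunner (the context-manager protocol and generate_greedy_logprobs_limit); the model name matches the mllama checkpoint exercised by the test file below, while the prompt and image are placeholders:

from transformers import AutoModelForVision2Seq

with HfHPURunner("meta-llama/Llama-3.2-11B-Vision-Instruct",
                 dtype="bfloat16",
                 auto_cls=AutoModelForVision2Seq) as hf_model:
    hf_outputs = hf_model.generate_greedy_logprobs_limit(
        ["<|image|><|begin_of_text|>Describe the image."],  # placeholder prompt
        max_tokens=128,
        num_logprobs=5,
        images=[image],  # placeholder PIL.Image
    )
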
22 changes: 22 additions & 0 deletions tests/models/encoder_decoder/vision_language/test_mllama.py
@@ -278,6 +278,28 @@ def test_models_single_leading_image(hf_runner, vllm_runner, image_assets,
    )


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("sizes", [
    [(512, 512), (512, 512), (512, 512)],
])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_hpu_models(hf_hpu_runner, vllm_runner, image_assets, model, sizes,
                    dtype, max_tokens, num_logprobs) -> None:
    run_test(
        hf_hpu_runner,
        vllm_runner,
        image_assets,
        model,
        sizes=sizes,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )


@large_gpu_test(min_gb=48)
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
142 changes: 138 additions & 4 deletions vllm/attention/backends/hpu_attn.py
@@ -83,6 +83,18 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
    attn_bias: Optional[torch.Tensor]
    seq_lens_tensor: Optional[torch.Tensor]
    context_lens_tensor: Optional[torch.Tensor]
    seq_lens: Optional[List[int]] = None
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None
    cross_block_indices: Optional[torch.Tensor] = None
    cross_block_offsets: Optional[torch.Tensor] = None
    cross_block_list: Optional[torch.Tensor] = None
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_mapping: Optional[torch.Tensor] = None
    cross_block_groups: Optional[torch.Tensor] = None
    cross_block_scales: Optional[torch.Tensor] = None
    cross_block_usage: Optional[torch.Tensor] = None
    cross_attn_bias: Optional[torch.Tensor] = None


class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
@@ -174,11 +186,22 @@ def forward(
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
        if (attn_type != AttentionType.DECODER
                and attn_type != AttentionType.ENCODER_DECODER):
            raise NotImplementedError("Encoder self-attention "
                                      "is not implemented for "
                                      "HPUAttentionImpl")
        if attn_type == AttentionType.ENCODER_DECODER:
            return self.forward_encoder_decoder(
                query=query,
                key=key,
                value=value,
                kv_cache=kv_cache,
                attn_metadata=attn_metadata,
                k_scale=k_scale,
                v_scale=v_scale,
            )

        batch_size, seq_len, hidden_size = query.shape
        _, seq_len_kv, _ = key.shape

@@ -274,6 +297,117 @@ def forward(
        # Reshape the output tensor.
        return output.view(batch_size, seq_len, hidden_size)

    def forward_encoder_decoder(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: HPUAttentionMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
    ) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        batch_size, hidden_size = query.shape

        if attn_metadata.is_prompt:
            batch_size = attn_metadata.num_prefills
            batched_tokens, _ = query.shape
            batched_kv_tokens, _, _ = key.shape
            assert batch_size > 0, (
                "In prefill stage the num_prefills should be > 0")
            assert batched_tokens % batch_size == 0
            assert batched_kv_tokens % batch_size == 0
            seq_len = batched_tokens // batch_size

        query = query.view(-1, self.num_heads, self.head_size)
        if key is not None:
            assert value is not None
            key = key.view(-1, self.num_kv_heads, self.head_size)
            value = value.view(-1, self.num_kv_heads, self.head_size)
        else:
            assert value is None

        block_indices = attn_metadata.cross_block_indices
        block_offsets = attn_metadata.cross_block_offsets
        if kv_cache is not None:
            key_cache, value_cache = HPUPagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            if (key is not None) and (value is not None):
                # During cross-attention decode, key & value will be None,
                # we don't need to cache them.
                key_cache = self.k_cache(key, key_cache, block_indices,
                                         block_offsets)
                value_cache = self.v_cache(value, value_cache, block_indices,
                                           block_offsets)

        if attn_metadata.is_prompt:
            # Prompt run.
            batch_size = attn_metadata.num_prefills

            query_shape = (batch_size, -1, self.num_heads, self.head_size)
            kv_shape = (batch_size, -1, self.num_kv_heads, self.head_size)
            # Just a workaround, to make ops.prompt_attention go into the
            # torch ops assembly path.
            # TODO: add new prompt_attention op in vllm_hpu_extension
            # which calls FusedSDPA with causal = False.
            attn_bias = torch.zeros((batch_size, 1, 1, 1),
                                    device=query.device,
                                    dtype=torch.bool)
            out = ops.prompt_attention(
                query.view(query_shape),
                key.view(kv_shape),
                value.view(kv_shape),
                attn_bias=attn_bias,
                p=0.0,
                scale=self.scale,
                matmul_qk_op=self.matmul_qk,
                softmax_op=self.softmax,
                matmul_av_op=self.matmul_av,
            )
            output = out.reshape(batch_size, seq_len, hidden_size)
        else:
            # Enc/dec cross-attention KVs match encoder sequence length;
            # cross-attention utilizes special "cross" block tables
            block_list = attn_metadata.cross_block_list
            block_mapping = attn_metadata.cross_block_mapping
            block_scales = attn_metadata.cross_block_scales
            block_groups = attn_metadata.cross_block_groups
            attn_bias = attn_metadata.cross_attn_bias
            # Decoding run.
            output = HPUPagedAttention.forward_decode(
                query=query,
                key_cache=key_cache,
                value_cache=value_cache,
                block_list=block_list,
                block_mapping=block_mapping,
                block_bias=attn_bias,
                block_scales=block_scales,
                block_groups=block_groups,
                scale=self.scale,
                matmul_qk_op=self.matmul_qk,
                matmul_av_op=self.matmul_av,
                batch2block_matmul_op=self.batch2block_matmul,
                block2batch_matmul_op=self.block2batch_matmul,
                keys_fetch_func=self.k_cache.fetch_from_cache,
                values_fetch_func=self.v_cache.fetch_from_cache)
        # Reshape the output tensor.
        return output.view(batch_size, -1, hidden_size)


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
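
A side note on the prompt branch above: per the TODO, the all-zero bias tensor exists only to route ops.prompt_attention into its bias-taking, non-causal FusedSDPA path. Below is an illustrative, HPU-free sketch of what that branch computes, written with plain PyTorch SDPA; the shapes are made up for the example and the additive float bias stands in for the routing-only bool tensor used in the commit:

import torch
import torch.nn.functional as F

# Made-up shapes: 2 decoder sequences of 16 tokens attending to 24 encoder tokens.
batch, q_len, kv_len, heads, head_size = 2, 16, 24, 4, 8
q = torch.randn(batch, heads, q_len, head_size)
k = torch.randn(batch, heads, kv_len, head_size)
v = torch.randn(batch, heads, kv_len, head_size)

# A zero additive bias broadcast over (batch, heads, q_len, kv_len) masks nothing,
# so every decoder token attends to every encoder token: full non-causal attention.
bias = torch.zeros(batch, 1, 1, 1)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias, is_causal=False)
assert out.shape == (batch, heads, q_len, head_size)
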
16 changes: 16 additions & 0 deletions vllm/model_executor/models/mllama.py
@@ -51,6 +51,7 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.platforms import current_platform
from vllm.sequence import SequenceData
from vllm.utils import is_list_of

@@ -63,6 +64,8 @@
MLLAMA_IMAGE_TOKEN_ID = 128256
MLLAMA_IMAGE_TOKEN = "<|image|>"

is_hpu = current_platform.is_hpu()


class MllamaImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
@@ -947,6 +950,14 @@ def forward(
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # The rank of full_text_row_masked_out_mask is 2, which does not match
        # hidden_states, so expand its rank to 3.
        # TODO: Change the input_tokens tensor at the beginning of model
        # execution to a 2D tensor to align with the public vLLM input_tokens
        # shape. Doing so currently hits a graph-building failure that still
        # needs to be investigated.
        if len(hidden_states.shape) == 3:
            full_text_row_masked_out_mask = full_text_row_masked_out_mask.view(
                hidden_states.size(0), -1, 1)
        hidden_states = full_text_row_masked_out_mask * hidden_states
        hidden_states = residual + self.cross_attn_attn_gate.tanh(
        ) * hidden_states
@@ -1016,6 +1027,11 @@ def forward(
            inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds

        if is_hpu:
            for idx, decoder_layer in enumerate(self.layers):
                if isinstance(decoder_layer, LlamaDecoderLayer):
                    self.layers[idx].self_attn.rotary_emb.prepare_cos_sin(
                        positions)
        for idx, decoder_layer in enumerate(self.layers):
            if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer):
                if not skip_cross_attention:
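
To make the rank-expansion comment in the cross-attention decoder layer above concrete, here is a small, self-contained illustration (illustrative sizes only, not taken from the model) of how a rank-2 per-token mask is reshaped so it broadcasts against rank-3 hidden states:

import torch

batch, seq_len, hidden = 2, 4, 8              # illustrative sizes
hidden_states = torch.randn(batch, seq_len, hidden)

# Rank-2 mask with one row per flattened token, as produced for a 1D token layout.
full_text_row_masked_out_mask = torch.ones(batch * seq_len, 1)

# view() to (batch, seq_len, 1) so the trailing 1 broadcasts over the hidden dim.
mask_3d = full_text_row_masked_out_mask.view(hidden_states.size(0), -1, 1)
out = mask_3d * hidden_states
assert out.shape == hidden_states.shape
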
