Add exporting and numerics verification for CLIP Large text model with IREE (#664)

Add exporting to MLIR and IREE parameters. We don't make the context
length dynamic since the maximum is only 77 anyway, so the token
sequences are padded to 77. We could explore making this dynamic later.
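
Below is a minimal usage sketch of the two export entry points added in
sharktank/sharktank/models/clip/export.py. The Hugging Face model name and the
output paths are illustrative assumptions, not values taken from this commit.

# Hedged sketch: export CLIP text model parameters to a dataset, then to MLIR.
# The model name and file paths below are assumptions for illustration.
from sharktank.models.clip.export import (
    export_clip_text_model_dataset_from_hugging_face,
    export_clip_text_model_mlir,
)

# Fetch the Hugging Face CLIP text model and save its parameters as a dataset.
export_clip_text_model_dataset_from_hugging_face(
    "openai/clip-vit-large-patch14",
    "clip_text_large.irpa",
)

# Export MLIR for a fixed batch size; token sequences are padded to the
# 77-token maximum, so no dynamic context-length dimension is used.
export_clip_text_model_mlir(
    "clip_text_large.irpa",
    batch_sizes=[1],
    mlir_output_path="clip_text_large.mlir",
)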

This adds a comparison of IREE execution of the float32 and bfloat16 model
variants against float32 torch eager. For bfloat16, results agree to within
1.43e-2 under cosine similarity.
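
Below is a minimal sketch of the per-token cosine-similarity check added in
sharktank/sharktank/utils/testing.py, assuming the eager and IREE outputs are
already available as torch tensors; the tensors and shapes are synthetic
stand-ins.

# Hedged sketch: assert that two last-hidden-state tensors agree per token.
import torch

from sharktank.utils.testing import assert_text_encoder_state_close

expected = torch.randn(2, 77, 768)                     # stand-in for the float32 eager output
actual = expected + 1e-3 * torch.randn_like(expected)  # stand-in for the IREE result
# Computes cosine similarity over the feature dimension of each token and
# requires it to be within atol of 1.
assert_text_encoder_state_close(actual, expected, atol=1.43e-2)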

A toy-sized model comparison for float32 and bfloat16 is also provided.
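
Below is a rough sketch of how the new helpers in
sharktank/sharktank/models/clip/testing.py can assemble such a toy model. The
hyperparameters are assumptions chosen only to keep the example small, and the
import paths follow the file layout in this change.

# Hedged sketch: build a toy CLIP text model with random parameters and run it
# eagerly. All sizes below are assumptions for illustration.
import transformers

from sharktank.layers.configs import ClipTextConfig
from sharktank.models.clip.clip import ClipTextModel
from sharktank.models.clip.testing import (
    make_clip_text_model_random_theta,
    make_random_input_token_sequences,
)

hf_config = transformers.CLIPTextConfig(
    vocab_size=11,
    hidden_size=13,
    intermediate_size=7,
    num_hidden_layers=2,
    num_attention_heads=1,
    max_position_embeddings=9,
)
config = ClipTextConfig.from_hugging_face_clip_text_model_config(hf_config)
theta = make_clip_text_model_random_theta(config)
model = ClipTextModel(theta=theta, config=config)

input_ids = make_random_input_token_sequences(batch_size=2, config=config)
output = model(input_ids=input_ids)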
sogartar authored Dec 10, 2024
1 parent 4c015d4 commit 36c0859
Showing 8 changed files with 482 additions and 79 deletions.
32 changes: 29 additions & 3 deletions sharktank/sharktank/layers/configs/llm_configs.py
@@ -18,6 +18,8 @@
from typing import Any, Optional
import torch

from ...types.tensors import serialized_name_to_dtype, dtype_to_serialized_name

__all__ = ["ClipTextConfig", "LlamaHParams", "LlamaModelConfig", "T5Config"]


@@ -287,9 +289,10 @@ class ClipTextConfig:
output_attentions: bool = False
output_hidden_states: bool = False
use_return_dict: bool = True
dtype: torch.dtype = torch.float32

@staticmethod
def from_transformers_clip_text_config(
def from_hugging_face_clip_text_model_config(
config: "transformers.CLIPTextConfig",
) -> "ClipTextConfig":
return ClipTextConfig(
@@ -308,7 +311,30 @@ def from_transformers_clip_text_config(
output_attentions=config.output_attentions,
output_hidden_states=config.output_hidden_states,
use_return_dict=config.use_return_dict,
dtype=config.torch_dtype or torch.float32,
)

def as_properties(self) -> dict[str, Any]:
return asdict(self)
def to_hugging_face_clip_text_model_config(self) -> "transformers.CLIPTextConfig":
kwargs = self.to_properties()
kwargs["torch_dtype"] = kwargs["dtype"]
del kwargs["dtype"]
kwargs["return_dict"] = kwargs["use_return_dict"]
del kwargs["use_return_dict"]
from transformers import CLIPTextConfig

return CLIPTextConfig(**kwargs)

@staticmethod
def from_properties(properties: dict[str, Any]) -> "ClipTextConfig":
kwargs = dict(properties)
kwargs.pop("SHARK_DATASET_VERSION")
if "dtype" in kwargs and kwargs["dtype"] is not None:
kwargs["dtype"] = serialized_name_to_dtype(kwargs["dtype"])

return ClipTextConfig(**kwargs)

def to_properties(self) -> dict[str, Any]:
res = asdict(self)
if self.dtype is not None:
res["dtype"] = dtype_to_serialized_name(self.dtype)
return res
47 changes: 33 additions & 14 deletions sharktank/sharktank/models/clip/clip.py
@@ -21,10 +21,10 @@
)
from collections import OrderedDict

from ...layers import BaseLayer, LinearLayer, LayerNorm, TokenEmbeddingLayer
from ...layers import ThetaLayer, LinearLayer, LayerNorm, TokenEmbeddingLayer
from ... import ops
from ...types.theta import Theta, Dataset
from ...types.tensors import DefaultPrimitiveTensor
from ...types.tensors import AnyTensor, DefaultPrimitiveTensor
from ...layers.configs import ClipTextConfig
from ...layers.activations import ACT2FN

@@ -68,11 +68,11 @@ def forward(
return embeddings


class ClipAttention(BaseLayer):
class ClipAttention(ThetaLayer):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(self, theta: Theta, config: ClipTextConfig):
super().__init__()
super().__init__(theta)
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
@@ -182,9 +182,9 @@ def forward(
return attn_output, attn_weights_reshaped


class ClipMlp(BaseLayer):
class ClipMlp(ThetaLayer):
def __init__(self, theta: Theta, config: ClipTextConfig):
super().__init__()
super().__init__(theta)
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = LinearLayer(theta("fc1"))
@@ -197,9 +197,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return hidden_states


class ClipEncoderLayer(BaseLayer):
class ClipEncoderLayer(ThetaLayer):
def __init__(self, theta: Theta, config: ClipTextConfig):
super().__init__()
super().__init__(theta)
self.embed_dim = config.hidden_size
self.self_attn = ClipAttention(theta=theta("self_attn"), config=config)
self.layer_norm1 = LayerNorm(
@@ -251,14 +251,14 @@ def forward(
return outputs


class ClipEncoder(BaseLayer):
class ClipEncoder(ThetaLayer):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`ClipEncoderLayer`].
"""

def __init__(self, theta: Theta, config: ClipTextConfig):
super().__init__()
super().__init__(theta)
self.config = config
self.layers = nn.ModuleList(
[
@@ -356,9 +356,9 @@ def forward(
)


class ClipTextTransformer(nn.Module):
class ClipTextTransformer(ThetaLayer):
def __init__(self, theta: Theta, config: ClipTextConfig):
super().__init__()
super().__init__(theta)
self.config = config
embed_dim = config.hidden_size
self.embeddings = ClipTextEmbeddings(theta=theta("embeddings"), config=config)
@@ -475,9 +475,9 @@ def forward(
)


class ClipTextModel(BaseLayer):
class ClipTextModel(ThetaLayer):
def __init__(self, theta: Theta, config: ClipTextConfig):
super().__init__()
super().__init__(theta)
self.config = config
self.text_model = ClipTextTransformer(theta=theta("text_model"), config=config)

@@ -487,6 +487,25 @@ def get_input_embeddings(self) -> nn.Module:
def set_input_embeddings(self, value):
self.text_model.embeddings.token_embedding = value

def sample_inputs(self, batch_size: int) -> OrderedDict[str, AnyTensor]:
input_ids = (
torch.arange(
start=0,
end=batch_size * self.config.max_position_embeddings,
dtype=torch.long,
)
% self.config.vocab_size
)
input_ids = input_ids.reshape([batch_size, self.config.max_position_embeddings])
return OrderedDict(
[
(
"input_ids",
input_ids,
)
]
)

def forward(
self,
input_ids: Optional[torch.Tensor] = None,
74 changes: 59 additions & 15 deletions sharktank/sharktank/models/clip/export.py
@@ -4,54 +4,98 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Union
from typing import Optional, Union
import transformers
from transformers.models.clip.modeling_clip import (
CLIPAttention as TransformersCLIPAttention,
CLIPEncoderLayer as TransformersCLIPEncoderLayer,
CLIPEncoder as TransformersCLIPEncoder,
CLIPAttention as HfCLIPAttention,
CLIPEncoderLayer as HfCLIPEncoderLayer,
CLIPEncoder as HfCLIPEncoder,
)
from os import PathLike
import torch

from ...types.theta import Theta, Dataset, torch_module_to_theta
from ...types.tensors import DefaultPrimitiveTensor
from ...layers.configs import ClipTextConfig
from .clip import ClipTextModel
from iree.turbine.aot import FxProgramsBuilder, export


def transformers_clip_attention_to_theta(model: TransformersCLIPAttention) -> Theta:
def hugging_face_clip_attention_to_theta(model: HfCLIPAttention) -> Theta:
return torch_module_to_theta(model)


def transformers_clip_encoder_layer_to_theta(model: TransformersCLIPEncoder) -> Theta:
def hugging_face_clip_encoder_layer_to_theta(model: HfCLIPEncoder) -> Theta:
return torch_module_to_theta(model)


def transformers_clip_encoder_to_theta(model: TransformersCLIPEncoderLayer) -> Theta:
def hugging_face_clip_encoder_to_theta(model: HfCLIPEncoderLayer) -> Theta:
return torch_module_to_theta(model)


def transformers_clip_text_model_to_theta(model: transformers.CLIPTextModel) -> Theta:
def hugging_face_clip_text_model_to_theta(model: transformers.CLIPTextModel) -> Theta:
return torch_module_to_theta(model)


def transformers_clip_text_model_to_dataset(
def hugging_face_clip_text_model_to_dataset(
model: transformers.CLIPTextModel,
) -> Dataset:
config = ClipTextConfig.from_transformers_clip_text_config(model.config)
properties = config.as_properties()
theta = transformers_clip_text_model_to_theta(model)
config = ClipTextConfig.from_hugging_face_clip_text_model_config(model.config)
properties = config.to_properties()
theta = hugging_face_clip_text_model_to_theta(model)
theta.rename_tensors_to_paths()
return Dataset(properties, theta)


def clip_text_model_to_dataset(model: ClipTextModel) -> Dataset:
return Dataset(properties=model.config.to_properties(), root_theta=model.theta)


def export_clip_text_model_dataset_from_hugging_face(
model_or_name_or_path: Union[str, PathLike, transformers.CLIPTextModel],
output_path: Union[str, PathLike],
dtype: Optional[torch.dtype] = None,
):
if isinstance(model_or_name_or_path, transformers.CLIPTextModel):
assert dtype is None
model = model_or_name_or_path
else:
model = transformers.CLIPTextModel.from_pretrained(model_or_name_or_path)
dataset = transformers_clip_text_model_to_dataset(model)
model = transformers.CLIPTextModel.from_pretrained(
model_or_name_or_path, torch_dtype=dtype
)
dataset = hugging_face_clip_text_model_to_dataset(model)
dataset.save(output_path)


def export_clip_text_model_mlir(
model: Union[ClipTextModel, PathLike],
batch_sizes: list[int],
mlir_output_path: str,
):
"""
Args:
model: either the torch module or path to GGUF/IRPA.
"""
if not isinstance(model, ClipTextModel):
dataset = Dataset.load(model)
config = ClipTextConfig.from_properties(dataset.properties)
model = ClipTextModel(theta=dataset.root_theta, config=config)

fxb = FxProgramsBuilder(model)

for batch_size in batch_sizes:
sample_inputs = model.sample_inputs(batch_size)

@fxb.export_program(
name=f"forward_bs{batch_size}",
args=tuple(sample_inputs.values()),
dynamic_shapes=None,
strict=False,
)
def _(
model,
input_ids,
):
return model(input_ids)

output = export(fxb, import_symbolic_shape_expressions=True)
output.save_mlir(mlir_output_path)
37 changes: 37 additions & 0 deletions sharktank/sharktank/models/clip/testing.py
@@ -0,0 +1,37 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ...layers.configs.llm_configs import ClipTextConfig
from ...types.theta import Theta
from .export import hugging_face_clip_text_model_to_theta
import torch


def make_clip_text_model_random_theta(config: ClipTextConfig) -> Theta:
from transformers import CLIPTextConfig as HfCLIPTextConfig
from transformers import CLIPTextModel as HfCLIPTextModel

hf_config = config.to_hugging_face_clip_text_model_config()
model = HfCLIPTextModel(hf_config)
return hugging_face_clip_text_model_to_theta(model)


def make_random_input_token_sequences(
batch_size: int, config: ClipTextConfig
) -> torch.LongTensor:
sequence_lens = torch.randint(
low=1, high=config.max_position_embeddings + 1, size=(batch_size,)
)
sequences = torch.full(
size=(batch_size, config.max_position_embeddings),
fill_value=config.eos_token_id,
dtype=torch.long,
)
for batch_idx, l in enumerate(sequence_lens):
sequences[batch_idx][0:l] = torch.randint(
low=0, high=config.vocab_size - 1, size=(l,), dtype=torch.long
)
return sequences
4 changes: 3 additions & 1 deletion sharktank/sharktank/types/theta.py
@@ -214,12 +214,14 @@ def rename_tensors_to_paths(self):


def torch_module_to_theta(module: torch.nn.Module) -> Theta:
return Theta(
res = Theta(
{
name: DefaultPrimitiveTensor(data=param)
for name, param in module.named_parameters()
}
)
res.rename_tensors_to_paths()
return res


def flat_to_nested_dict(flat: dict[str, Any]) -> dict[str, Any]:
2 changes: 1 addition & 1 deletion sharktank/sharktank/utils/math.py
@@ -19,7 +19,7 @@ def round_up_to_multiple_of(x: Number, multiple: Number) -> Number:

def cosine_similarity(
a: torch.Tensor, b: torch.Tensor, /, *, dim: Optional[Union[int, tuple[int]]] = None
) -> float:
) -> torch.Tensor:
"""Compute cosine similarity over dimensions dim.
If dim is None, computes over all dimensions."""
dot_product = torch.sum(a * b, dim=dim)
31 changes: 31 additions & 0 deletions sharktank/sharktank/utils/testing.py
@@ -18,6 +18,7 @@
import gc

from ..types import *
from .math import cosine_similarity

# Range of torch.rand() is [0,1)
# Range of torch.rand() * 2 - 1 is [-1, 1), includes negative values
@@ -184,6 +185,36 @@ def assert_iterables_equal(
), f"Iterables not equal at index {i} for elements {v1} and {v2}"


def assert_text_encoder_state_close(
actual: torch.Tensor, expected: torch.Tensor, atol: float
):
"""The cosine similarity has been suggested to compare encoder states.
Dehua Peng, Zhipeng Gui, Huayi Wu -
Interpreting the Curse of Dimensionality from Distance Concentration and Manifold
Effect (2023)
shows that cosine and all Minkowski distances suffer from the curse of
dimensionality.
The cosine similarity ignores the vector magnitudes. We can probably come up with a
better metric, but this is maybe good enough.
The function expects that the last dimension is the features per token.
It will compute the cosine similarity for each token.
"""
cosine_similarity_per_token = cosine_similarity(
actual,
expected,
dim=-1,
)
torch.testing.assert_close(
cosine_similarity_per_token,
torch.ones_like(cosine_similarity_per_token),
atol=atol,
rtol=0,
)


SHARKTANK_TEST_SKIP_ENV_VAR = "SHARKTANK_TEST_SKIP"

