Add IREE numerics test for Llama 3.1 8B FP16 TP8
Introduce a Llama 3.1 8B FP16 TP8 test that does not appear to have good
numerical accuracy. Its output is compared against an unsharded fp64 torch
variant so that the reference itself is known to be of high accuracy.

Refactor the sharded Llama tests: increase code reuse and use the
TorchGenerator in the toy-sized tests. Exercise the shard_llm_dataset and
export_paged_llm_v1 scripts in the test flow to increase their test
coverage.
sogartar committed Oct 31, 2024
1 parent 6ff055e commit 15fae20
Showing 6 changed files with 435 additions and 364 deletions.
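For orientation, here is a minimal sketch of the comparison the new test performs, written against plain torch. The function name and tolerances are assumptions for illustration, not the actual sharktank test code:

```python
import torch


def assert_logits_close(actual_fp16_tp8: torch.Tensor, reference_fp64: torch.Tensor):
    """Compare logits from the sharded FP16 TP8 path against the unsharded fp64
    torch reference. The reference is computed in fp64 so that any mismatch can
    be attributed to the sharded FP16 path rather than to reference error."""
    torch.testing.assert_close(
        actual_fp16_tp8.to(torch.float64),
        reference_fp64,
        atol=1e-2,
        rtol=1e-2,
    )
```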
13 changes: 5 additions & 8 deletions sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -14,15 +14,12 @@
from sharktank.layers import *
from sharktank.types import *

# TODO: Should be using a base class with the protocol supported.
from ..models.llama.llama import LlamaModelConfig, PagedLlamaModelV1
from ..models.llama.sharding import shard_theta
from ..models.mixtral.mixtral import *
from ..models.grok.grok import *
from .. import ops


-def main():
+def main(raw_args: list[str] | None = None):
from ..utils import cli

parser = cli.create_parser()
@@ -60,7 +57,7 @@ def main():
choices=["decomposed", "torch"],
)

-args = cli.parse(parser)
+args = cli.parse(parser, args=raw_args)
dataset_type = cli.get_input_data_files(args)
dataset_type = "irpa" if "irpa" in dataset_type else "gguf"
dataset = cli.get_input_dataset(args)
@@ -110,7 +107,7 @@ def generate_params_json(hp, prefill_bs: list[int], decode_bs: list[int]):

fxb = FxProgramsBuilder(model)

-def setup_cache(model, shard_count):
+def setup_cache(model):
if model.config.kv_cache_type == "paged":
cache_state = model.cache.allocate(
page_count=hp.context_length // llama_config.block_seq_stride
@@ -161,7 +158,7 @@ def generate_batch_prefill(bs: int):
sl_dim = llama_config.block_seq_stride * block_dim

cache, cache_shard_dim, cache_dynamic_shapes, arg_affinities = setup_cache(
-model, llama_config.tensor_parallelism_size
+model
)

# We need to offset the indices for the cache
@@ -234,7 +231,7 @@ def generate_batch_decode(bs: int):
cache_shard_dim,
cache_dynamic_shapes,
arg_affinities,
-) = setup_cache(model, llama_config.tensor_parallelism_size)
+) = setup_cache(model)

# We need to offset the indices for the cache
arg_affinities = {key + 4: arg_affinities[key] for key in arg_affinities}
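The `raw_args` parameter added above lets tests drive the exporter in-process instead of spawning a subprocess, in line with the commit's goal of running the scripts in the test flow. A rough usage sketch; the flag names are illustrative, since the real flags are defined through `sharktank.utils.cli` inside the script:

```python
from pathlib import Path

from sharktank.examples.export_paged_llm_v1 import main

out_dir = Path("/tmp/llama_export")  # a test would use its temporary directory
# Invoke the exporter in-process with explicit arguments instead of sys.argv.
main(
    [
        f"--irpa-file={out_dir / 'model.irpa'}",
        f"--output-mlir={out_dir / 'model.mlir'}",
        f"--output-config={out_dir / 'config.json'}",
    ]
)
```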
5 changes: 4 additions & 1 deletion sharktank/sharktank/layers/kv_cache.py
@@ -300,7 +300,7 @@ def shard_state(
"""Shard an unsharded state.
We can't just split the slab on the sub page dims.
First it needs to be reinterpreted into the actual shape.
-The split the head dimension, then flatten each shard.
+Then split the head dimension, then flatten each shard.
This is a work-around for the lack of block-cyclic sharded tensor type."""
if self.shard_count == 1:
return state
@@ -324,6 +324,9 @@ def shard_state(
flat_sharded_page_table = SplitPrimitiveTensor(ts=shards, shard_dim=1)
return [flat_sharded_page_table]

+def unshard_state(self, state: list[SplitPrimitiveTensor]) -> list[torch.Tensor]:
+return [ops.unshard(self.unflatten_page_table(state)).flatten(start_dim=1)]

@property
def pad_sequence_stride(self) -> int:
return self.block_seq_stride
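The new `unshard_state` is, from a test's point of view, the inverse of `shard_state`: it pulls the sharded page table back into plain tensors so it can be compared against an unsharded reference. A sketch under assumed object names and tolerances:

```python
import torch


def assert_cache_state_close(sharded_cache, sharded_state, reference_state):
    # `sharded_cache` is a paged KV cache with shard_count > 1; `reference_state`
    # is the state of an equivalent unsharded cache. Tolerances are illustrative.
    actual = sharded_cache.unshard_state(sharded_state)
    torch.testing.assert_close(actual[0], reference_state[0], atol=1e-3, rtol=1e-3)
```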
23 changes: 0 additions & 23 deletions sharktank/sharktank/models/llama/llama.py
@@ -186,29 +186,6 @@ def decode(
self._assert_device(start_positions)
self._assert_device(*cache_state, dtype=self.activation_dtype)

-if self.config.tensor_parallelism_size > 1:
-if not isinstance(tokens, ReplicatedTensor):
-tokens = ops.replicate(
-tokens, count=self.config.tensor_parallelism_size
-)
-if not isinstance(attention_mask, ReplicatedTensor):
-attention_mask = ops.replicate(
-attention_mask, count=self.config.tensor_parallelism_size
-)
-if not isinstance(start_positions, ReplicatedTensor):
-start_positions = ops.replicate(
-start_positions, count=self.config.tensor_parallelism_size
-)
-if not isinstance(seq_block_ids, ReplicatedTensor):
-seq_block_ids = ops.replicate(
-seq_block_ids, count=self.config.tensor_parallelism_size
-)
-# If the user provided unsharded arguments they probably want
-# an unsharded result as well.
-unshard_result = True
-else:
-unshard_result = False

bs, _ = tokens.shape
# Precompute a position based mask for computing rope embeddings
# as it is the same for all blocks.
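With this block removed, `decode` no longer replicates unsharded inputs or unshards its result; presumably the caller is now responsible for that, which is an assumption based on this diff alone. A sketch of call-site replication using the same `ops.replicate` calls that used to live inside `decode`:

```python
from sharktank import ops


def replicate_decode_args(tokens, attention_mask, start_positions, seq_block_ids, count: int):
    # Mirrors the replication removed from PagedLlamaModelV1.decode; whether the
    # callers do exactly this is an assumption, not confirmed by this diff.
    return (
        ops.replicate(tokens, count=count),
        ops.replicate(attention_mask, count=count),
        ops.replicate(start_positions, count=count),
        ops.replicate(seq_block_ids, count=count),
    )
```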
33 changes: 33 additions & 0 deletions sharktank/sharktank/utils/testing.py
@@ -14,9 +14,13 @@
from typing import Any, Callable
from operator import eq
from collections.abc import Iterable
+import pytest
+from sharktank.utils.tokenizer import InferenceTokenizer

from ..types import *

+longrun = pytest.mark.skipif("not config.getoption('longrun')")

# Range of torch.rand() is [0,1)
# Range of torch.rand() * 2 - 1 is [-1, 1), includes negative values
def make_rand_torch(shape, dtype=torch.float32):
@@ -31,6 +35,16 @@ def tearDown(self):
shutil.rmtree(self._temp_dir, ignore_errors=True)


+@pytest.mark.usefixtures("path_prefix")
+class PathPrefixTestBase(TempDirTestBase):
+"""Creates a temporary directory and uses it if a path prefix is not given."""
+
+def setUp(self):
+super().setUp()
+if self.path_prefix is None:
+self.path_prefix = f"{self._temp_dir}/"
+

class MainRunnerTestBase(TempDirTestBase):
"""Performs an in-process test of a `main(args)` func."""

@@ -54,6 +68,25 @@ def assertFileWritten(self, p: Path):
self.assertGreater(p.stat().st_size, 0, msg=f"Expected file {p} had zero size")


+class ModuloTokenizer(InferenceTokenizer):
+"""A tokenizer used for testing where we take a modulo of each character.
+Guarantees that we are producing tokens of up to the max token ID."""
+
+def __init__(self, vocabulary_size: int):
+self.vocabulary_size = vocabulary_size
+
+def _encode(self, texts: list[str], add_start_token: bool) -> list[list[int]]:
+return [
+[ord(character) % self.vocabulary_size for character in text]
+for text in texts
+]
+
+def _decode(self, tokens: list[list[int]]) -> list[str]:
+return [
+"".join([chr(token) for token in prompt_tokens]) for prompt_tokens in tokens
+]


@contextlib.contextmanager
def temporary_directory(identifier: str):
"""Returns a context manager TemporaryDirectory suitable for testing.
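A quick illustration of the new `ModuloTokenizer`, calling the private `_encode`/`_decode` directly just to show the mapping; a real test would go through the public `InferenceTokenizer` interface:

```python
from sharktank.utils.testing import ModuloTokenizer

tokenizer = ModuloTokenizer(vocabulary_size=128)
# ASCII code points are below 128, so the modulo is a no-op and decoding round-trips.
tokens = tokenizer._encode(["hi"], add_start_token=False)
assert tokens == [[104, 105]]
assert tokenizer._decode(tokens) == ["hi"]
```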
2 changes: 1 addition & 1 deletion sharktank/sharktank/utils/tokenizer.py
@@ -75,7 +75,7 @@ def pad_tokens(
return token_ids, lengths

@abstractmethod
-def _encode(self, texts: list[str]) -> list[list[int]]:
+def _encode(self, texts: list[str], add_start_token: bool) -> list[list[int]]:
...

@abstractmethod
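The abstract `_encode` now receives `add_start_token`, so concrete tokenizers can prepend a start token on request. A hypothetical subclass honoring the flag; the class and the BOS id of 1 are made up for illustration:

```python
from sharktank.utils.tokenizer import InferenceTokenizer


class FakeBosTokenizer(InferenceTokenizer):
    BOS = 1  # assumed BOS token id, for illustration only

    def _encode(self, texts: list[str], add_start_token: bool) -> list[list[int]]:
        # Prepend the start token only when the caller asks for it.
        prefix = [self.BOS] if add_start_token else []
        return [prefix + [ord(c) for c in text] for text in texts]

    def _decode(self, tokens: list[list[int]]) -> list[str]:
        # Drop the start token when reconstructing text.
        return ["".join(chr(t) for t in ts if t != self.BOS) for ts in tokens]
```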