diff --git a/.github/workflows/test-pytorch-xla-tpu-tgi-jetstream.yml b/.github/workflows/test-pytorch-xla-tpu-tgi-jetstream.yml
index 5b909bf5..fb18d865 100644
--- a/.github/workflows/test-pytorch-xla-tpu-tgi-jetstream.yml
+++ b/.github/workflows/test-pytorch-xla-tpu-tgi-jetstream.yml
@@ -1,6 +1,10 @@
 name: Optimum TPU / Test TGI on TPU / Jetstream Pytorch
 
 on:
+  push:
+    branches: [ main ]
+    paths:
+      - "text-generation-inference/**"
   pull_request:
     branches: [ main ]
     paths:
diff --git a/optimum/tpu/version.py b/optimum/tpu/version.py
index b6b27852..6aa2623b 100644
--- a/optimum/tpu/version.py
+++ b/optimum/tpu/version.py
@@ -15,5 +15,5 @@
 from packaging.version import parse
 
 
-__version__ = "0.1.5"
+__version__ = "0.2.0"
 VERSION = parse(__version__)
diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py
index 49cc73b8..f7e48ef3 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py
@@ -19,12 +19,35 @@
 
 if TYPE_CHECKING:
     from transformers import PretrainedConfig
+    from transformers import AutoConfig
 
 from .compatibility import model_can_use_jetstream_pt
 from .models import GemmaModel, LlamaModel, MixtralModel
 
 
+class OptimumJetstreamEngine(PyTorchEngine):
+    """This is essentially the same as the PyTorchEngine, but it also supports a sampling callback for prefill and
+    generation that can change on each request while not needing to be JIT'ed.
+    """
+    prefill_ex = PyTorchEngine.prefill
+
+    def __init__(
+        self,
+        pt_model: torch.nn.Module,
+        env: JetEngineEnvironment,
+        weights=None,
+    ):
+        super().__init__(pt_model, env, weights)
+        # Calls to the model's prefill and generate need to be JIT'ed, because they are called with sharded notations,
+        # and they would otherwise not work for some models.
+        self._call_model_prefill = jax.jit(
+            self._call_model_prefill,
+        )
+        self._call_model_generate = jax.jit(
+            self._call_model_generate,
+        )
+
 def _get_head_dim(config: "PretrainedConfig") -> int:
     if hasattr(config, "head_dim"):
         return config.head_dim
@@ -174,8 +197,9 @@ def create_engine(
     logger.info(f"Quantization took {end - start:.2f} seconds")
     model_weights = model.state_dict()
     sharded_weights = shard_weights(env, model_weights, weight_shardings)
-    return PyTorchEngine(
+    engine = OptimumJetstreamEngine(
         pt_model=model,
         env=env,
         weights=torchjax.from_torch_with_copy(sharded_weights),
     )
+    return engine
diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
index b0a3e250..97061421 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
@@ -55,7 +55,6 @@ def clear(self):
         self._state = Slot.State.EMPTY
         self._batch_id = None
         self._request_id = None
-        self._inputs = ""
         self._generation_config = None
         self._tokens = []
         self._selector = None
@@ -263,9 +262,8 @@ def __init__(
         tokenizer.truncation_side = "left"
         self.tokenizer = tokenizer
         self.special_tokens = self.tokenizer.all_special_ids
-
-        # Slots are empty to begin with, they will be populated as new batches arrive
-        self.slots = []
+        # The number of slots is static, it cannot grow over the batch size
+        self.slots = [Slot(i, tokenizer) for i in range(self.model.config.batch_size)]
         self.batch_id = 0
         # Note: this index will _never_ be decremented, and that's fine.
         self.slot_index = 0
@@ -363,13 +361,11 @@ def warmup(self, batch: Batch) -> int:
         seq_len = self.engine.env.seq_len
         return batch_size * seq_len
 
-    def _get_slot_id(self):
-        """Get the next available slot id."""
-        batch_size = self.engine.env.batch_size
-        used_ids = [slot.id for slot in self.slots if slot.state != Slot.State.EMPTY]
-        for i in range(batch_size):
-            if i not in used_ids:
-                return i
+    def _get_slot(self):
+        """Get the next available slot."""
+        for slot in self.slots:
+            if slot.state == Slot.State.EMPTY:
+                return slot
         # if we reach this point, all slots were used - this should not happen
         raise ValueError("All slots are used, but we should have stopped earlier")
@@ -388,6 +384,8 @@ def _token_encode(self, text: str, max_length: int) -> Tuple[jnp.ndarray, int]:
         """
         if max_length == 0:
             max_length = self.model.config.sequence_length
+        # Subtract one from max_length because the BOS token is going to be added when padding
+        max_length -= 1
         input_ids = self.tokenizer.encode(
             text,
             return_tensors="np",
@@ -417,14 +415,9 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
             A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
""" - slots = {state: [] for state in Slot.State} - for slot in self.slots: - slots[slot.state].append(slot) - len_active_slots = len(slots[Slot.State.READY]) - # Delete all empty slots, no need to have them anymore - empty_slots = slots[Slot.State.EMPTY] - for slot in empty_slots: - self.slots.remove(slot) + active_slots = [slot for slot in self.slots if slot.state == Slot.State.READY] + len_active_slots = len(active_slots) + len_requests = len(batch.requests) model_batch_size = self.model.config.batch_size if model_batch_size is not None and model_batch_size < len_active_slots + len_requests: @@ -439,10 +432,10 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]: # Assign each request to an empty slot logger.debug(f"Prefilling {len_requests} new request(s) adding to {len_active_slots} active slot(s)") generations = [] - + prefilled_active_slots = [] for request in batch.requests: # Dynamically create a new slot for each request - slot = Slot(self._get_slot_id(), self.tokenizer) + slot = self._get_slot() self.prefill_slot.set(slot) self.slot_index += 1 slot.assign(self.batch_id, request, self.model.generation_config) @@ -462,7 +455,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]: # To allow jit'ing the select function, we need to wrap it in a partial slot_select = jax.tree_util.Partial(self.prefill_slot.select) # Ask for prefill and insert - prefill_results, _result_tokens = self.engine.prefill( + prefill_results, _result_tokens = self.engine.prefill_ex( params=self.params, padded_tokens=input_ids, true_length=true_lengths, @@ -473,20 +466,12 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]: self._post_generate(slot, next_token, generations) if not slot.empty: - # append current to list of active slots - self.slots.append(slot) - len_active_slots += 1 - - batch = None - if len_active_slots > 0: - # Whatever initial batch these requests came from, we always return all pending requests in a single batch - request_ids = [slot.request_id for slot in self.slots if slot.state == Slot.State.READY] - batch = self._cached_batch(self.batch_id, request_ids) - else: - logger.debug("No more pending requests") + prefilled_active_slots.append(slot) + + cached_batch = self._cached_batch(self.batch_id, prefilled_active_slots) self.batch_id += 1 logger.debug("Model ready for decoding") - return generations, batch + return generations, cached_batch def _select_from_slots(self, logits: jnp.ndarray, batch_size: int=0) -> jnp.ndarray: pad_token_id = self.tokenizer.pad_token_id @@ -494,7 +479,7 @@ def _select_from_slots(self, logits: jnp.ndarray, batch_size: int=0) -> jnp.ndar tokens = jnp.full((batch_size, 1), pad_token_id) for slot in filter(lambda slot: slot.state == slot.State.READY, self.slots): # Every slot might have a different selection criteria, so we are obliged to call select in a loop - next_token = slot.select(logits) + next_token = slot.select(logits[slot.id : slot.id + 1, :]) tokens = tokens.at[slot.id].set(next_token) return tokens @@ -539,10 +524,8 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)") # Use a custom function to select the next token for each slot - select_fn = jax.tree_util.Partial(self._select_from_slots) - self.decode_state, result_tokens = self.engine.generate(self.params, self.decode_state, select_fn) + self.decode_state, result_tokens = 
self.engine.generate_impl(self.params, self.decode_state, self._select_from_slots) - newly_empty = [] generations = [] for slot in active_slots: # Get the next token. @@ -555,20 +538,9 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa raise ValueError("Unexpected Slot is not ready for decoding") self._post_generate(slot, next_token, generations) - if slot.empty: - newly_empty.append(slot) - - # Remove empty slots - for slot in newly_empty: - self.slots.remove(slot) - batch = None - if len(self.slots) > 0: - # Whatever initial batch these requests came from, we always return all pending requests in a single batch - request_ids = [slot.request_id for slot in self.slots if slot.state == Slot.State.READY] - batch = self._cached_batch(next_batch_id, request_ids) - else: - logger.debug("No more pending requests") - return generations, batch + + cached_batch = self._cached_batch(next_batch_id, active_slots) + return generations, cached_batch def _post_generate(self, slot: Slot, next_token: int, generations: List[Generation]) -> None: """Post-generate a slot after the generation has been completed. @@ -616,7 +588,13 @@ def _post_generate(self, slot: Slot, next_token: int, generations: List[Generati ) ) - def _cached_batch(self, batch_id: int, request_ids: List): + def _cached_batch(self, batch_id: int, active_slots: List): + """Create a CachedBatch from the active slots. + """ + request_ids = [slot.request_id for slot in active_slots if slot.state == Slot.State.READY] + if len(request_ids) == 0: + logger.debug("No more pending requests") + return None size = len(request_ids) max_tokens = size * self.model.config.sequence_length return CachedBatch(id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens) diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py index fe31ad9c..ce0820c4 100644 --- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py +++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py @@ -175,11 +175,13 @@ def select(self, input_ids: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray: """ scores = self.logits_processor(input_ids, logits) if self.mode == GenerationMode.SAMPLE: - return self._sample(scores) + # split the key to avoid reusing the same key for multiple samples + subkey, self.key = jax.random.split(self.key) + return self._sample(scores, subkey) else: return jnp.argmax(scores, axis=-1) - def _sample(self, scores: jnp.ndarray) -> jnp.ndarray: + def _sample(self, scores: jnp.ndarray, key) -> jnp.ndarray: do_top_k = self.logits_warper.top_k > 0 and self.logits_warper.top_k < scores.shape[-1] do_top_p = self.logits_warper.top_p < 1.0 and self.logits_warper.top_p > 0.0 @@ -188,14 +190,14 @@ def _sample(self, scores: jnp.ndarray) -> jnp.ndarray: scores, self.logits_warper.top_k, self.logits_warper.temperature, - self.key, + key, ) elif do_top_p: return sampling_utils.sample_nucleus_topp_logits( scores, self.logits_warper.top_p, self.logits_warper.temperature, - self.key, + key, ) - return jax.random.categorical(self.key, scores / self.logits_warper.temperature) + return jax.random.categorical(key, scores / self.logits_warper.temperature) diff --git a/text-generation-inference/server/text_generation_server/version.py b/text-generation-inference/server/text_generation_server/version.py index 
--- a/text-generation-inference/server/text_generation_server/version.py
+++ b/text-generation-inference/server/text_generation_server/version.py
@@ -1,5 +1,5 @@
 from pkg_resources import parse_version
 
 
-__version__ = "0.1.5"
+__version__ = "0.2.0"
 VERSION = parse_version(__version__)
diff --git a/text-generation-inference/tests/test_tinyllama.py b/text-generation-inference/tests/test_tinyllama.py
new file mode 100644
index 00000000..5566bccd
--- /dev/null
+++ b/text-generation-inference/tests/test_tinyllama.py
@@ -0,0 +1,179 @@
+
+import pytest
+from helpers import create_request, prepare_model
+from text_generation_server.auto_generator import AutoGenerator
+from text_generation_server.pb.generate_pb2 import Batch
+from tqdm import tqdm
+
+from optimum.tpu.jetstream_pt_support import jetstream_pt_available
+
+
+MODEL_ID = "Maykeye/TinyLLama-v0"
+SEQUENCE_LENGTH = 256
+
+
+@pytest.fixture(scope="module")
+def model_path():
+    return prepare_model(MODEL_ID, SEQUENCE_LENGTH)
+
+
+def test_jetstream_info(model_path):
+    """Verify the model info is correctly loaded and check expected results."""
+    if not jetstream_pt_available():
+        pytest.skip("Jetstream PyTorch is not available")
+    generator = AutoGenerator.from_pretrained(model_path, revision="", max_batch_size=1, max_sequence_length=1)
+    info = generator.info
+    assert info.requires_padding is True
+    assert info.device_type == "meta"
+    assert info.window_size == 0
+    assert info.speculate == 0
+
+
+@pytest.mark.parametrize(
+    "input_text, token_id, token_text, do_sample",
+    [
+        [
+            "It was a bright cold day in April, and the clocks were striking thirteen.",
+            347,
+            " The",
+            False,
+        ],
+        [
+            "It was a bright cold day in April, and the clocks were striking thirteen.",
+            13,
+            "\n",
+            True,
+        ],
+    ],
+    ids=["greedy", "sample"],
+)
+@pytest.mark.parametrize("batch_size", [1, 4], ids=["single", "multiple"])
+def test_jetstream_prefill(input_text, token_id, token_text, do_sample, batch_size, model_path):
+    """Verify that prefilling a batch of identical requests works with different sampling techniques.
+    """
+    if not jetstream_pt_available():
+        pytest.skip("Jetstream PyTorch is not available")
+    generator = AutoGenerator.from_pretrained(model_path, revision="", max_batch_size=batch_size, max_sequence_length=SEQUENCE_LENGTH)
+    requests = []
+    max_new_tokens = 20
+    for i in range(batch_size):
+        requests.append(create_request(id=0, inputs=input_text, do_sample=do_sample, max_new_tokens=max_new_tokens))
+    # Let's be pessimistic when estimating max_tokens
+    batch_size * (len(input_text) + max_new_tokens)
+    batch = Batch(id=0, requests=requests, size=batch_size, max_tokens=batch_size * SEQUENCE_LENGTH)
+    generations, next_batch = generator.prefill(batch)
+    assert next_batch.size == batch_size
+    # Whatever was passed as max_tokens, the server will correct it
+    # because of static batching
+    assert next_batch.max_tokens == batch_size * SEQUENCE_LENGTH
+    assert len(generations) == batch_size
+    for g in generations:
+        tokens = g.tokens
+        assert tokens.ids == [token_id]
+        assert tokens.texts == [token_text]
+
+
+def test_jetstream_prefill_change_sampling(model_path):
+    """Verify changing the sampling strategy between requests in the same batch works as expected.
+    """
+    if not jetstream_pt_available():
+        pytest.skip("Jetstream PyTorch is not available")
+    input_text = "It was a bright cold day in April, and the clocks were striking thirteen."
+    batch_size = 1
+    greedy_expected_token_id = 347
+    greedy_expected_text = " The"
+    sampling_expected_token_id = 13
+    sampling_expected_text = "\n"
+
+    generator = AutoGenerator.from_pretrained(model_path, revision="", max_batch_size=batch_size, max_sequence_length=SEQUENCE_LENGTH)
+    max_new_tokens = 20
+
+    def check_request(do_sample, expected_token_id, expected_text):
+        requests = [create_request(id=0, inputs=input_text, do_sample=do_sample, max_new_tokens=max_new_tokens)]
+        batch = Batch(id=0, requests=requests, size=batch_size, max_tokens=batch_size * SEQUENCE_LENGTH)
+        generations, _ = generator.prefill(batch)
+        tokens = generations[0].tokens
+        print(tokens)
+        assert tokens.ids == [expected_token_id]
+        assert tokens.texts == [expected_text]
+        generator.clear()
+
+    # First request is greedy
+    check_request(False, greedy_expected_token_id, greedy_expected_text)
+    # Second request is sampling
+    check_request(True, sampling_expected_token_id, sampling_expected_text)
+    # Third request is greedy again
+    check_request(False, greedy_expected_token_id, greedy_expected_text)
+
+
+def test_jetstream_decode_multiple(model_path):
+    """Verify that two requests added to the batch at different generation steps
+    generate the same outputs (continuous batching).
+    """
+    if not jetstream_pt_available():
+        pytest.skip("Jetstream PyTorch is not available")
+    generator = AutoGenerator.from_pretrained(model_path,
+                                              revision="",
+                                              max_batch_size=2,
+                                              max_sequence_length=SEQUENCE_LENGTH)
+    input_text = "Once upon a time"
+    max_new_tokens = 20
+    # Prefill a single request, remembering the generated token
+    tokens = {0: [], 1: []}
+    request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens)
+    batch = Batch(id=0, requests=[request], size=1, max_tokens=SEQUENCE_LENGTH)
+    generations, next_batch = generator.prefill(batch)
+    assert next_batch.size == 1
+    assert len(generations) == 1
+    g = generations[0]
+    tokens[g.request_id].append(g.tokens.ids[0])
+    assert len(tokens[0]) == 1
+    # Decode a few tokens
+    gen_tokens = 4
+    for _ in tqdm(range(gen_tokens - 1), "Decoding tokens"):
+        generations, next_batch = generator.decode([next_batch])
+        assert len(generations) == 1
+        g = generations[0]
+        tokens[g.request_id].append(g.tokens.ids[0])
+    assert len(tokens[0]) == gen_tokens
+    assert next_batch.size == 1
+    # Add a second request
+    request = create_request(id=1, inputs=input_text, max_new_tokens=max_new_tokens)
+    batch = Batch(id=1, requests=[request], size=1, max_tokens=SEQUENCE_LENGTH)
+    generations, next_batch_1 = generator.prefill(batch)
+    assert next_batch_1.size == 1
+    # We should have generated only a single token
+    assert len(generations) == 1
+    g = generations[0]
+    tokens[g.request_id].append(g.tokens.ids[0])
+    assert len(tokens[0]) == gen_tokens
+    assert len(tokens[1]) == 1
+    # Decode more tokens until we reach the maximum for the first request
+    batches = [next_batch, next_batch_1]
+    for _ in tqdm(range(max_new_tokens - gen_tokens), "Decoding tokens (2nd batch)"):
+        generations, next_batch = generator.decode(batches)
+        for g in generations:
+            tokens[g.request_id].append(g.tokens.ids[0])
+        batches = [next_batch]
+    # Verify we now only have one pending request
+    assert next_batch.size == 1
+    assert len(tokens[0]) == max_new_tokens
+    assert len(tokens[1]) == max_new_tokens - gen_tokens + 1
+    # Verify we have the output for the first request
+    for g in generations:
+        if g.request_id == 0:
+            output = g.generated_text
+            assert output.text != ""
+            assert output.generated_tokens == max_new_tokens
+            generated_text = output.text
+    # Continue decoding until the end of the second request
+    for _ in tqdm(range(gen_tokens - 1), "Decoding tokens (finishing)"):
+        generations, next_batch = generator.decode([next_batch])
+        assert len(generations) == 1
+        g = generations[0]
+        tokens[g.request_id].append(g.tokens.ids[0])
+    assert next_batch is None
+    output = generations[0].generated_text
+    assert output.generated_tokens == max_new_tokens
+    assert tokens[0] == tokens[1]
+    assert output.text == generated_text