From 107016b5f31bbb87be8e7bdea3e08273d7084882 Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Thu, 17 Oct 2024 15:37:28 +0000
Subject: [PATCH 01/12] fix(tgi): correct truncation in Jetstream Pytorch generator

---
 .../text_generation_server/jetstream_pt_support/generator.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
index b0a3e250..683169e3 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
@@ -388,6 +388,8 @@ def _token_encode(self, text: str, max_length: int) -> Tuple[jnp.ndarray, int]:
         """
         if max_length == 0:
             max_length = self.model.config.sequence_length
+        # Remove one from max_length because BOS is going to be added when padding
+        max_length -= 1
         input_ids = self.tokenizer.encode(
             text,
             return_tensors="np",

From 3a4ab501a9fa7b8c7801f0d8cf365551ff951645 Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Tue, 19 Nov 2024 10:54:18 +0000
Subject: [PATCH 02/12] chore(ci): jetstream TGI tests also run on main on push

---
 .github/workflows/test-pytorch-xla-tpu-tgi-jetstream.yml | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/.github/workflows/test-pytorch-xla-tpu-tgi-jetstream.yml b/.github/workflows/test-pytorch-xla-tpu-tgi-jetstream.yml
index 5b909bf5..fb18d865 100644
--- a/.github/workflows/test-pytorch-xla-tpu-tgi-jetstream.yml
+++ b/.github/workflows/test-pytorch-xla-tpu-tgi-jetstream.yml
@@ -1,6 +1,10 @@
 name: Optimum TPU / Test TGI on TPU / Jetstream Pytorch
 
 on:
+  push:
+    branches: [ main ]
+    paths:
+      - "text-generation-inference/**"
   pull_request:
     branches: [ main ]
     paths:

From d7fd56c2b65e46d65e9f218e812db4003c65a52e Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Wed, 6 Nov 2024 10:49:50 +0100
Subject: [PATCH 03/12] refactor(generator): inputs removed from slot

This is not used anyway.
---
 .../text_generation_server/jetstream_pt_support/generator.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
index 683169e3..f9ff4940 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
@@ -55,7 +55,6 @@ def clear(self):
         self._state = Slot.State.EMPTY
         self._batch_id = None
         self._request_id = None
-        self._inputs = ""
         self._generation_config = None
         self._tokens = []
         self._selector = None

From 9785d2b4d2fe97eaee542d18cb29b96350b077de Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Mon, 18 Nov 2024 14:42:40 +0000
Subject: [PATCH 04/12] fix(generator): correct cached_batch and set the number
 of slots to batch_size

The cached batch returned was wrong, because the generator expects only one
cached batch per prefill/decode call. Also, the number of slots is now fixed:
this prevents creating and destroying elements in the slot list, which allows
further optimizations and avoids triggering JIT recompilation.
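To illustrate the idea, here is a minimal sketch with simplified, hypothetical names (not the
generator's actual code): slots form a fixed pool that only flips between EMPTY and READY, and each
prefill/decode step reports at most one cached batch covering all pending requests.

    # Minimal sketch (simplified names): a statically sized slot pool and a single cached batch per step.
    from enum import Enum

    class State(Enum):
        EMPTY = 0
        READY = 1

    class Slot:
        def __init__(self, slot_id):
            self.id = slot_id
            self.state = State.EMPTY
            self.request_id = None

        def assign(self, request_id):
            self.request_id = request_id
            self.state = State.READY

    def get_free_slot(slots):
        # Reuse an EMPTY slot instead of growing or shrinking the list, so shapes stay static.
        for slot in slots:
            if slot.state == State.EMPTY:
                return slot
        raise ValueError("All slots are used, but we should have stopped earlier")

    def cached_batch(batch_id, slots, sequence_length):
        # A single batch describing all pending (READY) requests, or None when nothing is pending.
        request_ids = [s.request_id for s in slots if s.state == State.READY]
        if not request_ids:
            return None
        return dict(id=batch_id, request_ids=request_ids, max_tokens=len(request_ids) * sequence_length)

    slots = [Slot(i) for i in range(4)]  # pool size fixed to the model batch size
    get_free_slot(slots).assign(request_id=10)
    get_free_slot(slots).assign(request_id=11)
    print(cached_batch(0, slots, sequence_length=256))  # {'id': 0, 'request_ids': [10, 11], 'max_tokens': 512}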
--- .../jetstream_pt_support/generator.py | 76 +++++++------------ 1 file changed, 27 insertions(+), 49 deletions(-) diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py index f9ff4940..74cc6eab 100644 --- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py +++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py @@ -262,9 +262,8 @@ def __init__( tokenizer.truncation_side = "left" self.tokenizer = tokenizer self.special_tokens = self.tokenizer.all_special_ids - - # Slots are empty to begin with, they will be populated as new batches arrive - self.slots = [] + # Slots number is static, it cannot grow over the size of the batch + self.slots = [Slot(i, tokenizer) for i in range(self.model.config.batch_size)] self.batch_id = 0 # Note: this index will _never_ be decremented, and that's fine. self.slot_index = 0 @@ -362,13 +361,11 @@ def warmup(self, batch: Batch) -> int: seq_len = self.engine.env.seq_len return batch_size * seq_len - def _get_slot_id(self): - """Get the next available slot id.""" - batch_size = self.engine.env.batch_size - used_ids = [slot.id for slot in self.slots if slot.state != Slot.State.EMPTY] - for i in range(batch_size): - if i not in used_ids: - return i + def _get_slot(self): + """Get the next available slot.""" + for slot in self.slots: + if slot.state == Slot.State.EMPTY: + return slot # if we reach this point, all slots were used - this should not happen raise ValueError("All slots are used, but we should have stopped earlier") @@ -418,14 +415,9 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]: A list of `Generation` for each request and a `CachedBatch` containing all pending requests. 
""" - slots = {state: [] for state in Slot.State} - for slot in self.slots: - slots[slot.state].append(slot) - len_active_slots = len(slots[Slot.State.READY]) - # Delete all empty slots, no need to have them anymore - empty_slots = slots[Slot.State.EMPTY] - for slot in empty_slots: - self.slots.remove(slot) + active_slots = [slot for slot in self.slots if slot.state == Slot.State.READY] + len_active_slots = len(active_slots) + len_requests = len(batch.requests) model_batch_size = self.model.config.batch_size if model_batch_size is not None and model_batch_size < len_active_slots + len_requests: @@ -440,10 +432,10 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]: # Assign each request to an empty slot logger.debug(f"Prefilling {len_requests} new request(s) adding to {len_active_slots} active slot(s)") generations = [] - + prefilled_active_slots = [] for request in batch.requests: # Dynamically create a new slot for each request - slot = Slot(self._get_slot_id(), self.tokenizer) + slot = self._get_slot() self.prefill_slot.set(slot) self.slot_index += 1 slot.assign(self.batch_id, request, self.model.generation_config) @@ -474,20 +466,12 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]: self._post_generate(slot, next_token, generations) if not slot.empty: - # append current to list of active slots - self.slots.append(slot) - len_active_slots += 1 - - batch = None - if len_active_slots > 0: - # Whatever initial batch these requests came from, we always return all pending requests in a single batch - request_ids = [slot.request_id for slot in self.slots if slot.state == Slot.State.READY] - batch = self._cached_batch(self.batch_id, request_ids) - else: - logger.debug("No more pending requests") + prefilled_active_slots.append(slot) + + cached_batch = self._cached_batch(self.batch_id, prefilled_active_slots) self.batch_id += 1 logger.debug("Model ready for decoding") - return generations, batch + return generations, cached_batch def _select_from_slots(self, logits: jnp.ndarray, batch_size: int=0) -> jnp.ndarray: pad_token_id = self.tokenizer.pad_token_id @@ -495,7 +479,7 @@ def _select_from_slots(self, logits: jnp.ndarray, batch_size: int=0) -> jnp.ndar tokens = jnp.full((batch_size, 1), pad_token_id) for slot in filter(lambda slot: slot.state == slot.State.READY, self.slots): # Every slot might have a different selection criteria, so we are obliged to call select in a loop - next_token = slot.select(logits) + next_token = slot.select(logits[slot.id : slot.id + 1, :]) tokens = tokens.at[slot.id].set(next_token) return tokens @@ -543,7 +527,6 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa select_fn = jax.tree_util.Partial(self._select_from_slots) self.decode_state, result_tokens = self.engine.generate(self.params, self.decode_state, select_fn) - newly_empty = [] generations = [] for slot in active_slots: # Get the next token. 
@@ -556,20 +539,9 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa raise ValueError("Unexpected Slot is not ready for decoding") self._post_generate(slot, next_token, generations) - if slot.empty: - newly_empty.append(slot) - - # Remove empty slots - for slot in newly_empty: - self.slots.remove(slot) - batch = None - if len(self.slots) > 0: - # Whatever initial batch these requests came from, we always return all pending requests in a single batch - request_ids = [slot.request_id for slot in self.slots if slot.state == Slot.State.READY] - batch = self._cached_batch(next_batch_id, request_ids) - else: - logger.debug("No more pending requests") - return generations, batch + + cached_batch = self._cached_batch(next_batch_id, active_slots) + return generations, cached_batch def _post_generate(self, slot: Slot, next_token: int, generations: List[Generation]) -> None: """Post-generate a slot after the generation has been completed. @@ -617,7 +589,13 @@ def _post_generate(self, slot: Slot, next_token: int, generations: List[Generati ) ) - def _cached_batch(self, batch_id: int, request_ids: List): + def _cached_batch(self, batch_id: int, active_slots: List): + """Create a CachedBatch from the active slots. + """ + request_ids = [slot.request_id for slot in active_slots if slot.state == Slot.State.READY] + if len(request_ids) == 0: + logger.debug("No more pending requests") + return None size = len(request_ids) max_tokens = size * self.model.config.sequence_length return CachedBatch(id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens) From 20ff0dba3730abf7be774a0a025dfea4128d3cf5 Mon Sep 17 00:00:00 2001 From: Alvaro Moran Date: Tue, 19 Nov 2024 14:04:18 +0000 Subject: [PATCH 05/12] feat(rng): improve randomness in sampling on Jetstream/Pt The randomness when sampling has been improved by splitting the key as suggested by the documentation of the JAX random submodule. 
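To illustrate the pattern, here is a minimal standalone sketch (not the repository code): reusing the
same PRNG key makes `jax.random.categorical` return the same draw on every call, while splitting off
a fresh subkey before each sample, as the patched `select` method now does, restores the expected
randomness.

    # Minimal sketch: why the PRNG key must be split before each sample.
    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)
    logits = jnp.array([[0.1, 0.2, 0.3, 0.4]])

    # Reusing the same key always yields the same "random" draw.
    first = jax.random.categorical(key, logits)
    second = jax.random.categorical(key, logits)
    assert (first == second).all()

    # Splitting the key before each draw gives an independent subkey per sample.
    subkey, key = jax.random.split(key)
    token_a = jax.random.categorical(subkey, logits)
    subkey, key = jax.random.split(key)
    token_b = jax.random.categorical(subkey, logits)  # drawn with a different subkey, so it may differ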
--- .../jetstream_pt_support/token_selector.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py index fe31ad9c..ce0820c4 100644 --- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py +++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py @@ -175,11 +175,13 @@ def select(self, input_ids: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray: """ scores = self.logits_processor(input_ids, logits) if self.mode == GenerationMode.SAMPLE: - return self._sample(scores) + # split the key to avoid reusing the same key for multiple samples + subkey, self.key = jax.random.split(self.key) + return self._sample(scores, subkey) else: return jnp.argmax(scores, axis=-1) - def _sample(self, scores: jnp.ndarray) -> jnp.ndarray: + def _sample(self, scores: jnp.ndarray, key) -> jnp.ndarray: do_top_k = self.logits_warper.top_k > 0 and self.logits_warper.top_k < scores.shape[-1] do_top_p = self.logits_warper.top_p < 1.0 and self.logits_warper.top_p > 0.0 @@ -188,14 +190,14 @@ def _sample(self, scores: jnp.ndarray) -> jnp.ndarray: scores, self.logits_warper.top_k, self.logits_warper.temperature, - self.key, + key, ) elif do_top_p: return sampling_utils.sample_nucleus_topp_logits( scores, self.logits_warper.top_p, self.logits_warper.temperature, - self.key, + key, ) - return jax.random.categorical(self.key, scores / self.logits_warper.temperature) + return jax.random.categorical(key, scores / self.logits_warper.temperature) From 071a7f9cfaef8a8c491f967ffc221000b531bcb8 Mon Sep 17 00:00:00 2001 From: Alvaro Moran Date: Mon, 18 Nov 2024 14:57:02 +0000 Subject: [PATCH 06/12] test(jetstream): added prefill and decode multiple tests A GPT2 test file exists to verify the generator behaviour when using the legacy Pytorch/XLA code, so now this test has been added to verify the same behaviour on the Jetstream/Pytorch counterpart. 
--- .../tests/test_tinyllama.py | 140 ++++++++++++++++++ 1 file changed, 140 insertions(+) create mode 100644 text-generation-inference/tests/test_tinyllama.py diff --git a/text-generation-inference/tests/test_tinyllama.py b/text-generation-inference/tests/test_tinyllama.py new file mode 100644 index 00000000..326288a1 --- /dev/null +++ b/text-generation-inference/tests/test_tinyllama.py @@ -0,0 +1,140 @@ + +import pytest +from helpers import create_request, prepare_model +from text_generation_server.auto_generator import AutoGenerator +from text_generation_server.pb.generate_pb2 import Batch +from tqdm import tqdm + +from optimum.tpu.jetstream_pt_support import jetstream_pt_available + + +MODEL_ID = "Maykeye/TinyLLama-v0" +SEQUENCE_LENGTH = 256 + + +@pytest.fixture(scope="module") +def model_path(): + return prepare_model(MODEL_ID, SEQUENCE_LENGTH) + + +def test_jetstream_info(model_path): + if not jetstream_pt_available(): + pytest.skip("Jetstream PyTorch is not available") + generator = AutoGenerator.from_pretrained(model_path, revision="", max_batch_size=1, max_sequence_length=1) + info = generator.info + assert info.requires_padding is True + assert info.device_type == "meta" + assert info.window_size == 0 + assert info.speculate == 0 + + +@pytest.mark.parametrize( + "input_text, token_id, token_text, do_sample", + [ + [ + "It was a bright cold day in April, and the clocks were striking thirteen.", + 347, + " The", + False, + ], + [ + "It was a bright cold day in April, and the clocks were striking thirteen.", + 13, + "\n", + True, + ], + ], + ids=["greedy", "sample"], +) +@pytest.mark.parametrize("batch_size", [1, 4], ids=["single", "multiple"]) +def test_jetstream_prefill(input_text, token_id, token_text, do_sample, batch_size, model_path): + if not jetstream_pt_available(): + pytest.skip("Jetstream PyTorch is not available") + generator = AutoGenerator.from_pretrained(model_path, revision="", max_batch_size=batch_size, max_sequence_length=SEQUENCE_LENGTH) + requests = [] + max_new_tokens = 20 + for i in range(batch_size): + requests.append(create_request(id=0, inputs=input_text, do_sample=do_sample, max_new_tokens=max_new_tokens)) + # Let's be pessimistic when estimating max_tokens + batch_size * (len(input_text) + max_new_tokens) + batch = Batch(id=0, requests=requests, size=batch_size, max_tokens=batch_size * SEQUENCE_LENGTH) + generations, next_batch = generator.prefill(batch) + assert next_batch.size == batch_size + # Whatever was passed as max_tokens, the server will correct it + # because of static batching + assert next_batch.max_tokens == batch_size * SEQUENCE_LENGTH + assert len(generations) == batch_size + for g in generations: + tokens = g.tokens + assert tokens.ids == [token_id] + assert tokens.texts == [token_text] + + +def test_jetstream_decode_multiple(model_path): + if not jetstream_pt_available(): + pytest.skip("Jetstream PyTorch is not available") + generator = AutoGenerator.from_pretrained(model_path, + revision="", + max_batch_size=2, + max_sequence_length=SEQUENCE_LENGTH) + input_text = "Once upon a time" + max_new_tokens = 20 + # Prefill a single request, remembering the generated token + tokens = {0: [], 1: []} + request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens) + batch = Batch(id=0, requests=[request], size=1, max_tokens=SEQUENCE_LENGTH) + generations, next_batch = generator.prefill(batch) + assert next_batch.size == 1 + assert len(generations) == 1 + g = generations[0] + tokens[g.request_id].append(g.tokens.ids[0]) + assert 
len(tokens[0]) == 1 + # Decode a few tokens + gen_tokens = 4 + for _ in tqdm(range(gen_tokens - 1), "Decoding tokens"): + generations, next_batch = generator.decode([next_batch]) + assert len(generations) == 1 + g = generations[0] + tokens[g.request_id].append(g.tokens.ids[0]) + assert len(tokens[0]) == gen_tokens + assert next_batch.size == 1 + # Add a second request + request = create_request(id=1, inputs=input_text, max_new_tokens=max_new_tokens) + batch = Batch(id=1, requests=[request], size=1, max_tokens=SEQUENCE_LENGTH) + generations, next_batch_1 = generator.prefill(batch) + assert next_batch_1.size == 1 + # We should have generated only a single token + assert len(generations) == 1 + g = generations[0] + tokens[g.request_id].append(g.tokens.ids[0]) + assert len(tokens[0]) == gen_tokens + assert len(tokens[1]) == 1 + # Decode more tokens until we reach the maximum for the first request + batches = [next_batch, next_batch_1] + for _ in tqdm(range(max_new_tokens - gen_tokens), "Decoding tokens (2nd batch)"): + generations, next_batch = generator.decode(batches) + for g in generations: + tokens[g.request_id].append(g.tokens.ids[0]) + batches = [next_batch] + # Verify we now only have one pending request + assert next_batch.size == 1 + assert len(tokens[0]) == max_new_tokens + assert len(tokens[1]) == max_new_tokens - gen_tokens + 1 + # Verify we have the output for the first request + for g in generations: + if g.request_id == 0: + output = g.generated_text + assert output.text != "" + assert output.generated_tokens == max_new_tokens + generated_text = output.text + # Continue decoding until the end of the second request + for _ in tqdm(range(gen_tokens - 1), "Decoding tokens (finishing)"): + generations, next_batch = generator.decode([next_batch]) + assert len(generations) == 1 + g = generations[0] + tokens[g.request_id].append(g.tokens.ids[0]) + assert next_batch is None + output = generations[0].generated_text + assert output.generated_tokens == max_new_tokens + assert tokens[0] == tokens[1] + assert output.text == generated_text From f1dbc1327dba95a4f2fcdef1936ed000c1d5b969 Mon Sep 17 00:00:00 2001 From: Alvaro Moran Date: Wed, 20 Nov 2024 08:13:44 +0000 Subject: [PATCH 07/12] test(jetstream): added failing test to check sampling can be changed --- .../tests/test_tinyllama.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/text-generation-inference/tests/test_tinyllama.py b/text-generation-inference/tests/test_tinyllama.py index 326288a1..d7968bf0 100644 --- a/text-generation-inference/tests/test_tinyllama.py +++ b/text-generation-inference/tests/test_tinyllama.py @@ -70,6 +70,38 @@ def test_jetstream_prefill(input_text, token_id, token_text, do_sample, batch_si assert tokens.texts == [token_text] +@pytest.xfail(reason="A bug prevents changing the sampling strategy after the first request") +def test_jetstream_prefill_change_sampling(model_path): + if not jetstream_pt_available(): + pytest.skip("Jetstream PyTorch is not available") + input_text = "It was a bright cold day in April, and the clocks were striking thirteen." 
+    batch_size = 1
+    greedy_expected_token_id = 347
+    greedy_expected_text = " The"
+    sampling_expected_token_id = 13
+    sampling_expected_text = "\n"
+
+    generator = AutoGenerator.from_pretrained(model_path, revision="", max_batch_size=batch_size, max_sequence_length=SEQUENCE_LENGTH)
+    max_new_tokens = 20
+
+    def check_request(do_sample, expected_token_id, expected_text):
+        requests = [create_request(id=0, inputs=input_text, do_sample=do_sample, max_new_tokens=max_new_tokens)]
+        batch = Batch(id=0, requests=requests, size=batch_size, max_tokens=batch_size * SEQUENCE_LENGTH)
+        generations, _ = generator.prefill(batch)
+        tokens = generations[0].tokens
+        print(tokens)
+        assert tokens.ids == [expected_token_id]
+        assert tokens.texts == [expected_text]
+        generator.clear()
+
+    # First request is greedy
+    check_request(False, greedy_expected_token_id, greedy_expected_text)
+    # Second request is sampling
+    check_request(True, sampling_expected_token_id, sampling_expected_text)
+    # Third request is greedy again
+    check_request(False, greedy_expected_token_id, greedy_expected_text)
+
+
 def test_jetstream_decode_multiple(model_path):
     if not jetstream_pt_available():
         pytest.skip("Jetstream PyTorch is not available")
     generator = AutoGenerator.from_pretrained(model_path,

From d97a024e913f018f668e36bcb3f4e59e7e3c1b7a Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Tue, 19 Nov 2024 15:46:32 +0000
Subject: [PATCH 08/12] fix(jetstream): correct sampling for jetstream

The Jetstream/Pt engine allows passing a callback to the prefill and generate
methods. This callback is used to sample the generated token with a custom
function, but the caller is JIT'ed, which places a strong constraint on the
callback signature. So far the callback was compiled on the first call, making
it impossible to change the sampling algorithm across requests. This commit
fixes the issue by subclassing the PyTorchEngine class and defining a new
`prefill_ex` method that is not JIT'ed. The model calls are still compiled, so
performance should not be noticeably affected.
---
 .../jetstream_pt_support/engine_loader.py | 26 ++++++++++++++++++-
 .../jetstream_pt_support/generator.py     |  6 ++---
 .../tests/test_tinyllama.py               |  1 -
 3 files changed, 28 insertions(+), 5 deletions(-)

diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py
index 49cc73b8..f7e48ef3 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py
@@ -19,12 +19,35 @@
 if TYPE_CHECKING:
     from transformers import PretrainedConfig
+    from transformers import AutoConfig
 
 from .compatibility import model_can_use_jetstream_pt
 from .models import GemmaModel, LlamaModel, MixtralModel
 
 
+class OptimumJetstreamEngine(PyTorchEngine):
+    """This is essentially the same as the PyTorchEngine, but it also supports a callback for sampling in prefill and
+    generation that can change on each request while not needing to be JIT'ed.
+    """
+    prefill_ex = PyTorchEngine.prefill
+
+    def __init__(
+        self,
+        pt_model: torch.nn.Module,
+        env: JetEngineEnvironment,
+        weights=None,
+    ):
+        super().__init__(pt_model, env, weights)
+        # Calls to model prefill and generate need to be JIT'ed, because they are called with sharded notations
+        # and would otherwise not work for some models.
+ self._call_model_prefill = jax.jit( + self._call_model_prefill, + ) + self._call_model_generate = jax.jit( + self._call_model_generate, + ) + def _get_head_dim(config: "PretrainedConfig") -> int: if hasattr(config, "head_dim"): return config.head_dim @@ -174,8 +197,9 @@ def create_engine( logger.info(f"Quantization took {end - start:.2f} seconds") model_weights = model.state_dict() sharded_weights = shard_weights(env, model_weights, weight_shardings) - return PyTorchEngine( + engine = OptimumJetstreamEngine( pt_model=model, env=env, weights=torchjax.from_torch_with_copy(sharded_weights), ) + return engine diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py index 74cc6eab..98d6b019 100644 --- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py +++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py @@ -455,7 +455,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]: # To allow jit'ing the select function, we need to wrap it in a partial slot_select = jax.tree_util.Partial(self.prefill_slot.select) # Ask for prefill and insert - prefill_results, _result_tokens = self.engine.prefill( + prefill_results, _result_tokens = self.engine.prefill_ex( params=self.params, padded_tokens=input_ids, true_length=true_lengths, @@ -524,8 +524,8 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)") # Use a custom function to select the next token for each slot - select_fn = jax.tree_util.Partial(self._select_from_slots) - self.decode_state, result_tokens = self.engine.generate(self.params, self.decode_state, select_fn) + # select_fn = jax.tree_util.Partial(self._select_from_slots) + self.decode_state, result_tokens = self.engine.generate_impl(self.params, self.decode_state, self._select_from_slots) generations = [] for slot in active_slots: diff --git a/text-generation-inference/tests/test_tinyllama.py b/text-generation-inference/tests/test_tinyllama.py index d7968bf0..4ab71d06 100644 --- a/text-generation-inference/tests/test_tinyllama.py +++ b/text-generation-inference/tests/test_tinyllama.py @@ -70,7 +70,6 @@ def test_jetstream_prefill(input_text, token_id, token_text, do_sample, batch_si assert tokens.texts == [token_text] -@pytest.xfail(reason="A bug prevents changing the sampling strategy after the first request") def test_jetstream_prefill_change_sampling(model_path): if not jetstream_pt_available(): pytest.skip("Jetstream PyTorch is not available") From d9e89bee8f7fe3f2d7a0bf1a8aa5ff35366963b6 Mon Sep 17 00:00:00 2001 From: Alvaro Moran Date: Wed, 20 Nov 2024 08:19:17 +0000 Subject: [PATCH 09/12] chore: bump version to 0.2.0 Minor version is increased mainly because of Jetstream Pytorch support on TGI. 
--- optimum/tpu/version.py | 2 +- .../server/text_generation_server/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/optimum/tpu/version.py b/optimum/tpu/version.py index b6b27852..133af626 100644 --- a/optimum/tpu/version.py +++ b/optimum/tpu/version.py @@ -15,5 +15,5 @@ from packaging.version import parse -__version__ = "0.1.5" +__version__ = "0.1.6" VERSION = parse(__version__) diff --git a/text-generation-inference/server/text_generation_server/version.py b/text-generation-inference/server/text_generation_server/version.py index 55988374..493feadb 100644 --- a/text-generation-inference/server/text_generation_server/version.py +++ b/text-generation-inference/server/text_generation_server/version.py @@ -1,5 +1,5 @@ from pkg_resources import parse_version -__version__ = "0.1.5" +__version__ = "0.1.6" VERSION = parse_version(__version__) From 2308722fea138103549717a20e30545a50a8ee0f Mon Sep 17 00:00:00 2001 From: Alvaro Moran Date: Wed, 20 Nov 2024 09:24:50 +0000 Subject: [PATCH 10/12] fix(version): version number was not correctly updated, fix it --- optimum/tpu/version.py | 2 +- .../server/text_generation_server/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/optimum/tpu/version.py b/optimum/tpu/version.py index 133af626..6aa2623b 100644 --- a/optimum/tpu/version.py +++ b/optimum/tpu/version.py @@ -15,5 +15,5 @@ from packaging.version import parse -__version__ = "0.1.6" +__version__ = "0.2.0" VERSION = parse(__version__) diff --git a/text-generation-inference/server/text_generation_server/version.py b/text-generation-inference/server/text_generation_server/version.py index 493feadb..30c8700b 100644 --- a/text-generation-inference/server/text_generation_server/version.py +++ b/text-generation-inference/server/text_generation_server/version.py @@ -1,5 +1,5 @@ from pkg_resources import parse_version -__version__ = "0.1.6" +__version__ = "0.2.0" VERSION = parse_version(__version__) From 270435cedefa6af711a05ec7f8cf5d91458a0958 Mon Sep 17 00:00:00 2001 From: Alvaro Moran Date: Wed, 20 Nov 2024 09:27:02 +0000 Subject: [PATCH 11/12] review: remove commented code leftover --- .../text_generation_server/jetstream_pt_support/generator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py index 98d6b019..97061421 100644 --- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py +++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py @@ -524,7 +524,6 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)") # Use a custom function to select the next token for each slot - # select_fn = jax.tree_util.Partial(self._select_from_slots) self.decode_state, result_tokens = self.engine.generate_impl(self.params, self.decode_state, self._select_from_slots) generations = [] From ad9e3e5c0f23c6f72dd77f0dfe37d61c65e84788 Mon Sep 17 00:00:00 2001 From: Alvaro Moran Date: Wed, 20 Nov 2024 09:56:50 +0000 Subject: [PATCH 12/12] review: add docstring to explain tests goals --- text-generation-inference/tests/test_tinyllama.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/text-generation-inference/tests/test_tinyllama.py 
b/text-generation-inference/tests/test_tinyllama.py
index 4ab71d06..5566bccd 100644
--- a/text-generation-inference/tests/test_tinyllama.py
+++ b/text-generation-inference/tests/test_tinyllama.py
@@ -18,6 +18,7 @@ def model_path():
 
 
 def test_jetstream_info(model_path):
+    """Verify the model info is correctly loaded and check expected results."""
     if not jetstream_pt_available():
         pytest.skip("Jetstream PyTorch is not available")
     generator = AutoGenerator.from_pretrained(model_path, revision="", max_batch_size=1, max_sequence_length=1)
@@ -48,6 +49,8 @@
 )
 @pytest.mark.parametrize("batch_size", [1, 4], ids=["single", "multiple"])
 def test_jetstream_prefill(input_text, token_id, token_text, do_sample, batch_size, model_path):
+    """Verify that prefilling a batch with a single request works with different sampling techniques.
+    """
     if not jetstream_pt_available():
         pytest.skip("Jetstream PyTorch is not available")
     generator = AutoGenerator.from_pretrained(model_path, revision="", max_batch_size=batch_size, max_sequence_length=SEQUENCE_LENGTH)
@@ -71,6 +74,8 @@ def test_jetstream_prefill(input_text, token_id, token_text, do_sample, batch_si
 
 
 def test_jetstream_prefill_change_sampling(model_path):
+    """Verify changing the sampling strategy between requests in the same batch works as expected.
+    """
     if not jetstream_pt_available():
         pytest.skip("Jetstream PyTorch is not available")
     input_text = "It was a bright cold day in April, and the clocks were striking thirteen."
@@ -102,6 +107,9 @@ def check_request(do_sample, expected_token_id, expected_text):
 
 
 def test_jetstream_decode_multiple(model_path):
+    """Verify that two requests added to the batch at different generation steps
+    generate the same outputs (continuous batching).
+    """
     if not jetstream_pt_available():
         pytest.skip("Jetstream PyTorch is not available")
     generator = AutoGenerator.from_pretrained(model_path,