-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
More Jetstream Pytorch fixes, prepare for release #116
Merged
Merged
Changes from 9 commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
107016b
fix(tgi): correct truncation in Jetstream Pytorch generator
tengomucho 3a4ab50
chore(ci): jetstream TGI tests also run on main on push
tengomucho d7fd56c
refactor(generator): inputs removed from slot
tengomucho 9785d2b
fix(generator): correct cached_batch and set slot numbers to batch_size
tengomucho 20ff0db
feat(rng): improve randomness in sampling on Jetstream/Pt
tengomucho 071a7f9
test(jetstream): added prefill and decode multiple tests
tengomucho f1dbc13
test(jetstream): added failing test to check sampling can be changed
tengomucho d97a024
fix(jetstream): correct sampling for jetstream
tengomucho d9e89be
chore: bump version to 0.2.0
tengomucho 2308722
fix(version): version number was not correctly updated, fix it
tengomucho 270435c
review: remove commented code leftover
tengomucho ad9e3e5
review: add docstring to explain tests goals
tengomucho File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,5 +15,5 @@ | |
from packaging.version import parse | ||
|
||
|
||
__version__ = "0.1.5" | ||
__version__ = "0.1.6" | ||
VERSION = parse(__version__) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -55,7 +55,6 @@ def clear(self): | |
self._state = Slot.State.EMPTY | ||
self._batch_id = None | ||
self._request_id = None | ||
self._inputs = "" | ||
self._generation_config = None | ||
self._tokens = [] | ||
self._selector = None | ||
|
@@ -263,9 +262,8 @@ def __init__( | |
tokenizer.truncation_side = "left" | ||
self.tokenizer = tokenizer | ||
self.special_tokens = self.tokenizer.all_special_ids | ||
|
||
# Slots are empty to begin with, they will be populated as new batches arrive | ||
self.slots = [] | ||
# Slots number is static, it cannot grow over the size of the batch | ||
self.slots = [Slot(i, tokenizer) for i in range(self.model.config.batch_size)] | ||
self.batch_id = 0 | ||
# Note: this index will _never_ be decremented, and that's fine. | ||
self.slot_index = 0 | ||
|
@@ -363,13 +361,11 @@ def warmup(self, batch: Batch) -> int: | |
seq_len = self.engine.env.seq_len | ||
return batch_size * seq_len | ||
|
||
def _get_slot_id(self): | ||
"""Get the next available slot id.""" | ||
batch_size = self.engine.env.batch_size | ||
used_ids = [slot.id for slot in self.slots if slot.state != Slot.State.EMPTY] | ||
for i in range(batch_size): | ||
if i not in used_ids: | ||
return i | ||
def _get_slot(self): | ||
"""Get the next available slot.""" | ||
for slot in self.slots: | ||
if slot.state == Slot.State.EMPTY: | ||
return slot | ||
# if we reach this point, all slots were used - this should not happen | ||
raise ValueError("All slots are used, but we should have stopped earlier") | ||
|
||
|
@@ -388,6 +384,8 @@ def _token_encode(self, text: str, max_length: int) -> Tuple[jnp.ndarray, int]: | |
""" | ||
if max_length == 0: | ||
max_length = self.model.config.sequence_length | ||
# Remove one to max_length because BOS is going to be added when padding | ||
max_length -= 1 | ||
input_ids = self.tokenizer.encode( | ||
text, | ||
return_tensors="np", | ||
|
@@ -417,14 +415,9 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]: | |
A list of `Generation` for each request and a `CachedBatch` containing all pending requests. | ||
""" | ||
|
||
slots = {state: [] for state in Slot.State} | ||
for slot in self.slots: | ||
slots[slot.state].append(slot) | ||
len_active_slots = len(slots[Slot.State.READY]) | ||
# Delete all empty slots, no need to have them anymore | ||
empty_slots = slots[Slot.State.EMPTY] | ||
for slot in empty_slots: | ||
self.slots.remove(slot) | ||
active_slots = [slot for slot in self.slots if slot.state == Slot.State.READY] | ||
len_active_slots = len(active_slots) | ||
|
||
len_requests = len(batch.requests) | ||
model_batch_size = self.model.config.batch_size | ||
if model_batch_size is not None and model_batch_size < len_active_slots + len_requests: | ||
|
@@ -439,10 +432,10 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]: | |
# Assign each request to an empty slot | ||
logger.debug(f"Prefilling {len_requests} new request(s) adding to {len_active_slots} active slot(s)") | ||
generations = [] | ||
|
||
prefilled_active_slots = [] | ||
for request in batch.requests: | ||
# Dynamically create a new slot for each request | ||
slot = Slot(self._get_slot_id(), self.tokenizer) | ||
slot = self._get_slot() | ||
self.prefill_slot.set(slot) | ||
self.slot_index += 1 | ||
slot.assign(self.batch_id, request, self.model.generation_config) | ||
|
@@ -462,7 +455,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]: | |
# To allow jit'ing the select function, we need to wrap it in a partial | ||
slot_select = jax.tree_util.Partial(self.prefill_slot.select) | ||
# Ask for prefill and insert | ||
prefill_results, _result_tokens = self.engine.prefill( | ||
prefill_results, _result_tokens = self.engine.prefill_ex( | ||
baptistecolle marked this conversation as resolved.
Show resolved
Hide resolved
|
||
params=self.params, | ||
padded_tokens=input_ids, | ||
true_length=true_lengths, | ||
|
@@ -473,28 +466,20 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]: | |
|
||
self._post_generate(slot, next_token, generations) | ||
if not slot.empty: | ||
# append current to list of active slots | ||
self.slots.append(slot) | ||
len_active_slots += 1 | ||
|
||
batch = None | ||
if len_active_slots > 0: | ||
# Whatever initial batch these requests came from, we always return all pending requests in a single batch | ||
request_ids = [slot.request_id for slot in self.slots if slot.state == Slot.State.READY] | ||
batch = self._cached_batch(self.batch_id, request_ids) | ||
else: | ||
logger.debug("No more pending requests") | ||
prefilled_active_slots.append(slot) | ||
|
||
cached_batch = self._cached_batch(self.batch_id, prefilled_active_slots) | ||
self.batch_id += 1 | ||
logger.debug("Model ready for decoding") | ||
return generations, batch | ||
return generations, cached_batch | ||
|
||
def _select_from_slots(self, logits: jnp.ndarray, batch_size: int=0) -> jnp.ndarray: | ||
pad_token_id = self.tokenizer.pad_token_id | ||
batch_size = logits.shape[0] | ||
tokens = jnp.full((batch_size, 1), pad_token_id) | ||
for slot in filter(lambda slot: slot.state == slot.State.READY, self.slots): | ||
# Every slot might have a different selection criteria, so we are obliged to call select in a loop | ||
next_token = slot.select(logits) | ||
next_token = slot.select(logits[slot.id : slot.id + 1, :]) | ||
tokens = tokens.at[slot.id].set(next_token) | ||
return tokens | ||
|
||
|
@@ -539,10 +524,9 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa | |
raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)") | ||
|
||
# Use a custom function to select the next token for each slot | ||
select_fn = jax.tree_util.Partial(self._select_from_slots) | ||
self.decode_state, result_tokens = self.engine.generate(self.params, self.decode_state, select_fn) | ||
# select_fn = jax.tree_util.Partial(self._select_from_slots) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Debug leftover |
||
self.decode_state, result_tokens = self.engine.generate_impl(self.params, self.decode_state, self._select_from_slots) | ||
|
||
newly_empty = [] | ||
generations = [] | ||
for slot in active_slots: | ||
# Get the next token. | ||
|
@@ -555,20 +539,9 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa | |
raise ValueError("Unexpected Slot is not ready for decoding") | ||
|
||
self._post_generate(slot, next_token, generations) | ||
if slot.empty: | ||
newly_empty.append(slot) | ||
|
||
# Remove empty slots | ||
for slot in newly_empty: | ||
self.slots.remove(slot) | ||
batch = None | ||
if len(self.slots) > 0: | ||
# Whatever initial batch these requests came from, we always return all pending requests in a single batch | ||
request_ids = [slot.request_id for slot in self.slots if slot.state == Slot.State.READY] | ||
batch = self._cached_batch(next_batch_id, request_ids) | ||
else: | ||
logger.debug("No more pending requests") | ||
return generations, batch | ||
|
||
cached_batch = self._cached_batch(next_batch_id, active_slots) | ||
return generations, cached_batch | ||
|
||
def _post_generate(self, slot: Slot, next_token: int, generations: List[Generation]) -> None: | ||
"""Post-generate a slot after the generation has been completed. | ||
|
@@ -616,7 +589,13 @@ def _post_generate(self, slot: Slot, next_token: int, generations: List[Generati | |
) | ||
) | ||
|
||
def _cached_batch(self, batch_id: int, request_ids: List): | ||
def _cached_batch(self, batch_id: int, active_slots: List): | ||
"""Create a CachedBatch from the active slots. | ||
""" | ||
request_ids = [slot.request_id for slot in active_slots if slot.state == Slot.State.READY] | ||
if len(request_ids) == 0: | ||
logger.debug("No more pending requests") | ||
return None | ||
size = len(request_ids) | ||
max_tokens = size * self.model.config.sequence_length | ||
return CachedBatch(id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 1 addition & 1 deletion
2
text-generation-inference/server/text_generation_server/version.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
from pkg_resources import parse_version | ||
|
||
|
||
__version__ = "0.1.5" | ||
__version__ = "0.1.6" | ||
VERSION = parse_version(__version__) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
0.2.0