Few more Inference Endpoints fixes (#69)
* fix(TGI): correct clear request with a given batch id

* ci(tgi): create images when pushing on current branch

* fix(generator): raise error if prefill receives too many requests

* feat(tgi): add more prefill lengths

Since bucketing does not work for now, we add more (small) prefill
lengths. This will increase the warmup time, but it will also speed up
generation.

* Revert "ci(tgi): create images when pushing on current branch"

This reverts commit 26e1193.

* fix(test): multiple decode test requires max_batch_size to be > 1

* fix(test): expected result is different when model is compiled

Compiled model results are not always very good. While this should be
better investigated later on, the current solution is just to use the
non-compiled version. This results in some tests generating different
results, so expectations have been updated accordingly.

* chore: bump to version v0.1.2
tengomucho authored Jul 8, 2024
1 parent fd29591 commit 7cce24c
Showing 6 changed files with 19 additions and 9 deletions.
2 changes: 1 addition & 1 deletion optimum/tpu/version.py
@@ -15,5 +15,5 @@
 from pkg_resources import parse_version


-__version__ = "0.1.1"
+__version__ = "0.1.2"
 VERSION = parse_version(__version__)
@@ -40,7 +40,7 @@
 optimum_logger.setLevel("CRITICAL")

 # These will do some bucketing on prefill lengths to avoid too many different sizes
-PREFILL_LENGTHS = [
+PREFILL_LENGTHS = list(range(6, 16)) + [
     16,
     32,
     64,
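The extended list above resolves to [6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 32, 64, ...]. As a rough illustration of the bucketing these lengths are meant for (the helper below is hypothetical and not part of the generator; per the commit message, bucketing is not active yet), a prompt length would be rounded up to the nearest bucket:

PREFILL_LENGTHS = list(range(6, 16)) + [16, 32, 64]  # the real file continues with larger sizes

def round_up_to_bucket(input_length: int) -> int:
    # Hypothetical helper: pick the smallest bucket that can hold the prompt.
    for length in PREFILL_LENGTHS:
        if length >= input_length:
            return length
    return PREFILL_LENGTHS[-1]

round_up_to_bucket(7)   # 7, thanks to the new small buckets
round_up_to_bucket(20)  # 32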
@@ -446,6 +446,12 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
 active_slots = slots[Slot.State.READY]
 # Delete all empty slots, no need to have them anymore
 empty_slots = slots[Slot.State.EMPTY]
+model_batch_size = self.model.config.batch_size
+if model_batch_size is not None and model_batch_size < len(active_slots) + len(batch.requests):
+    raise ValueError(
+        f"Cannot prefill {len(batch.requests)} new request(s)."
+        f" Maximum batch size supported is: {model_batch_size}."
+    )
 for slot in empty_slots:
     self.slots.remove(slot)
 # Assign each request to an empty slot
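To make the new guard concrete, here is a small self-contained sketch of the same check outside the generator (the setup values below are made up for illustration): a model compiled with a static batch size of 2 that already has one active slot can accept at most one more request, so prefilling two raises the new error.

model_batch_size = 2                  # static batch size the model was compiled with
active_slots = ["slot-0"]             # one request already decoding
new_requests = ["req-1", "req-2"]     # two more requests arrive for prefill

try:
    if model_batch_size is not None and model_batch_size < len(active_slots) + len(new_requests):
        raise ValueError(
            f"Cannot prefill {len(new_requests)} new request(s)."
            f" Maximum batch size supported is: {model_batch_size}."
        )
except ValueError as e:
    print(e)  # Cannot prefill 2 new request(s). Maximum batch size supported is: 2.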
@@ -836,7 +842,8 @@ def return_to_caller(*data):
     cached_batch = generator.filter(batch_id, request_ids)
     return_to_caller(cached_batch.SerializeToString())
 if command == GeneratorCommand.CLEAR:
-    generator.clear()
+    batch_id = data[0]
+    generator.clear(batch_id)
 if command == GeneratorCommand.DELETE:
     if rank == 0:
         # Set agent to ready
@@ -902,8 +909,8 @@ def filter(self, batch_id: int, request_ids: List[int]) -> CachedBatch:
     s_cached_batch = self.mailbox.send(GeneratorCommand.FILTER, batch_id, request_ids)[0]
     return CachedBatch.FromString(s_cached_batch)

-def clear(self):
-    self.mailbox.send(GeneratorCommand.CLEAR)
+def clear(self, batch_id: Optional[int] = None):
+    self.mailbox.send(GeneratorCommand.CLEAR, batch_id)

 def leave(self):
     if self.mailbox is None:
@@ -56,7 +56,7 @@ def filter(self, batch_id: int, request_ids: List[int]) -> CachedBatch:
"""Remove requests that are not listed from the specified batch"""
raise NotImplementedError

def clear(self):
def clear(self, batch_id: Optional[int] = None):
"""Remove all requests from the generator"""
raise NotImplementedError

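The diff above only changes the interface and the mailbox routing; the concrete TpuGenerator.clear is not shown here. As a hypothetical sketch (assumed behavior, not the actual implementation), honoring the new optional batch_id could look like this, keeping the old "clear everything" behavior when no id is passed:

from typing import Optional

class ExampleGenerator:
    def __init__(self):
        # Illustrative state: slot index -> batch id of the request occupying it.
        self.slot_batches = {0: 7, 1: 7, 2: 9}

    def clear(self, batch_id: Optional[int] = None):
        if batch_id is None:
            # Old behavior: drop every pending request.
            self.slot_batches.clear()
        else:
            # New behavior: drop only the slots belonging to the given batch.
            self.slot_batches = {s: b for s, b in self.slot_batches.items() if b != batch_id}

gen = ExampleGenerator()
gen.clear(batch_id=7)
print(gen.slot_batches)  # {2: 9}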
@@ -1,5 +1,5 @@
 from pkg_resources import parse_version


-__version__ = "0.1.1"
+__version__ = "0.1.2"
 VERSION = parse_version(__version__)
2 changes: 1 addition & 1 deletion text-generation-inference/tests/test_decode.py
@@ -44,7 +44,7 @@ def test_decode_single(params):
 DecodeTestParams(
     model_id="google/gemma-7b",
     sequence_length=128,
-    expected_text="\n\nThe time is 1984. The place is Airstrip One, the British",
+    expected_text="\n\nThe first line of George Orwell’s <em>1984</em> is a perfect example",
 ),
 DecodeTestParams(
     model_id="mistralai/Mistral-7B-v0.3",
5 changes: 4 additions & 1 deletion text-generation-inference/tests/test_gpt2.py
@@ -65,7 +65,10 @@ def test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_


 def test_decode_multiple(model_path):
-    generator = TpuGenerator.from_pretrained(model_path, revision="", max_batch_size=1, max_sequence_length=SEQUENCE_LENGTH)
+    generator = TpuGenerator.from_pretrained(model_path,
+                                             revision="",
+                                             max_batch_size=2,
+                                             max_sequence_length=SEQUENCE_LENGTH)
     input_text = "Once upon a time"
     max_new_tokens = 20
     # Prefill a single request, remembering the generated token
