More Jetstream Pytorch fixes, prepare for release #116
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
This is not used anyway.
The cached batch returned was wrong, because the generator expects only one cached batch to be returned per prefill/decode call. Also, the slot list size is now fixed: this prevents creating and destroying elements in the slot list, allowing further optimizations and avoiding JIT recompilation (see the sketch below).
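For illustration, a minimal sketch of the fixed-size slot list idea; the `Slot` and `SlotState` names here are assumptions for the example, not the actual generator classes:

```python
from enum import Enum


class SlotState(Enum):
    EMPTY = 0
    READY = 1


class Slot:
    """Placeholder for one batch entry; reset in place when its request ends."""

    def __init__(self, index: int):
        self.index = index
        self.state = SlotState.EMPTY

    def clear(self):
        self.state = SlotState.EMPTY


# Preallocate one slot per batch entry once. Because slots are reset in
# place rather than appended/removed, the list length never changes, so
# traced code that depends on its structure does not need to be recompiled.
BATCH_SIZE = 4
slots = [Slot(i) for i in range(BATCH_SIZE)]

slots[2].state = SlotState.READY  # a request occupies slot 2
slots[2].clear()                  # finishing the request frees the slot in place
free_slots = [s for s in slots if s.state is SlotState.EMPTY]
```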
The randomness when sampling has been improved by splitting the key, as suggested by the documentation of the JAX random submodule; the recommended pattern looks like the sketch below.
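For reference, with toy logits:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
logits = jnp.array([0.1, 2.0, 0.3])  # toy values

# Split before each draw: the subkey is consumed for sampling and the new
# key is carried forward. Reusing the same key at every step would produce
# the same "random" token each time.
key, subkey = jax.random.split(key)
token = jax.random.categorical(subkey, logits)
```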
A GPT2 test file already exists to verify the generator behaviour when using the legacy Pytorch/XLA code; this test has now been added to verify the same behaviour on the Jetstream/Pytorch counterpart.
The Jetstream/Pt engine allows passing a callback to the prefill and generate methods. This callback is used to sample the generated token with a custom function, but the caller function is JIT'ed, which puts a strong constraint on the callback signature. Until now the callback was compiled on the first call, making it impossible to change the sampling algorithm across requests. This commit fixes the issue by subclassing the PytorchEngine class and defining a new `prefill_ex` method that is not JIT'ed. The model calls are still compiled, so performance should not be noticeably affected; the sketch below illustrates the idea.
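A minimal sketch of the pattern, with a toy model standing in for the engine (the names are illustrative assumptions, not the actual jetstream_pt API): keeping the outer call in plain Python lets the sampling callback vary per request, while the expensive model call stays compiled.

```python
import jax
import jax.numpy as jnp


@jax.jit
def model_step(x):
    # Stand-in for the expensive model call; this part stays compiled.
    return x * 2.0


def prefill_ex(x, sampling_fn):
    # Plain Python wrapper: because it is not JIT'ed, a different
    # sampling_fn can be passed on every request instead of being frozen
    # into a compiled caller on the first call.
    logits = model_step(x)
    return sampling_fn(logits)


greedy = lambda logits: jnp.argmax(logits)
print(prefill_ex(jnp.array([0.5, 1.5]), greedy))  # -> 1
```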
The minor version is increased, mainly because of the Jetstream Pytorch support in TGI.
2e3208e to d9e89be
A few nits but otherwise LGTM. Thanks!
optimum/tpu/version.py (outdated)

```diff
@@ -15,5 +15,5 @@
 from packaging.version import parse

-__version__ = "0.1.5"
+__version__ = "0.1.6"
```
0.2.0
```diff
@@ -524,8 +524,8 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
             raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")

-        # Use a custom function to select the next token for each slot
-        select_fn = jax.tree_util.Partial(self._select_from_slots)
-        self.decode_state, result_tokens = self.engine.generate(self.params, self.decode_state, select_fn)
+        # select_fn = jax.tree_util.Partial(self._select_from_slots)
```
Debug leftover
text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py (resolved)
text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py (resolved)
Except for that, it looks good to me. For the testing part, I would have broken the big tests down into multiple smaller ones. This would make them easier to maintain and help pinpoint issues more easily, as the current tests cover a lot of different things at once.
Can you elaborate? The tests are inspired by the ones I wrote for the neuron TGI, and each tests a single feature, except for the multiple-decode one.
Thanks for the clarification. I think I misunderstood the goal of the test and the specific scenario it was testing. Maybe it would be nice to add some documentation to the test to explain it in more detail, like https://github.com/huggingface/optimum-neuron/blob/dd60749502cd05385d6f4fe3dd884dc221e22926/text-generation-inference/tests/server/helpers.py#L83C1-L85C8
@baptistecolle I just added a docstring to the test, hoping to make it simpler to understand.
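For illustration, a docstring along these lines (hypothetical wording inferred from the discussion above, not the actual test file) is the kind of addition being discussed:

```python
def test_decode_multiple(model_path):
    """Check that requests can join the batch at different decode steps.

    A second request is inserted while the first one is still generating,
    and both outputs are compared against single-request generation to
    verify that joining a batch mid-decode does not perturb ongoing
    generations.
    """
```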
What does this PR do?
This PR includes several fixes for the Jetstream Pytorch TGI implementation, including corrected batch support and support for switching sampling strategies.
It also prepares for the upcoming 0.2.0 release, the first to officially support Jetstream Pytorch in TGI.
Before submitting