
More Jetstream Pytorch fixes, prepare for release #116

Merged: 12 commits into main from jetstream-switch on Nov 20, 2024

Conversation

tengomucho (Collaborator)

What does this PR do?

This PR includes several fixes for the Jetstream Pytorch TGI implementation, including corrected batch handling and support for switching the sampling strategy between requests.
It also prepares for the upcoming 0.2.0 release, the first that will officially support Jetstream Pytorch on TGI.

Before submitting

  • Did you write any new necessary tests?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

The cached batch returned was wrong: the generator expects exactly one
cached batch to be returned per prefill/decode call.
Also, the slot list size is now fixed: this prevents creating and
destroying elements in the slot list, which allows further
optimizations and avoids triggering JIT recompilation.
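
For illustration, here is a minimal, hypothetical sketch of a fixed-size slot list; it is not the actual optimum-tpu `Slot` class, just the pattern of pre-allocating slots and reusing them instead of growing and shrinking the list:

```python
from enum import Enum


class SlotState(Enum):
    EMPTY = 0
    READY = 1


class Slot:
    """Hypothetical slot holding the state of one generation request."""

    def __init__(self, index: int):
        self.index = index
        self.state = SlotState.EMPTY
        self.request_id = None

    def assign(self, request_id: int):
        self.state = SlotState.READY
        self.request_id = request_id

    def clear(self):
        self.state = SlotState.EMPTY
        self.request_id = None


BATCH_SIZE = 4
# The list is created once with a fixed size and never resized;
# requests only flip slots between EMPTY and READY.
slots = [Slot(i) for i in range(BATCH_SIZE)]
free_slot = next(s for s in slots if s.state == SlotState.EMPTY)
free_slot.assign(request_id=0)
```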
The randomness of sampling has been improved by splitting the PRNG key,
as suggested by the documentation of the JAX random submodule (see the sketch below).
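
As a reminder of the pattern recommended by the `jax.random` documentation, here is a small standalone sketch (the logits below are made up for the example): the key is split before each sampling call so the same key is never reused.

```python
import jax
import jax.numpy as jnp

# Made-up logits for a single decoding step over a tiny vocabulary.
logits = jnp.array([0.1, 2.0, 0.3, 1.5, 0.0, 0.7, 0.2, 1.1])

key = jax.random.PRNGKey(0)
# Split the key and sample with the fresh subkey; keep the other half
# for the next step instead of reusing the consumed one.
key, subkey = jax.random.split(key)
token = jax.random.categorical(subkey, logits)
print(int(token))
```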
A GPT2 test file already verifies the generator behaviour when using the
legacy Pytorch/XLA code; an equivalent test has now been added to verify
the same behaviour with the Jetstream/Pytorch counterpart.
The Jetstream/Pt engine allows passing a callback to the prefill and
generate methods. This callback is used to sample the generated token
with a custom function, but the caller function is JIT'ed, placing a
strong constraint on the callback signature. Until now, the callback was
compiled on the first call, making it impossible to change the sampling
algorithm across requests.
This commit fixes the issue by subclassing the PytorchEngine class and
defining a new `prefill_ex` method that is not JIT'ed. The model calls
are still compiled, so performance should not be noticeably affected
(see the sketch below).
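
The snippet below is a simplified, self-contained illustration of that idea, not the actual jetstream_pt API: the model forward pass is jitted once, while the sampling callback is invoked in plain Python outside the compiled function, so a different sampling function can be passed on each request without being frozen at trace time.

```python
import jax
import jax.numpy as jnp


@jax.jit
def model_forward(params, tokens):
    # Stand-in for the real model call; produces fake "logits".
    return tokens @ params


def prefill_ex(params, tokens, sample_fn):
    """Non-jitted wrapper: the model call is compiled, the sampler is not."""
    logits = model_forward(params, tokens)
    return sample_fn(logits)


def greedy(logits):
    return jnp.argmax(logits, axis=-1)


def sample(logits, key=jax.random.PRNGKey(0)):
    return jax.random.categorical(key, logits, axis=-1)


params = jnp.ones((4, 8))
tokens = jnp.ones((1, 4))
# The sampling strategy can change between calls without retracing the model.
print(prefill_ex(params, tokens, greedy))
print(prefill_ex(params, tokens, sample))
```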
The minor version is increased, mainly because of the Jetstream Pytorch
support in TGI.

@dacorvo left a comment


A few nits but otherwise LGTM. Thanks!

@@ -15,5 +15,5 @@
 from packaging.version import parse
 
 
-__version__ = "0.1.5"
+__version__ = "0.1.6"
Review comment: 0.2.0

@@ -524,8 +524,8 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
 raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")
 
 # Use a custom function to select the next token for each slot
-select_fn = jax.tree_util.Partial(self._select_from_slots)
-self.decode_state, result_tokens = self.engine.generate(self.params, self.decode_state, select_fn)
+# select_fn = jax.tree_util.Partial(self._select_from_slots)

Review comment: Debug leftover

@baptistecolle (Contributor)

Except for that, it looks good to me.

Maybe for the testing part, I would have broken down the big tests into multiple smaller ones. This would make them easier to maintain and help pinpoint issues more easily, as the current tests cover a lot of different things at once.

@dacorvo commented Nov 20, 2024

> Maybe for the testing part, I would have broken down the big tests into multiple smaller ones. This would make them easier to maintain and help pinpoint issues more easily, as the current tests cover a lot of different things at once.

Can you elaborate? The tests are inspired by the tests I wrote for neuron TGI, and are all testing a single feature, except for the multiple decode.
This longer test verifies that the decoding of multiple sequences that do not end at the same time works, so it cannot be broken down into smaller pieces.

@baptistecolle (Contributor) commented Nov 20, 2024

> Maybe for the testing part, I would have broken down the big tests into multiple smaller ones. This would make them easier to maintain and help pinpoint issues more easily, as the current tests cover a lot of different things at once.

> Can you elaborate? The tests are inspired by the tests I wrote for neuron TGI, and are all testing a single feature, except for the multiple decode. This longer test verifies that the decoding of multiple sequences that are not ending at the same time works, so it cannot be broken down into smaller pieces.

Thanks for the clarification. I think I misunderstood the goal of the test and the specific scenario it was testing. Maybe it would be nice to add some documentation to the test to explain it in more detail, like https://github.com/huggingface/optimum-neuron/blob/dd60749502cd05385d6f4fe3dd884dc221e22926/text-generation-inference/tests/server/helpers.py#L83C1-L85C8

@tengomucho (Collaborator, Author)

@baptistecolle I just added the docstring to the test, hoping to make it simpler to understand.

@tengomucho merged commit 1fc59ce into main on Nov 20, 2024 (5 checks passed).
@tengomucho deleted the jetstream-switch branch on November 20, 2024 at 12:17.