Jetstream by default #118

Merged
tengomucho merged 11 commits into main from jetstream-by-default on Nov 27, 2024
Conversation

tengomucho (Collaborator)

What does this PR do?

This makes all the changes needed to have the Jetstream PyTorch engine be the default backend for TGI on TPUs.
This backend is reliable and performant and gives the best throughput on TGI.

The implementation is slightly different, so a separate test is added.
Most tests work for both backends, except for the continuous batching one.
This allows removing the old GPT-2 based tests, which are quite slow and
do not use any sharding or KV cache, so they might not really be
representative of the most relevant models on TGI.
There are now equivalent tests on the TinyLlama model, which run faster
and use the KV cache and sharding.
The only test that does not have an equivalent is the continuous
batching one, but that test was not working for most other models, so I
prefer to remove it anyway; having it pass was not representative of the
current state anyway.
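
For context on how the default is selected, the tests rely on a jetstream_pt_available() helper defined in optimum/tpu/jetstream_pt_support.py (see the review threads below). A minimal sketch of what such an availability check could look like is shown here; the JETSTREAM_PT_DISABLE handling and the jetstream_pt import name are assumptions for illustration, not the actual implementation:

# Sketch only, not the actual content of optimum/tpu/jetstream_pt_support.py.
import os

def jetstream_pt_available() -> bool:
    """Return True when the Jetstream PyTorch engine can be used as the backend."""
    # Assumed opt-out: the review below mentions JETSTREAM_PT_DISABLE=1 to force
    # the torch xla backend instead.
    if os.environ.get("JETSTREAM_PT_DISABLE", "0") == "1":
        return False
    try:
        # Assumed import name for the Jetstream PyTorch engine.
        import jetstream_pt  # noqa: F401
    except ImportError:
        return False
    return True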
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Now that the Jetstream engine is stable and tested, it is set as the
default one for TGI.
tengomucho marked this pull request as ready for review on November 22, 2024 at 15:39
@dacorvo left a comment

I think I got lost in your changes: can you summarize how tests are now supposed to work?

    ids=["spaces", "chinese-utf8", "emojis"],
)
def test_decode_streaming_jetstream(tokenizer, input_text, generated_text):
    if not jetstream_pt_available():
dacorvo:
Note that you could have created a decorator.
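
For illustration, the suggested decorator could look roughly like the sketch below; the requires_jetstream name and the import path are assumptions, and the PR ultimately moved to custom pytest markers instead (see the later commit):

import functools

import pytest

# Assumed import path, based on the module referenced in the review threads.
from optimum.tpu.jetstream_pt_support import jetstream_pt_available

def requires_jetstream(test_func):
    """Skip the decorated test when the Jetstream PyTorch engine is unavailable."""
    @functools.wraps(test_func)
    def wrapper(*args, **kwargs):
        if not jetstream_pt_available():
            pytest.skip("Jetstream PyTorch is not available")
        return test_func(*args, **kwargs)
    return wrapper

@requires_jetstream
def test_decode_streaming_jetstream(tokenizer, input_text, generated_text):
    ...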

tengomucho (Author):
I refactored the test to avoid repetitions.

    assert generations[0].tokens.texts == [" the"]


def test_prefill_truncate_jetstream():
dacorvo:
I fail to see the difference between the two tests: I don't think it was required to add the 'jetstream' one

tengomucho (Author):
The two tests are identical in behaviour, but if Jetstream is loaded the other test will fail to run correctly because of dependency incompatibilities when using some PyTorch features (e.g. multiprocessing).
So I just have two identical tests: one is run when Jetstream is enabled and the other is skipped, and when Jetstream is disabled it is the other way around.

dacorvo:
They are not only identical in behaviour: this is the same test with two different names... What am I missing?

    _test_continuous_batching_two_requests(model_path)


"""NOTE: This test does not work on PyTorch/XLA, because of the way
dacorvo:
You should adapt the test to make it actually useful for the XLA configuration.

tengomucho (Author):
In my tests, with BF16 and KV cache, I was not able to get this test working. I think there might be an issue in the way the KV cache is implemented, because this test succeeds on the Jetstream backend with BF16 on the same hardware. This is why I left the test there, as a reminder that this should be done later on, but for now it does not really work as expected.

So far filtering was done using the name of the test. Now the selection
is done using a custom marker, which allows for clearer filtering.
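
As a rough sketch of how marker-based selection can be wired in conftest.py (the jetstream and torch_xla marker names come from the diff excerpt further down; the hook choice and registration details are assumptions, not the exact PR code):

import pytest

# Assumed import path, based on the module referenced in the review threads.
from optimum.tpu.jetstream_pt_support import jetstream_pt_available

def pytest_configure(config):
    # Register the custom markers so pytest does not warn about unknown marks.
    config.addinivalue_line("markers", "jetstream: test requires the Jetstream PT engine")
    config.addinivalue_line("markers", "torch_xla: test targets the plain torch xla backend")

def pytest_runtest_setup(item):
    jetstream_pt_enabled = jetstream_pt_available()
    marker_names = {marker.name for marker in item.iter_markers()}
    # Skip Jetstream-only tests when the engine is not enabled.
    if "jetstream" in marker_names and not jetstream_pt_enabled:
        pytest.skip("Jetstream PyTorch is not enabled")
    # Skip torch xla-only tests when Jetstream is loaded, as they would fail.
    if "torch_xla" in marker_names and "jetstream" not in marker_names and jetstream_pt_enabled:
        pytest.skip("Jetstream PyTorch must be disabled")

With markers like these, a subset can also be selected explicitly, e.g. pytest -m jetstream or pytest -m "not jetstream".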
@dacorvo left a comment

Nice! Thank you for this pull request.

# Skip tests that require torch xla but not jetstream
if "torch_xla" in marker_names and "jetstream" not in marker_names:
    if jetstream_pt_enabled:
        pytest.skip("Jetstream PyTorch must be disabled")
dacorvo:
nit: I would find it clearer to say something like "Jetstream is enabled: xla test will fail".

tengomucho (Author):

> I think I got lost in your changes: can you summarize how tests are now supposed to work?

@dacorvo as discussed offline, the idea is to change the default backend of TPU TGI from torch xla to jetstream.
I just updated the tests so they use clearer markers to check that the backend they run on is correctly selected.

For some reason the env var was not carried over (though Jetstream was
disabled anyway). Moving the variable to the command line invocation
removes a warning in the logs.
@baptistecolle (Contributor) left a comment

JETSTREAM_PT_DISABLE=1

Resolved review threads:
optimum/tpu/jetstream_pt_support.py
Makefile (outdated)
text-generation-inference/tests/test_prefill_truncate.py (outdated)
.github/workflows/test-pytorch-xla-tpu.yml (outdated)
Some test results change when operations are done in a slightly
different way. This has now happened with the torch xla tests, resulting
in different results on the CI.
To avoid this, the tests now check that the obtained token and text
differ from those obtained when running with greedy search.
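
A sketch of the kind of check this describes; the helper name and the example values are illustrative only, not the PR's actual test code:

# Illustrative only: instead of pinning exact expected strings (which can drift
# when operations are reordered), the sampled output is checked to differ from
# the greedy output obtained for the same prompt.
def assert_differs_from_greedy(sampled, greedy):
    sampled_token, sampled_text = sampled
    greedy_token, greedy_text = greedy
    assert sampled_token != greedy_token
    assert sampled_text != greedy_text

# Hypothetical values for illustration.
assert_differs_from_greedy(sampled=(4321, " blue"), greedy=(1234, " the"))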
tengomucho merged commit 8c2c199 into main on Nov 27, 2024
5 checks passed
tengomucho deleted the jetstream-by-default branch on November 27, 2024 at 10:12