✈️ Introduce Jetstream/Pytorch in TGI #88
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
use packaging.version's parse instead of pkg_resources' parse_version.
The custom HfEngine contains functions that will allow for prefill and generate functions to use custom sampling functions.
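For illustration, a minimal sketch of that idea, assuming only that jetstream_pt exposes engine.PyTorchEngine as shown in the diff hunks further down; the constructor argument and attribute names here are illustrative, not the PR's actual interface:

```python
from typing import Callable, Optional

import jax.numpy as jnp
from jetstream_pt import engine


class HfEngine(engine.PyTorchEngine):
    """Engine variant that lets prefill/generate use a caller-provided sampling function."""

    def __init__(self, *args, sampling_fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, **kwargs):
        super().__init__(*args, **kwargs)
        # Map logits -> selected token ids; default to greedy selection when no
        # custom sampling function is provided (hypothetical attribute name).
        self.sampling_fn = sampling_fn or (lambda logits: jnp.argmax(logits, axis=-1))
```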
This implementation is equivalent to the torch_xla one, but uses the Jetstream/Pytorch engine instead.
This way we can avoid trying to import torch_xla.
This is just a way to provide a factory class method to create Jetstream/Pytorch or Pytorch XLA generator.
There are still some issues related to some fine-tuned models, so for now just enable only when JETSTREAM_PT is set.
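A rough sketch of the selection logic these commits describe: the Jetstream/Pytorch generator is chosen only when the JETSTREAM_PT environment variable opts in, otherwise the Pytorch XLA generator is used. The class name, the import paths, and the exact gating condition are assumptions; the keyword arguments mirror the diff hunk quoted later in the conversation.

```python
import os


class AutoGenerator:
    """Factory that picks the Jetstream/Pytorch generator only when explicitly enabled."""

    @classmethod
    def from_pretrained(cls, model_path, revision, max_batch_size, max_sequence_length):
        # Opt-in gate: use the Jetstream/Pytorch path only when JETSTREAM_PT is set.
        if os.environ.get("JETSTREAM_PT"):
            from .jetstream_pt_support import TpuGeneratorJetStream  # hypothetical import path
            return TpuGeneratorJetStream.from_pretrained(
                model_path, revision=revision, max_batch_size=max_batch_size, max_sequence_length=max_sequence_length
            )
        from .generator import TpuGenerator
        return TpuGenerator.from_pretrained(
            model_path, revision=revision, max_batch_size=max_batch_size, max_sequence_length=max_sequence_length
        )
```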
Force-pushed from 5a73926 to 82849fa.
For now it is possible to install the dependency after optimum-tpu has been installed, by issuing this command: pip install "optimum-tpu[jetstream-pt]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Force-pushed from aab4506 to c11d2bb.
Also adapted other tests to avoid the torch-xla generator implementation, to prevent conflicts. I also added the Jetstream/Pytorch test to the CI workflow.
Force-pushed from c11d2bb to 07a71db.
        )
        return tokens, true_length

    def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
Can you share more insights on where the server takes the request and calls prefill?
Thanks Alvaro, great work! I took a first pass and left some fairly minor comments.
        return False
    # Torch XLA should not be imported before torch_xla2 to avoid conflicts.
    if 'torch_xla2' not in sys.modules and 'torch_xla.core' in sys.modules:
        return False
nit: would it make sense to emit a warning here? Like "JETSTREAM_PT is enabled, but torch_xla2 is not installed. Falling back to torch_xla".
It's actually a little trickier than that: torch_xla2 cannot be imported after torch_xla has been imported. I will add a warning.
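A sketch of what such a warning could look like, built around the check shown in the diff quoted above; the function name, the use of the standard logging module, and the message wording are assumptions, and any other checks the PR performs are omitted here.

```python
import logging
import sys

logger = logging.getLogger(__name__)


def check() -> bool:
    """Return True when the Jetstream/Pytorch path can be used safely (partial sketch)."""
    # torch_xla2 cannot be loaded once torch_xla has already been imported,
    # so warn and fall back to the torch_xla generator in that case.
    if "torch_xla2" not in sys.modules and "torch_xla.core" in sys.modules:
        logger.warning(
            "torch_xla is already imported; torch_xla2 cannot be loaded after it, "
            "so Jetstream/Pytorch support is disabled and the torch_xla generator will be used."
        )
        return False
    return True
```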
@@ -0,0 +1,35 @@
from .generator_base import Generator
from .jetstream_pt_support import check
from .jetstream_pt_support import check as should_use_jetstream, or something along these lines, could be more descriptive. Alternatively, you could just change the check def within jetstream_pt_support.
I renamed it model_can_use_jetstream_pt.
            model_path, revision=revision, max_batch_size=max_batch_size, max_sequence_length=max_sequence_length
        )
    else:
        from .generator import TpuGenerator
It would be useful to a user to log 1) when we have successfully loaded Jetstream and 2) when we're falling back to the base generator.
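For illustration, the kind of logging being asked for could look like the small sketch below; the helper name and the message wording are invented, not the PR's code.

```python
import logging

logger = logging.getLogger(__name__)


def log_generator_choice(use_jetstream_pt: bool) -> None:
    """Log which generator backend the factory selected (illustrative helper)."""
    if use_jetstream_pt:
        logger.info("Jetstream/Pytorch engine loaded successfully; using TpuGeneratorJetStream.")
    else:
        logger.info("Falling back to the base torch_xla TpuGenerator.")
```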
Rather than re-implement llama/model_exportable.py, could we implement some type of parameter transformation logic instead? That would allow us to directly use jetstream_pt's code.
That's what I tried to do at first, but if we want to support models as they are defined in transformers, the simplest way is to extract the model parameters from the config file. In the model definition in transformers, for Llama some of the original parameters (hidden_dim, multiple_of and ffn_dim_multiplier) were combined into the intermediate_size variable. I could not see a trivial way to go back to the original values. That is why I ended up re-implementing FeedForward, and as a consequence I ended up modifying the other classes that use it. If you can think of a way to get the original parameters back in a reliable way, then I can drop most of this and just use jetstream_pt's code.
        return len(self._tokens) == 0


class TpuGeneratorJetStream(Generator):
One general comment (no need for change at this point): since this is essentially re-implementing the responsibility of JetStream's orchestrator as designed, this will lose out on features like disaggregated serving and will likely result in different performance.
Yes, I thought about it, and I agree with you that using only the engine and not the orchestrator means we will end up with different performance results. The reason why I did this was the API: the engine API is similar to TGI's model_server, while the orchestrator is not meant to be driven through a Python API, but rather through gRPC, and its interface is more similar to the one in the TGI router. So interfacing the orchestrator with TGI would mean taking the TGI requests, re-encoding them as requests for the Jetstream orchestrator, forwarding them, and then re-transcoding the responses. So yes, at some point we might need to look at a way to integrate those, but it seems more complicated and I think we can do that later.
from jetstream_pt import engine


class HfEngine(engine.PyTorchEngine):
General note (no need to respond to this within this PR): Ray support for multi-node currently lives within PyTorchRayEngine, so as is, this won't be able to take advantage of Ray multi-host. A few options:
- [within JetStream] Consolidate PyTorchRayEngine with PyTorchEngine - probably preferred, since we saw issues arise because of the decoupled design (cc @FanhaiLu1)
- [within TGI] Create a RayHfEngine or use some type of mixin
- Added warning when trying to load torch_xla2 after torch_xla
- Renamed jetstream_pt_support.check to model_can_use_jetstream_pt
Force-pushed from 1f5e9c4 to 76fbf94.
    def __call__(self, logits: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        if self.temperature != 1.0:
            logits = logits / self.temperature
qq: what happens if temp = 0?
Good question @miladm. In that case the operation will give an array with [inf, -inf] values. The generation will still give some result, though probably not the one you would expect (in my case it was as if it was using greedy search). By the way, you will have the same division in the Jetstream sampling code.
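One way to guard against this, sketched here as an assumption rather than the PR's actual code, is to treat a zero temperature as greedy decoding so the division by zero never happens:

```python
import jax
import jax.numpy as jnp


def sample_with_temperature(logits: jnp.ndarray, key: jax.Array, temperature: float) -> jnp.ndarray:
    """Sample token ids, treating temperature == 0 as greedy instead of dividing by zero."""
    if temperature <= 0.0:
        # logits / 0 would produce +/-inf values; greedy decoding is the usual convention here.
        return jnp.argmax(logits, axis=-1)
    return jax.random.categorical(key, logits / temperature, axis=-1)
```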
What does this PR do?
This allows using TGI with the meta-llama/Llama-2-7b-hf model through the Jetstream/Pytorch engine. This should be the starting point for a more complete integration in the future. It is not ready yet to replace the legacy implementation, in particular because:
Before submitting