JetStream is a throughput- and memory-optimized engine for LLM inference on XLA devices, starting with TPUs (and GPUs in the future -- PRs welcome).
Currently, there are two reference engine implementations available -- one for JAX models and another for PyTorch models.
- JAX (MaxText)
  - Git: https://github.com/google/maxtext
  - README: https://github.com/google/JetStream/blob/main/docs/online-inference-with-maxtext-engine.md
- PyTorch
  - Git: https://github.com/google/jetstream-pytorch
  - README: https://github.com/google/jetstream-pytorch/blob/main/README.md
- Online Inference with MaxText on v5e Cloud TPU VM
- Online Inference with PyTorch on v5e Cloud TPU VM
- Serve Gemma using TPUs on GKE with JetStream
- Benchmark JetStream Server
- Observability in JetStream Server
- Profiling in JetStream Server
- JetStream Standalone Local Setup
Install the dependencies:
pip install -r requirements.txt
Use the following commands to run a mock server locally and to test it:
# Start a server
python -m jetstream.core.implementations.mock.server
# Test local mock server
python -m jetstream.tools.requester
# Load test local mock server
python -m jetstream.tools.load_tester
# Test JetStream core orchestrator
python -m unittest -v jetstream.tests.core.test_orchestrator
# Test JetStream core server library
python -m unittest -v jetstream.tests.core.test_server
# Test mock JetStream engine implementation
python -m unittest -v jetstream.tests.engine.test_mock_engine
# Test mock JetStream token utils
python -m unittest -v jetstream.tests.engine.test_token_utils
python -m unittest -v jetstream.tests.engine.test_utils
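
The unit test commands above can also be run in one pass. The following is a minimal sketch, not part of JetStream: it assumes the package is installed in the current environment and that the module paths listed above exist. `TEST_MODULES`, `build_commands`, and `run_all` are hypothetical names introduced for illustration.

```python
import subprocess
import sys

# Test modules copied from the commands listed above.
TEST_MODULES = (
    "jetstream.tests.core.test_orchestrator",
    "jetstream.tests.core.test_server",
    "jetstream.tests.engine.test_mock_engine",
    "jetstream.tests.engine.test_token_utils",
    "jetstream.tests.engine.test_utils",
)


def build_commands(modules=TEST_MODULES, python=sys.executable):
    """Return one `python -m unittest -v <module>` command per module."""
    return [[python, "-m", "unittest", "-v", module] for module in modules]


def run_all(commands):
    """Run each command in order and return the commands that failed."""
    failed = []
    for cmd in commands:
        # check=False: keep going after a failure so every module is exercised.
        result = subprocess.run(cmd, check=False)
        if result.returncode != 0:
            failed.append(cmd)
    return failed
```

Calling `run_all(build_commands())` returns an empty list when every module passes; otherwise it returns the commands that exited with a non-zero status, which can be rerun individually.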