More Jetstream Pytorch fixes, prepare for release #116

Merged: 12 commits merged on Nov 20, 2024
Changes from 11 commits
4 changes: 4 additions & 0 deletions .github/workflows/test-pytorch-xla-tpu-tgi-jetstream.yml
@@ -1,6 +1,10 @@
name: Optimum TPU / Test TGI on TPU / Jetstream Pytorch

on:
  push:
    branches: [ main ]
    paths:
      - "text-generation-inference/**"
  pull_request:
    branches: [ main ]
    paths:
2 changes: 1 addition & 1 deletion optimum/tpu/version.py
@@ -15,5 +15,5 @@
from packaging.version import parse


__version__ = "0.1.5"
__version__ = "0.2.0"
VERSION = parse(__version__)
Changes to another file:
@@ -19,12 +19,35 @@

if TYPE_CHECKING:
    from transformers import PretrainedConfig

from transformers import AutoConfig

from .compatibility import model_can_use_jetstream_pt
from .models import GemmaModel, LlamaModel, MixtralModel


class OptimumJetstreamEngine(PyTorchEngine):
"""This is essentially the same as the PytorchEngine, but it also supports a callback for sampling in prefill and
generation that can change on each request while not needing to be JIT'ed.
"""
    prefill_ex = PyTorchEngine.prefill

    def __init__(
        self,
        pt_model: torch.nn.Module,
        env: JetEngineEnvironment,
        weights=None,
    ):
        super().__init__(pt_model, env, weights)
        # The model prefill and generate calls need to be JIT'ed, because they are called with sharding
        # annotations and would otherwise not work for some models.
        self._call_model_prefill = jax.jit(
            self._call_model_prefill,
        )
        self._call_model_generate = jax.jit(
            self._call_model_generate,
        )

def _get_head_dim(config: "PretrainedConfig") -> int:
    if hasattr(config, "head_dim"):
        return config.head_dim
@@ -174,8 +197,9 @@ def create_engine(
logger.info(f"Quantization took {end - start:.2f} seconds")
model_weights = model.state_dict()
sharded_weights = shard_weights(env, model_weights, weight_shardings)
return PyTorchEngine(
engine = OptimumJetstreamEngine(
pt_model=model,
env=env,
weights=torchjax.from_torch_with_copy(sharded_weights),
)
return engine
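
For context on the class above: the intent (per its docstring) is to keep an un-JIT'ed prefill entry point (prefill_ex) so that a Python-level sampling callback can change on every request, while the inner model calls are wrapped in jax.jit. A minimal, generic sketch of that split; inner_forward and run_prefill are illustrative names, not part of this PR or of Jetstream:

import jax
import jax.numpy as jnp

@jax.jit
def inner_forward(params, tokens):
    # Stand-in for the compiled model call: this part is traced and compiled once.
    return tokens * params

def run_prefill(params, tokens, sample_fn):
    # Not JIT'ed: sample_fn can be a different Python callable on every request
    # without forcing a recompilation of inner_forward.
    logits = inner_forward(params, tokens)
    return sample_fn(logits)

greedy = lambda logits: jnp.argmax(logits, axis=-1)
print(run_prefill(2.0, jnp.arange(4.0), greedy))  # -> 3
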
Changes to another file:
@@ -55,7 +55,6 @@ def clear(self):
        self._state = Slot.State.EMPTY
        self._batch_id = None
        self._request_id = None
        self._inputs = ""
        self._generation_config = None
        self._tokens = []
        self._selector = None
@@ -263,9 +262,8 @@ def __init__(
        tokenizer.truncation_side = "left"
        self.tokenizer = tokenizer
        self.special_tokens = self.tokenizer.all_special_ids

        # Slots are empty to begin with, they will be populated as new batches arrive
        self.slots = []
        # The number of slots is static: it cannot grow beyond the model's batch size
        self.slots = [Slot(i, tokenizer) for i in range(self.model.config.batch_size)]
        self.batch_id = 0
        # Note: this index will _never_ be decremented, and that's fine.
        self.slot_index = 0
@@ -363,13 +361,11 @@ def warmup(self, batch: Batch) -> int:
        seq_len = self.engine.env.seq_len
        return batch_size * seq_len

    def _get_slot_id(self):
        """Get the next available slot id."""
        batch_size = self.engine.env.batch_size
        used_ids = [slot.id for slot in self.slots if slot.state != Slot.State.EMPTY]
        for i in range(batch_size):
            if i not in used_ids:
                return i
    def _get_slot(self):
        """Get the next available slot."""
        for slot in self.slots:
            if slot.state == Slot.State.EMPTY:
                return slot
        # if we reach this point, all slots were used - this should not happen
        raise ValueError("All slots are used, but we should have stopped earlier")

@@ -388,6 +384,8 @@ def _token_encode(self, text: str, max_length: int) -> Tuple[jnp.ndarray, int]:
"""
if max_length == 0:
max_length = self.model.config.sequence_length
# Remove one to max_length because BOS is going to be added when padding
max_length -= 1
input_ids = self.tokenizer.encode(
text,
return_tensors="np",
@@ -417,14 +415,9 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
            A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
        """

        slots = {state: [] for state in Slot.State}
        for slot in self.slots:
            slots[slot.state].append(slot)
        len_active_slots = len(slots[Slot.State.READY])
        # Delete all empty slots, no need to have them anymore
        empty_slots = slots[Slot.State.EMPTY]
        for slot in empty_slots:
            self.slots.remove(slot)
        active_slots = [slot for slot in self.slots if slot.state == Slot.State.READY]
        len_active_slots = len(active_slots)

        len_requests = len(batch.requests)
        model_batch_size = self.model.config.batch_size
        if model_batch_size is not None and model_batch_size < len_active_slots + len_requests:
@@ -439,10 +432,10 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
        # Assign each request to an empty slot
        logger.debug(f"Prefilling {len_requests} new request(s) adding to {len_active_slots} active slot(s)")
        generations = []

        prefilled_active_slots = []
        for request in batch.requests:
            # Dynamically create a new slot for each request
            slot = Slot(self._get_slot_id(), self.tokenizer)
            slot = self._get_slot()
            self.prefill_slot.set(slot)
            self.slot_index += 1
            slot.assign(self.batch_id, request, self.model.generation_config)
@@ -462,7 +455,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
            # To allow jit'ing the select function, we need to wrap it in a partial
            slot_select = jax.tree_util.Partial(self.prefill_slot.select)
            # Ask for prefill and insert
            prefill_results, _result_tokens = self.engine.prefill(
            prefill_results, _result_tokens = self.engine.prefill_ex(
                params=self.params,
                padded_tokens=input_ids,
                true_length=true_lengths,
@@ -473,28 +466,20 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:

            self._post_generate(slot, next_token, generations)
            if not slot.empty:
                # append current to list of active slots
                self.slots.append(slot)
                len_active_slots += 1

        batch = None
        if len_active_slots > 0:
            # Whatever initial batch these requests came from, we always return all pending requests in a single batch
            request_ids = [slot.request_id for slot in self.slots if slot.state == Slot.State.READY]
            batch = self._cached_batch(self.batch_id, request_ids)
        else:
            logger.debug("No more pending requests")
                prefilled_active_slots.append(slot)

        cached_batch = self._cached_batch(self.batch_id, prefilled_active_slots)
        self.batch_id += 1
        logger.debug("Model ready for decoding")
        return generations, batch
        return generations, cached_batch

    def _select_from_slots(self, logits: jnp.ndarray, batch_size: int=0) -> jnp.ndarray:
        pad_token_id = self.tokenizer.pad_token_id
        batch_size = logits.shape[0]
        tokens = jnp.full((batch_size, 1), pad_token_id)
        for slot in filter(lambda slot: slot.state == slot.State.READY, self.slots):
            # Every slot might have different selection criteria, so we have to call select in a loop
            next_token = slot.select(logits)
            next_token = slot.select(logits[slot.id : slot.id + 1, :])
            tokens = tokens.at[slot.id].set(next_token)
        return tokens
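
Because each slot can carry its own generation settings, the selection above stays a plain Python loop: every slot reads only its own row of the logits and writes its token back with a functional .at[].set update. A small standalone sketch of that per-row pattern (select_per_row and the greedy lambda are illustrative, not the real Slot.select):

import jax.numpy as jnp

def select_per_row(logits, row_select_fns, pad_token_id=0):
    # One (possibly different) selection function per row, applied outside of jit.
    batch_size = logits.shape[0]
    tokens = jnp.full((batch_size, 1), pad_token_id)
    for i, select_fn in enumerate(row_select_fns):
        next_token = select_fn(logits[i : i + 1, :])  # keep a 2D (1, vocab) slice
        tokens = tokens.at[i].set(next_token)         # functional, out-of-place update
    return tokens

logits = jnp.array([[0.1, 2.0, 0.3], [5.0, 0.2, 0.1]])
greedy = lambda x: jnp.argmax(x, axis=-1)
print(select_per_row(logits, [greedy, greedy]))  # -> [[1], [0]]
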

@@ -539,10 +524,8 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")

# Use a custom function to select the next token for each slot
select_fn = jax.tree_util.Partial(self._select_from_slots)
self.decode_state, result_tokens = self.engine.generate(self.params, self.decode_state, select_fn)
self.decode_state, result_tokens = self.engine.generate_impl(self.params, self.decode_state, self._select_from_slots)

        newly_empty = []
        generations = []
        for slot in active_slots:
            # Get the next token.
@@ -555,20 +538,9 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
raise ValueError("Unexpected Slot is not ready for decoding")

self._post_generate(slot, next_token, generations)
if slot.empty:
newly_empty.append(slot)

# Remove empty slots
for slot in newly_empty:
self.slots.remove(slot)
batch = None
if len(self.slots) > 0:
# Whatever initial batch these requests came from, we always return all pending requests in a single batch
request_ids = [slot.request_id for slot in self.slots if slot.state == Slot.State.READY]
batch = self._cached_batch(next_batch_id, request_ids)
else:
logger.debug("No more pending requests")
return generations, batch

cached_batch = self._cached_batch(next_batch_id, active_slots)
return generations, cached_batch

    def _post_generate(self, slot: Slot, next_token: int, generations: List[Generation]) -> None:
        """Post-generate a slot after the generation has been completed.
@@ -616,7 +588,13 @@ def _post_generate(self, slot: Slot, next_token: int, generations: List[Generation]) -> None:
            )
        )

    def _cached_batch(self, batch_id: int, request_ids: List):
    def _cached_batch(self, batch_id: int, active_slots: List):
        """Create a CachedBatch from the active slots.
        """
        request_ids = [slot.request_id for slot in active_slots if slot.state == Slot.State.READY]
        if len(request_ids) == 0:
            logger.debug("No more pending requests")
            return None
        size = len(request_ids)
        max_tokens = size * self.model.config.sequence_length
        return CachedBatch(id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens)
Changes to another file:
@@ -175,11 +175,13 @@ def select(self, input_ids: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray:
"""
scores = self.logits_processor(input_ids, logits)
if self.mode == GenerationMode.SAMPLE:
return self._sample(scores)
# split the key to avoid reusing the same key for multiple samples
subkey, self.key = jax.random.split(self.key)
return self._sample(scores, subkey)
else:
return jnp.argmax(scores, axis=-1)

    def _sample(self, scores: jnp.ndarray) -> jnp.ndarray:
    def _sample(self, scores: jnp.ndarray, key) -> jnp.ndarray:
        do_top_k = self.logits_warper.top_k > 0 and self.logits_warper.top_k < scores.shape[-1]
        do_top_p = self.logits_warper.top_p < 1.0 and self.logits_warper.top_p > 0.0

@@ -188,14 +190,14 @@ def _sample(self, scores: jnp.ndarray) -> jnp.ndarray:
                scores,
                self.logits_warper.top_k,
                self.logits_warper.temperature,
                self.key,
                key,
            )
        elif do_top_p:
            return sampling_utils.sample_nucleus_topp_logits(
                scores,
                self.logits_warper.top_p,
                self.logits_warper.temperature,
                self.key,
                key,
            )

        return jax.random.categorical(self.key, scores / self.logits_warper.temperature)
        return jax.random.categorical(key, scores / self.logits_warper.temperature)
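
The reason for threading a fresh key into _sample is that JAX PRNG keys are stateless: calling a sampler twice with the same key returns the same draw, so the selector now splits self.key and hands a new subkey to every sample. A standalone illustration of that behavior (not using the TokenSelector class):

import jax
import jax.numpy as jnp

logits = jnp.array([0.1, 0.2, 0.3, 0.4])
key = jax.random.PRNGKey(0)

# Reusing the same key gives the exact same "random" token on every call.
same_a = jax.random.categorical(key, logits)
same_b = jax.random.categorical(key, logits)
assert same_a == same_b

# Splitting produces a fresh subkey per sample while keeping the chain reproducible.
subkey, key = jax.random.split(key)
tok1 = jax.random.categorical(subkey, logits)
subkey, key = jax.random.split(key)
tok2 = jax.random.categorical(subkey, logits)
print(tok1, tok2)  # drawn with different subkeys
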
Changes to another file:
@@ -1,5 +1,5 @@
from pkg_resources import parse_version


__version__ = "0.1.5"
__version__ = "0.2.0"
VERSION = parse_version(__version__)