
Commit

add execution time tracking to metrics and improve error handling in various components, fixes issues related to trainers being slow in #184
erfanzar committed Jan 20, 2025
1 parent e2cf950 commit d13ecbb
Showing 13 changed files with 306 additions and 169 deletions.
3 changes: 2 additions & 1 deletion easydel/inference/utils.py
@@ -86,6 +86,7 @@ def tree_flatten(self):
self._loop_rows,
), {}

@classmethod
def tree_unflatten(cls, aux, children):
return cls(*children)

@@ -268,7 +269,7 @@ def __repr__(self):
if len(repr_src) < 500
else f" {k} : " + f"{v.__class__.__name__}(...)" + "\n"
)
except TypeError:
except (TypeError, AttributeError):
pass
return string.strip() + "\n)"

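Why the added `@classmethod` matters: JAX's pytree registration calls `tree_unflatten` on the class itself when rebuilding objects after a transformation. A minimal illustrative sketch of the pattern (the `LoopState` class below is a made-up stand-in, not the class from this file):

```python
# Illustrative sketch: register_pytree_node_class invokes
# cls.tree_unflatten(aux_data, children) on the class, so tree_unflatten
# must be a classmethod for reconstruction to succeed.
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class LoopState:  # hypothetical stand-in
    def __init__(self, step, rows):
        self.step = step
        self.rows = rows

    def tree_flatten(self):
        return (self.step, self.rows), None  # (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


leaves, treedef = jax.tree_util.tree_flatten(LoopState(jnp.asarray(0), jnp.zeros((4,))))
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)  # a LoopState again
```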
4 changes: 1 addition & 3 deletions easydel/inference/vinference/_fn.py
@@ -17,7 +17,6 @@

import time
import typing as tp # noqa: F401
from functools import partial

import jax
from flax import nnx as nn
@@ -59,7 +58,6 @@ def measure_flops(func, *args, **kwargs):
return result, flops, flops / elapsed_time, elapsed_time


@partial(jax.jit, static_argnums=(0, 3))
def basic_generation_first_iter_fn(
graphdef: EasyDeLBaseModule,
graphstate: dict,
@@ -90,7 +88,7 @@ def basic_generation_first_iter_fn(
return state


@partial(jax.jit, static_argnums=(0, 3))
# @partial(jax.jit, static_argnums=(0, 3))
def basic_generation_iter_fn(
graphdef: EasyDeLBaseModule,
graphstate: dict,
141 changes: 90 additions & 51 deletions easydel/inference/vinference/vinference.py
@@ -33,6 +33,7 @@
from flax import nnx as nn
from jax import lax
from jax import numpy as jnp
from jax.interpreters import pxla
from jax._src.stages import Compiled
from jax.sharding import NamedSharding, PartitionSpec
from pydantic import BaseModel
@@ -67,6 +68,22 @@
TIME = str(datetime.fromtimestamp(time.time())).split(" ")[0]


def extract_shardings(tree, mesh=None):
if mesh is None:
mesh = pxla.thread_resources.env.physical_mesh

def cond(x):
sharding = x.sharding if hasattr(x, "sharding") else None
if isinstance(sharding, jax.sharding.PartitionSpec):
assert mesh is not None, "Mesh cannot be None (call this function inside a `with mesh:` block)."
sharding = jax.sharding.NamedSharding(mesh=mesh, spec=sharding)
if not isinstance(sharding, jax.sharding.NamedSharding):
return None
return sharding

return jax.tree_util.tree_map(cond, tree)
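A rough usage sketch for the helper above (editorial, not part of the commit): leaves that already carry a `NamedSharding` keep it, a bare `PartitionSpec` is promoted onto the active mesh, and anything else maps to `None`. The mesh and pytree below are made up, and the sketch assumes `extract_shardings` as defined above is importable.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Single-device mesh so the demo runs anywhere.
mesh = Mesh(np.array(jax.devices()[:1]), ("dp",))
weights = jax.device_put(
    jnp.zeros((8, 16)), NamedSharding(mesh, PartitionSpec("dp", None))
)
tree = {"weights": weights, "step": 3}  # hypothetical pytree

shardings = extract_shardings(tree, mesh=mesh)
# -> {"weights": NamedSharding(mesh, PartitionSpec("dp", None)), "step": None}
```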


class vInferenceMetaData(BaseModel):
inference_name: str
generation_config: vInferenceConfig
@@ -96,45 +113,48 @@ def __init__(
inference_name: tp.Optional[str] = None,
):
"""
Initializes the vInference class.
Args:
model: The pre-trained language model.
processor_class: The processor_class for the model.
generation_config: The generation configuration.
seed: The random seed for generation.
input_partition_spec: The partitioning specification for input data.
max_new_tokens: The maximum number of new tokens to generate.
Arguments:
model: The pre-trained language model.
processor_class: The processor_class for the model.
generation_config: The generation configuration.
seed: The random seed for generation.
input_partition_spec: The partitioning specification for input data.
max_new_tokens: The maximum number of new tokens to generate.
"""
# fmt:off

graphdef, graphstate = nn.split(model)
self.graphdef = graphdef
self.graphdef = graphdef
self.graphstate = graphstate
self.model=model
self.model = model
self.processor_class = processor_class
self.generation_config = self._init_generation_config(generation_config, max_new_tokens)
self.generation_config = self._init_generation_config(
generation_config, max_new_tokens
)
if seed is None:
seed = random.randint(0, int(1e6))
self.input_partition_spec = input_partition_spec or PartitionSpec(("dp", "fsdp"), "sp")
self.input_partition_spec = input_partition_spec or PartitionSpec(
("dp", "fsdp"), "sp"
)
self.mesh = self.model.config.mesh
self._rng_generator = GenerateRNG(seed)
self._precompile_lock = asyncio.Lock()
self._precompiled_configs = set()
self._in_compiling_process = set()
self._init_variables()
self._validate_token_ids()
self._uuid4 = uuid4().hex
self._uuid4 = uuid4().hex
self._inference_name = inference_name or self._generate_inference_name(model)
erm = os.environ.get("EASYDEL_RECORDS_METRICS", "true").lower() in ["true", "yes", "1", "on"]
# fmt:on
erm = os.environ.get("EASYDEL_RECORDS_METRICS", "true").lower() in [
"true",
"yes",
"1",
"on",
]
self._report_metrics = erm and jax.process_count() == 1
if not self._report_metrics:
logger.info("vInference-metrics is disabled")
logger.debug(f"vInference-metrics is disabled [status erm={erm}]")

self._basic_generation_first_iter_fn = basic_generation_first_iter_fn
self._basic_generation_iter_fn = basic_generation_iter_fn

@cached_property
def metrics(self):
if self._report_metrics:
@@ -248,10 +268,10 @@ def model_prefill_length(self) -> int:
the maximum new tokens from the model's maximum sequence length.
Returns:
int: The maximum length available for input prefill
int: The maximum length available for input prefill
Raises:
ValueError: If no maximum sequence length configuration is found
ValueError: If no maximum sequence length configuration is found
"""
possible_length_attributes = [
"granted_mask_max_position_embedding",
@@ -274,10 +294,10 @@ def _get_model_max_length(self, attributes: list[str]) -> tp.Optional[int]:
Find the first available maximum length configuration from a list of possible attributes.
Args:
attributes: tp.List of attribute names to check in order of preference
attributes: tp.List of attribute names to check in order of preference
Returns:
tp.Optional[int]: The maximum length if found, None otherwise
tp.Optional[int]: The maximum length if found, None otherwise
"""
for attr in attributes:
max_length = getattr(self.model.config, attr, None)
@@ -294,11 +314,11 @@ def _init_generation_config(
Initializes the generation configuration.
Args:
generation_config: The generation configuration.
max_new_tokens: The maximum number of new tokens to generate.
generation_config: The generation configuration.
max_new_tokens: The maximum number of new tokens to generate.
Returns:
vInferenceConfig: The initialized generation configuration.
vInferenceConfig: The initialized generation configuration.
"""
if generation_config is None:
if self.model.generation_config is not None:
@@ -326,7 +346,7 @@ def _init_variables(self):
spec=PartitionSpec(),
mesh=self.model.mesh,
)
self.gen_input_sharding = NamedSharding(
self.generation_input_shape = NamedSharding(
spec=PartitionSpec(self.input_partition_spec[0], None),
mesh=self.model.mesh,
)
@@ -338,8 +358,6 @@ def _init_state_non_jit(
rng: tp.Optional[PRNGKey] = None,
**model_kwargs,
):
if rng is None:
rng = self._rng_generator.rng
pad_token_id = jnp.array(self.generation_config.pad_token_id, dtype=jnp.int32)
batch_size, current_length = input_ids.shape
max_length = current_length + self.generation_config.max_new_tokens
@@ -379,9 +397,9 @@ def _validate_token_ids(self):
"(Set `tokenizer.pad_token_id = tokenizer.eos_token_id` if undefined"
" or (`processing_class.tokenizer.pad_token_id = processing_class.tokenizer.eos_token_id`))"
)
assert (
self.generation_config.eos_token_id is not None
), "`eos_token_id` cannot be None."
assert self.generation_config.eos_token_id is not None, (
"`eos_token_id` cannot be None."
)

def generate(
self,
@@ -412,6 +430,7 @@ def generate(
input_ids = jnp.array(input_ids, dtype=jnp.int32)

batch_size, sequence_length = input_ids.shape
self.precompile(batch_size=batch_size, input_tokens_length=sequence_length)
if batch_size <= 0 or sequence_length <= 0:
raise ValueError(f"Invalid input dimensions: {input_ids.shape}")

@@ -429,8 +448,6 @@
batch_size=batch_size,
input_tokens_length=sequence_length,
id=self._uuid4,
fn1=self._basic_generation_first_iter_fn,
fn2=self._basic_generation_iter_fn,
)

# Main generation loop
@@ -469,7 +486,9 @@ def _prepare_generation_state(
attention_mask = jnp.asarray(attention_mask, dtype="i4", device=self.input_sharding)
input_ids = jnp.asarray(input_ids, dtype="i4", device=self.input_sharding)
model_kwargs.update({"input_ids": input_ids, "attention_mask": attention_mask})

if model_kwargs.get("rng") is None:
rng = self._rng_generator.rng
model_kwargs["rng"] = rng
return self._init_state(**model_kwargs)

def _inner_generate(
@@ -580,6 +599,7 @@ def _get_compile_model_kwargs(self, batch_size, input_tokens_length):
dtype="i4",
device=self.input_sharding,
),
rng=self._rng_generator.rng,
)

def _compile_and_lower_funs(self, batch_size: int, input_tokens_length: int):
@@ -592,15 +612,22 @@ def _compile_and_lower_funs(self, batch_size: int, input_tokens_length: int):
do_compile = compiled_generate_func is None or compiled_interval_func is None
if do_compile:
logger.debug("initiating state for lowering and compiling func.")
state = self._init_state(
**self._get_compile_model_kwargs(
batch_size=batch_size,
input_tokens_length=input_tokens_length,
)
wargs = self._get_compile_model_kwargs(
batch_size=batch_size,
input_tokens_length=input_tokens_length,
)

state = self._init_state(**wargs)
logger.debug("smart compiling `first_iter_fn`")
logger.debug("lowering `first_iter_fn`")
first_iter_fn_lowered = basic_generation_first_iter_fn.lower(
first_iter_fn_lowered = jax.jit(
basic_generation_first_iter_fn,
static_argnums=(0, 3),
in_shardings=(
extract_shardings(self.graphstate),
extract_shardings(state),
),
).lower(
self.graphdef,
self.graphstate,
state,
Expand All @@ -613,10 +640,21 @@ def _compile_and_lower_funs(self, batch_size: int, input_tokens_length: int):
)
logger.debug("smart compiling `iter_fn`")
logger.debug("lowering `iter_fn`")
iter_fn_lowered = basic_generation_iter_fn.lower(
sample_state = compiled_generate_func(self.graphstate, state)
sample_state_shardings = extract_shardings(sample_state)
iter_fn_lowered = jax.jit(
basic_generation_iter_fn,
static_argnums=(0, 3),
in_shardings=(
extract_shardings(self.graphstate),
sample_state_shardings,
None,
),
out_shardings=sample_state_shardings,
).lower(
self.graphdef,
self.graphstate,
compiled_generate_func(self.graphstate, state),
sample_state,
self.generation_config,
self.generation_config.streaming_chunks,
)
@@ -649,11 +687,11 @@ def precompile(
in a cache.
Args:
batch_size: The batch size.
input_tokens_length: The length of the input tokens.
batch_size: The batch size.
input_tokens_length: The length of the input tokens.
Returns:
bool: True if precompilation was successful, False otherwise.
bool: True if precompilation was successful, False otherwise.
"""
if input_tokens_length is None:
input_tokens_length = self.model_prefill_length
@@ -677,10 +715,11 @@
with self._compilation_metrics_recorder():
logger.debug(f"lowering and compiling with `config` {config_key}")
self._in_compiling_process.add(config_key)
self._compile_and_lower_funs(
batch_size=batch_size,
input_tokens_length=input_tokens_length,
)
with self.mesh:
self._compile_and_lower_funs(
batch_size=batch_size,
input_tokens_length=input_tokens_length,
)
self._precompiled_configs.add(config_key)
return True

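The compile path above follows JAX's ahead-of-time workflow: wrap the target with `jax.jit` (optionally with explicit `in_shardings`/`out_shardings`), `lower` it against concrete arguments, then `compile` and cache the result; the compiled callable is invoked without the static arguments. A toy sketch with invented names (not EasyDeL APIs):

```python
import jax
import jax.numpy as jnp


def decode_step(mode: str, state: jax.Array) -> jax.Array:
    # `mode` stands in for the static graphdef/config arguments.
    return state + 1


lowered = jax.jit(decode_step, static_argnums=(0,)).lower("greedy", jnp.zeros((4, 16)))
compiled = lowered.compile()        # reusable for this shape/dtype
out = compiled(jnp.zeros((4, 16)))  # static args are baked in, only array args are passed
```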
1 change: 1 addition & 0 deletions easydel/infra/loss_utils.py
@@ -103,6 +103,7 @@ class LossMetrics:
chosen_rewards: tp.Optional[jax.Array] = None
rejected_rewards: tp.Optional[jax.Array] = None
other_metrics: tp.Optional[tp.Mapping[str, jax.Array]] = None
execution_time: tp.Optional[float] = None


def sigmoid_cross_entropy_with_logits(
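A hedged illustration of what the new `execution_time` field enables (the classes and functions below are invented stand-ins, not EasyDeL code): time the step on the host, block until the asynchronous device work is done, and attach the elapsed seconds to the returned metrics.

```python
import dataclasses
import time
import typing as tp

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class StepMetrics:  # stand-in for LossMetrics
    loss: jax.Array
    execution_time: tp.Optional[float] = None


def train_step(batch: jax.Array) -> StepMetrics:
    return StepMetrics(loss=jnp.mean(batch**2))


def timed_train_step(batch: jax.Array) -> StepMetrics:
    start = time.perf_counter()
    metrics = train_step(batch)
    metrics.loss.block_until_ready()  # don't stop the clock before async dispatch finishes
    metrics.execution_time = time.perf_counter() - start
    return metrics


print(timed_train_step(jnp.ones((8,))).execution_time)
```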
9 changes: 4 additions & 5 deletions easydel/infra/mixins/generation.py
@@ -184,7 +184,7 @@ def prepare_inputs_for_generation(
"""
batch_size, seq_length = input_ids.shape
past_key_values = self.init_cache(batch_size, max_length)

sharding = input_ids.sharding if hasattr(input_ids, "sharding") else None
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
if attention_mask is not None:
position_ids = attention_mask.cumsum(axis=-1) - 1
@@ -195,15 +195,14 @@
)
else:
position_ids = jnp.broadcast_to(
jnp.arange(seq_length, dtype="i4")[None, :],
(batch_size, seq_length),
jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
)

return self.prepare_inputs_for_call(
**{
"past_key_values": past_key_values,
"attention_mask": extended_attention_mask,
"position_ids": position_ids,
"attention_mask": jax.device_put(extended_attention_mask, device=sharding),
"position_ids": jax.device_put(position_ids, device=sharding),
}
)

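The change above keeps the derived attention mask and position ids on the same sharding as `input_ids` instead of leaving them on the default device. A standalone sketch of that idea (mesh and shapes are made up):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()[:1]), ("dp",))
input_ids = jax.device_put(
    jnp.ones((2, 128), dtype="i4"),
    NamedSharding(mesh, PartitionSpec("dp", None)),
)

extended_mask = jnp.ones((2, 256), dtype="i4")                     # built on the default device
extended_mask = jax.device_put(extended_mask, input_ids.sharding)  # now follows input_ids
assert extended_mask.sharding == input_ids.sharding
```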
15 changes: 10 additions & 5 deletions easydel/layers/caching/transformer_cache.py
@@ -172,11 +172,16 @@ def init(
)

def __repr__(self):
return (
self.__class__.__name__
+ f"(key={self.key.shape}, value={self.value.shape}, layer_index={self.layer_index})"
)

try:
return (
self.__class__.__name__
+ f"(key={self.key.shape}, value={self.value.shape}, layer_index={self.layer_index})"
)
except AttributeError:
return (
self.__class__.__name__
+ f"(key={self.key}, value={self.value}, layer_index={self.layer_index})"
)
__str__ = __repr__

