Skip to content

Hacky way to test AOT compilation in JetStream for interleaved mode #259

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 99 additions & 11 deletions jetstream/core/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,11 @@
from jetstream.core.metrics.prometheus import JetstreamMetricsCollector
import numpy as np

log_level = os.getenv("LOG_LEVEL", "WARNING").upper()
from jax.experimental import layout as jax_layout
DLL = jax_layout.DeviceLocalLayout
Layout = jax_layout.Layout

log_level = os.getenv("LOG_LEVEL", "DEBUG").upper()

logger = logging.getLogger("JetstreamLogger")
logger.propagate = False
Expand Down Expand Up @@ -405,6 +409,26 @@ def __init__(

self._jax_padding = jax_padding

##### Auto layout compile for interleaved engine
self._generate_executables = [None for _ in self._generate_engines]
self._cached_insert = [None for _ in self._generate_engines]
self._cached_prefill = [None for _ in self._prefill_engines]
if self._interleaved_mode:
for idx in range(len(self._generate_engines)):
logger.debug("Compiling interleaved engine {}".format(idx))
engine = self._generate_engines[idx]
params = self._generate_params[idx]
engine, params, gen_fn, prefill_fn, insert_fn = self._auto_layout_compile(engine, params)

self._prefill_engines[idx] = engine
self._generate_engines[idx] = engine
self._prefill_params[idx] = params
self._generate_params[idx] = params
self._cached_prefill[idx] = prefill_fn
self._cached_insert[idx] = insert_fn
self._generate_executables[idx] = gen_fn


# Create all threads
self._prefill_threads = [
JetThread(
Expand Down Expand Up @@ -670,6 +694,56 @@ def _do_chunked_prefill(

return prefill_result, first_token

def _auto_layout_compile(
    self,
    engine,
    params,
    prefill_buckets=(64, 128, 256, 512, 1024),
):
    """Ahead-of-time compile generate, prefill and insert for one engine.

    The generate executable comes from ``engine.aot_compile``. Prefill and
    insert are then lowered and compiled once per padded prompt length in
    ``prefill_buckets``; prefill is compiled with ``Layout(DLL.AUTO)``
    outputs, and the layout chosen for its first output is fed to the
    matching insert as its input layout.

    Args:
      engine: Engine object providing ``aot_compile``, ``prefill_aot``,
        ``insert``, ``param_layouts``, ``decode_state_layouts`` and
        ``decode_state_shapes``.
      params: Model parameters handed to ``engine.aot_compile``; the value
        that call returns replaces them for everything that follows.
      prefill_buckets: Iterable of padded token lengths to compile prefill
        and insert executables for. Defaults to the original hard-coded
        buckets.

    Returns:
      A 5-tuple ``(engine, params, generate_executable, cached_prefill,
      cached_insert)`` where the two caches are dicts keyed by bucket
      length holding the compiled prefill / insert executables.
    """
    logger.debug("Compiling generate function")
    generate_executable, params, decode_state_executable = engine.aot_compile(
        params, pass_rng_shape=False
    )
    # NOTE(review): the original bound this result to an unused local. The
    # call is kept in case building the decode state has side effects the
    # rest of the pipeline relies on; confirm whether the value should be
    # stored for the generate thread instead of being discarded.
    decode_state_executable(None)

    # Loop-invariant abstract values, built once instead of per bucket.
    i32_scalar = jax.ShapeDtypeStruct((), int)
    token_dtype = jax.numpy.dtype("int32")

    cached_prefill = {}
    cached_insert = {}
    for length in prefill_buckets:
        logger.debug("Compiling prefill: %d", length)
        input_data = jax.ShapeDtypeStruct((length,), token_dtype)

        cached_prefill[length] = (
            jax.jit(
                engine.prefill_aot,
                in_shardings=(engine.param_layouts, None, None),
                out_shardings=(Layout(DLL.AUTO), Layout(DLL.AUTO)),
            )
            .lower(params, input_data, i32_scalar)
            .compile(compiler_options=None)
        )

        logger.debug("Generate dummy prefix: %d", length)
        dummy_tokens = jax.numpy.ones(shape=(length,), dtype=token_dtype)
        # Only the output shapes are needed here; nothing is executed.
        prefix_shapes = jax.eval_shape(
            engine.prefill_aot, params, dummy_tokens, 1
        )

        logger.debug("Compiling insert: %d", length)
        # Reuse the layout picked for the first prefill output as the
        # insert input layout.
        prefill_output_layout, _ = cached_prefill[length].output_layouts
        # Lazy %-style arguments: the (potentially large) reprs are only
        # rendered when DEBUG logging is actually enabled.
        logger.debug("Prefill output layout: %s", prefill_output_layout)
        logger.debug("Prefix shapes: %s", prefix_shapes)

        cached_insert[length] = (
            jax.jit(
                engine.insert,
                in_shardings=(
                    prefill_output_layout,
                    engine.decode_state_layouts,
                    None,
                ),
                out_shardings=engine.decode_state_layouts,
                donate_argnames="decode_state",
            )
            .lower(prefix_shapes[0], engine.decode_state_shapes, i32_scalar)
            .compile(compiler_options=None)
        )
    return engine, params, generate_executable, cached_prefill, cached_insert

def _prefill_thread(self, idx: int):
"""Thread which runs in the background performing prefills."""
logger.info("Spinning up prefill thread %d.", idx)
Expand All @@ -683,6 +757,12 @@ def _prefill_thread(self, idx: int):
thread_name = f"Prefill thread {idx}"
ThreadDebugLog(thread_name, f"Prefill params {idx} loaded.")

if not self.interleaved:
prefill_engine, prefill_params, gen_fn, prefill_fn, insert_fn = self._auto_layout_compile(
prefill_engine, prefill_params
)
self._cached_prefill[idx] = prefill_fn

while self.live:
my_transfer_backlog = self._transfer_backlogs[idx]
# The prefill thread can just sleep until it has work to do.
Expand Down Expand Up @@ -759,10 +839,11 @@ def _prefill_thread(self, idx: int):
)
else:
# Compute new kv cache for the prefill_content.
prefill_result, first_token = prefill_engine.prefill(
params=final_prefill_params,
padded_tokens=padded_tokens,
true_length=true_length,
assert padded_tokens.shape[0] in self._cached_prefill[idx]
prefill_result, first_token = self._cached_prefill[idx][padded_tokens.shape[0]](
final_prefill_params,
padded_tokens,
true_length,
)

request.complete = np.zeros(
Expand Down Expand Up @@ -967,10 +1048,11 @@ def _insert_if_possible(
else:
break

decode_state = generate_engine.insert(
length = new_request.prefill_result['cache']['decoder']['layers_0']['self_attention']['KVCache_0']['cache_prefill_segment_id'].value.shape[1]
decode_state = self._cached_insert[idx][length](
new_request.prefill_result,
decode_state,
slot=slot,
slot,
# request_id=new_request.request_id,
)
ThreadDebugLog(
Expand Down Expand Up @@ -1115,9 +1197,15 @@ def _generate_thread(self, idx: int):
# Keep track of what step tokens were generated at.
generate_timestep = 0
# State to store things like running kv cache in.
decode_state = generate_engine.init_decode_state()

decode_state = self.decode_state
generate_params = self._generate_params[idx]

if not self.interleaved:
generate_engine, generate_params, gen_fn, prefill_fn, insert_fn = self._auto_layout_compile(
generate_engine, generate_params
)
self._generate_executables[idx] = gen_fn

thread_name = f"Generate thread {idx}"
ThreadDebugLog(thread_name, f"Generate params {idx} loaded.")
time_of_last_generate = time.time()
Expand Down Expand Up @@ -1178,8 +1266,8 @@ def _generate_thread(self, idx: int):
), "At this point we must have some requests inserted into the slots."

# Now we actually take a generate step on requests in the slots.
decode_state, sampled_tokens = generate_engine.generate(
generate_params, decode_state
decode_state, sampled_tokens = self._generate_executables[idx](
generate_params, decode_state, None
)
sampled_tokens.copy_to_host_async()
# Respond to detokenization backpressure.
Expand Down
Loading