From 4054dc46b8bad9932903322acf5fb90712febdb2 Mon Sep 17 00:00:00 2001 From: aman2930 Date: Mon, 14 Apr 2025 18:19:42 +0000 Subject: [PATCH] Squashed commit --- README.md | 3 + jetstream/core/lora/adapter_tensorstore.py | 593 ++++++++ .../core/lora/multi_lora_inference_api.py | 126 ++ jetstream/core/metrics/prometheus.py | 38 + jetstream/core/orchestrator.py | 234 +++- jetstream/core/proto/jetstream.proto | 6 +- jetstream/core/proto/jetstream_pb2.py | 52 +- jetstream/core/proto/jetstream_pb2_grpc.py | 26 +- .../core/proto/multi_lora_decoding.proto | 67 + .../core/proto/multi_lora_decoding_pb2.py | 42 + .../proto/multi_lora_decoding_pb2_grpc.py | 169 +++ jetstream/core/server_lib.py | 78 +- jetstream/engine/mock_engine.py | 27 + .../core/lora/test_adapter_tensorstore.py | 1199 +++++++++++++++++ .../core/lora/test_multi_lora_manager.py | 247 ++++ jetstream/tests/core/test_orchestrator.py | 212 +++ jetstream/tests/core/test_server.py | 1 + .../tools/maxtext/model_ckpt_conversion.sh | 78 +- .../tools/multi_adapter_service_client.py | 193 +++ .../tools/multi_lora_decode_requester.py | 254 ++++ jetstream/tools/requester.py | 8 + 21 files changed, 3585 insertions(+), 68 deletions(-) create mode 100644 jetstream/core/lora/adapter_tensorstore.py create mode 100644 jetstream/core/lora/multi_lora_inference_api.py create mode 100644 jetstream/core/proto/multi_lora_decoding.proto create mode 100644 jetstream/core/proto/multi_lora_decoding_pb2.py create mode 100644 jetstream/core/proto/multi_lora_decoding_pb2_grpc.py create mode 100644 jetstream/tests/core/lora/test_adapter_tensorstore.py create mode 100644 jetstream/tests/core/lora/test_multi_lora_manager.py create mode 100644 jetstream/tools/multi_adapter_service_client.py create mode 100644 jetstream/tools/multi_lora_decode_requester.py diff --git a/README.md b/README.md index 62959c46..3fe3feac 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,9 @@ python -m unittest -v jetstream.tests.core.test_orchestrator # Test JetStream core server library python -m unittest -v jetstream.tests.core.test_server +# Test JetStream lora adapter tensorstore +python -m unittest -v jetstream.tests.core.lora.test_adapter_tensorstore + # Test mock JetStream engine implementation python -m unittest -v jetstream.tests.engine.test_mock_engine diff --git a/jetstream/core/lora/adapter_tensorstore.py b/jetstream/core/lora/adapter_tensorstore.py new file mode 100644 index 00000000..ffadd079 --- /dev/null +++ b/jetstream/core/lora/adapter_tensorstore.py @@ -0,0 +1,593 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Manages the list of fine-tuned adapters loaded on top of the base model for serving. 
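+Adapters are kept in device HBM while in active use, with host CPU RAM as a
+secondary tier; a least-recently-used policy demotes or drops entries when
+the configured memory budgets are exceeded.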
+""" + +import logging +import dataclasses + +import jax +import jax.numpy as jnp +from flax import struct +import time +import asyncio +import functools +from typing import Dict, Optional, Any +import numpy as np +from jetstream.engine import engine_api +from enum import Enum + + +def _get_size_of_pytree(params): + """Get the size of the PyTree.""" + + params_bytes = jax.tree_util.tree_map(lambda x: x.nbytes, params) + total_bytes = jax.tree_util.tree_reduce(lambda x, y: x + y, params_bytes) + return total_bytes + + +def _as_np_array(params): + """Create a new PyTree with Tensors as np.array.""" + + def convert_if_jnp(leaf): + return np.array(leaf) + + return jax.tree_util.tree_map(convert_if_jnp, params) + + +def _as_jnp_array(params): + """Create a new PyTree with Tensors as jnp.array.""" + + def convert_if_np(leaf): + return jnp.array(leaf) + + return jax.tree_util.tree_map(convert_if_np, params) + + +class AdapterStatus(str, Enum): + UNLOADED = "unloaded" + LOADING = "loading" + LOADED_HBM = "loaded_hbm" + LOADED_CPU = "loaded_cpu" + + +@dataclasses.dataclass +class AdapterMetadata: + adapter_id: str + adapter_path: str + status: AdapterStatus = AdapterStatus.UNLOADED + size_hbm: int = 0 # Size in HBM (bytes) + size_cpu: int = 0 # Size in CPU RAM (bytes) + last_accessed: float = 0.0 # timestamp + config: Dict[str, Any] = dataclasses.field(default_factory=dict) + loading_event: Optional[asyncio.Event] = None # Add Event + + +class AdapterTensorStore: + """ + Manages the storage and retrieval of LoRA adapter weights, handling + placement in either HBM (High Bandwidth Memory, on the TPU/GPU) or CPU RAM. + + This class implements an LRU (Least Recently Used) eviction policy + to manage memory usage. It supports asynchronous loading and unloading + of adapters to avoid blocking the main inference thread. + + Args: + engine: The instance of the JetStream Engine for this AdapterTensorStore + adapters_dir_path: Location of all the adapters + hbm_memory_budget (int): The maximum amount of HBM (in bytes) to use for + storing LoRA adapter weights. + cpu_memory_budget (int): The maximum amount of CPU RAM (in bytes) to use + for storing LoRA adapter weights. + """ + + def __init__( + self, + engine: engine_api.Engine, + adapters_dir_path: str, + hbm_memory_budget: int, + cpu_memory_budget: int, + ): + """Initializes the AdapterTensorStore.""" + self.engine = engine # Possibly MaxEngine object + self.adapters_dir_path = adapters_dir_path.rstrip( + "/" + ) # All Adapters path without trailing `/` + self.hbm_memory_budget = hbm_memory_budget + self.cpu_memory_budget = cpu_memory_budget + self.adapter_registry: Dict[str, AdapterMetadata] = {} # All known adapters + self.loaded_adapters_hbm: Dict[str, jnp.ndarray] = ( + {} + ) # adapter_id -> LoRA params (in HBM) + self.loaded_adapters_cpu: Dict[str, np.ndarray] = ( + {} + ) # adapter_id -> LoRA params (in CPU RAM) + self.current_hbm_usage: int = 0 + self.current_cpu_usage: int = 0 + self.running_requests: int = ( + 0 # Number of async tasks which are in "loading" state + ) + self.lock = asyncio.Lock() # Use an asyncio Lock for thread safety + + # --- Unsafe Internal methods which assumes that lock is held --- + def _unsafe_transfer_to_hbm(self, adapter_id: str): + """Internal: Transfers an adapter from CPU RAM to HBM. 
    Assumes lock is held."""
    if adapter_id not in self.loaded_adapters_cpu:
      raise ValueError(f"Adapter '{adapter_id}' not loaded in CPU RAM.")

    metadata = self.adapter_registry[adapter_id]

    # Check if we have enough space in HBM; evict if necessary
    while (self.current_hbm_usage + metadata.size_hbm) > self.hbm_memory_budget:
      if not self._evict(from_hbm=True):
        raise RuntimeError(
            "Not enough HBM to transfer adapter, and HBM eviction failed."
        )

    # Move from CPU RAM to HBM
    logging.info(f"Transferring {adapter_id} from CPU to HBM.")
    self.loaded_adapters_hbm[adapter_id] = _as_jnp_array(
        self.loaded_adapters_cpu[adapter_id]
    )  # Convert to JAX array

    # TODO: We can avoid deleting cpu_loaded adapters if RAM is not a concern
    del self.loaded_adapters_cpu[adapter_id]

    self.current_cpu_usage -= metadata.size_cpu
    self.current_hbm_usage += metadata.size_hbm

    metadata.status = AdapterStatus.LOADED_HBM
    metadata.last_accessed = time.time()  # Update time on transfer

  def _unsafe_transfer_to_cpu(self, adapter_id: str):
    """Internal: Transfers an adapter from HBM to CPU RAM. Assumes lock is held."""

    if adapter_id not in self.loaded_adapters_hbm:
      raise ValueError(f"Adapter '{adapter_id}' not loaded in HBM.")

    metadata = self.adapter_registry[adapter_id]

    # Check if we have enough space in CPU; evict if necessary.
    while (self.current_cpu_usage + metadata.size_cpu) > self.cpu_memory_budget:
      if not self._evict(from_hbm=False):
        raise RuntimeError(
            "Not enough CPU RAM to transfer adapter, and CPU eviction failed."
        )

    # Move from HBM to CPU RAM
    logging.info(f"Transferring {adapter_id} from HBM to CPU.")
    self.loaded_adapters_cpu[adapter_id] = _as_np_array(
        self.loaded_adapters_hbm[adapter_id]
    )
    del self.loaded_adapters_hbm[adapter_id]

    self.current_hbm_usage -= metadata.size_hbm
    self.current_cpu_usage += metadata.size_cpu

    metadata.status = AdapterStatus.LOADED_CPU
    metadata.last_accessed = time.time()  # Update time on transfer

  def _unsafe_unload_adapter(self, adapter_id: str):
    """Internal: Unload adapter. Assumes lock is held."""

    if adapter_id not in self.adapter_registry:
      raise ValueError(f"Adapter with ID '{adapter_id}' not found.")

    metadata = self.adapter_registry[adapter_id]
    if metadata.status == AdapterStatus.UNLOADED:
      return

    logging.info(f"Unloading adapter {adapter_id}.")
    if metadata.status == AdapterStatus.LOADED_HBM:
      del self.loaded_adapters_hbm[adapter_id]
      self.current_hbm_usage -= metadata.size_hbm
    elif metadata.status == AdapterStatus.LOADED_CPU:
      del self.loaded_adapters_cpu[adapter_id]
      self.current_cpu_usage -= metadata.size_cpu

    metadata.status = AdapterStatus.UNLOADED
    metadata.last_accessed = time.time()
    metadata.size_hbm = 0
    metadata.size_cpu = 0

  # --- Public Methods (Acquire lock, then call unsafe methods) ---

  async def register_adapter(
      self,
      adapter_id: str,
      adapter_path: str | None = None,
      adapter_config: Dict[str, Any] | None = None,
  ):
    """Registers a LoRA adapter with the TensorStore.

    If called without adapter_config, the config must be fetched via the
    engine's load_single_adapter() call, which also returns the adapter
    params; since the params are then already in hand, the adapter is loaded
    into HBM right away. This call path is expected only from a direct
    inference request. Otherwise, this simply adds the adapter's metadata to
    the registry.
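
    Example (hypothetical id): await register_adapter("sql-expert") resolves
    the path to f"{self.adapters_dir_path}/sql-expert", fetches the params
    and config from the engine, and loads the adapter into HBM.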
+ + Args: + adapter_id (str): A unique identifier for the adapter. + adapter_path (str): The path to the adapter weights (file or directory). + adapter_config (dict): Config of the loRA adapter. + + Raises: + ValueError: If an adapter with the same ID is already registered. + """ + if adapter_id in self.adapter_registry: + logging.warning(f"Adapter with ID '{adapter_id}' already registered.") + return + + if adapter_path is None: + adapter_path = f"{self.adapters_dir_path}/{adapter_id}" + + adapter_params = None + if adapter_config is None: + # This call happens *outside* the lock for potentially slow I/O. + loop = asyncio.get_running_loop() + adapter_params, adapter_config = await loop.run_in_executor( + None, functools.partial(self.engine.load_single_adapter, adapter_path) + ) + + if adapter_config is None: + raise ValueError(f"Failed to read adapter_config from {adapter_path}") + + async with self.lock: + # Double check registration inside lock + if adapter_id in self.adapter_registry: + logging.warning(f"Adapter '{adapter_id}' registered concurrently.") + return + + self.adapter_registry[adapter_id] = AdapterMetadata( + adapter_id=adapter_id, + adapter_path=adapter_path, + config=adapter_config, + ) + + # If params were loaded outside lock, now load them into store. + if adapter_params is not None: + await self.load_adapter(adapter_id, adapter_params, True) + + async def get_hbm_loaded_adapters(self): + """Returns a comma separated list of adapters loaded into HBM.""" + + hbm_loaded_adapters = [] + + async with self.lock: + for adapter_id, metadata in self.adapter_registry.items(): + if metadata.status == AdapterStatus.LOADED_HBM: + hbm_loaded_adapters.append(adapter_id) + + return ", ".join(hbm_loaded_adapters) + + async def load_adapter( + self, adapter_id: str, adapter_weights=None, to_hbm: bool = True + ): + """ + Loads a LoRA adapter's weights into memory (either HBM or CPU RAM). + + This method is asynchronous to avoid blocking the main thread during + potentially slow I/O operations. It handles: + - Checking if the adapter is already loaded. + - Checking if there's enough memory (and evicting if necessary). + - Loading the weights (in a separate thread). + - Updating the adapter's status and metadata. + + Args: + adapter_id (str): The ID of the adapter to load. + adapter_weights: In the form of a PyTree. + to_hbm (bool): Whether to load the adapter into HBM (True) or + CPU RAM (False). Defaults to True (HBM). + + Raises: + ValueError: If the adapter ID is not registered. + RuntimeError: If there is not enough memory to load the adapter, + and eviction fails to free up enough space. + """ + if adapter_id not in self.adapter_registry: + raise ValueError(f"Adapter with ID '{adapter_id}' not registered.") + + event_to_wait_on: Optional[asyncio.Event] = None + + async with self.lock: + metadata = self.adapter_registry[adapter_id] + + if metadata.status in ( + AdapterStatus.LOADED_HBM, + AdapterStatus.LOADED_CPU, + ): + metadata.last_accessed = time.time() + + # if already loaded in HBM and we want HBM, or + # already loaded in CPU and we want CPU, we're done. 
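
        # A compact summary of the transitions handled here and below
        # (mirrors the code paths in this method):
        #   already in the target tier -> refresh last_accessed and return
        #   LOADED_CPU but HBM wanted  -> _unsafe_transfer_to_hbm
        #   LOADED_HBM but CPU wanted  -> _unsafe_transfer_to_cpu
        #   LOADING in another task    -> wait on loading_event, then retry
        #   UNLOADED                   -> mark LOADING and fetch the weights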
        if (to_hbm and metadata.status == AdapterStatus.LOADED_HBM) or (
            not to_hbm and metadata.status == AdapterStatus.LOADED_CPU
        ):
          return  # Already in the expected state
        elif to_hbm and metadata.status == AdapterStatus.LOADED_CPU:
          # Transfer from CPU to HBM
          self._unsafe_transfer_to_hbm(adapter_id)
          return
        elif not to_hbm and metadata.status == AdapterStatus.LOADED_HBM:
          # Transfer from HBM to CPU
          self._unsafe_transfer_to_cpu(adapter_id)
          return

      # --- Handle LOADING state ---
      if metadata.status == AdapterStatus.LOADING:
        # Wait until loading is done.
        logging.info(
            f"Adapter {adapter_id} is already being loaded by another task, waiting..."
        )

        # Get the event created by the first loading task
        event_to_wait_on = metadata.loading_event
        if event_to_wait_on is None:
          # Should not happen if status is LOADING; indicates an inconsistency
          raise RuntimeError(
              f"Inconsistent state: Adapter {adapter_id} is LOADING but has no event."
          )

        logging.info(f"Adapter {adapter_id} is loading, will wait.")

      if metadata.status == AdapterStatus.UNLOADED:  # Check if it was UNLOADED
        logging.info(f"Beginning load for adapter {adapter_id}...")

        metadata.loading_event = (
            asyncio.Event()
        )  # Create event *before* releasing lock
        metadata.status = AdapterStatus.LOADING
        self.running_requests += 1

    # ---- Wait if needed (Unlocked) ----
    if event_to_wait_on:
      await event_to_wait_on.wait()
      # After waiting, the original loader finished (or failed).
      # Re-call load_adapter to ensure the desired state (HBM/CPU) and update
      # the timestamp.
      logging.info(f"Finished waiting for {adapter_id}. Re-checking state.")
      await self.load_adapter(adapter_id, adapter_weights, to_hbm)
      return  # Recursive call handled the final state

    # --- Perform actual loading outside the main lock ---
    load_successful = False
    try:
      if adapter_weights is None:
        adapter_path = metadata.adapter_path  # Use path from metadata

        # TODO: Compare the performance of the two options.
        # Option 1: Lower performance (run blocking I/O on the main thread)
        # adapter_weights, adapter_config = self.engine.load_single_adapter(adapter_path)

        # Option 2: Better performance
        # Run blocking I/O in an executor
        loop = asyncio.get_running_loop()
        adapter_weights, adapter_config = await loop.run_in_executor(
            None,
            functools.partial(self.engine.load_single_adapter, adapter_path),
        )

        if adapter_weights is None:
          raise ValueError(f"Failed to load adapter_weights from {adapter_path}.")

      # Convert to JAX/NumPy outside the main lock if possible (CPU heavy)
      adapter_weights_as_jnp_array = _as_jnp_array(adapter_weights)
      adapter_weights_as_np_array = _as_np_array(adapter_weights)
      del adapter_weights

      # --- Re-acquire lock for final memory check and update ---
      async with self.lock:  # Critical section for memory management
        metadata = self.adapter_registry[adapter_id]  # Re-fetch latest metadata

        # If the status changed while loading (e.g., unloaded), abort
        if metadata.status != AdapterStatus.LOADING:
          logging.warning(
              f"Load cancelled for {adapter_id}, status changed to {metadata.status}"
          )
          return

        # Size of the unified lora_params when stored in HBM as JAX arrays
        adapter_size_hbm = _get_size_of_pytree(adapter_weights_as_jnp_array)

        # Size of the unified lora_params when stored in CPU RAM as NumPy arrays
        adapter_size_cpu = _get_size_of_pytree(adapter_weights_as_np_array)

        metadata.size_hbm = adapter_size_hbm
        metadata.size_cpu = adapter_size_cpu

        # --- EVICTION (if needed) ---
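        # Illustrative sizing (assumed numbers, not from this change): with a
        # 2 GiB HBM budget and three ~800 MiB adapters, loading the third into
        # HBM forces _evict(from_hbm=True), which demotes the least recently
        # used adapter to CPU RAM instead of dropping it.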
# Evict if necessary *before* loading into the target memory + if to_hbm: + while ( + self.current_hbm_usage + adapter_size_hbm + ) > self.hbm_memory_budget: + if not self._evict(from_hbm=True): + raise RuntimeError( + "Not enough HBM to load adapter, and eviction failed." + ) + else: # to_cpu + while ( + self.current_cpu_usage + adapter_size_cpu + ) > self.cpu_memory_budget: + if not self._evict(from_hbm=False): + raise RuntimeError( + "Not enough CPU RAM to load adapter, and eviction failed." + ) + + # Now that we have space (potentially), do the actual loading + if to_hbm: + self.loaded_adapters_hbm[adapter_id] = ( + adapter_weights_as_jnp_array # Convert the PyTree to Jax Array + ) + self.current_hbm_usage += adapter_size_hbm + metadata.status = AdapterStatus.LOADED_HBM + + else: # to cpu + self.loaded_adapters_cpu[adapter_id] = ( + adapter_weights_as_np_array # Convert the PyTree to NumPy Array + ) + self.current_cpu_usage += adapter_size_cpu + metadata.status = AdapterStatus.LOADED_CPU + + metadata.last_accessed = time.time() + load_successful = True + + except Exception as e: + async with self.lock: + metadata = self.adapter_registry[adapter_id] + metadata.status = AdapterStatus.UNLOADED # Mark as unloaded on error + if metadata.loading_event is not None: + metadata.loading_event.set() + metadata.loading_event = None # Clear the event + + raise e # Re-Raise the exception + finally: + # --- Decrement running_requests, ensure status is correct --- + async with self.lock: + metadata = self.adapter_registry[adapter_id] + self.running_requests -= 1 + + if metadata.loading_event is not None: + metadata.loading_event.set() + metadata.loading_event = None # Clear the event + + # If load failed after marking LOADING, reset status + if ( + not load_successful + and self.adapter_registry[adapter_id].status + == AdapterStatus.LOADING + ): + metadata.status = AdapterStatus.UNLOADED # Mark as unloaded on error + + async def get_lora_config( + self, adapter_id: str, load_if_not_loaded: bool = False + ): + """Getter for the LoRA adapter config.""" + metadata = self.adapter_registry.get(adapter_id) + + if load_if_not_loaded and metadata is None: + await self.register_adapter(adapter_id) + metadata = self.adapter_registry.get(adapter_id) + + if metadata is None: + raise ValueError(f"LoRA adapter with id={adapter_id} is not loaded.") + + return metadata.config + + async def get_lora_weights( + self, adapter_id, to_hbm: bool = True, load_if_not_loaded: bool = False + ): + """Retrieves the unified LoRA parameters for the given adapter IDs. + Handles HBM/CPU placement. 
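    If the adapter is not yet resident anywhere, it is loaded first (and,
    when load_if_not_loaded=True, registered first) before its params are
    returned.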
+ """ + + metadata = self.adapter_registry.get(adapter_id) + + if load_if_not_loaded and metadata is None: + await self.register_adapter(adapter_id) + metadata = self.adapter_registry.get(adapter_id) + + if metadata is None: + raise ValueError(f"LoRA adapter with id={adapter_id} is not loaded.") + + if metadata.status not in ( + AdapterStatus.LOADED_HBM, + AdapterStatus.LOADED_CPU, + ): + await self.load_adapter(adapter_id, None, to_hbm) # Start loading (async) + elif to_hbm and metadata.status == AdapterStatus.LOADED_CPU: + async with self.lock: + self._unsafe_transfer_to_hbm(adapter_id) + elif not to_hbm and metadata.status == AdapterStatus.LOADED_HBM: + async with self.lock: + self._unsafe_transfer_to_cpu(adapter_id) + + # Now all required adapters should be loaded in correct memory (HBM or CPU), get them + adapter_params = None + if to_hbm: + if adapter_id not in self.loaded_adapters_hbm: + raise RuntimeError( + f"Adapter {adapter_id} should be in HBM but wasn't found after loading." + ) + adapter_params = self.loaded_adapters_hbm[adapter_id] + else: + if adapter_id not in self.loaded_adapters_cpu: + raise RuntimeError( + f"Adapter {adapter_id} should be in CPU but wasn't found after loading." + ) + adapter_params = self.loaded_adapters_cpu[adapter_id] + + return adapter_params + + async def unload_adapter(self, adapter_id: str): + """Unloads a LoRA adapter's weights and removes it from the TensorStore.""" + if adapter_id not in self.adapter_registry: + raise ValueError(f"Adatper with ID '{adapter_id}' not found.") + + event_to_wait_on: Optional[asyncio.Event] = None + async with self.lock: + metadata = self.adapter_registry[adapter_id] + if metadata.status == AdapterStatus.LOADING: + event_to_wait_on = metadata.loading_event + + if event_to_wait_on: + await event_to_wait_on.wait() + + async with self.lock: + metadata = self.adapter_registry[adapter_id] + if metadata.status == AdapterStatus.LOADING: + raise RuntimeError( + f"Inconsistent state: Adapter {adapter_id} is LOADING after just finishing one." + ) + + self._unsafe_unload_adapter(adapter_id) + + def list_adapters(self) -> Dict[str, AdapterMetadata]: + """Lists all registered adatpers and their metadata.""" + return self.adapter_registry + + def _evict(self, from_hbm: bool = True) -> bool: + """Evicts the least recently used adapter from memory (HBM or CPU).""" + + # Find the least recently used adapter that is currently loaded. + lru_adapter_id = None + lru_time = float("inf") + + for adapter_id, metadata in self.adapter_registry.items(): + if ( + metadata.status == AdapterStatus.LOADED_HBM + if from_hbm + else metadata.status == AdapterStatus.LOADED_CPU + ): + if metadata.last_accessed < lru_time: + lru_time = metadata.last_accessed + lru_adapter_id = adapter_id + + # If no adapter found to evict, return False + if lru_adapter_id is None: + return False + + if from_hbm: + # Instead of completely unloading it, kept it in CPU RAM. + # It can be loaded to HBM if any request demanded it, or + # it will be evicted from CPU when cpu memory budget reached. 
+ self._unsafe_transfer_to_cpu(lru_adapter_id) + else: + # Unload the LRU adapter + self._unsafe_unload_adapter( + lru_adapter_id + ) # This is not synchronous, but ONLY within the lock + return True diff --git a/jetstream/core/lora/multi_lora_inference_api.py b/jetstream/core/lora/multi_lora_inference_api.py new file mode 100644 index 00000000..5dee2227 --- /dev/null +++ b/jetstream/core/lora/multi_lora_inference_api.py @@ -0,0 +1,126 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Manages the list of fine-tuned adapters loaded on top of +the base model for serving. +""" + +import logging +import grpc +import asyncio + +from typing import Optional +from jetstream.core import orchestrator +from jetstream.core.proto import multi_lora_decoding_pb2_grpc +from jetstream.core.proto import multi_lora_decoding_pb2 + + +class MultiLoraManager(multi_lora_decoding_pb2_grpc.v1Servicer): + """Manages the parameters of multiple lora requests and their + status/lifetimes. + """ + + _driver: orchestrator.Driver + + def __init__(self, driver: orchestrator.Driver): + self._driver = driver + + def models( + self, + request: multi_lora_decoding_pb2.ListAdaptersRequest, + context: Optional[grpc.aio.ServicerContext] = None, + ) -> multi_lora_decoding_pb2.ListAdaptersResponse: + """ListAdapters all loaded LoRA adapters.""" + + try: + adapters = self._driver.list_adapters_from_tensorstore() + + adapter_infos = [] + for adapter_id, adapter_data in adapters.items(): + if adapter_data.status == "loaded_hbm": + loading_cost = 0 + elif adapter_data.status == "loaded_cpu": + loading_cost = 1 + elif adapter_data.status == "unloaded": + loading_cost = 2 + else: + loading_cost = -1 + + adapter_info = multi_lora_decoding_pb2.AdapterInfo( + adapter_id=adapter_id, + loading_cost=loading_cost, + size_hbm=adapter_data.size_hbm, + size_cpu=adapter_data.size_cpu, + last_accessed=adapter_data.last_accessed, + status=adapter_data.status, + ) + + adapter_infos.append(adapter_info) + + return multi_lora_decoding_pb2.ListAdaptersResponse( + success=True, adapter_infos=adapter_infos + ) + except Exception as e: # pylint: disable=broad-exception-caught + logging.info("Listing of adapters failed with error: %s", str(e)) + return multi_lora_decoding_pb2.ListAdaptersResponse( + success=False, error_message=str(e) + ) + + def load_lora_adapter( + self, + request: multi_lora_decoding_pb2.LoadAdapterRequest, + context: Optional[grpc.aio.ServicerContext] = None, + ) -> multi_lora_decoding_pb2.LoadAdapterResponse: + """Load a LoRA adapter as mentioned in the request.""" + + try: + asyncio.run( + self._driver.load_adapter_to_tensorstore( + request.adapter_id, request.adapter_path + ) + ) + + return multi_lora_decoding_pb2.LoadAdapterResponse(success=True) + except Exception as e: # pylint: disable=broad-exception-caught + logging.info( + "Loading of adapter_id=%s failed with error: %s", + request.adapter_id, + str(e), + ) + return multi_lora_decoding_pb2.LoadAdapterResponse( + success=False, 
error_message=str(e) + ) + + def unload_lora_adapter( + self, + request: multi_lora_decoding_pb2.UnloadAdapterRequest, + context: Optional[grpc.aio.ServicerContext] = None, + ) -> multi_lora_decoding_pb2.UnloadAdapterResponse: + """Unload a LoRA adapter as mentioned in the request.""" + + try: + asyncio.run( + self._driver.unload_adapter_from_tensorstore(request.adapter_id) + ) + + return multi_lora_decoding_pb2.UnloadAdapterResponse(success=True) + except Exception as e: # pylint: disable=broad-exception-caught + logging.info( + "Loading of adapter_id=%s failed with error: %s", + request.adapter_id, + str(e), + ) + return multi_lora_decoding_pb2.UnloadAdapterResponse( + success=False, error_message=str(e) + ) diff --git a/jetstream/core/metrics/prometheus.py b/jetstream/core/metrics/prometheus.py index 34475e23..c99c3215 100644 --- a/jetstream/core/metrics/prometheus.py +++ b/jetstream/core/metrics/prometheus.py @@ -245,6 +245,31 @@ def __init__(self, model_name: Optional[str] = None): ], ) + self._num_requests_waiting = Gauge( + name="num_requests_waiting", + documentation="Requests count waiting to be processed for inference.", + labelnames=universal_label_names, + multiprocess_mode="sum", + ) + + self._kv_cache_utilization = Gauge( + name="kv_cache_utilization_perc", + documentation="kv-cache utilization by the requests under processing.", + labelnames=universal_label_names, + multiprocess_mode="sum", + ) + + self._lora_request_info = Gauge( + name="lora_request_info", + documentation="LoRA adapters loaded into HBM for processing requests.", + labelnames=universal_label_names + + [ + "max_lora", + "running_lora_adapters", + ], + multiprocess_mode="livemostrecent", + ) + def get_prefill_backlog_metric(self): return self._prefill_backlog.labels(**self.universal_labels) @@ -289,3 +314,16 @@ def get_request_output_length(self): def get_request_success_count_metric(self): return self._request_success_count.labels(**self.universal_labels) + + def get_num_requests_waiting_metric(self): + return self._num_requests_waiting.labels(**self.universal_labels) + + def get_kv_cache_utilization_metric(self): + return self._kv_cache_utilization.labels(**self.universal_labels) + + def get_lora_request_info_metric(self, max_lora: int, loaded_adapters: str): + return self._lora_request_info.labels( + **self.universal_labels, + max_lora=max_lora, + running_lora_adapters=loaded_adapters + ) diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index 01ee2f97..87cca39e 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -86,11 +86,13 @@ import threading import time import traceback +import asyncio import uuid from typing import Any, AsyncIterator, Optional, Tuple, cast, List import grpc import jax +from jetstream.core.lora import adapter_tensorstore as adapterstore from jetstream.core.proto import jetstream_pb2 from jetstream.core.proto import jetstream_pb2_grpc @@ -174,6 +176,8 @@ class ActiveRequest: metadata: ActiveRequestMetadata = dataclasses.field( default_factory=ActiveRequestMetadata ) + ################## Id of the adapter ################### + adapter_id: str = "" def enqueue_samples(self, generated_samples: list[ReturnSample]): """Adds the generated sample(s) to return channel for current step. 
@@ -253,12 +257,22 @@ class Driver:
   # All metrics we want to monitor should be collected with this
   _metrics_collector: JetstreamMetricsCollector | None = None
 
+  # Store and manage the adapters for each prefill & generate Engine
+  _prefill_adapterstore: list[adapterstore.AdapterTensorStore] | None = None
+  _generate_adapterstore: list[adapterstore.AdapterTensorStore] | None = None
+
   def __init__(
       self,
       prefill_engines: Optional[list[engine_api.Engine]] = None,
       generate_engines: Optional[list[engine_api.Engine]] = None,
       prefill_params: Optional[list[Any]] = None,
       generate_params: Optional[list[Any]] = None,
+      prefill_adapterstore: Optional[
+          list[adapterstore.AdapterTensorStore]
+      ] = None,
+      generate_adapterstore: Optional[
+          list[adapterstore.AdapterTensorStore]
+      ] = None,
       interleaved_mode: bool = False,
       jax_padding: bool = True,
       metrics_collector: JetstreamMetricsCollector | None = None,
@@ -274,6 +288,9 @@ def __init__(
     if generate_params is None:
       raise ValueError("No generate parameter provided.")
 
+    self._prefill_adapterstore = prefill_adapterstore
+    self._generate_adapterstore = generate_adapterstore
+
     logger.info(
         "Initializing the driver with %d prefill engines and %d "
         "generate engines in %s mode",
@@ -439,6 +456,15 @@ def __init__(
     )
     self.live = True
     self._is_ray_backend = is_ray_backend
+
+    if self._metrics_collector:
+      self._metrics_collector.get_num_requests_waiting_metric().set_function(
+          self._get_total_requests_waiting_decode
+      )
+      self._metrics_collector.get_kv_cache_utilization_metric().set_function(
+          self._get_kv_cache_utilization
+      )
+
     # Start all threads
     for t in self._all_threads:
       t.start()
@@ -499,6 +525,50 @@ def stop(self):
 
     logger.info("Driver stopped.")
 
+  def _get_kv_cache_utilization(self):
+    """
+    Calculates the KV-cache utilization, as a percentage, based on the
+    requests currently being decoded.
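+
+    Example (hypothetical numbers): two generate engines with
+    max_concurrent_decodes=8 each give total_slots=16; with 4 empty slots
+    across their queues, the method returns (16 - 4) * 100 / 16 = 75.0.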
+    """
+    total_slots = 0
+    empty_slots = 0
+    for idx, engine in enumerate(self._generate_engines):
+      total_slots += engine.max_concurrent_decodes
+      empty_slots += self._generate_slots[idx].qsize()
+
+    return (total_slots - empty_slots) * 100 / total_slots
+
+  def _get_total_requests_waiting_decode(self):
+    """Calculate the total size of all relevant queues."""
+    total_size = self._prefill_backlog.qsize()
+
+    for transfer_queue in self._transfer_backlogs:
+      total_size += transfer_queue.qsize()
+
+    for gen_queue in self._generate_backlogs.values():
+      total_size += gen_queue.qsize()
+
+    return float(total_size)
+
+  def _export_lora_request_info(self):
+    """Export the metric named `lora_request_info`."""
+
+    adapters_list_str = ""
+    max_loras = 0
+    if self._metrics_collector:
+      for idx, engine in enumerate(self._generate_engines):
+        max_loras += engine.max_concurrent_decodes
+        if self._generate_adapterstore and idx < len(
+            self._generate_adapterstore
+        ):
+          adapters_list_str += asyncio.run(
+              self._generate_adapterstore[idx].get_hbm_loaded_adapters()
+          )
+
+      self._metrics_collector.get_lora_request_info_metric(
+          max_loras, adapters_list_str
+      ).set_to_current_time()
+
  def get_total_concurrent_requests(self) -> int:
    """Gets the total number of concurrent requests the driver can handle."""
    # We don't support filling all backlogs at once because it can cause GIL
@@ -605,6 +675,9 @@ def _prefill_thread(self, idx: int):
    logger.info("Spinning up prefill thread %d.", idx)
    prefill_engine = self._prefill_engines[idx]
    prefill_params = self._prefill_params[idx]
+    adapter_tensorstore = None
+    if self._prefill_adapterstore and idx < len(self._prefill_adapterstore):
+      adapter_tensorstore = self._prefill_adapterstore[idx]
    metadata = prefill_engine.get_tokenizer()
    tokenizer = prefill_engine.build_tokenizer(metadata)
    thread_name = f"Prefill thread {idx}"
@@ -633,10 +706,43 @@
          prefill_engine.max_prefill_length,
      )

+      adapter_id = request.adapter_id
+
+      # Apply the LoRA adapter params on top of the base params before
+      # running prefill. In interleaved mode, prefill and generate share the
+      # same params; that is safe as long as prefill and decode run
+      # sequentially. Running them in parallel on shared params would be a
+      # problem, because prefill uses the pre-merged weights while generate
+      # expects only the base weights.
+      final_prefill_params = prefill_params
+      if adapter_id and adapter_tensorstore is not None:
+        try:
+          lora_params = asyncio.run(
+              adapter_tensorstore.get_lora_weights(
+                  adapter_id=adapter_id, load_if_not_loaded=True
+              )
+          )
+          lora_config = asyncio.run(
+              adapter_tensorstore.get_lora_config(
+                  adapter_id=adapter_id, load_if_not_loaded=True
+              )
+          )
+          prefill_engine.apply_adapter(
+              final_prefill_params, lora_config, lora_params
+          )
+        except Exception as e:  # pylint: disable=broad-exception-caught
+          request.num_samples = 1
+          request.complete = np.zeros((request.num_samples,), np.bool_)
+          error_message = f"An error occurred: {type(e).__name__} - {str(e)}"
+          error_result = ReturnSample(text=[error_message], token_ids=[])
+          request.enqueue_samples([error_result])
+          request.return_channel.close()
+          continue
+
      # Compute new kv cache for the prefill_content.
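+      # Note: final_prefill_params (base weights merged with the LoRA
+      # adapter) are used for the prefill call below and are reverted via
+      # unapply_adapter() right after prefill completes.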
if self._multi_sampling: prefill_result, first_token = prefill_engine.prefill_multisampling( - params=prefill_params, + params=final_prefill_params, padded_tokens=padded_tokens, true_length=true_length, num_samples=request.num_samples, @@ -647,14 +753,14 @@ def _prefill_thread(self, idx: int): if prefill_engine.use_chunked_prefill: prefill_result, first_token = self._do_chunked_prefill( prefill_engine, - prefill_params, + final_prefill_params, tokenizer, padded_tokens[:true_length], ) else: # Compute new kv cache for the prefill_content. prefill_result, first_token = prefill_engine.prefill( - params=prefill_params, + params=final_prefill_params, padded_tokens=padded_tokens, true_length=true_length, ) @@ -663,6 +769,27 @@ def _prefill_thread(self, idx: int): (prefill_engine.samples_per_slot,), np.bool_ ) + if adapter_id and adapter_tensorstore is not None: + try: + lora_params = asyncio.run( + adapter_tensorstore.get_lora_weights(adapter_id) + ) + lora_config = asyncio.run( + adapter_tensorstore.get_lora_config(adapter_id) + ) + prefill_engine.unapply_adapter( + final_prefill_params, lora_config, lora_params + ) + except Exception as e: # pylint: disable=broad-exception-caught + request.num_samples = 1 + request.complete = np.zeros((request.num_samples,), np.bool_) + error_message = f"An error occurred: {type(e).__name__} - {str(e)}" + error_result = ReturnSample(text=[error_message], token_ids=[]) + request.enqueue_samples([error_result]) + request.return_channel.close() + continue + + del final_prefill_params request.prefill_result = prefill_result # put first token to detokenize queue @@ -844,7 +971,7 @@ def _insert_if_possible( new_request.prefill_result, decode_state, slot=slot, - request_id=new_request.request_id, + # request_id=new_request.request_id, ) ThreadDebugLog( thread_name, @@ -1038,9 +1165,13 @@ def _generate_thread(self, idx: int): generate_engine, my_detokenize_backlog, ) + if decode_state is None: break + # Export the lora_request_info metric + self._export_lora_request_info() + # At this point, we know that we have at least some slots filled. assert ( my_slots.qsize() < max_concurrent_decodes @@ -1240,6 +1371,100 @@ def _detokenize_thread(self, idx: int): logger.info("Detokenize thread %d stopped.", idx) + async def load_adapter_to_tensorstore( + self, adapter_id: str, adapter_path: str + ): + """Load the adapter to adapter_tensorstore for each engine.""" + logger.info("Loading adapter_id=%s from %s.", adapter_id, adapter_path) + + for idx, tensorstore in enumerate(self._prefill_adapterstore): + try: + engine = self._prefill_engines[idx] + adapter_params, adapter_config = engine.load_single_adapter( + adapter_path + ) + + if not adapter_params or not adapter_config: + raise ValueError( + f"Failed to load adapter={adapter_id} from {adapter_path}." + ) + + await tensorstore.register_adapter( + adapter_id, adapter_path, adapter_config + ) + + await tensorstore.load_adapter(adapter_id, adapter_params, True) + + logger.info("Successfully loaded '%s' in engine_%d.", adapter_id, idx) + engine.print_stats(f"After loading '{adapter_id}' in engine_{idx}") + + except Exception as e: + logger.info("Adapter loading failed with error: %s", str(e)) + raise e + + for idx, tensorstore in enumerate(self._generate_adapterstore): + try: + engine = self._generate_engines[idx] + adapter_params, adapter_config = engine.load_single_adapter( + adapter_path + ) + + if not adapter_params or not adapter_config: + raise ValueError( + f"Failed to load adapter={adapter_id} from {adapter_path}." 
+ ) + + await tensorstore.register_adapter( + adapter_id, adapter_path, adapter_config + ) + + await tensorstore.load_adapter(adapter_id, adapter_params, True) + + logger.info("Successfully loaded '%s' in engine_%d.", adapter_id, idx) + engine.print_stats(f"After loading '{adapter_id}' in engine_{idx}") + + except Exception as e: + logger.info("Adapter loading failed with error: %s", str(e)) + raise e + + async def unload_adapter_from_tensorstore(self, adapter_id: str): + """Unload the adapter from adapter_tensorstore of each engine.""" + logger.info("Unloading adapter_id=%s", adapter_id) + + for idx, tensorstore in enumerate(self._prefill_adapterstore): + try: + engine = self._prefill_engines[idx] + await tensorstore.unload_adapter(adapter_id) + + logger.info("Successfully unloaded '%s' in engine_%d.", adapter_id, idx) + engine.print_stats(f"After unloading '{adapter_id}' in engine_{idx}") + + except Exception as e: + logger.info("Adapter unloading failed with error: %s", str(e)) + raise e + + for idx, tensorstore in enumerate(self._generate_adapterstore): + try: + engine = self._generate_engines[idx] + await tensorstore.unload_adapter(adapter_id) + + logger.info("Successfully unloaded '%s' in engine_%d.", adapter_id, idx) + engine.print_stats(f"After unloading '{adapter_id}' in engine_{idx}") + + except Exception as e: + logger.info("Adapter unloading failed with error: %s", str(e)) + raise e + + def list_adapters_from_tensorstore(self): + """List all the adapters from the adapter_tensorstore of each engine.""" + logger.info("Listing loaded adapters.") + + listed_adapters = {} + for tensorstore in self._generate_adapterstore: + listed_adapters.update(tensorstore.adapter_registry) + + return listed_adapters + class LLMOrchestrator(jetstream_pb2_grpc.OrchestratorServicer): """Coordinates a set of prefill and generate slices for LLM decoding.""" @@ -1346,6 +1571,7 @@ async def Decode( # pylint: disable=invalid-overridden-method prefill_content=prefill_content, is_client_side_tokenization=is_client_side_tokenization, return_channel=return_channel, + adapter_id=request.lora_adapter_id, metadata=ActiveRequestMetadata( start_time=request.metadata.start_time, prefill_enqueue_time=time.perf_counter(), diff --git a/jetstream/core/proto/jetstream.proto b/jetstream/core/proto/jetstream.proto index 59cd7dc4..624313c7 100644 --- a/jetstream/core/proto/jetstream.proto +++ b/jetstream/core/proto/jetstream.proto @@ -61,8 +61,10 @@ message DecodeRequest { int32 num_samples = 8; + string lora_adapter_id = 9; + reserved 1, 2, 3; - // Next ID: 9 + // Next ID: 10 } message DecodeResponse { @@ -93,4 +95,4 @@ message HealthCheckRequest {} message HealthCheckResponse { // Denotes whether the model server is live bool is_live = 1; -} \ No newline at end of file +} diff --git a/jetstream/core/proto/jetstream_pb2.py b/jetstream/core/proto/jetstream_pb2.py index e649b3bb..a2f72bb0 100644 --- a/jetstream/core/proto/jetstream_pb2.py +++ b/jetstream/core/proto/jetstream_pb2.py @@ -13,7 +13,7 @@ # limitations under the License. # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
-# source: jetstream/core/proto/jetstream.proto +# source: jetstream.proto # Protobuf Python Version: 4.25.1 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor @@ -26,36 +26,34 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"\x91\x03\n\rDecodeRequest\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x12;\n\x08metadata\x18\x07 \x01(\x0b\x32\'.jetstream_proto.DecodeRequest.MetadataH\x01\x12\x13\n\x0bnum_samples\x18\x08 \x01(\x05\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x1a\x1e\n\x08Metadata\x12\x12\n\nstart_time\x18\x01 \x01(\x02\x42\t\n\x07\x63ontentB\x13\n\x11metadata_optionalJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02"\x14\n\x12HealthCheckRequest"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\x32\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse"\x00\x62\x06proto3' + b'\n\x0fjetstream.proto\x12\x0fjetstream_proto"\xaa\x03\n\rDecodeRequest\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x12;\n\x08metadata\x18\x07 \x01(\x0b\x32\'.jetstream_proto.DecodeRequest.MetadataH\x01\x12\x13\n\x0bnum_samples\x18\x08 \x01(\x05\x12\x17\n\x0flora_adapter_id\x18\t \x01(\t\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x1a\x1e\n\x08Metadata\x12\x12\n\nstart_time\x18\x01 \x01(\x02\x42\t\n\x07\x63ontentB\x13\n\x11metadata_optionalJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02"\x14\n\x12HealthCheckRequest"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 
\x01(\x08\x32\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse"\x00\x62\x06proto3' ) _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages( - DESCRIPTOR, "jetstream.core.proto.jetstream_pb2", _globals -) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "jetstream_pb2", _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals["_DECODEREQUEST"]._serialized_start = 58 - _globals["_DECODEREQUEST"]._serialized_end = 459 - _globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_start = 315 - _globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_end = 342 - _globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_start = 344 - _globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_end = 377 - _globals["_DECODEREQUEST_METADATA"]._serialized_start = 379 - _globals["_DECODEREQUEST_METADATA"]._serialized_end = 409 - _globals["_DECODERESPONSE"]._serialized_start = 462 - _globals["_DECODERESPONSE"]._serialized_end = 793 - _globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_start = 628 - _globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_end = 644 - _globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_start = 647 - _globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_end = 776 - _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_start = 735 - _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_end = 776 - _globals["_HEALTHCHECKREQUEST"]._serialized_start = 795 - _globals["_HEALTHCHECKREQUEST"]._serialized_end = 815 - _globals["_HEALTHCHECKRESPONSE"]._serialized_start = 817 - _globals["_HEALTHCHECKRESPONSE"]._serialized_end = 855 - _globals["_ORCHESTRATOR"]._serialized_start = 858 - _globals["_ORCHESTRATOR"]._serialized_end = 1043 + _globals["_DECODEREQUEST"]._serialized_start = 37 + _globals["_DECODEREQUEST"]._serialized_end = 463 + _globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_start = 319 + _globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_end = 346 + _globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_start = 348 + _globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_end = 381 + _globals["_DECODEREQUEST_METADATA"]._serialized_start = 383 + _globals["_DECODEREQUEST_METADATA"]._serialized_end = 413 + _globals["_DECODERESPONSE"]._serialized_start = 466 + _globals["_DECODERESPONSE"]._serialized_end = 797 + _globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_start = 632 + _globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_end = 648 + _globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_start = 651 + _globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_end = 780 + _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_start = 739 + _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_end = 780 + _globals["_HEALTHCHECKREQUEST"]._serialized_start = 799 + _globals["_HEALTHCHECKREQUEST"]._serialized_end = 819 + _globals["_HEALTHCHECKRESPONSE"]._serialized_start = 821 + _globals["_HEALTHCHECKRESPONSE"]._serialized_end = 859 + _globals["_ORCHESTRATOR"]._serialized_start = 862 + _globals["_ORCHESTRATOR"]._serialized_end = 1047 # @@protoc_insertion_point(module_scope) diff --git a/jetstream/core/proto/jetstream_pb2_grpc.py b/jetstream/core/proto/jetstream_pb2_grpc.py index d571ade8..3302ab76 100644 --- a/jetstream/core/proto/jetstream_pb2_grpc.py +++ 
b/jetstream/core/proto/jetstream_pb2_grpc.py @@ -15,7 +15,7 @@ """Client and server classes corresponding to protobuf-defined services.""" import grpc -from jetstream.core.proto import jetstream_pb2 as jetstream_dot_core_dot_proto_dot_jetstream__pb2 +from jetstream.core.proto import jetstream_pb2 as jetstream__pb2 class OrchestratorStub(object): @@ -29,13 +29,13 @@ def __init__(self, channel): """ self.Decode = channel.unary_stream( "/jetstream_proto.Orchestrator/Decode", - request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeRequest.SerializeToString, - response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeResponse.FromString, + request_serializer=jetstream__pb2.DecodeRequest.SerializeToString, + response_deserializer=jetstream__pb2.DecodeResponse.FromString, ) self.HealthCheck = channel.unary_unary( "/jetstream_proto.Orchestrator/HealthCheck", - request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckRequest.SerializeToString, - response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckResponse.FromString, + request_serializer=jetstream__pb2.HealthCheckRequest.SerializeToString, + response_deserializer=jetstream__pb2.HealthCheckResponse.FromString, ) @@ -59,13 +59,13 @@ def add_OrchestratorServicer_to_server(servicer, server): rpc_method_handlers = { "Decode": grpc.unary_stream_rpc_method_handler( servicer.Decode, - request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeRequest.FromString, - response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeResponse.SerializeToString, + request_deserializer=jetstream__pb2.DecodeRequest.FromString, + response_serializer=jetstream__pb2.DecodeResponse.SerializeToString, ), "HealthCheck": grpc.unary_unary_rpc_method_handler( servicer.HealthCheck, - request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckRequest.FromString, - response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckResponse.SerializeToString, + request_deserializer=jetstream__pb2.HealthCheckRequest.FromString, + response_serializer=jetstream__pb2.HealthCheckResponse.SerializeToString, ), } generic_handler = grpc.method_handlers_generic_handler( @@ -95,8 +95,8 @@ def Decode( request, target, "/jetstream_proto.Orchestrator/Decode", - jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeRequest.SerializeToString, - jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeResponse.FromString, + jetstream__pb2.DecodeRequest.SerializeToString, + jetstream__pb2.DecodeResponse.FromString, options, channel_credentials, insecure, @@ -124,8 +124,8 @@ def HealthCheck( request, target, "/jetstream_proto.Orchestrator/HealthCheck", - jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckRequest.SerializeToString, - jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckResponse.FromString, + jetstream__pb2.HealthCheckRequest.SerializeToString, + jetstream__pb2.HealthCheckResponse.FromString, options, channel_credentials, insecure, diff --git a/jetstream/core/proto/multi_lora_decoding.proto b/jetstream/core/proto/multi_lora_decoding.proto new file mode 100644 index 00000000..30df270e --- /dev/null +++ b/jetstream/core/proto/multi_lora_decoding.proto @@ -0,0 +1,67 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// NOTICE: run `make generate-protos` if making changes to this file + +syntax = "proto3"; + +service v1 { + // Lists all the currently loaded LoRA adapters + rpc models (ListAdaptersRequest) returns (ListAdaptersResponse) {} + + // Loads a new LoRA adapter. + rpc load_lora_adapter (LoadAdapterRequest) returns (LoadAdapterResponse) {} + + // Unloads a LoRA adapter + rpc unload_lora_adapter (UnloadAdapterRequest) returns (UnloadAdapterResponse) {} +} + + +message ListAdaptersRequest {} + +message ListAdaptersResponse { + bool success = 1; // True if successful, False otherwise + string error_message = 2; // Error message if listing the adapters + repeated AdapterInfo adapter_infos = 3; // List of information about loaded adapters. +} + +// Information about a single loaded LoRA adapter +message AdapterInfo { + string adapter_id = 1; + int64 loading_cost = 2; + int64 size_hbm = 3; + int64 size_cpu = 4; + float last_accessed = 5; + string status = 6; +} + +message LoadAdapterRequest { + string adapter_id = 1; // Unique ID/name for the adapter + string adapter_path = 2; // Path to the LoRA adapter (config & weights) +} + +message LoadAdapterResponse { + bool success = 1; // True if successful, false otherwise + string error_message = 2; // Error message if loading failed +} + +message UnloadAdapterRequest { + string adapter_id = 1; // ID/Name of the adapter to unload +} + +message UnloadAdapterResponse { + bool success = 1; // True if successful, false otherwise + string error_message = 2; // Error message if unloading failed +} + diff --git a/jetstream/core/proto/multi_lora_decoding_pb2.py b/jetstream/core/proto/multi_lora_decoding_pb2.py new file mode 100644 index 00000000..d53069c3 --- /dev/null +++ b/jetstream/core/proto/multi_lora_decoding_pb2.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: multi_lora_decoding.proto +# Protobuf Python Version: 4.25.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x19multi_lora_decoding.proto"\x15\n\x13ListAdaptersRequest"c\n\x14ListAdaptersResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t\x12#\n\radapter_infos\x18\x03 \x03(\x0b\x32\x0c.AdapterInfo"\x82\x01\n\x0b\x41\x64\x61pterInfo\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0cloading_cost\x18\x02 \x01(\x03\x12\x10\n\x08size_hbm\x18\x03 \x01(\x03\x12\x10\n\x08size_cpu\x18\x04 \x01(\x03\x12\x15\n\rlast_accessed\x18\x05 \x01(\x02\x12\x0e\n\x06status\x18\x06 \x01(\t">\n\x12LoadAdapterRequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0c\x61\x64\x61pter_path\x18\x02 \x01(\t"=\n\x13LoadAdapterResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t"*\n\x14UnloadAdapterRequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t"?\n\x15UnloadAdapterResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t2\xc7\x01\n\x02v1\x12\x37\n\x06models\x12\x14.ListAdaptersRequest\x1a\x15.ListAdaptersResponse"\x00\x12@\n\x11load_lora_adapter\x12\x13.LoadAdapterRequest\x1a\x14.LoadAdapterResponse"\x00\x12\x46\n\x13unload_lora_adapter\x12\x15.UnloadAdapterRequest\x1a\x16.UnloadAdapterResponse"\x00\x62\x06proto3' +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages( + DESCRIPTOR, "multi_lora_decoding_pb2", _globals +) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals["_LISTADAPTERSREQUEST"]._serialized_start = 29 + _globals["_LISTADAPTERSREQUEST"]._serialized_end = 50 + _globals["_LISTADAPTERSRESPONSE"]._serialized_start = 52 + _globals["_LISTADAPTERSRESPONSE"]._serialized_end = 151 + _globals["_ADAPTERINFO"]._serialized_start = 154 + _globals["_ADAPTERINFO"]._serialized_end = 284 + _globals["_LOADADAPTERREQUEST"]._serialized_start = 286 + _globals["_LOADADAPTERREQUEST"]._serialized_end = 348 + _globals["_LOADADAPTERRESPONSE"]._serialized_start = 350 + _globals["_LOADADAPTERRESPONSE"]._serialized_end = 411 + _globals["_UNLOADADAPTERREQUEST"]._serialized_start = 413 + _globals["_UNLOADADAPTERREQUEST"]._serialized_end = 455 + _globals["_UNLOADADAPTERRESPONSE"]._serialized_start = 457 + _globals["_UNLOADADAPTERRESPONSE"]._serialized_end = 520 + _globals["_V1"]._serialized_start = 523 + _globals["_V1"]._serialized_end = 722 +# @@protoc_insertion_point(module_scope) diff --git a/jetstream/core/proto/multi_lora_decoding_pb2_grpc.py b/jetstream/core/proto/multi_lora_decoding_pb2_grpc.py new file mode 100644 index 00000000..6b22989b --- /dev/null +++ b/jetstream/core/proto/multi_lora_decoding_pb2_grpc.py @@ -0,0 +1,169 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +from jetstream.core.proto import multi_lora_decoding_pb2 as multi__lora__decoding__pb2 + + +class v1Stub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. 
+ + Args: + channel: A grpc.Channel. + """ + self.models = channel.unary_unary( + "/v1/models", + request_serializer=multi__lora__decoding__pb2.ListAdaptersRequest.SerializeToString, + response_deserializer=multi__lora__decoding__pb2.ListAdaptersResponse.FromString, + ) + self.load_lora_adapter = channel.unary_unary( + "/v1/load_lora_adapter", + request_serializer=multi__lora__decoding__pb2.LoadAdapterRequest.SerializeToString, + response_deserializer=multi__lora__decoding__pb2.LoadAdapterResponse.FromString, + ) + self.unload_lora_adapter = channel.unary_unary( + "/v1/unload_lora_adapter", + request_serializer=multi__lora__decoding__pb2.UnloadAdapterRequest.SerializeToString, + response_deserializer=multi__lora__decoding__pb2.UnloadAdapterResponse.FromString, + ) + + +class v1Servicer(object): + """Missing associated documentation comment in .proto file.""" + + def models(self, request, context): + """Lists all the currently loaded LoRA adapters""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def load_lora_adapter(self, request, context): + """Loads a new LoRA adapter.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def unload_lora_adapter(self, request, context): + """Unloads a LoRA adapter""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + +def add_v1Servicer_to_server(servicer, server): + rpc_method_handlers = { + "models": grpc.unary_unary_rpc_method_handler( + servicer.models, + request_deserializer=multi__lora__decoding__pb2.ListAdaptersRequest.FromString, + response_serializer=multi__lora__decoding__pb2.ListAdaptersResponse.SerializeToString, + ), + "load_lora_adapter": grpc.unary_unary_rpc_method_handler( + servicer.load_lora_adapter, + request_deserializer=multi__lora__decoding__pb2.LoadAdapterRequest.FromString, + response_serializer=multi__lora__decoding__pb2.LoadAdapterResponse.SerializeToString, + ), + "unload_lora_adapter": grpc.unary_unary_rpc_method_handler( + servicer.unload_lora_adapter, + request_deserializer=multi__lora__decoding__pb2.UnloadAdapterRequest.FromString, + response_serializer=multi__lora__decoding__pb2.UnloadAdapterResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + "v1", rpc_method_handlers + ) + server.add_generic_rpc_handlers((generic_handler,)) + + +# This class is part of an EXPERIMENTAL API. 
+class v1(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def models( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/v1/models", + multi__lora__decoding__pb2.ListAdaptersRequest.SerializeToString, + multi__lora__decoding__pb2.ListAdaptersResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def load_lora_adapter( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/v1/load_lora_adapter", + multi__lora__decoding__pb2.LoadAdapterRequest.SerializeToString, + multi__lora__decoding__pb2.LoadAdapterResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def unload_lora_adapter( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/v1/unload_lora_adapter", + multi__lora__decoding__pb2.UnloadAdapterRequest.SerializeToString, + multi__lora__decoding__pb2.UnloadAdapterResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index d9ae2cb0..cbf5a1af 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -27,6 +27,7 @@ import threading import time import traceback +import importlib from typing import Any, Type @@ -34,6 +35,7 @@ import jax from jetstream.core import config_lib from jetstream.core import orchestrator +from jetstream.core.lora import adapter_tensorstore as adapterstore from jetstream.core.metrics.prometheus import JetstreamMetricsCollector from jetstream.core.proto import jetstream_pb2_grpc from jetstream.engine import warmup_utils, engine_api @@ -62,7 +64,12 @@ class JetStreamServer: """JetStream grpc server.""" def __init__( - self, driver: orchestrator.Driver, threads: int, port, credentials + self, + driver: orchestrator.Driver, + threads: int, + port, + credentials, + enable_llm_inference_pool=False, ): self._executor = futures.ThreadPoolExecutor(max_workers=threads) @@ -81,6 +88,19 @@ async def do_init(): jetstream_pb2_grpc.add_OrchestratorServicer_to_server( orchestrator.LLMOrchestrator(driver=self._driver), self._grpc_server ) + + if enable_llm_inference_pool: + module_name = "jetstream.core.lora.multi_lora_inference_api" + multi_lora_inference = importlib.import_module(module_name) + + module_name = "jetstream.core.proto.multi_lora_decoding_pb2_grpc" + multi_lora_decoding_pb2_grpc = importlib.import_module(module_name) + + multi_lora_decoding_pb2_grpc.add_v1Servicer_to_server( + multi_lora_inference.MultiLoraManager(driver=self._driver), + self._grpc_server, + ) + self._grpc_server.add_secure_port(f"{_HOST}:{port}", credentials) async def _async_start(self) -> None: @@ -117,6 +137,7 @@ def create_driver( metrics_collector: 
JetstreamMetricsCollector | None = None, enable_model_warmup: bool = False, multi_sampling: bool = False, + lora_input_adapters_path: str | None = None, ): """Creates a driver with a specified config. @@ -145,10 +166,47 @@ def create_driver( len(config.prefill_slices) + len(config.generate_slices) == 0 ) + prefill_adapterstore = [] + generate_adapterstore = [] + shared_adapterstore = [] + + if lora_input_adapters_path: + for pe in engines.prefill_engines: + prefill_adapterstore.append( + adapterstore.AdapterTensorStore( + engine=pe, + adapters_dir_path=lora_input_adapters_path, + hbm_memory_budget=20 * (1024**3), # 20 GB HBM + cpu_memory_budget=100 * (1024**3), # 100 GB RAM + ) + ) + # TODO: Make hbm_memory_budget and cpu_memory_budget configurable + for ge in engines.generate_engines: + generate_adapterstore.append( + adapterstore.AdapterTensorStore( + engine=ge, + adapters_dir_path=lora_input_adapters_path, + hbm_memory_budget=20 * (1024**3), # 20 GB HBM + cpu_memory_budget=100 * (1024**3), # 100 GB RAM + ) + ) + + for ie in engines.interleaved_engines: + shared_adapterstore.append( + adapterstore.AdapterTensorStore( + engine=ie, + adapters_dir_path=lora_input_adapters_path, + hbm_memory_budget=20 * (1024**3), # 20 GB HBM + cpu_memory_budget=100 * (1024**3), # 100 GB RAM + ) + ) + prefill_engines = engines.prefill_engines + engines.interleaved_engines generate_engines = engines.generate_engines + engines.interleaved_engines prefill_params = prefill_params + shared_params generate_params = generate_params + shared_params + prefill_adapterstore += shared_adapterstore + generate_adapterstore += shared_adapterstore if prefill_engines is None: prefill_engines = [] # pragma: no branch @@ -183,6 +241,8 @@ def create_driver( generate_engines=generate_engines, prefill_params=prefill_params, generate_params=generate_params, + prefill_adapterstore=prefill_adapterstore, + generate_adapterstore=generate_adapterstore, interleaved_mode=interleaved_mode, jax_padding=jax_padding, metrics_collector=metrics_collector, @@ -220,16 +280,11 @@ def run( jax_profiler_port: The port JAX profiler server (default to 9999). enable_model_warmup: The flag to enable model server warmup. multi_sampling: The flag to enable multi-sampling. - lora_input_adapters_path: Path to define the location of all lora adapters. + lora_input_adapters_path: Input path for all lora adapters. Returns: JetStreamServer that wraps the grpc server and orchestrator driver. """ - # TODO: Deleting the lora_input_adapters_path for now. - # Planning to use it in next big PR. Currently accomodating it - # to fix the params mismatch between maxText and JetStream - del lora_input_adapters_path - server_start_time = time.time() logger.info("Kicking off gRPC server.") # Setup Prometheus server @@ -254,11 +309,18 @@ def run( metrics_collector, enable_model_warmup, multi_sampling, + lora_input_adapters_path, ) # We default threads to the total number of concurrent allowed decodes, # to make sure we can fully saturate the model. Set default minimum to 64. threads = threads or max(driver.get_total_concurrent_requests(), 64) - jetstream_server = JetStreamServer(driver, threads, port, credentials) + enable_llm_inference_pool = False + if lora_input_adapters_path: + enable_llm_inference_pool = True + jetstream_server = JetStreamServer( + driver, threads, port, credentials, enable_llm_inference_pool + ) + logging.info("Starting server on port %d with %d threads", port, threads) # Tweak gc config. # Force a gen 2 collection here. 
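For reference, here is a minimal client sketch against the `v1` adapter-management service wired up above. It uses only the generated messages and `v1Stub` from this patch; the server address, channel credentials, and the adapter id/path are illustrative assumptions, not part of the change.

import grpc

from jetstream.core.proto import multi_lora_decoding_pb2
from jetstream.core.proto import multi_lora_decoding_pb2_grpc


def main():
  # Assumes a JetStream server started with lora_input_adapters_path set,
  # serving on localhost:9000. The server binds a secure port, so local
  # channel credentials are used here; adjust to match your deployment.
  channel = grpc.secure_channel(
      "localhost:9000", grpc.local_channel_credentials()
  )
  stub = multi_lora_decoding_pb2_grpc.v1Stub(channel)

  # Load a LoRA adapter; adapter_id and adapter_path are hypothetical.
  load_resp = stub.load_lora_adapter(
      multi_lora_decoding_pb2.LoadAdapterRequest(
          adapter_id="my_adapter",
          adapter_path="/adapters/my_adapter",
      )
  )
  print("load:", load_resp.success, load_resp.error_message)

  # List registered adapters and where each currently lives (HBM or CPU).
  list_resp = stub.models(multi_lora_decoding_pb2.ListAdaptersRequest())
  for info in list_resp.adapter_infos:
    print(info.adapter_id, info.status, info.size_hbm, info.size_cpu)

  # Unload to release the adapter's HBM/CPU budget in the tensorstore.
  unload_resp = stub.unload_lora_adapter(
      multi_lora_decoding_pb2.UnloadAdapterRequest(adapter_id="my_adapter")
  )
  print("unload:", unload_resp.success, unload_resp.error_message)


if __name__ == "__main__":
  main()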
diff --git a/jetstream/engine/mock_engine.py b/jetstream/engine/mock_engine.py
index 19af8fbc..6bead7c3 100644
--- a/jetstream/engine/mock_engine.py
+++ b/jetstream/engine/mock_engine.py
@@ -99,6 +99,10 @@ def __init__(
 self._prng_key = jax.random.PRNGKey(42)
 self._use_chunked_prefill = use_chunked_prefill
 
+ def print_stats(self, label: str):
+ del label
+ print("print_stats() is not yet supported in TestEngine")
+
 def load_params(self) -> Params:
 """Loads model weights."""
 # An integer, used to multiply inputs.
@@ -109,6 +113,29 @@ def load_params_dict(self) -> Params:
 # An integer, used to multiply inputs.
 return {"params": jnp.array([self.weight], dtype=jnp.float32)}
 
+ def apply_adapter(self, base_params, adapter_config, adapter_params):
+ """Apply the adapter and return the updated params (jnp arrays are immutable)."""
+
+ del adapter_config
+ return jnp.add(base_params, adapter_params)
+
+ def unapply_adapter(self, base_params, adapter_config, adapter_params):
+ """Unapply the adapter from the base params and return the result."""
+
+ del adapter_config
+ return jnp.subtract(base_params, adapter_params)
+
+ def load_single_adapter(self, adapter_path):
+ """Return adapter_params and adapter_config."""
+
+ if "fail" in adapter_path:
+ return None, None
+
+ adapter_params = jnp.array([3.0], dtype=jnp.float32)
+ adapter_config = {"r": 4, "alpha": 32}
+
+ return adapter_params, adapter_config
+
 @functools.partial(
 jax.jit,
 static_argnums=(0,),
diff --git a/jetstream/tests/core/lora/test_adapter_tensorstore.py b/jetstream/tests/core/lora/test_adapter_tensorstore.py
new file mode 100644
index 00000000..8bfe732e
--- /dev/null
+++ b/jetstream/tests/core/lora/test_adapter_tensorstore.py
@@ -0,0 +1,1199 @@
+import asyncio
+import time
+import unittest
+from unittest.mock import patch, MagicMock, AsyncMock, call, ANY
+import dataclasses
+import logging
+import functools
+from typing import Dict, Optional, Any, List
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from absl.testing import absltest, parameterized # Keep for parameterized tests
+
+# Assuming the adapter_tensorstore code is in this path relative to the tests
+# NOTE: Adjust the import path based on your project structure
+from jetstream.core.lora import adapter_tensorstore
+from jetstream.engine import engine_api # For mocking engine type
+
+# --- Mocking Helpers ---
+# Use helpers directly from the module
+_get_size_of_pytree = adapter_tensorstore._get_size_of_pytree
+_as_np_array = adapter_tensorstore._as_np_array
+_as_jnp_array = adapter_tensorstore._as_jnp_array
+AdapterStatus = adapter_tensorstore.AdapterStatus
+AdapterMetadata = adapter_tensorstore.AdapterMetadata
+
+
+def create_mock_weights(
+ size_multiplier: int = 1, dtype=np.float32
+) -> Dict[str, np.ndarray]:
+ """Creates a dummy PyTree of NumPy arrays for LoRA weights."""
+ rank = 8
+ input_dim = 128
+ output_dim = 256
+ # Simple structure for testing purposes
+ return {
+ "lora_A": (np.ones((input_dim, rank)) * size_multiplier).astype(dtype),
+ "lora_B": (np.ones((rank, output_dim)) * size_multiplier).astype(dtype),
+ }
+
+
+def get_mock_config(rank=8, alpha=16):
+ return {"rank": rank, "alpha": alpha}
+
+
+def get_mock_size(weights):
+ weights_as_jnp_array = _as_jnp_array(weights)
+ weights_as_np_array = _as_np_array(weights)
+
+ size_hbm = _get_size_of_pytree(weights_as_jnp_array)
+ size_cpu = _get_size_of_pytree(weights_as_np_array)
+
+ return size_hbm, size_cpu
+
+
+# Mock function for engine's load_single_adapter
+def load_single_adapter_sync_mock(adapter_path, store_instance):
+ """Synchronous mock 
loading function for run_in_executor.""" + logging.info(f"SYNC MOCK: Loading from {adapter_path}") + time.sleep(0.01) # Simulate slight delay + adapter_id = adapter_path.split("/")[-1] + if adapter_id == "adapter_a": + return store_instance.mock_weights_a, store_instance.mock_config_a + elif adapter_id == "adapter_b": + return store_instance.mock_weights_b, store_instance.mock_config_b + elif adapter_id == "adapter_c": # Add another for eviction tests + return store_instance.mock_weights_c, store_instance.mock_config_c + elif adapter_id == "adapter_fail": + raise FileNotFoundError(f"Mock intentionally failed for {adapter_path}") + elif "no_config" in adapter_path: + return create_mock_weights(99), None # Simulate missing config return + else: + raise FileNotFoundError(f"Mock sync path not found: {adapter_path}") + + +# --- Test Class --- + + +class AdapterTensorStoreTest( + parameterized.TestCase, unittest.IsolatedAsyncioTestCase +): # Use IsolatedAsyncioTestCase + + async def asyncSetUp(self): + """Set up mocks and the AdapterTensorStore instance before each test.""" + await super().asyncSetUp() + + self.mock_engine = MagicMock(spec=engine_api.Engine) + self.mock_engine.load_single_adapter = MagicMock() + self.mock_engine.load_single_adapter.side_effect = ( + lambda path: load_single_adapter_sync_mock(path, self) + ) + + self.adapters_dir_path = "/test/adapters" + + self.mock_weights_a = create_mock_weights(1) + self.mock_config_a = get_mock_config(rank=8) + self.mock_size_hbm_a = _get_size_of_pytree( + _as_jnp_array(self.mock_weights_a) + ) + self.mock_size_cpu_a = _get_size_of_pytree(self.mock_weights_a) + + self.mock_weights_b = create_mock_weights(2) + self.mock_config_b = get_mock_config(rank=4) + self.mock_size_hbm_b = _get_size_of_pytree( + _as_jnp_array(self.mock_weights_b) + ) + self.mock_size_cpu_b = _get_size_of_pytree(self.mock_weights_b) + + self.mock_weights_c = create_mock_weights(4) + self.mock_config_c = get_mock_config(rank=12) + self.mock_size_hbm_c = _get_size_of_pytree( + _as_jnp_array(self.mock_weights_c) + ) + self.mock_size_cpu_c = _get_size_of_pytree(self.mock_weights_c) + + # Default budgets + self.hbm_budget = self.mock_size_hbm_a + self.mock_size_hbm_b + 100 + self.cpu_budget = self.mock_size_cpu_a + self.mock_size_cpu_b + 100 + + # Patch time.time + self.time_patcher = patch("time.time") + self.mock_time = self.time_patcher.start() + self.current_time = 1000.0 + self.mock_time.return_value = self.current_time + self.addCleanup(self.time_patcher.stop) + + # Create the store instance + self.store = adapter_tensorstore.AdapterTensorStore( + engine=self.mock_engine, + adapters_dir_path=self.adapters_dir_path, + hbm_memory_budget=self.hbm_budget, + cpu_memory_budget=self.cpu_budget, + ) + + # Pre-register adapters for most tests to simplify setup + # Use await now because register_adapter is async + await self.store.register_adapter( + "adapter_a", adapter_config=self.mock_config_a + ) + await self.store.register_adapter( + "adapter_b", adapter_config=self.mock_config_b + ) + await self.store.register_adapter( + "adapter_c", adapter_config=self.mock_config_c + ) + # Reset mock call count after potential loads during registration + self.mock_engine.load_single_adapter.reset_mock() + + def advance_time(self, seconds: float = 1.0): + """Helper to advance the mocked time.""" + self.current_time += seconds + self.mock_time.return_value = self.current_time + + # === Test Initialization === + def test_initialization(self): + """Test basic attribute initialization.""" + 
self.assertEqual(self.store.engine, self.mock_engine) + self.assertEqual(self.store.adapters_dir_path, self.adapters_dir_path) + self.assertEqual(self.store.hbm_memory_budget, self.hbm_budget) + self.assertEqual(self.store.cpu_memory_budget, self.cpu_budget) + # Registry will have pre-registered adapters from setUp + self.assertIn("adapter_a", self.store.adapter_registry) + self.assertIn("adapter_b", self.store.adapter_registry) + self.assertIn("adapter_c", self.store.adapter_registry) + self.assertEqual(self.store.loaded_adapters_hbm, {}) + self.assertEqual(self.store.loaded_adapters_cpu, {}) + self.assertEqual(self.store.current_hbm_usage, 0) + self.assertEqual(self.store.current_cpu_usage, 0) + self.assertEqual(self.store.running_requests, 0) + + # === Test register_adapter === + async def test_register_adapter_with_config_only(self): + """Test registration when only config is provided (no auto-load).""" + adapter_id = "adapter_new_config" + adapter_path = "/custom/path/adapter_new" + adapter_config = {"rank": 16} + # Clear registry for this test + self.store.adapter_registry.clear() + + await self.store.register_adapter(adapter_id, adapter_path, adapter_config) + + self.assertIn(adapter_id, self.store.adapter_registry) + metadata = self.store.adapter_registry[adapter_id] + self.assertEqual(metadata.adapter_id, adapter_id) + self.assertEqual(metadata.adapter_path, adapter_path) + self.assertEqual(metadata.config, adapter_config) + self.assertEqual(metadata.status, AdapterStatus.UNLOADED) + self.assertEqual(self.store.current_hbm_usage, 0) + self.assertEqual(self.store.current_cpu_usage, 0) + self.mock_engine.load_single_adapter.assert_not_called() + + async def test_register_adapter_without_config_triggers_load(self): + """Test registration without config triggers engine load and store load.""" + adapter_id = "adapter_new_load" + adapter_path = f"{self.adapters_dir_path}/{adapter_id}" + # Configure engine mock for this specific ID + mock_weights_new = create_mock_weights(9) + mock_config_new = get_mock_config(9) + self.mock_engine.load_single_adapter.side_effect = ( + lambda p: (mock_weights_new, mock_config_new) + if p == adapter_path + else FileNotFoundError + ) + + # Call register - this will call engine.load_single_adapter synchronously + # and then internally call await self.load_adapter(...) + await self.store.register_adapter(adapter_id) # No path, no config + + # 1. Check registration occurred and config populated + self.assertIn(adapter_id, self.store.adapter_registry) + metadata = self.store.adapter_registry[adapter_id] + self.assertEqual(metadata.adapter_id, adapter_id) + self.assertEqual(metadata.adapter_path, adapter_path) + self.assertEqual( + metadata.config, mock_config_new + ) # Config populated by load + + # 2. Check engine was called to get weights/config + self.mock_engine.load_single_adapter.assert_called_once_with(adapter_path) + + # 3. 
Check the final state (since mocked run was synchronous) + self.assertEqual( + metadata.status, AdapterStatus.LOADED_HBM + ) # Default load is HBM + self.assertIn(adapter_id, self.store.loaded_adapters_hbm) + self.assertTrue(self.store.current_hbm_usage > 0) + + async def test_register_adapter_load_fails_no_config(self): + """Test register raises error if engine load fails to provide config.""" + adapter_id = "adapter_no_config" + adapter_path = f"{self.adapters_dir_path}/{adapter_id}" + # Mock is configured in setUp to return None for config + + with self.assertRaisesRegex( + ValueError, f"Failed to read adapter_config from {adapter_path}" + ): + await self.store.register_adapter(adapter_id) + + self.mock_engine.load_single_adapter.assert_called_once_with(adapter_path) + self.assertNotIn(adapter_id, self.store.adapter_registry) + + async def test_register_adapter_duplicate_logs_warning(self): + """Test registering a duplicate adapter ID logs a warning and is no-op.""" + adapter_id = "adapter_a" # Already registered in setUp + initial_metadata = self.store.adapter_registry[adapter_id] + self.mock_engine.load_single_adapter.reset_mock() + + with self.assertLogs(level="WARNING") as log: + await self.store.register_adapter( + adapter_id, adapter_config={"rank": 32} + ) # Duplicate + + self.assertIn( + f"Adapter with ID '{adapter_id}' already registered.", log.output[0] + ) + # Ensure registry wasn't overwritten + self.assertIs(self.store.adapter_registry[adapter_id], initial_metadata) + self.assertEqual( + self.store.adapter_registry[adapter_id].config, self.mock_config_a + ) + self.mock_engine.load_single_adapter.assert_not_called() + + # === Test load_adapter === + # (Includes tests for success, already loaded, transfers, waiting, eviction) + + async def test_load_unregistered_adapter_raises_error(self): + with self.assertRaisesRegex( + ValueError, "Adapter with ID 'unregistered' not registered." 
+ ): + await self.store.load_adapter("unregistered") + self.assertEqual(self.store.running_requests, 0) + + # Mock the executor call path specifically for load_adapter + @parameterized.named_parameters( + ("load_to_hbm", True, AdapterStatus.LOADED_HBM, "_as_jnp_array"), + ("load_to_cpu", False, AdapterStatus.LOADED_CPU, "_as_np_array"), + ) + async def test_load_new_adapter_success( + self, to_hbm, final_status, convert_func_name + ): + """Test loading a new adapter successfully to HBM or CPU.""" + adapter_id = "adapter_a" + # Reset mocks as register_adapter might have been called implicitly if config was None + self.mock_engine.load_single_adapter.reset_mock() + + # Patch the conversion functions to check they are called + with patch( + f"jetstream.core.lora.adapter_tensorstore.{convert_func_name}", + wraps=getattr(adapter_tensorstore, convert_func_name), + ) as mock_convert: + + await self.store.load_adapter(adapter_id, to_hbm=to_hbm) + + mock_convert.assert_called_once() # Verify the correct conversion was called + + # Verify engine load was called via executor + self.mock_engine.load_single_adapter.assert_called_once_with( + f"{self.adapters_dir_path}/{adapter_id}" + ) + + metadata = self.store.adapter_registry[adapter_id] + self.assertEqual(metadata.status, final_status) + self.assertEqual(metadata.last_accessed, self.current_time) + self.assertEqual(metadata.size_hbm, self.mock_size_hbm_a) + self.assertEqual(metadata.size_cpu, self.mock_size_cpu_a) + self.assertEqual(self.store.running_requests, 0) # Should be decremented + + if to_hbm: + self.assertIn(adapter_id, self.store.loaded_adapters_hbm) + self.assertNotIn(adapter_id, self.store.loaded_adapters_cpu) + self.assertEqual(self.store.current_hbm_usage, self.mock_size_hbm_a) + self.assertEqual(self.store.current_cpu_usage, 0) + jax.tree_util.tree_map( + lambda x, y: self.assertIsInstance(x, jax.Array), + self.store.loaded_adapters_hbm[adapter_id], + self.mock_weights_a, # Structure reference + ) + else: # Loaded to CPU + self.assertIn(adapter_id, self.store.loaded_adapters_cpu) + self.assertNotIn(adapter_id, self.store.loaded_adapters_hbm) + self.assertEqual(self.store.current_cpu_usage, self.mock_size_cpu_a) + self.assertEqual(self.store.current_hbm_usage, 0) + jax.tree_util.tree_map( + lambda x, y: self.assertIsInstance(x, np.ndarray), + self.store.loaded_adapters_cpu[adapter_id], + self.mock_weights_a, + ) + + async def test_load_adapter_with_preloaded_weights_success(self): + """Test loading works when weights are passed directly.""" + adapter_id = "adapter_pre" + weights_np = create_mock_weights(3) # Use NumPy as if loaded + config = get_mock_config() + size_hbm, size_cpu = get_mock_size(weights_np) + + # Register first + await self.store.register_adapter(adapter_id, adapter_config=config) + + await self.store.load_adapter( + adapter_id, adapter_weights=weights_np, to_hbm=True + ) + + self.mock_engine.load_single_adapter.assert_not_called() # Should not call engine + metadata = self.store.adapter_registry[adapter_id] + self.assertEqual(metadata.status, AdapterStatus.LOADED_HBM) + self.assertIn(adapter_id, self.store.loaded_adapters_hbm) + self.assertEqual(self.store.current_hbm_usage, size_hbm) + self.assertEqual(metadata.size_hbm, size_hbm) + self.assertEqual(metadata.size_cpu, size_cpu) + self.assertEqual(self.store.running_requests, 0) + + @parameterized.named_parameters( + ("hbm_to_hbm", True, AdapterStatus.LOADED_HBM), + ("cpu_to_cpu", False, AdapterStatus.LOADED_CPU), + ) + async def 
test_load_adapter_already_loaded_correct_location( + self, to_hbm, initial_status + ): + """Test loading when adapter is already in the desired location.""" + adapter_id = "adapter_a" + # Manually set initial state + if initial_status == AdapterStatus.LOADED_HBM: + self.store.loaded_adapters_hbm[adapter_id] = _as_jnp_array( + self.mock_weights_a + ) + self.store.current_hbm_usage = self.mock_size_hbm_a + else: + self.store.loaded_adapters_cpu[adapter_id] = self.mock_weights_a + self.store.current_cpu_usage = self.mock_size_cpu_a + self.store.adapter_registry[adapter_id].status = initial_status + initial_time = self.current_time + self.advance_time(10) + + # Patch transfers to ensure they are not called + with patch.object( + self.store, "_unsafe_transfer_to_hbm" + ) as mock_t_hbm, patch.object( + self.store, "_unsafe_transfer_to_cpu" + ) as mock_t_cpu: + await self.store.load_adapter(adapter_id, to_hbm=to_hbm) + mock_t_hbm.assert_not_called() + mock_t_cpu.assert_not_called() + + self.mock_engine.load_single_adapter.assert_not_called() # No reload + metadata = self.store.adapter_registry[adapter_id] + self.assertEqual(metadata.status, initial_status) + self.assertEqual(metadata.last_accessed, self.current_time) # Time updated + self.assertNotEqual(metadata.last_accessed, initial_time) + self.assertEqual(self.store.running_requests, 0) + + @parameterized.named_parameters( + ( + "cpu_to_hbm", + False, + True, + AdapterStatus.LOADED_HBM, + "_unsafe_transfer_to_hbm", + ), + ( + "hbm_to_cpu", + True, + False, + AdapterStatus.LOADED_CPU, + "_unsafe_transfer_to_cpu", + ), + ) + async def test_load_adapter_triggers_transfer( + self, initial_hbm, to_hbm, final_status, transfer_method_name + ): + """Test loading when adapter needs transferring between HBM and CPU.""" + adapter_id = "adapter_a" + # Manually set initial state + if initial_hbm: + self.store.loaded_adapters_hbm[adapter_id] = _as_jnp_array( + self.mock_weights_a + ) + self.store.adapter_registry[adapter_id].status = AdapterStatus.LOADED_HBM + self.store.current_hbm_usage = self.mock_size_hbm_a + else: + self.store.loaded_adapters_cpu[adapter_id] = self.mock_weights_a + self.store.adapter_registry[adapter_id].status = AdapterStatus.LOADED_CPU + self.store.current_cpu_usage = self.mock_size_cpu_a + self.advance_time(10) + + # Patch the specific internal transfer method we expect to be called + with patch.object( + self.store, + transfer_method_name, + wraps=getattr(self.store, transfer_method_name), + ) as mock_transfer: + await self.store.load_adapter(adapter_id, to_hbm=to_hbm) + mock_transfer.assert_called_once_with(adapter_id) + + # Verify final state + metadata = self.store.adapter_registry[adapter_id] + self.assertEqual(metadata.status, final_status) + self.assertEqual(metadata.last_accessed, self.current_time) + self.mock_engine.load_single_adapter.assert_not_called() # No reload + self.assertEqual(self.store.running_requests, 0) + + # This decorator temporarily replaces the real 'asyncio.get_running_loop' + # function with a fake one ('mock_get_loop') for this test only. + @patch( + "asyncio.get_running_loop" + ) # Need to mock the loop for run_in_executor + async def test_load_adapter_waits_for_loading_state(self, mock_get_loop): + """Test that a second load call waits if the adapter is already loading.""" + adapter_id = "adapter_a" + + # Create an asyncio Event. 
This will be used later to signal the task to finish.
+ load_finished_event = asyncio.Event()
+
+ # --- Create a Fake 'run_in_executor' ---
+ # 'run_in_executor' is the tool asyncio uses to run slow, blocking tasks
+ # (like loading from disk) in the background. We need to fake this tool.
+
+ mock_loop = MagicMock() # Create a fake "manager" for async tasks.
+
+ # Create a placeholder "box" for the result of the background task.
+ load_task_future = asyncio.Future() # Future to control completion
+
+ # This is our FAKE function that will replace the real 'run_in_executor'.
+ async def mock_run_in_executor(executor, func):
+ # func is the actual loading function (self.engine.load_single_adapter)
+ print(f"Test Executor: Fake background task started for {adapter_id}")
+
+ # PAUSE HERE: Wait until the signal light (event) turns green.
+ await load_finished_event.wait()
+
+ print(f"Test Executor: Finishing fake background task for {adapter_id}")
+ result = func() # Execute the original sync function
+ load_task_future.set_result(result) # Set future result
+ return await load_task_future # Return awaitable future
+
+ # When 'run_in_executor' is called, use the **fake** function.
+ mock_loop.run_in_executor.side_effect = mock_run_in_executor
+ # Tell the fake 'get_running_loop' (mock_get_loop) to return fake manager
+ mock_get_loop.return_value = mock_loop
+
+ # --- Start the Test Scenario ---
+
+ # Start task 1: Try to load 'adapter_a'.
+ # This will eventually call our fake 'run_in_executor' and pause at
+ # 'await load_finished_event.wait()'
+ task1 = asyncio.create_task(
+ self.store.load_adapter(adapter_id, to_hbm=True)
+ )
+ await asyncio.sleep(0.02) # Give task1 time to enter the executor call
+
+ # Assert that task1 has marked the adapter as LOADING
+ async with self.store.lock:
+ self.assertEqual(
+ self.store.adapter_registry[adapter_id].status, AdapterStatus.LOADING
+ )
+ self.assertEqual(self.store.running_requests, 1)
+
+ # Start task 2: Try to load the *same* adapter 'adapter_a' again while task 1 is "loading"
+ task2 = asyncio.create_task(
+ self.store.load_adapter(adapter_id, to_hbm=True)
+ )
+
+ # Give task 2 a tiny moment to start. It should see the status is 'LOADING',
+ # release the lock, and enter its own waiting loop (calling asyncio.sleep)
+ await asyncio.sleep(0.01)
+
+ # Allow task 1's load (in executor) to finish
+ load_finished_event.set()
+
+ # Wait for both tasks to complete
+ await asyncio.gather(task1, task2)
+
+ # Assertions
+ self.mock_engine.load_single_adapter.assert_called_once() # Load from disk only once
+ metadata = self.store.adapter_registry[adapter_id]
+ self.assertEqual(metadata.status, AdapterStatus.LOADED_HBM)
+ self.assertIn(adapter_id, self.store.loaded_adapters_hbm)
+ self.assertEqual(self.store.running_requests, 0)
+
+ @patch(
+ "asyncio.get_running_loop"
+ ) # Need to mock the loop for run_in_executor
+ async def test_load_adapter_inconsistent_loading_state(self, mock_get_loop):
+ """Test that an adapter stuck in LOADING with no loading event raises a RuntimeError."""
+ adapter_id = "adapter_a"
+
+ # Create an asyncio Event. This will be used later to signal the task to finish.
+ load_finished_event = asyncio.Event()
+
+ # --- Create a Fake 'run_in_executor' ---
+ # 'run_in_executor' is the tool asyncio uses to run slow, blocking tasks
+ # (like loading from disk) in the background. We need to fake this tool.
+
+ mock_loop = MagicMock() # Create a fake "manager" for async tasks.
+
+ # Create a placeholder "box" for the result of the background task. 
+ load_task_future = asyncio.Future() # Future to control completion + + # This is our FAKE function that will replace the real 'run_in_executor'. + async def mock_run_in_executor(executor, func): + # func is the actual loading function (self.engine.load_single_adapter) + print(f"Test Executor: Fake background task started for {adapter_id}") + + # PAUSE HERE: Wait until the signal light (event) turns green. + await load_finished_event.wait() + + print(f"Test Executor: Finishing fake background task for {adapter_id}") + result = func() # Execute the original sync function + load_task_future.set_result(result) # Set future result + return await load_task_future # Return awaitable future + + # When 'run_in_executor' is called, use the **fake** function. + mock_loop.run_in_executor.side_effect = mock_run_in_executor + # Tell the fake 'get_running_loop' (mock_get_loop) to return fake manager + mock_get_loop.return_value = mock_loop + + with self.assertRaisesRegex( + RuntimeError, + f"Inconsistent state: Adapter {adapter_id} is LOADING but has no event.", + ): + task1 = asyncio.create_task( + self.store.load_adapter(adapter_id, to_hbm=True) + ) + await asyncio.sleep(0.02) # Give task1 time to enter the executor call + + self.store.adapter_registry[adapter_id].loading_event = None + + # Start task 2: Try to load the *same* adapter 'adapter_a' again while task 1 is "loading" + task2 = asyncio.create_task( + self.store.load_adapter(adapter_id, to_hbm=True) + ) + + # Allow task 1's load (in executor) to finish + load_finished_event.set() + + # Wait for both tasks to complete + await asyncio.gather(task1, task2) + + async def test_load_single_adapter_returning_none(self): + """Test when load_single_adapter returns adapter_weights=None.""" + adapter_id = "adapter_a" + adapter_path = f"{self.adapters_dir_path}/{adapter_id}" + + self.mock_engine.load_single_adapter.side_effect = ( + lambda p: (None, None) if p == adapter_path else FileNotFoundError + ) + + with self.assertRaisesRegex( + ValueError, f"Failed to load adapter_weights from {adapter_path}." 
+ ):
+ await self.store.load_adapter(adapter_id)
+
+ async def test_load_adapter_with_changed_status_before_loading(self):
+ """Test corner case of LOADING status change before loading weights."""
+ adapter_id = "adapter_a"
+
+ event_load_finished = asyncio.Event()
+
+ # Mock run_in_executor to control load duration
+ async def mock_executor(executor, func):
+ print(f"Test Executor: Started load {adapter_id}")
+ await event_load_finished.wait()
+ print(f"Test Executor: Finishing load {adapter_id}")
+ return func()
+
+ with patch("asyncio.get_running_loop") as mock_get_loop, self.assertLogs(
+ level="WARNING"
+ ) as cm:
+ mock_loop = MagicMock()
+ mock_loop.run_in_executor.side_effect = mock_executor
+ mock_get_loop.return_value = mock_loop
+
+ # Start loading task
+ load_task = asyncio.create_task(self.store.load_adapter(adapter_id))
+ await asyncio.sleep(0.01) # Let load start
+
+ # Update the metadata.status to not-LOADING
+ self.store.adapter_registry[adapter_id].status = AdapterStatus.UNLOADED
+
+ # Allow the in-flight load to finish
+ event_load_finished.set()
+
+ # Wait for the load task to complete
+ await asyncio.gather(load_task)
+
+ self.assertEqual(len(cm.output), 1) # Expect exactly one warning message
+ expected_log = f"Load cancelled for {adapter_id}, status changed to {AdapterStatus.UNLOADED}"
+ # Check if the expected message is present in the captured output lines
+ self.assertIn(
+ expected_log, cm.output[0]
+ ) # Check the first (and only) logged line
+
+ # --- Eviction Tests ---
+
+ async def test_load_triggers_hbm_eviction(self):
+ """Test loading to HBM triggers HBM LRU eviction (transfer to CPU)."""
+ self.store.hbm_memory_budget = (
+ self.mock_size_hbm_a + self.mock_size_hbm_b + self.mock_size_hbm_c // 2
+ ) # Fits A & B, but not A+B+C
+
+ # Load A (HBM), advance time (A is LRU)
+ await self.store.load_adapter("adapter_a", to_hbm=True)
+ self.advance_time(10)
+ # Load B (HBM), advance time
+ await self.store.load_adapter("adapter_b", to_hbm=True)
+ self.advance_time(5)
+
+ # Patch the internal methods involved in eviction
+ with patch.object(
+ self.store, "_evict", wraps=self.store._evict
+ ) as mock_evict, patch.object(
+ self.store,
+ "_unsafe_transfer_to_cpu",
+ wraps=self.store._unsafe_transfer_to_cpu,
+ ) as mock_transfer_cpu:
+
+ await self.store.load_adapter(
+ "adapter_c", to_hbm=True
+ ) # Load C, should evict A (the LRU entry)
+
+ # Verify eviction happened
+ mock_evict.assert_called_with(from_hbm=True)
+ # Check that _unsafe_transfer_to_cpu was called within the evict logic
+ mock_transfer_cpu.assert_called_once_with("adapter_a") # A was LRU
+
+ # Verify final state
+ self.assertEqual(
+ self.store.adapter_registry["adapter_a"].status,
+ AdapterStatus.LOADED_CPU,
+ ) # A evicted to CPU
+ self.assertEqual(
+ self.store.adapter_registry["adapter_b"].status,
+ AdapterStatus.LOADED_HBM,
+ ) # B stays in HBM
+ self.assertEqual(
+ self.store.adapter_registry["adapter_c"].status,
+ AdapterStatus.LOADED_HBM,
+ ) # C loaded
+ self.assertIn("adapter_a", self.store.loaded_adapters_cpu)
+ self.assertIn("adapter_b", self.store.loaded_adapters_hbm)
+ self.assertIn("adapter_c", self.store.loaded_adapters_hbm)
+ self.assertNotIn("adapter_a", self.store.loaded_adapters_hbm)
+ self.assertEqual(
+ self.store.current_hbm_usage,
+ self.mock_size_hbm_b + self.mock_size_hbm_c,
+ )
+ self.assertEqual(self.store.current_cpu_usage, self.mock_size_cpu_a)
+
+ async def test_load_triggers_cpu_eviction(self):
+ """Test loading to CPU triggers CPU LRU eviction (unload)."""
+ # self.store.cpu_memory_budget = 
self.mock_size_cpu_a # Budget fits A + self.store.cpu_memory_budget = ( + self.mock_size_hbm_a + self.mock_size_hbm_b + self.mock_size_hbm_c // 2 + ) # Fits A & B, but not A+B+C + + # Load A (CPU), advance time (A is LRU) + await self.store.load_adapter("adapter_a", to_hbm=False) + self.advance_time(10) + # Load B (CPU), advance time + await self.store.load_adapter("adapter_b", to_hbm=False) + self.advance_time(5) + + with patch.object( + self.store, "_evict", wraps=self.store._evict + ) as mock_evict, patch.object( + self.store, + "_unsafe_unload_adapter", + wraps=self.store._unsafe_unload_adapter, + ) as mock_unload: + + await self.store.load_adapter( + "adapter_c", to_hbm=False + ) # Load C to CPU, should evict A + + mock_evict.assert_called_with(from_hbm=False) + # Check that _unsafe_unload_adapter was called within the evict logic + mock_unload.assert_called_once_with("adapter_a") # A was LRU + + # Verify final state + self.assertEqual( + self.store.adapter_registry["adapter_a"].status, AdapterStatus.UNLOADED + ) # A unloaded + self.assertEqual( + self.store.adapter_registry["adapter_b"].status, + AdapterStatus.LOADED_CPU, + ) # B remains + self.assertEqual( + self.store.adapter_registry["adapter_c"].status, + AdapterStatus.LOADED_CPU, + ) # C loaded + self.assertNotIn("adapter_a", self.store.loaded_adapters_cpu) + self.assertIn("adapter_b", self.store.loaded_adapters_cpu) + self.assertIn("adapter_c", self.store.loaded_adapters_cpu) + self.assertEqual( + self.store.current_cpu_usage, + self.mock_size_cpu_b + self.mock_size_cpu_c, + ) + self.assertEqual(self.store.current_hbm_usage, 0) + + async def test_load_hbm_eviction_fails_raises_error(self): + """Test RuntimeError when HBM eviction fails (no suitable adapter).""" + self.store.hbm_memory_budget = self.mock_size_hbm_a # Fits A + await self.store.load_adapter("adapter_a", to_hbm=True) # Load A + + # Mock _evict to simulate no adapter can be evicted + with patch.object(self.store, "_evict", return_value=False) as mock_evict: + with self.assertRaisesRegex( + RuntimeError, "Not enough HBM to load adapter, and eviction failed." 
+ ): + await self.store.load_adapter("adapter_b", to_hbm=True) # Try load B + mock_evict.assert_called_once_with(from_hbm=True) + + # Check state reverted + self.assertEqual( + self.store.adapter_registry["adapter_b"].status, AdapterStatus.UNLOADED + ) + self.assertEqual( + self.store.adapter_registry["adapter_a"].status, + AdapterStatus.LOADED_HBM, + ) # A remains + self.assertEqual(self.store.current_hbm_usage, self.mock_size_hbm_a) + self.assertEqual(self.store.running_requests, 0) + + async def test_load_cpu_eviction_fails_raises_error(self): + """Test RuntimeError when CPU eviction fails (no suitable adapter).""" + self.store.cpu_memory_budget = self.mock_size_cpu_a # Fits A + await self.store.load_adapter("adapter_a", to_hbm=False) # Load A + + # Mock _evict to simulate no adapter can be evicted + with patch.object(self.store, "_evict", return_value=False) as mock_evict: + with self.assertRaisesRegex( + RuntimeError, + "Not enough CPU RAM to load adapter, and eviction failed.", + ): + await self.store.load_adapter("adapter_b", to_hbm=False) # Try load B + mock_evict.assert_called_once_with(from_hbm=False) + + # Check state reverted + self.assertEqual( + self.store.adapter_registry["adapter_b"].status, AdapterStatus.UNLOADED + ) + self.assertEqual( + self.store.adapter_registry["adapter_a"].status, + AdapterStatus.LOADED_CPU, + ) # A remains + self.assertEqual(self.store.current_cpu_usage, self.mock_size_cpu_a) + self.assertEqual(self.store.running_requests, 0) + + async def test_load_fails_during_io(self): + """Test status reset if engine load fails during I/O.""" + adapter_id = "adapter_fail" # Mock engine will raise FileNotFoundError + await self.store.register_adapter( + adapter_id, adapter_config=get_mock_config() + ) # Register first + self.mock_engine.load_single_adapter.reset_mock() + + # Expect FileNotFoundError from our mock side effect wrapped in executor + with self.assertRaises(FileNotFoundError): + await self.store.load_adapter(adapter_id, to_hbm=True) + + # Check state reverted correctly in finally block + metadata = self.store.adapter_registry[adapter_id] + self.assertEqual(metadata.status, AdapterStatus.UNLOADED) + self.assertEqual(self.store.running_requests, 0) + self.assertEqual(self.store.current_hbm_usage, 0) + self.assertEqual(self.store.current_cpu_usage, 0) + self.assertNotIn(adapter_id, self.store.loaded_adapters_hbm) + self.assertNotIn(adapter_id, self.store.loaded_adapters_cpu) + + # === Test unload_adapter === + @parameterized.named_parameters( + ("unload_hbm", True), + ("unload_cpu", False), + ) + async def test_unload_adapter_success(self, loaded_hbm): + """Test unloading a loaded adapter from HBM or CPU.""" + adapter_id = "adapter_a" + await self.store.load_adapter( + adapter_id, to_hbm=loaded_hbm + ) # Load it first + self.assertTrue( + self.store.adapter_registry[adapter_id].status + in (AdapterStatus.LOADED_HBM, AdapterStatus.LOADED_CPU) + ) + initial_hbm = self.store.current_hbm_usage + initial_cpu = self.store.current_cpu_usage + + with patch.object( + self.store, + "_unsafe_unload_adapter", + wraps=self.store._unsafe_unload_adapter, + ) as mock_unsafe_unload: + await self.store.unload_adapter(adapter_id) + mock_unsafe_unload.assert_called_once_with(adapter_id) + + metadata = self.store.adapter_registry[adapter_id] + self.assertEqual(metadata.status, AdapterStatus.UNLOADED) + self.assertEqual( + self.store.current_hbm_usage, 0 if loaded_hbm else initial_hbm + ) + self.assertEqual( + self.store.current_cpu_usage, 0 if not loaded_hbm else initial_cpu 
+ ) + self.assertNotIn(adapter_id, self.store.loaded_adapters_hbm) + self.assertNotIn(adapter_id, self.store.loaded_adapters_cpu) + self.assertEqual(metadata.size_hbm, 0) + self.assertEqual(metadata.size_cpu, 0) + + async def test_unload_unregistered_adapter_raises_error(self): + with self.assertRaisesRegex( + ValueError, "Adatper with ID 'unknown' not found." + ): + await self.store.unload_adapter("unknown") + + async def test_unload_already_unloaded_adapter_is_noop(self): + adapter_id = "adapter_a" # Registered but not loaded + self.assertEqual( + self.store.adapter_registry[adapter_id].status, AdapterStatus.UNLOADED + ) + + with patch.object( + self.store, + "_unsafe_unload_adapter", + wraps=self.store._unsafe_unload_adapter, + ) as mock_unsafe_unload: + await self.store.unload_adapter(adapter_id) # Should do nothing + mock_unsafe_unload.assert_called_once_with( + adapter_id + ) # Unsafe method called once + + self.assertEqual( + self.store.adapter_registry[adapter_id].status, AdapterStatus.UNLOADED + ) + + async def test_unload_waits_for_loading(self): + """Test unload waits if adapter is currently loading.""" + adapter_id = "adapter_a" + event_load_finished = asyncio.Event() + + # Mock run_in_executor to control load duration + async def mock_executor(executor, func): + print(f"Test Executor: Started load {adapter_id}") + await event_load_finished.wait() + print(f"Test Executor: Finishing load {adapter_id}") + return func() + + with patch("asyncio.get_running_loop") as mock_get_loop: + mock_loop = MagicMock() + mock_loop.run_in_executor.side_effect = mock_executor + mock_get_loop.return_value = mock_loop + + # Start loading task + load_task = asyncio.create_task( + self.store.load_adapter(adapter_id, to_hbm=True) + ) + await asyncio.sleep(0.01) # Let load start + + self.assertEqual( + self.store.adapter_registry[adapter_id].status, AdapterStatus.LOADING + ) + + # Start unload task concurrently + unload_task = asyncio.create_task(self.store.unload_adapter(adapter_id)) + await asyncio.sleep(0.01) # Let unload start and potentially wait + + # Allow loading to finish + event_load_finished.set() + + # Wait for both tasks + await asyncio.gather(load_task, unload_task) + + # Final state should be unloaded + self.assertEqual( + self.store.adapter_registry[adapter_id].status, AdapterStatus.UNLOADED + ) + self.assertNotIn(adapter_id, self.store.loaded_adapters_hbm) + self.assertEqual(self.store.current_hbm_usage, 0) + + # === Test get_lora_config === + + async def test_get_lora_config_success(self): + adapter_id = "adapter_a" + # Already registered in setUp + config = await self.store.get_lora_config(adapter_id) + self.assertEqual(config, self.mock_config_a) + + async def test_get_lora_config_unregistered_raises_error(self): + with self.assertRaisesRegex( + ValueError, "LoRA adapter with id=unknown is not loaded." 
+ ):
+ await self.store.get_lora_config("unknown")
+
+ async def test_get_lora_config_unregistered_with_load_flag(self):
+ """Test get_lora_config triggers registration (and potential load)."""
+ adapter_id = "adapter_new_config"
+ adapter_path = f"{self.adapters_dir_path}/{adapter_id}"
+ mock_weights_new = create_mock_weights(9)
+ mock_config_new = get_mock_config(9)
+ # Configure engine mock for this specific ID
+ self.mock_engine.load_single_adapter.side_effect = (
+ lambda p: (mock_weights_new, mock_config_new)
+ if p == adapter_path
+ else FileNotFoundError
+ )
+
+ config = await self.store.get_lora_config(
+ adapter_id, load_if_not_loaded=True
+ )
+
+ self.assertEqual(config, mock_config_new)
+ self.assertIn(adapter_id, self.store.adapter_registry)
+ self.mock_engine.load_single_adapter.assert_called_once_with(adapter_path)
+ # Check status after the mocked synchronous load
+ self.assertEqual(
+ self.store.adapter_registry[adapter_id].status, AdapterStatus.LOADED_HBM
+ )
+
+ # === Test get_lora_weights ===
+
+ async def test_get_lora_weights_hbm_loaded(self):
+ """Test getting weights already in HBM."""
+ adapter_id = "adapter_a"
+ await self.store.load_adapter(adapter_id, to_hbm=True)
+ start_time = self.current_time
+
+ weights = await self.store.get_lora_weights(adapter_id, to_hbm=True)
+
+ # Ensure it's the correct weights and type
+ self.assertIsInstance(jax.tree_util.tree_leaves(weights)[0], jax.Array)
+ # Basic check, assumes structure matches mock_weights_a
+ self.assertTrue(
+ jnp.allclose(
+ jax.tree_util.tree_leaves(weights)[0],
+ jax.tree_util.tree_leaves(_as_jnp_array(self.mock_weights_a))[0],
+ )
+ )
+ # Check access time updated
+ self.assertEqual(
+ self.store.adapter_registry[adapter_id].last_accessed, start_time
+ )
+
+ async def test_get_lora_weights_needs_register_and_load(self):
+ """Test getting weights triggers registration and loading."""
+ adapter_id = "adapter_a"
+ del self.store.adapter_registry[adapter_id]
+
+ weights = await self.store.get_lora_weights(
+ adapter_id, to_hbm=True, load_if_not_loaded=True
+ )
+
+ self.mock_engine.load_single_adapter.assert_called_once() # Load should be triggered
+ self.assertIsInstance(jax.tree_util.tree_leaves(weights)[0], jax.Array)
+ self.assertEqual(
+ self.store.adapter_registry[adapter_id].status, AdapterStatus.LOADED_HBM
+ )
+
+ async def test_get_lora_weights_with_unregistered_adapter(self):
+ """Test getting weights for an unregistered adapter raises an error."""
+ adapter_id = "adapter_a"
+ del self.store.adapter_registry[adapter_id]
+
+ with self.assertRaisesRegex(
+ ValueError, f"LoRA adapter with id={adapter_id} is not loaded." 
+ ):
+ weights = await self.store.get_lora_weights(adapter_id, to_hbm=True)
+
+ async def test_get_lora_weights_needs_load(self):
+ """Test getting weights triggers loading."""
+ adapter_id = "adapter_a"
+ self.assertEqual(
+ self.store.adapter_registry[adapter_id].status, AdapterStatus.UNLOADED
+ )
+
+ weights = await self.store.get_lora_weights(adapter_id, to_hbm=True)
+
+ self.mock_engine.load_single_adapter.assert_called_once() # Load should be triggered
+ self.assertIsInstance(jax.tree_util.tree_leaves(weights)[0], jax.Array)
+ self.assertEqual(
+ self.store.adapter_registry[adapter_id].status, AdapterStatus.LOADED_HBM
+ )
+
+ @parameterized.named_parameters(
+ (
+ "cpu_to_hbm",
+ False,
+ True,
+ AdapterStatus.LOADED_HBM,
+ "_unsafe_transfer_to_hbm",
+ ),
+ (
+ "hbm_to_cpu",
+ True,
+ False,
+ AdapterStatus.LOADED_CPU,
+ "_unsafe_transfer_to_cpu",
+ ),
+ )
+ async def test_get_lora_weights_needs_transfer(
+ self, initial_hbm, to_hbm, final_status, transfer_method_name
+ ):
+ """Test getting weights triggers transfer."""
+ adapter_id = "adapter_a"
+ await self.store.load_adapter(adapter_id, to_hbm=initial_hbm) # Load to initial location
+
+ with patch.object(
+ self.store,
+ transfer_method_name,
+ wraps=getattr(self.store, transfer_method_name),
+ ) as mock_transfer:
+ weights = await self.store.get_lora_weights(
+ adapter_id, to_hbm=to_hbm
+ ) # Request the target location
+ # The transfer happens *inside* the load_adapter call that
+ # get_lora_weights triggers.
+ mock_transfer.assert_called_once_with(adapter_id)
+
+ self.assertEqual(
+ self.store.adapter_registry[adapter_id].status, final_status
+ )
+
+ # === Test list_adapters ===
+ async def test_list_adapters_multiple_states(self):
+ """Test listing adapters with various statuses."""
+ # adapter_a, adapter_b, adapter_c are registered in setUp
+ await self.store.load_adapter("adapter_b", to_hbm=False) # Load B to CPU
+ await self.store.load_adapter("adapter_c", to_hbm=True) # Load C to HBM
+
+ adapters = self.store.list_adapters()
+
+ self.assertEqual(len(adapters), 3) # a, b, c
+ self.assertIn("adapter_a", adapters)
+ self.assertIn("adapter_b", adapters)
+ self.assertIn("adapter_c", adapters)
+ self.assertEqual(adapters["adapter_a"].status, AdapterStatus.UNLOADED)
+ self.assertEqual(adapters["adapter_b"].status, AdapterStatus.LOADED_CPU)
+ self.assertEqual(adapters["adapter_c"].status, AdapterStatus.LOADED_HBM)
+
+ # === Test get_hbm_loaded_adapters ===
+ async def test_get_hbm_loaded_adapters_mixed(self):
+ """Test getting only HBM loaded adapters."""
+ await self.store.load_adapter("adapter_a", to_hbm=True)
+ await self.store.load_adapter("adapter_b", to_hbm=False) # B on CPU
+ await self.store.load_adapter("adapter_c", to_hbm=True)
+
+ hbm_list_str = await self.store.get_hbm_loaded_adapters()
+ hbm_list = set(s.strip() for s in hbm_list_str.split(",") if s.strip())
+
+ self.assertEqual(hbm_list, {"adapter_a", "adapter_c"})
+
+ async def test_get_hbm_loaded_adapters_none(self):
+ """Test getting HBM adapters when none are loaded."""
+ await self.store.load_adapter("adapter_a", to_hbm=False) # Load A to CPU
+
+ hbm_list_str = await self.store.get_hbm_loaded_adapters()
+
+ self.assertEqual(hbm_list_str, "")
+
+ # === Other Tests ===
+ async def test_unsafe_transfer_to_hbm_with_valueerror(self):
+ """Test raises error in transfer_to_hbm if adapter not in CPU."""
+ adapter_id = "adapter_not_in_cpu"
+ adapter_path = f"{self.adapters_dir_path}/{adapter_id}"
+
+ with self.assertRaisesRegex(
+ 
ValueError, f"Adapter '{adapter_id}' not loaded in CPU RAM." + ): + self.store._unsafe_transfer_to_hbm(adapter_id) + + async def test_unsafe_transfer_to_hbm_with_evict_failure(self): + """Test eviction failure during transfer_to_hbm.""" + adapter_id = "adapter_a" + self.store.hbm_memory_budget = self.mock_size_hbm_a // 2 # Not enough for A + await self.store.load_adapter("adapter_a", to_hbm=False) + + with self.assertRaisesRegex( + RuntimeError, + "Not enough HBM to transfer adapter, and HBM eviction failed.", + ): + self.store._unsafe_transfer_to_hbm(adapter_id) + + async def test_unsafe_transfer_to_cpu_with_valueerror(self): + """Test raises error in transfer_to_cpu if adapter not in hbm.""" + adapter_id = "adapter_not_in_hbm" + adapter_path = f"{self.adapters_dir_path}/{adapter_id}" + + with self.assertRaisesRegex( + ValueError, f"Adapter '{adapter_id}' not loaded in HBM." + ): + self.store._unsafe_transfer_to_cpu(adapter_id) + + async def test_unsafe_transfer_to_cpu_with_evict_failure(self): + """Test eviction failure during transfer_to_cpu.""" + adapter_id = "adapter_a" + self.store.cpu_memory_budget = self.mock_size_cpu_a // 2 # Not enough for A + await self.store.load_adapter("adapter_a", to_hbm=True) + + with self.assertRaisesRegex( + RuntimeError, + "Not enough CPU RAM to transfer adapter, and CPU eviction failed.", + ): + self.store._unsafe_transfer_to_cpu(adapter_id) + + async def test_unsafe_unload_adapter_with_unregistered_adapter(self): + """Test failure with unregistered adapterd during unload_adapter.""" + adapter_id = "adapter_unregistered" + + with self.assertRaisesRegex( + ValueError, f"Adapter with ID '{adapter_id}' not found." + ): + self.store._unsafe_unload_adapter(adapter_id) + + async def test_register_adapter_with_concurrent_registrations(self): + """Test register adapter scenario with concurrent registrations.""" + adapter_id = "adapter_a" + + adapter_metadata = self.store.adapter_registry[adapter_id] + del self.store.adapter_registry[adapter_id] # Delete already registered one + event_load_finished = asyncio.Event() + + # Mock run_in_executor to control load duration + async def mock_executor(executor, func): + print(f"Test Executor: Started load {adapter_id}") + await event_load_finished.wait() + print(f"Test Executor: Finishing load {adapter_id}") + return func() + + with patch("asyncio.get_running_loop") as mock_get_loop, self.assertLogs( + level="WARNING" + ) as cm: + mock_loop = MagicMock() + mock_loop.run_in_executor.side_effect = mock_executor + mock_get_loop.return_value = mock_loop + + # Start register task + register_task = asyncio.create_task( + self.store.register_adapter(adapter_id) + ) + await asyncio.sleep(0.01) # Let load start + + self.store.adapter_registry[adapter_id] = adapter_metadata + + # Allow register_adapter to finish + event_load_finished.set() + + # Wait for both tasks + await asyncio.gather(register_task) + + # Final state should be unloaded + self.assertEqual( + self.store.adapter_registry[adapter_id].status, AdapterStatus.UNLOADED + ) + self.assertEqual(self.store.current_hbm_usage, 0) + + self.assertEqual(len(cm.output), 1) # Expect exactly one warning message + expected_log = f"Adapter '{adapter_id}' registered concurrently." 
+ # Check if the expected message is present in the captured output lines + self.assertIn( + expected_log, cm.output[0] + ) # Check the first (and only) logged line diff --git a/jetstream/tests/core/lora/test_multi_lora_manager.py b/jetstream/tests/core/lora/test_multi_lora_manager.py new file mode 100644 index 00000000..dc498011 --- /dev/null +++ b/jetstream/tests/core/lora/test_multi_lora_manager.py @@ -0,0 +1,247 @@ +import asyncio +import logging +import time +import unittest +from unittest.mock import patch, MagicMock, AsyncMock # Use AsyncMock for async methods + +import grpc # For mocking context if needed, often None suffices +import numpy as np # For dummy weights if needed by helpers +import jax.numpy as jnp # For dummy weights if needed by helpers + + +# Assuming protos are generated and importable +from jetstream.core.proto import multi_lora_decoding_pb2 +from jetstream.core.proto import multi_lora_decoding_pb2_grpc + +# Assuming the class under test and its dependencies are importable +from jetstream.core.lora import multi_lora_inference_api # Adjust import path +from jetstream.core import orchestrator +from jetstream.core.lora import adapter_tensorstore # For status enum and metadata + +AdapterStatus = adapter_tensorstore.AdapterStatus +AdapterMetadata = ( + adapter_tensorstore.AdapterMetadata +) # Assuming this is accessible +MultiLoraManager = multi_lora_inference_api.MultiLoraManager + + +# --- Mocking Helpers --- +def create_mock_adapter_metadata(adapter_id, status, last_accessed_offset=0): + """Creates a mock AdapterMetadata object.""" + return AdapterMetadata( + adapter_id=adapter_id, + adapter_path=f"/fake/path/{adapter_id}", + status=status, + size_hbm=1024 * 1024 * 10, # 10 MiB + size_cpu=1024 * 1024 * 12, # 12 MiB + last_accessed=time.time() - last_accessed_offset, + config={"rank": 8}, + ) + + +async def mock_load_adapter_to_tensorstore(adapter_id: str, adapter_path: str): + print(f"Test Executor: Fake load_adapter_to_tensorstore") + del adapter_id + del adapter_path + + +# --- Test Class --- + + +class MultiLoraManagerTest(unittest.IsolatedAsyncioTestCase): + + def setUp(self): + """Set up mocks before each test.""" + self.mock_driver = MagicMock(spec=orchestrator.Driver) + # Mock the async methods on the driver using AsyncMock + self.mock_driver.load_adapter_to_tensorstore = AsyncMock(return_value=None) + self.mock_driver.unload_adapter_from_tensorstore = AsyncMock( + return_value=None + ) + + # list_adapters_from_tensorstore is synchronous in the example + self.mock_driver.list_adapters_from_tensorstore = MagicMock() + + # Create the instance of the class under test + self.manager = MultiLoraManager(driver=self.mock_driver) + + # === Test models (ListAdapters) === + + def test_models_success_multiple_adapters(self): + """Test listing adapters successfully with various statuses.""" + mock_registry_data = { + "adapter1": create_mock_adapter_metadata( + "adapter1", AdapterStatus.LOADED_HBM, 10 + ), + "adapter2": create_mock_adapter_metadata( + "adapter2", AdapterStatus.LOADED_CPU, 5 + ), + "adapter3": create_mock_adapter_metadata( + "adapter3", AdapterStatus.UNLOADED, 20 + ), + "adapter4": create_mock_adapter_metadata( + "adapter4", AdapterStatus.LOADING, 1 + ), + } + self.mock_driver.list_adapters_from_tensorstore.return_value = ( + mock_registry_data + ) + + request = multi_lora_decoding_pb2.ListAdaptersRequest() + response = self.manager.models(request) # Call the sync method + + self.mock_driver.list_adapters_from_tensorstore.assert_called_once() + 
+    self.assertTrue(response.success)
+    self.assertEqual(response.error_message, "")
+    self.assertEqual(len(response.adapter_infos), 4)
+
+    # Check mapping and content (order might vary depending on dict iteration)
+    response_map = {info.adapter_id: info for info in response.adapter_infos}
+    self.assertIn("adapter1", response_map)
+    self.assertIn("adapter2", response_map)
+    self.assertIn("adapter3", response_map)
+    self.assertIn("adapter4", response_map)
+
+    # Check loading_cost mapping based on status
+    self.assertEqual(response_map["adapter1"].loading_cost, 0)  # LOADED_HBM
+    self.assertEqual(response_map["adapter2"].loading_cost, 1)  # LOADED_CPU
+    self.assertEqual(response_map["adapter3"].loading_cost, 2)  # UNLOADED
+    self.assertEqual(
+        response_map["adapter4"].loading_cost, -1
+    )  # LOADING (or other)
+
+    # Check other fields are copied correctly
+    self.assertEqual(
+        response_map["adapter1"].size_hbm,
+        mock_registry_data["adapter1"].size_hbm,
+    )
+    self.assertEqual(
+        response_map["adapter2"].size_cpu,
+        mock_registry_data["adapter2"].size_cpu,
+    )
+    self.assertEqual(
+        response_map["adapter3"].status,
+        mock_registry_data["adapter3"].status.value,
+    )
+
+  def test_models_success_no_adapters(self):
+    """Test listing when no adapters are registered."""
+    self.mock_driver.list_adapters_from_tensorstore.return_value = {}  # Empty dict
+
+    request = multi_lora_decoding_pb2.ListAdaptersRequest()
+    response = self.manager.models(request)
+
+    self.mock_driver.list_adapters_from_tensorstore.assert_called_once()
+    self.assertTrue(response.success)
+    self.assertEqual(response.error_message, "")
+    self.assertEqual(len(response.adapter_infos), 0)
+
+  def test_models_driver_exception(self):
+    """Test error handling when the driver raises an exception."""
+    error_message = "Driver failed!"
+    self.mock_driver.list_adapters_from_tensorstore.side_effect = Exception(
+        error_message
+    )
+
+    request = multi_lora_decoding_pb2.ListAdaptersRequest()
+    with self.assertLogs(level="INFO") as log:
+      response = self.manager.models(request)
+
+    self.mock_driver.list_adapters_from_tensorstore.assert_called_once()
+    self.assertFalse(response.success)
+    self.assertEqual(response.error_message, error_message)
+    self.assertEqual(len(response.adapter_infos), 0)
+    self.assertIn("Listing of adapters failed with error:", log.output[0])
+    self.assertIn(error_message, log.output[0])
+
+  # === Test load_lora_adapter ===
+
+  def test_load_lora_adapter_success(self):
+    """Test successful loading of an adapter."""
+    adapter_id = "adapter_to_load"
+    adapter_path = "/path/to/load"
+    request = multi_lora_decoding_pb2.LoadAdapterRequest(
+        adapter_id=adapter_id, adapter_path=adapter_path
+    )
+
+    response = self.manager.load_lora_adapter(request)  # Call sync method
+
+    self.mock_driver.load_adapter_to_tensorstore.assert_awaited_once_with(
+        adapter_id, adapter_path
+    )
+    self.assertEqual(response.error_message, "")
+    self.assertTrue(response.success)
+
+  def test_load_lora_adapter_driver_exception(self):
+    """Test error handling when driver load fails."""
+    adapter_id = "adapter_fail_load"
+    adapter_path = "/path/to/fail"
+    error_message = "Loading failed in driver!"
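+    # The manager is expected to catch the driver exception and surface it
+    # through success=False and error_message rather than raising.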
+    request = multi_lora_decoding_pb2.LoadAdapterRequest(
+        adapter_id=adapter_id, adapter_path=adapter_path
+    )
+
+    # Configure the async mock to raise an exception
+    self.mock_driver.load_adapter_to_tensorstore.side_effect = Exception(
+        error_message
+    )
+
+    with self.assertLogs(level="INFO") as log:
+      response = self.manager.load_lora_adapter(request)
+
+    self.mock_driver.load_adapter_to_tensorstore.assert_awaited_once_with(
+        adapter_id, adapter_path
+    )
+    self.assertFalse(response.success)
+    self.assertEqual(response.error_message, error_message)
+    self.assertIn(
+        f"Loading of adapter_id={adapter_id} failed with error:", log.output[0]
+    )
+    self.assertIn(error_message, log.output[0])
+
+  # === Test unload_lora_adapter ===
+
+  def test_unload_lora_adapter_success(self):
+    """Test successful unloading of an adapter."""
+    adapter_id = "adapter_to_unload"
+    request = multi_lora_decoding_pb2.UnloadAdapterRequest(
+        adapter_id=adapter_id
+    )
+
+    self.mock_driver.unload_adapter_from_tensorstore.return_value = None
+
+    response = self.manager.unload_lora_adapter(request)
+
+    self.mock_driver.unload_adapter_from_tensorstore.assert_awaited_once_with(
+        adapter_id
+    )
+    self.assertTrue(response.success)
+    self.assertEqual(response.error_message, "")
+
+  def test_unload_lora_adapter_driver_exception(self):
+    """Test error handling when driver unload fails."""
+    adapter_id = "adapter_fail_unload"
+    error_message = "Unloading failed in driver!"
+    request = multi_lora_decoding_pb2.UnloadAdapterRequest(
+        adapter_id=adapter_id
+    )
+
+    self.mock_driver.unload_adapter_from_tensorstore.side_effect = Exception(
+        error_message
+    )
+
+    # The manager currently logs unload failures with the same message as
+    # load failures, so that is the message asserted on below.
+    with self.assertLogs(level="INFO") as log:
+      response = self.manager.unload_lora_adapter(request)
+
+    self.mock_driver.unload_adapter_from_tensorstore.assert_awaited_once_with(
+        adapter_id
+    )
+    self.assertFalse(response.success)
+    self.assertEqual(response.error_message, error_message)
+    self.assertIn(
+        f"Loading of adapter_id={adapter_id} failed with error:", log.output[0]
+    )  # Check log message
+    self.assertIn(error_message, log.output[0])
diff --git a/jetstream/tests/core/test_orchestrator.py b/jetstream/tests/core/test_orchestrator.py
index 1f872c8e..55db5a09 100644
--- a/jetstream/tests/core/test_orchestrator.py
+++ b/jetstream/tests/core/test_orchestrator.py
@@ -42,8 +42,10 @@
 """

 import unittest
+import jax.numpy as jnp
 from parameterized import parameterized
 from jetstream.core import orchestrator
+from jetstream.core.lora import adapter_tensorstore as adapterstore
 from jetstream.core.proto import jetstream_pb2
 from jetstream.core.utils.return_sample import ReturnSample
 from jetstream.engine import mock_engine
@@ -93,6 +95,61 @@ def _setup_driver_chunked_prefill(self, interleaved_mode: bool = True):
     )
     return driver

+  async def _setup_driver_with_adapterstore(
+      self, interleaved_mode: bool = True, multi_sampling: bool = False
+  ):
+    prefill_engine = mock_engine.TestEngine(
+        batch_size=32, cache_length=256, weight=2.0
+    )
+    # Create a generate engine with a different set of weights
+    # so that we can test that the right one is in use at a given time.
+    generate_engine = mock_engine.TestEngine(
+        batch_size=4, cache_length=32, weight=4.0
+    )
+
+    prefill_adapterstore = adapterstore.AdapterTensorStore(
+        engine=prefill_engine,
+        adapters_dir_path="/tmp/",
+        hbm_memory_budget=20 * (1024**3),  # 20 GB HBM
+        cpu_memory_budget=100 * (1024**3),  # 100 GB RAM
+    )
+
+    generate_adapterstore = adapterstore.AdapterTensorStore(
+        engine=generate_engine,
+        adapters_dir_path="/tmp/",
+        hbm_memory_budget=20 * (1024**3),  # 20 GB HBM
+        cpu_memory_budget=100 * (1024**3),  # 100 GB RAM
+    )
+
+    await prefill_adapterstore.register_adapter(
+        adapter_id="test_adapter_1", adapter_config={"r": 4, "alpha": 32}
+    )
+
+    adapter_params = jnp.array([3.0], dtype=jnp.float32)
+    await prefill_adapterstore.load_adapter(
+        adapter_id="test_adapter_1", adapter_weights=adapter_params, to_hbm=True
+    )
+
+    await generate_adapterstore.register_adapter(
+        adapter_id="test_adapter_1", adapter_config={"r": 4, "alpha": 32}
+    )
+
+    await generate_adapterstore.load_adapter(
+        adapter_id="test_adapter_1", adapter_weights=adapter_params, to_hbm=True
+    )
+
+    driver = orchestrator.Driver(
+        prefill_engines=[prefill_engine],
+        generate_engines=[generate_engine],
+        prefill_params=[prefill_engine.load_params()],
+        generate_params=[generate_engine.load_params()],
+        prefill_adapterstore=[prefill_adapterstore],
+        generate_adapterstore=[generate_adapterstore],
+        interleaved_mode=interleaved_mode,
+        multi_sampling=multi_sampling,
+    )
+    return driver
+
   @unittest.skip("Rewrite mock engine to test chunked prefill call correctly.")
   @parameterized.expand([True, False])
   async def test_orchestrator_chunked_prefill(self, interleaved_mode: bool):
@@ -264,3 +321,158 @@ def test_should_buffer_response(self, interleaved_mode: bool):
     )
     driver.stop()
     print("Orchestrator driver stopped.")
+
+  @parameterized.expand([True, False])
+  async def test_orchestrator_with_adapterstore(self, interleaved_mode: bool):
+    """Test the multithreaded orchestration with a LoRA adapter store."""
+    driver = await self._setup_driver_with_adapterstore(interleaved_mode)
+    client = orchestrator.LLMOrchestrator(driver=driver)
+
+    # The string representation of np.array([[65, 66]]), [2] will be prepended
+    # as BOS.
+    text = "AB"
+
+    request = jetstream_pb2.DecodeRequest(
+        text_content=jetstream_pb2.DecodeRequest.TextContent(text=text),
+        max_tokens=3,
+        lora_adapter_id="test_adapter_1",
+    )
+    iterator = client.Decode(request)
+    # chr of [266, 332, 415].
+    expected_text = ["Ċ", "Ō", "Ɵ", ""]
+    expected_token_ids = [266, 332, 415, None]
+    counter = 0
+    async for resp in iterator:
+      output_text = resp.stream_content.samples[0].text
+      token_ids = resp.stream_content.samples[0].token_ids
+      output_token_id = token_ids[0] if len(token_ids) > 0 else None
+      print(f"actual output: {output_text=} {output_token_id=}")
+      assert output_text == expected_text[counter]
+      assert output_token_id == expected_token_ids[counter]
+      counter += 1
+    driver.stop()
+    print("Orchestrator driver stopped.")
+
+  @parameterized.expand([True, False])
+  async def test_load_unload_adapter(self, interleaved_mode: bool):
+    """Test loading and unloading of an adapter via the adapter tensorstore."""
+    driver = await self._setup_driver_with_adapterstore(interleaved_mode)
+
+    await driver.load_adapter_to_tensorstore("test_adapter_2", "/tmp/")
+    client = orchestrator.LLMOrchestrator(driver=driver)
+
+    # The string representation of np.array([[65, 66]]), [2] will be prepended
+    # as BOS.
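+    # The mock TestEngine is deterministic, so the decoded text and token ids
+    # asserted below are fixed for this prompt.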
+ text = "AB" + + request = jetstream_pb2.DecodeRequest( + text_content=jetstream_pb2.DecodeRequest.TextContent(text=text), + max_tokens=3, + lora_adapter_id="test_adapter_2", + ) + + # results = asyncio.run(_consume_decode_iterator(client, request)) + iterator = client.Decode(request) + # chr of [266, 332, 415]. + expected_text = ["Ċ", "Ō", "Ɵ", ""] + expected_token_ids = [266, 332, 415, None] + counter = 0 + async for resp in iterator: + output_text = resp.stream_content.samples[0].text + token_ids = resp.stream_content.samples[0].token_ids + output_token_id = token_ids[0] if len(token_ids) > 0 else None + print(f"actual output: {output_text=} {output_token_id=}") + assert output_text == expected_text[counter] + assert output_token_id == expected_token_ids[counter] + counter += 1 + + adapters = driver.list_adapters_from_tensorstore() + + assert "test_adapter_2" in adapters + + metadata = adapters["test_adapter_2"] + assert metadata.status in ( + adapterstore.AdapterStatus.LOADED_HBM, + adapterstore.AdapterStatus.LOADED_CPU, + ) + + await driver.unload_adapter_from_tensorstore("test_adapter_2") + + adapters = driver.list_adapters_from_tensorstore() + + assert "test_adapter_2" in adapters + + metadata = adapters["test_adapter_2"] + assert metadata.status == adapterstore.AdapterStatus.UNLOADED + + driver.stop() + print("Orchestrator driver stopped.") + + async def test_drivers_with_none_engine_and_params(self): + """Test should raise error when driver is init with none engine/driver.""" + prefill_engine = mock_engine.TestEngine( + batch_size=32, cache_length=256, weight=2.0 + ) + # Create a generate engine with a different set of weights + # so that we can test that the right one is in use at a given time. + generate_engine = mock_engine.TestEngine( + batch_size=4, cache_length=32, weight=4.0 + ) + + with self.assertRaisesRegex(ValueError, "No prefill engine provided."): + driver = orchestrator.Driver( + generate_engines=[generate_engine], + prefill_params=[prefill_engine.load_params()], + generate_params=[generate_engine.load_params()], + ) + del driver + + with self.assertRaisesRegex(ValueError, "No generate engine provided."): + driver = orchestrator.Driver( + prefill_engines=[prefill_engine], + prefill_params=[prefill_engine.load_params()], + generate_params=[generate_engine.load_params()], + ) + del driver + + with self.assertRaisesRegex(ValueError, "No prefill parameter provided."): + driver = orchestrator.Driver( + generate_engines=[generate_engine], + prefill_engines=[prefill_engine], + generate_params=[generate_engine.load_params()], + ) + del driver + + with self.assertRaisesRegex(ValueError, "No generate parameter provided."): + driver = orchestrator.Driver( + generate_engines=[generate_engine], + prefill_engines=[prefill_engine], + prefill_params=[prefill_engine.load_params()], + ) + del driver + + async def test_adapterstores_exceptions(self, interleaved_mode: bool = True): + driver = await self._setup_driver_with_adapterstore(interleaved_mode) + + client = orchestrator.LLMOrchestrator(driver=driver) + + # The string representation of np.array([[65, 66]]), [2] will be prepend + # as BOS. + text = "AB" + + request = jetstream_pb2.DecodeRequest( + text_content=jetstream_pb2.DecodeRequest.TextContent(text=text), + max_tokens=3, + lora_adapter_id="test_adapter_fail", + ) + iterator = client.Decode(request) + + # chr of [266, 332, 415]. 
+ expected_text = "An error occurred" + output_text = "" + async for resp in iterator: + output_text += resp.stream_content.samples[0].text + + self.assertIn(expected_text, output_text) + driver.stop() + print("Orchestrator driver stopped.") diff --git a/jetstream/tests/core/test_server.py b/jetstream/tests/core/test_server.py index 89afc02d..bfed1540 100644 --- a/jetstream/tests/core/test_server.py +++ b/jetstream/tests/core/test_server.py @@ -89,6 +89,7 @@ def test_server( ) if metrics_enabled is True else None, + lora_input_adapters_path="/test/adapter/", ) ###################### Requester side ###################################### diff --git a/jetstream/tools/maxtext/model_ckpt_conversion.sh b/jetstream/tools/maxtext/model_ckpt_conversion.sh index 0340dbfe..79187dfc 100644 --- a/jetstream/tools/maxtext/model_ckpt_conversion.sh +++ b/jetstream/tools/maxtext/model_ckpt_conversion.sh @@ -38,8 +38,17 @@ export MODEL_BUCKET=$4 # Point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you created, this bucket will store all the files generated by MaxText during a run, specifically the unscanned checkpoint. export BASE_OUTPUT_DIRECTORY=$5 +export HUGGING_FACE_CHECKPOINT=$6 + +export LORA_INPUT_ADAPTERS_PATH=$7 + export BUCKET_LOCATION=US +if [[ -z "HUGGING_FACE_CHECKPOINT" ]]; then + echo "HUGGING_FACE_CHECKPOINT is required." + exit 1 +fi + # Create three GCS buckets for the demo. gcloud storage buckets create ${MODEL_BUCKET} --location=${BUCKET_LOCATION} || true gcloud storage buckets create ${BASE_OUTPUT_DIRECTORY} --location=${BUCKET_LOCATION} || true @@ -56,32 +65,73 @@ else pip install torch --index-url https://download.pytorch.org/whl/cpu # llama_or_mistral_ckpt.py requires local path, so we need to copy the checkpoint from CHKPT_BUCKET to local. tmp_ckpt_path="/tmp/" - gcloud storage cp -r ${CHKPT_BUCKET} ${tmp_ckpt_path} + #gcloud storage cp -r ${CHKPT_BUCKET} ${tmp_ckpt_path} + path_parts=(${CHKPT_BUCKET//\// }) directory_substring=${path_parts[-1]} CONVERT_CKPT_SCRIPT="llama_or_mistral_ckpt.py" - JAX_PLATFORMS=cpu python MaxText/${CONVERT_CKPT_SCRIPT} \ - --base-model-path ${tmp_ckpt_path}${directory_substring} \ - --maxtext-model-path ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx} \ - --model-size ${MODEL_NAME} + + if [[ ! -z "${LORA_INPUT_ADAPTERS_PATH}" ]]; then + lora_local_path="/tmp/" + + if [[ "${LORA_INPUT_ADAPTERS_PATH}" =~ ^gs:// ]]; then + path_parts=(${LORA_INPUT_ADAPTERS_PATH//\// }) + lora_dir_substring=${path_parts[-1]} + + lora_local_path="${tmp_ckpt_path}${lora_dir_substring}" + if [[ ! -d ${lora_local_path} ]]; then + mkdir ${lora_local_path} + fi + gcloud storage cp -r ${LORA_INPUT_ADAPTERS_PATH} ${tmp_ckpt_path} + else + lora_local_path=${LORA_INPUT_ADAPTERS_PATH} + fi + + JAX_PLATFORMS=cpu python MaxText/${CONVERT_CKPT_SCRIPT} \ + --base-model-path ${tmp_ckpt_path}${directory_substring} \ + --maxtext-model-path ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx} \ + --model-size ${MODEL_NAME} \ + --lora-input-adapters-path ${lora_local_path} \ + --huggingface-checkpoint ${HUGGING_FACE_CHECKPOINT} + else + JAX_PLATFORMS=cpu python MaxText/${CONVERT_CKPT_SCRIPT} \ + --base-model-path ${tmp_ckpt_path}${directory_substring} \ + --maxtext-model-path ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx} \ + --model-size ${MODEL_NAME} \ + --huggingface-checkpoint ${HUGGING_FACE_CHECKPOINT} + fi fi echo "Written MaxText compatible checkpoint to ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}" # We define `SCANNED_CKPT_PATH` to refer to the checkpoint subdirectory. 
-export SCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}/0/items
+export SCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}

 # Convert MaxText compatible checkpoints to unscanned checkpoints.
 # Note that the `SCANNED_CKPT_PATH` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format.
 export RUN_NAME=${MODEL_NAME}_unscanned_chkpt_${idx}

-JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py \
-MaxText/configs/base.yml \
-base_output_directory=${BASE_OUTPUT_DIRECTORY} \
-load_parameters_path=${SCANNED_CKPT_PATH} \
-run_name=${RUN_NAME} \
-model_name=${MODEL_NAME} \
-force_unroll=true
-echo "Written MaxText unscanned checkpoint to ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints"
+if [[ ! -z "${LORA_INPUT_ADAPTERS_PATH}" ]]; then
+  JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py \
+    MaxText/configs/base.yml \
+    base_output_directory=${BASE_OUTPUT_DIRECTORY} \
+    load_parameters_path=${SCANNED_CKPT_PATH}/base/0/items \
+    lora_input_adapters_path=${SCANNED_CKPT_PATH}/loras \
+    run_name=${RUN_NAME} \
+    model_name=${MODEL_NAME} \
+    force_unroll=true
+  echo "Written MaxText unscanned checkpoint to ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints"
+else
+  JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py \
+    MaxText/configs/base.yml \
+    base_output_directory=${BASE_OUTPUT_DIRECTORY} \
+    load_parameters_path=${SCANNED_CKPT_PATH}/0/items \
+    run_name=${RUN_NAME} \
+    model_name=${MODEL_NAME} \
+    force_unroll=true
+  echo "Written MaxText unscanned checkpoint to ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints"
+fi
+
 # We will use the unscanned checkpoints by passing `UNSCANNED_CKPT_PATH` into `LOAD_PARAMETERS_PATH` in the following sections.
 export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items
diff --git a/jetstream/tools/multi_adapter_service_client.py b/jetstream/tools/multi_adapter_service_client.py
new file mode 100644
index 00000000..f91609ae
--- /dev/null
+++ b/jetstream/tools/multi_adapter_service_client.py
@@ -0,0 +1,193 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""A gRPC client to interact with JetStream Server.""" + +from typing import Sequence + +from absl import app +from absl import flags +import grpc +from jetstream.core.proto import jetstream_pb2 +from jetstream.core.proto import jetstream_pb2_grpc +from jetstream.core.proto import multi_lora_decoding_pb2 +from jetstream.core.proto import multi_lora_decoding_pb2_grpc +from jetstream.engine.token_utils import load_vocab + + +_SERVER = flags.DEFINE_string("server", "0.0.0.0", "server address") +_PORT = flags.DEFINE_string("port", "9000", "port to ping") +# _TEXT = flags.DEFINE_string("text", "My dog is cute", "The message") +_TEXT = flags.DEFINE_string("text", "22 year old", "The message") +_MAX_TOKENS = flags.DEFINE_integer( + "max_tokens", 3, "Maximum number of output/decode tokens of a sequence" +) + +_LORA_ADAPTER_ID = flags.DEFINE_string( + "lora_adapter_id", + None, + "Id of the fine-tuned adapter to be loaded on top of the base model.", + required=False, +) + +_LORA_ADAPTER_PATH = flags.DEFINE_string( + "lora_adapter_path", + None, + "Path of the fine-tuned adapter to be loaded from.", + required=False, +) + +_TEST_API_NAME = flags.DEFINE_string( + "test_api_name", + None, + "Name of the JetStream API to call.", + required=True, +) + + +def main(argv: Sequence[str]) -> None: + """ + Main function for a gRPC client that interacts with a JetStream server. + + This client can: + - Load a LoRA adapter. + - Unload a LoRA adapter. + - List loaded adapters and their metadata. + - Generate text completions (using LoRA adapters if specified). + + The client uses command-line flags to specify the server address, port, + text input, maximum number of tokens, adapter ID, adapter path, and the + API to call. It uses insecure gRPC channels (suitable for local testing). + + Args: + argv: Command-line arguments (not used directly, flags are used instead). + + Raises: + ValueError: For invalid configurations, like missing required parameters + for specific API calls. + """ + + del argv # Unused + # Note: Uses insecure_channel only for local testing. Please add grpc + # credentials for Production. + address = f"{_SERVER.value}:{_PORT.value}" + with grpc.insecure_channel(address) as channel: + grpc.channel_ready_future(channel).result() + stub = multi_lora_decoding_pb2_grpc.v1Stub(channel) + print(f"Sending request to: {address}") + + if _TEST_API_NAME.value == "load_lora_adapter": + print(f"Calling the /v1/load_lora_adapter.") + + adapter_id = _LORA_ADAPTER_ID.value + adapter_path = _LORA_ADAPTER_PATH.value + + if adapter_id == None or adapter_path == None: + print( + f"For `load_lora_adapter` API call, `adapter_id` and `adapter_path` must be passed." + ) + return + + request = multi_lora_decoding_pb2.LoadAdapterRequest( + adapter_id=adapter_id, adapter_path=adapter_path + ) + + response = stub.load_lora_adapter(request) + + if response.success is True: + print(f"Adapter={adapter_id} is loaded successfully.") + else: + print( + f"Adapter={adapter_id} loading failed with error={response.error_message}" + ) + + elif _TEST_API_NAME.value == "unload_lora_adapter": + print(f"Calling the /v1/unload_lora_adapter.") + + adapter_id = _LORA_ADAPTER_ID.value + + if adapter_id == None: + print( + f"For `unload_lora_adapter` API call, `adapter_id` must be passed." 
+        )
+        return
+
+      request = multi_lora_decoding_pb2.UnloadAdapterRequest(
+          adapter_id=adapter_id,
+      )
+
+      response = stub.unload_lora_adapter(request)
+
+      if response.success:
+        print(f"Adapter={adapter_id} is unloaded successfully.")
+      else:
+        print(
+            f"Adapter={adapter_id} unloading failed with error={response.error_message}"
+        )
+
+    elif _TEST_API_NAME.value == "models":
+      print("Calling the /v1/models.")
+
+      request = multi_lora_decoding_pb2.ListAdaptersRequest()
+
+      response = stub.models(request)
+
+      if response.success:
+        print("`models` call responded successfully.")
+        if response.adapter_infos:
+          print("Here is the list of adapters loaded on the server:")
+        else:
+          print("No adapters are loaded on the server.")
+
+        for adapter_info in response.adapter_infos:
+          print(
+              f"adapter_id={adapter_info.adapter_id}, loading_cost={adapter_info.loading_cost}, size_hbm={adapter_info.size_hbm} bytes, size_cpu={adapter_info.size_cpu} bytes, last_accessed={adapter_info.last_accessed}, status={adapter_info.status}"
+          )
+      else:
+        print(f"`models` call failed with error={response.error_message}")
+
+    elif _TEST_API_NAME.value == "completions":
+      print("Calling the /v1/completions.")
+
+      request = jetstream_pb2.DecodeRequest(
+          text_content=jetstream_pb2.DecodeRequest.TextContent(
+              text=_TEXT.value,
+          ),
+          max_tokens=_MAX_TOKENS.value,
+          lora_adapter_id=_LORA_ADAPTER_ID.value,
+      )
+      stub = jetstream_pb2_grpc.OrchestratorStub(channel)
+
+      response = stub.Decode(request)
+
+      output = []
+      for resp in response:
+        output.extend(resp.stream_content.samples[0].text)
+
+      text_output = "".join(output)
+
+      print(f"Prompt: {_TEXT.value}")
+      print(f"Response: {text_output}")
+
+    elif _TEST_API_NAME.value is None:
+      print("`test_api_name` flag is not set. So exiting.")
+      return
+
+    else:
+      print(f"API={_TEST_API_NAME.value} is not implemented yet. So exiting.")
+      return
+
+
+if __name__ == "__main__":
+  app.run(main)
diff --git a/jetstream/tools/multi_lora_decode_requester.py b/jetstream/tools/multi_lora_decode_requester.py
new file mode 100644
index 00000000..a651f9f8
--- /dev/null
+++ b/jetstream/tools/multi_lora_decode_requester.py
@@ -0,0 +1,254 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Decoding multiple LoRA requests via JetStream online serving.
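+
+Example invocation (a sketch against a locally running server; the flag
+defaults match the argparse definitions at the bottom of this file):
+
+  python -m jetstream.tools.multi_lora_decode_requester \
+    --server=0.0.0.0 --port=9000 --tokenizer=test --total-mock-requests=3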
+""" + + +import argparse +import asyncio +from dataclasses import dataclass, field +from datetime import datetime +import json +import random +import time +from typing import Any, AsyncGenerator, Optional +import os + + +import grpc +from jetstream.core.proto import jetstream_pb2 +from jetstream.core.proto import jetstream_pb2_grpc +from jetstream.engine.token_utils import load_vocab +from jetstream.external_tokenizers.llama3 import llama3_tokenizer +import numpy as np + + +@dataclass +class InputRequest: + prompt: str = "" + output: str = "" + output_len: int = 0 + sample_idx: int = -1 + adapter_id: str = "" + + +@dataclass +class RequestFuncOutput: + input_request: Optional[InputRequest] = None + generated_token_list: list[str] = field(default_factory=list) + generated_text: str = "" + success: bool = False + latency: float = 0 + ttft: float = 0 + + # Flatten the structure and return only the necessary results + def to_dict(self): + return { + "prompt": self.input_request.prompt, + "original_output": self.input_request.output, + "generated_text": self.generated_text, + "success": self.success, + "latency": self.latency, + "sample_idx": self.input_request.sample_idx, + } + + +def get_tokenizer( + model_id: str, + tokenizer_name: str, +) -> Any: + """Return a tokenizer or a tokenizer placholder.""" + if tokenizer_name == "test": + print("Using test tokenizer") + return "test" + elif model_id == "llama-3": + # Llama 3 uses a tiktoken tokenizer. + print(f"Using llama-3 tokenizer: {tokenizer_name}") + return llama3_tokenizer.Tokenizer(tokenizer_name) + else: + # Use JetStream tokenizer util. It's using the sentencepiece wrapper in + # seqio library. + print(f"Using tokenizer: {tokenizer_name}") + vocab = load_vocab(tokenizer_name) + return vocab.tokenizer + + +async def grpc_async_request( + api_url: str, request: Any +) -> tuple[list[str], float, float]: + """Send grpc synchronous request since the current grpc server is sync.""" + options = [("grpc.keepalive_timeout_ms", 10000)] + async with grpc.aio.insecure_channel(api_url, options=options) as channel: + stub = jetstream_pb2_grpc.OrchestratorStub(channel) + print("Making request") + ttft = 0 + token_list = [] + request_start_time = time.perf_counter() + response = stub.Decode(request) + async for resp in response: + if ttft == 0: + ttft = time.perf_counter() - request_start_time + token_list.extend(resp.stream_content.samples[0].token_ids) + latency = time.perf_counter() - request_start_time + return token_list, ttft, latency + + +async def send_request( + api_url: str, + tokenizer: Any, + input_request: InputRequest, +) -> RequestFuncOutput: + """Send the request to JetStream server.""" + # Tokenization on client side following MLPerf standard. + token_ids = tokenizer.encode(input_request.prompt) + request = jetstream_pb2.DecodeRequest( + token_content=jetstream_pb2.DecodeRequest.TokenContent( + token_ids=token_ids + ), + max_tokens=input_request.output_len, + lora_adapter_id=input_request.adapter_id, + ) + output = RequestFuncOutput() + output.input_request = input_request + generated_token_list, ttft, latency = await grpc_async_request( + api_url, request + ) + output.ttft = ttft + output.latency = latency + output.generated_token_list = generated_token_list + # generated_token_list is a list of token ids, decode it to generated_text. 
+  output.generated_text = tokenizer.decode(generated_token_list)
+  output.success = True
+  return output
+
+
+async def get_request(
+    input_requests: list[InputRequest],
+) -> AsyncGenerator[InputRequest, None]:
+  """Yields the input requests one at a time."""
+  for request in input_requests:
+    yield request
+
+
+async def send_multi_request(
+    api_url: str,
+    tokenizer: Any,
+    input_requests: list[InputRequest],
+):
+  """Send multiple LoRA adapter requests concurrently."""
+  tasks = []
+  async for request in get_request(input_requests):
+    tasks.append(
+        asyncio.create_task(
+            send_request(
+                api_url=api_url,
+                tokenizer=tokenizer,
+                input_request=request,
+            )
+        )
+    )
+  outputs = await asyncio.gather(*tasks)
+
+  return outputs
+
+
+def mock_adapter_requests(total_mock_requests: int):
+  """Generates a list of mock requests containing mock data."""
+  data = []
+  for index in range(total_mock_requests):
+    request = InputRequest()
+    request.prompt = "22 year old"
+    if index == 0:
+      request.adapter_id = ""  # The first request targets the base model.
+    else:
+      i = (index % 10) + 1
+      request.adapter_id = f"test_lora_{i}"
+    request.output_len = 200
+    data.append(request)
+  return data
+
+
+def main(args: argparse.Namespace):
+  print(args)
+
+  model_id = args.model
+  tokenizer_id = args.tokenizer
+
+  api_url = f"{args.server}:{args.port}"
+
+  tokenizer = get_tokenizer(model_id, tokenizer_id)
+  input_requests = mock_adapter_requests(args.total_mock_requests)
+
+  request_outputs = asyncio.run(
+      send_multi_request(
+          api_url=api_url,
+          tokenizer=tokenizer,
+          input_requests=input_requests,
+      )
+  )
+
+  outputs = [request_output.to_dict() for request_output in request_outputs]
+
+  # Process outputs
+  for index, output in enumerate(outputs):
+    print(f"Prompt: {input_requests[index].prompt}")
+    print(f"AdapterId: {input_requests[index].adapter_id}")
+    print(f"Output: {output}")
+
+
+if __name__ == "__main__":
+
+  parser = argparse.ArgumentParser(
+      description="Sending multiple serving requests to JetStream Server"
+  )
+  parser.add_argument(
+      "--server",
+      type=str,
+      default="0.0.0.0",
+      help="Server address.",
+  )
+  parser.add_argument("--port", type=str, default="9000", help="Server port.")
+  parser.add_argument(
+      "--model",
+      type=str,
+      default="no_model",
+      help=(
+          "Name of the model like llama-2, llama-3, gemma. (it's just used to"
+          " label the benchmark, pick the tokenizer, the model config is"
+          " defined in config_lib, and passed as the server config flag when"
+          " we run the JetStream server)"
+      ),
+  )
+  parser.add_argument(
+      "--total-mock-requests",
+      type=int,
+      default=3,
+      help="The maximum number of mock requests to send for benchmark testing.",
+  )
+  parser.add_argument(
+      "--tokenizer",
+      type=str,
+      default="test",
+      help=(
+          "Name or path of the tokenizer. (For mock model testing, use the"
+          " default value)"
+      ),
+  )
+
+  parsed_args = parser.parse_args()
+  main(parsed_args)
diff --git a/jetstream/tools/requester.py b/jetstream/tools/requester.py
index e215710f..bd394211 100644
--- a/jetstream/tools/requester.py
+++ b/jetstream/tools/requester.py
@@ -44,6 +44,12 @@
     False,
     "Enable client side tokenization with tokenizer.",
 )
+_LORA_ADAPTER_ID = flags.DEFINE_string(
+    "lora_adapter_id",
+    "",
+    "ID of the adapter for this decode request.",
+    required=False,
+)


 def _GetResponseAsync(
@@ -89,6 +95,7 @@ def main(argv: Sequence[str]) -> None:
         ),
         max_tokens=_MAX_TOKENS.value,
         num_samples=_NUM_SAMPLES.value,
+        lora_adapter_id=_LORA_ADAPTER_ID.value,
     )
   else:
     request = jetstream_pb2.DecodeRequest(
@@ -97,6 +104,7 @@ def main(argv: Sequence[str]) -> None:
         ),
         max_tokens=_MAX_TOKENS.value,
         num_samples=_NUM_SAMPLES.value,
+        lora_adapter_id=_LORA_ADAPTER_ID.value,
     )
   return _GetResponseAsync(stub, request)
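Example usage of the new multi-adapter client against a locally running server (a sketch; the adapter id and path are illustrative, and the flags are the ones defined in jetstream/tools/multi_adapter_service_client.py):

python -m jetstream.tools.multi_adapter_service_client --test_api_name=load_lora_adapter \
  --lora_adapter_id=test_lora_1 --lora_adapter_path=/tmp/loras/test_lora_1
python -m jetstream.tools.multi_adapter_service_client --test_api_name=models
python -m jetstream.tools.multi_adapter_service_client --test_api_name=completions \
  --text="22 year old" --lora_adapter_id=test_lora_1
python -m jetstream.tools.multi_adapter_service_client --test_api_name=unload_lora_adapter \
  --lora_adapter_id=test_lora_1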