fix: refactor adapter weight loading and mapping #2193

Merged on Jul 24, 2024 (8 commits)
4 changes: 4 additions & 0 deletions launcher/src/main.rs
@@ -1592,6 +1592,10 @@ fn main() -> Result<(), LauncherError> {
// Download and convert lora adapters if any
if let Some(lora_adapters) = &args.lora_adapters {
for adapter in lora_adapters.split(',') {
// skip download if a path is provided
if adapter.contains('=') {
continue;
}
download_convert_model(
adapter,
None,
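For context, `--lora-adapters` takes a comma-separated list, and this change skips the download step for any entry containing `=`, which presumably denotes an `adapter_id=local_path` pair. A minimal Python sketch of that parsing rule (the exact spec format is inferred from this check, and `split_adapter_specs` is a hypothetical helper, not part of the launcher):

```python
# Hypothetical helper illustrating the rule behind the new check: entries with a
# "=" already point at a local path and should not be downloaded.
from typing import List, Tuple


def split_adapter_specs(lora_adapters: str) -> Tuple[List[str], List[Tuple[str, str]]]:
    to_download: List[str] = []
    local: List[Tuple[str, str]] = []
    for spec in lora_adapters.split(","):
        spec = spec.strip()
        if "=" in spec:
            adapter_id, path = spec.split("=", 1)
            local.append((adapter_id, path))  # path provided, skip download
        else:
            to_download.append(spec)  # hub id, still needs download_convert_model
    return to_download, local


# split_adapter_specs("org/adapter-a,my-adapter=/data/adapters/my-adapter")
# -> (["org/adapter-a"], [("my-adapter", "/data/adapters/my-adapter")])
```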
187 changes: 187 additions & 0 deletions server/tests/utils/test_adapter.py
@@ -0,0 +1,187 @@
import pytest
from unittest.mock import Mock
from text_generation_server.utils.adapter import get_attn_weights, get_mlp_weights


def test_get_attn_weights():
    # create a mock layer
    mock_layer = Mock()
    mock_layer.self_attn.query_key_value = Mock()
    mock_layer.self_attn.o_proj = Mock()

    # call the function
    result = get_attn_weights(2, mock_layer)

    # assert the result
    expected = {
        (2, "q_proj"): (
            "model.layers.2.self_attn.q_proj",
            mock_layer.self_attn.query_key_value,
        ),
        (2, "k_proj"): (
            "model.layers.2.self_attn.k_proj",
            mock_layer.self_attn.query_key_value,
        ),
        (2, "v_proj"): (
            "model.layers.2.self_attn.v_proj",
            mock_layer.self_attn.query_key_value,
        ),
        (2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj),
    }
    assert result == expected


def test_get_mlp_weights_with_gate_up_proj():
    # create a mock layer with gate_up_proj
    mock_layer = Mock()
    mock_layer.mlp.gate_up_proj = Mock()
    mock_layer.mlp.down_proj = Mock()

    # call the function
    result = get_mlp_weights(3, mock_layer)

    # assert the result
    expected = {
        (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
        (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
        (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
    }
    assert result == expected


def test_get_mlp_weights_without_gate_up_proj():
    # create a mock layer without gate_up_proj
    mock_layer = Mock()
    mock_layer.mlp = Mock(spec=[])

    # call the function
    result = get_mlp_weights(1, mock_layer)

    # assert the result
    assert result == {}


@pytest.mark.parametrize("layer_index", [0, 1, 5])
def test_get_attn_weights_different_layers(layer_index):
    mock_layer = Mock()
    mock_layer.self_attn.query_key_value = Mock()
    mock_layer.self_attn.o_proj = Mock()

    result = get_attn_weights(layer_index, mock_layer)

    for k in ["q", "k", "v"]:
        assert (layer_index, f"{k}_proj") in result
        assert (
            result[(layer_index, f"{k}_proj")][0]
            == f"model.layers.{layer_index}.self_attn.{k}_proj"
        )

    assert (layer_index, "o_proj") in result
    assert (
        result[(layer_index, "o_proj")][0]
        == f"model.layers.{layer_index}.self_attn.o_proj"
    )


@pytest.mark.parametrize("layer_index", [0, 1, 5])
def test_get_mlp_weights_different_layers(layer_index):
    mock_layer = Mock()
    mock_layer.mlp.gate_up_proj = Mock()
    mock_layer.mlp.down_proj = Mock()

    result = get_mlp_weights(layer_index, mock_layer)

    for k in ["gate", "up", "down"]:
        assert (layer_index, f"{k}_proj") in result
        assert (
            result[(layer_index, f"{k}_proj")][0]
            == f"model.layers.{layer_index}.mlp.{k}_proj"
        )


def test_get_attn_weights_llama_compatibility():
    mock_layer = Mock()
    mock_layer.self_attn.query_key_value = Mock()
    mock_layer.self_attn.o_proj = Mock()

    result = get_attn_weights(2, mock_layer)

    expected = {
        (2, "q_proj"): (
            "model.layers.2.self_attn.q_proj",
            mock_layer.self_attn.query_key_value,
        ),
        (2, "k_proj"): (
            "model.layers.2.self_attn.k_proj",
            mock_layer.self_attn.query_key_value,
        ),
        (2, "v_proj"): (
            "model.layers.2.self_attn.v_proj",
            mock_layer.self_attn.query_key_value,
        ),
        (2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj),
    }
    assert result == expected


def test_get_mlp_weights_llama_compatibility():
    mock_layer = Mock()
    mock_layer.mlp.gate_up_proj = Mock()
    mock_layer.mlp.down_proj = Mock()

    result = get_mlp_weights(3, mock_layer)

    expected = {
        (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_up_proj),
        (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.gate_up_proj),
        (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
    }
    assert result == expected


def test_get_attn_weights_gemma_compatibility():
    mock_layer = Mock()
    mock_layer.self_attn.query_key_value = Mock()
    mock_layer.self_attn.o_proj = Mock()

    result = get_attn_weights(2, mock_layer)

    expected = {
        (2, "q_proj"): (
            "model.layers.2.self_attn.q_proj",
            mock_layer.self_attn.query_key_value,
        ),
        (2, "k_proj"): (
            "model.layers.2.self_attn.k_proj",
            mock_layer.self_attn.query_key_value,
        ),
        (2, "v_proj"): (
            "model.layers.2.self_attn.v_proj",
            mock_layer.self_attn.query_key_value,
        ),
        (2, "o_proj"): ("model.layers.2.self_attn.o_proj", mock_layer.self_attn.o_proj),
    }
    assert result == expected


def test_get_mlp_weights_gemma_compatibility():
    mock_layer = Mock()
    mock_layer.mlp.gate_proj = Mock()
    mock_layer.mlp.up_proj = Mock()
    mock_layer.mlp.down_proj = Mock()

    # ensure that the mock_layer.mlp.gate_up_proj attribute does not exist.
    # This is necessary because the use of `Mock` automatically creates any
    # attributes that are accessed, even if they don't exist in the actual
    # implementation. If `gate_up_proj` were created, `get_mlp_weights` might
    # follow the wrong execution path and return an incorrect result.
    del mock_layer.mlp.gate_up_proj

    result = get_mlp_weights(3, mock_layer)

    expected = {
        (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mock_layer.mlp.gate_proj),
        (3, "up_proj"): ("model.layers.3.mlp.up_proj", mock_layer.mlp.up_proj),
        (3, "down_proj"): ("model.layers.3.mlp.down_proj", mock_layer.mlp.down_proj),
    }
    assert result == expected
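The tests above pin down the contract of the new helpers in `text_generation_server.utils.adapter`. A hedged sketch of what `get_attn_weights` and `get_mlp_weights` presumably look like, reconstructed from the expected dictionaries rather than copied from the PR:

```python
# Sketch inferred from the tests; the real implementation in utils/adapter.py may differ.
from typing import Dict, Tuple


def get_attn_weights(layer_id: int, layer) -> Dict[Tuple[int, str], Tuple[str, object]]:
    # q/k/v all resolve to the fused query_key_value module; o_proj maps to itself.
    prefix = f"model.layers.{layer_id}.self_attn"
    weights = {}
    for proj in ("q_proj", "k_proj", "v_proj"):
        weights[(layer_id, proj)] = (f"{prefix}.{proj}", layer.self_attn.query_key_value)
    weights[(layer_id, "o_proj")] = (f"{prefix}.o_proj", layer.self_attn.o_proj)
    return weights


def get_mlp_weights(layer_id: int, layer) -> Dict[Tuple[int, str], Tuple[str, object]]:
    # Llama-style layers expose a fused gate_up_proj; the Gemma-style layer in these
    # tests exposes separate gate_proj/up_proj. Layers with neither contribute nothing.
    prefix = f"model.layers.{layer_id}.mlp"
    weights = {}
    mlp = layer.mlp
    if hasattr(mlp, "gate_up_proj"):
        weights[(layer_id, "gate_proj")] = (f"{prefix}.gate_proj", mlp.gate_up_proj)
        weights[(layer_id, "up_proj")] = (f"{prefix}.up_proj", mlp.gate_up_proj)
    elif hasattr(mlp, "gate_proj") and hasattr(mlp, "up_proj"):
        weights[(layer_id, "gate_proj")] = (f"{prefix}.gate_proj", mlp.gate_proj)
        weights[(layer_id, "up_proj")] = (f"{prefix}.up_proj", mlp.up_proj)
    if hasattr(mlp, "down_proj"):
        weights[(layer_id, "down_proj")] = (f"{prefix}.down_proj", mlp.down_proj)
    return weights
```

This is also why the Gemma test deletes `gate_up_proj` from the mock: a bare `Mock` would otherwise satisfy the `hasattr` check and send the function down the fused-projection path.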
13 changes: 1 addition & 12 deletions server/text_generation_server/adapters/config.py
@@ -4,7 +4,7 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
from typing import TYPE_CHECKING, Dict, Set, Tuple

import torch

@@ -31,14 +31,3 @@ def map_weights_for_model(
weight_names: Tuple[str],
) -> Tuple[ModuleMap, Set[str]]:
pass

@abstractmethod
def load_batched_adapter_weights(
self,
model: "Model",
module_map: ModuleMap,
layer_type: str,
unused_weight_names: Set[str],
dynamic: bool,
) -> Optional[AdapterWeights]:
pass
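With `load_batched_adapter_weights` gone, the abstract config is reduced to weight mapping; batched loading now lives with the concrete weight classes (see `LoraWeights.prepare_weights` below). Roughly, the remaining interface looks like this (a sketch of the post-change shape; the dataclass field and the `ModuleMap` alias are assumptions, not copied from the file):

```python
# Approximate post-change shape of AdapterConfig: only the mapping hook stays abstract.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, Set, Tuple

import torch

# Assumed alias for illustration; the real ModuleMap type is defined elsewhere in the package.
ModuleMap = Dict[str, Dict[str, Tuple[torch.Tensor, str]]]


@dataclass
class AdapterConfig(ABC):
    base_model_name_or_path: str  # field assumed from context, not shown in this hunk

    @abstractmethod
    def map_weights_for_model(
        self,
        adapter_weights: Dict,
        weight_names: Tuple[str],
    ) -> Tuple[ModuleMap, Set[str]]:
        ...
```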
60 changes: 26 additions & 34 deletions server/text_generation_server/adapters/lora.py
@@ -102,22 +102,6 @@ def map_weights_for_model(
adapter_weight_names.add(lora_b_name)
return module_map, adapter_weight_names

def load_batched_adapter_weights(
self,
model: "Model",
module_map: Dict[str, Dict],
layer_type: str,
unused_weight_names: Set[str],
dynamic: bool,
) -> Optional[AdapterWeights]:
return LoraWeights.load(
self,
model,
module_map,
layer_type,
unused_weight_names,
)

@classmethod
def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)
@@ -192,22 +176,38 @@ def _transpose_weights(self):
def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
return [BatchLoraWeights]

# prepare pre-loaded lora weights for use in the model.
#
# this method processes and organizes lora weights for a specific layer type across all layers:
# - uses `config` (LoraConfig) to apply lora-specific settings like scaling factor.
# - retrieves weights from `module_map` based on the `layer_type`.
# - processes `nlayers` number of layers.
# - converts weights to the specified `dtype`.
# - shards weights across `world_size` number of processes using the `process_group`.
# - maps weights to specific layers using `target_to_layer`.
# - tracks `unused_weight_names` to identify any unused weights.
#
# the method handles weight transposition, scaling, and padding to ensure compatibility
# with SGMV or BGMV operations.
@classmethod
def load(
def prepare_weights(
cls,
config: LoraConfig,
model: "Model",
module_map: Dict[str, Dict],
layer_type: str,
unused_weight_names: Set[str],
nlayers: int,
dtype: torch.dtype,
world_size: int,
process_group: ProcessGroup,
target_to_layer: Dict[str, Tuple[str, torch.Tensor]],
) -> Optional[AdapterWeights]:
nlayers = model.get_num_layers_for_type(layer_type)
lora_a_list = [None] * nlayers
lora_b_list = [None] * nlayers

for layer_id in range(nlayers):
key = (layer_id, layer_type)
weight_name, layer = model.target_to_layer[key]
weight_name, layer = target_to_layer[key]
base_weight = layer.base_layer.linear.weight
base_device = base_weight.device

@@ -216,10 +216,10 @@ def load(
return None

lora_a, lora_a_name = module_map[weight_name]["lora_A"]
lora_a = lora_a.to(base_device, model.dtype)
lora_a = lora_a.to(base_device, dtype)

lora_b, lora_b_name = module_map[weight_name]["lora_B"]
lora_b = lora_b.to(base_device, model.dtype)
lora_b = lora_b.to(base_device, dtype)

scale = get_scaling_factor(
config.lora_alpha,
@@ -236,12 +236,8 @@ def load(
lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale

# pad lora ranks to be compatible with sgmv
lora_a_list = [
pad_rank(w, dim=1, world_size=model.world_size) for w in lora_a_list
]
lora_b_list = [
pad_rank(w, dim=0, world_size=model.world_size) for w in lora_b_list
]
lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list]
lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list]

if lora_a_list:
# update rank if it was padded
@@ -252,8 +248,8 @@ def load(
*shard_lora_weights(
weights_a=lora_a_list,
weights_b=lora_b_list,
split_dim=0 if model.is_row_parallel(layer_type) else 1,
process_group=model.process_group,
split_dim=0 if layer_type in {"o_proj", "down_proj", "lm_head"} else 1,
process_group=process_group,
),
config,
)
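The important shift in this hunk is that `prepare_weights` no longer receives the whole `model`; every model-specific value is passed in explicitly, which makes weight preparation usable without a live model object. A hedged sketch of how a caller might now invoke it (the wrapper function and its name are hypothetical):

```python
# Hypothetical call site: the values load() used to pull off `model` internally are
# now gathered by the caller and handed to the classmethod explicitly.
from text_generation_server.adapters.lora import LoraWeights


def load_lora_weights_for_layer(config, model, module_map, layer_type, unused_weight_names):
    return LoraWeights.prepare_weights(
        config=config,
        module_map=module_map,
        layer_type=layer_type,
        unused_weight_names=unused_weight_names,
        nlayers=model.get_num_layers_for_type(layer_type),
        dtype=model.dtype,
        world_size=model.world_size,
        process_group=model.process_group,
        target_to_layer=model.target_to_layer,
    )
```

Note that row-parallel detection is decoupled as well: the split dimension is now chosen by checking `layer_type` against `{"o_proj", "down_proj", "lm_head"}` instead of calling `model.is_row_parallel(layer_type)`.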
@@ -293,10 +289,6 @@ def can_vectorize(self, pg: ProcessGroup) -> bool:
for rank_data in self.rank_data.values()
)

@classmethod
def key(cls) -> str:
return "lora"

@classmethod
def load(
self,
16 changes: 2 additions & 14 deletions server/text_generation_server/adapters/weights.py
@@ -42,10 +42,6 @@ class BatchAdapterWeights(ABC):
def has_adapter(self, adapter_index: int) -> bool:
pass

@abstractclassmethod
def key(cls) -> str:
pass

@abstractclassmethod
def load(
cls,
@@ -71,13 +67,6 @@ def remove_adapter(self, adapter_idx: int):
return
del self.adapter_weights[adapter_idx]

@property
def max_speculative_tokens(self) -> int:
return max(
adapter_weights.speculative_tokens
for adapter_weights in self.adapter_weights.values()
)

def is_empty(self) -> bool:
return len(self.adapter_weights) == 0

Expand All @@ -101,7 +90,7 @@ def get_data(
adapter_weights, meta, prefill, prefill_head_indices
)
if batched_weights is not None:
batch_data[batch_type.key()] = batched_weights
batch_data = batched_weights
return batch_data
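Because `key()` was removed, `get_data` no longer nests the batched weights in a per-adapter-type dict; since `get_batch_types` only returns `BatchLoraWeights`, callers receive that object directly. A small sketch of the consumer-side change (variable names assumed), mirroring the simplified `ranks` property in the next hunk:

```python
# Sketch: per-layer data is now the BatchLoraWeights object itself rather than a dict
# keyed by adapter type.
def iter_lora_data(adapter_batch_data):
    for layer_name, lora_data in adapter_batch_data.items():
        # before this PR: lora_data = layer_data.get("lora")
        if lora_data is None:
            continue
        yield layer_name, lora_data
```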


@@ -133,8 +122,7 @@ def from_meta(
def ranks(self) -> Set[int]:
# TODO(travis): refactor to be less coupled to lora implementation
ranks = set()
for layer_data in self.data.values():
lora_data = layer_data.get("lora")
for lora_data in self.data.values():
if lora_data is None:
continue
