Merge pull request #4 from tang-t21/main
Clean up and add new features
kamahori authored Apr 25, 2024
2 parents 1ad7a6c + ecf8abd commit 83c16a3
Showing 5 changed files with 129 additions and 101 deletions.
3 changes: 1 addition & 2 deletions .gitignore
@@ -156,5 +156,4 @@ cython_debug/
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
25 changes: 17 additions & 8 deletions benchmarks/latency.py
@@ -27,6 +27,13 @@
choices=[0, 1],
help="0: exeute at GPU (baseline), 1: offload to CPU.",
)
parser.add_argument(
"--batch_size",
type=int,
default=1,
help="batch size for inference.",
)
parser.add_argument("--beam_num", type=int, default=1, help="Beam search number.")

args = parser.parse_args()

@@ -58,15 +65,17 @@
# enough input length
break
prefill_time, decode_time, hit_rate = model.generate(
text, output_token=output_token, input_token=input_token
[text], output_token=output_token, input_token=input_token
)
prefill_time_sum += prefill_time
decode_time_sum += decode_time
hit_rate_sum += hit_rate
print(
f"input_token: {input_token}, output_token: {output_token}, "
f"prefill_time: {prefill_time_sum / n_sample}, "
f"decode_time: {decode_time_sum / n_sample}, "
f"hit_rate: {hit_rate_sum / n_sample},"
f"{output_token / (prefill_time_sum + decode_time_sum):.2f}token/s"
)
# write to file
with open("latency.txt", "a") as f:
f.write(
f"input_token: {input_token}, output_token: {output_token}, "
f"prefill_time: {prefill_time_sum / n_sample}, "
f"decode_time: {decode_time_sum / n_sample}, "
f"hit_rate: {hit_rate_sum / n_sample},"
f"{output_token *n_sample/ (prefill_time_sum + decode_time_sum):.2f}token/s\n"
)
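For context, a minimal standalone sketch (made-up numbers, not from a real run) of how the token/s figure written above follows from the accumulated timings: the loop sums prefill and decode times over n_sample prompts, each generating output_token tokens.

n_sample = 3
output_token = 128
prefill_time_sum = 0.9    # seconds, summed over all samples
decode_time_sum = 12.6    # seconds, summed over all samples

# total generated tokens divided by total wall-clock time
throughput = output_token * n_sample / (prefill_time_sum + decode_time_sum)
print(f"{throughput:.2f} token/s")   # ~28.44 token/s under these numbers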
1 change: 0 additions & 1 deletion src/fiddler/__init__.py
@@ -1,2 +1 @@
from .infer import FiddlerMixtral
from .mixtral import FiddlerMixtral
3 changes: 2 additions & 1 deletion src/fiddler/infer.py
@@ -3,6 +3,7 @@

from mixtral import FiddlerMixtral


if __name__ == "__main__":
parser = argparse.ArgumentParser()

@@ -33,9 +34,9 @@
default=20,
help="Number of tokens to generate.",
)
parser.add_argument("--beam-width", type=int, default=1, help="Beam search width.")

args = parser.parse_args()

model = FiddlerMixtral(args)
prefill_time, decode_time, hit_rate = model.generate(
args.input, output_token=args.n_token
198 changes: 109 additions & 89 deletions src/fiddler/mixtral.py
@@ -5,6 +5,7 @@
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import transformers


@@ -28,9 +29,10 @@ def __init__(self, args):
self.past_key_value = transformers.cache_utils.DynamicCache.from_legacy_cache()
self.past_key_values_length = 0
self.cpu_offload = args.cpu_offload

self.beam_width = args.beam_width
self.n_layer = len(self.model.layers)
self.n_expert = len(self.model.layers[0].block_sparse_moe.experts)


# TODO: find this value based on device config
self.latency_cpu = 7
@@ -49,7 +51,7 @@ def __init__(self, args):
)

self.set_expert_loc(n_expert_on_gpu)
print(self.expert_loc)
# print(self.expert_loc)

self.bring_expert_to_gpu()

@@ -355,16 +357,27 @@ def calc_n_expert_on_gpu(self):
)
# get the amount of free memory on GPU
total_mem = torch.cuda.get_device_properties(self.dev).total_memory
free_mem = total_mem * 0.95 - torch.cuda.memory_allocated(self.dev)
free_mem = total_mem * 0.95 - torch.cuda.memory_allocated(self.dev) # TODO: magic number
return int((free_mem) // (n_param * 2))
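As a rough illustration of the memory budget behind calc_n_expert_on_gpu (assumed sizes, fp16 weights at 2 bytes per parameter; the expert size is an approximation, not taken from the repo):

n_param = 3 * 4096 * 14336      # roughly one Mixtral-8x7B expert (w1, w2, w3)
bytes_per_param = 2             # fp16
total_mem = 24 * 1024**3        # e.g. a 24 GiB GPU
allocated = 8 * 1024**3         # assumed memory already taken by non-expert weights

free_mem = total_mem * 0.95 - allocated               # keep a 5% head-room margin
n_expert_on_gpu = int(free_mem // (n_param * bytes_per_param))
print(n_expert_on_gpu)          # about 45 experts fit under these assumptions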

def generate(self, text, output_token=20, input_token=None):
def initial_beam_tensor(self, input_tensor):
# reshape the prefill top-k tensor of shape (beam_width, seq_len, beam_width) into beam candidates of shape (beam_width, 1)
assert input_tensor.shape[-1] == self.beam_width
input_tensor = input_tensor[:, -1]
row_idx = torch.tensor(
[i * self.beam_width for i in range(input_tensor.shape[0] // self.beam_width)]
)
output_tensor = input_tensor[row_idx].view(-1, 1)
return output_tensor
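A toy walk-through of initial_beam_tensor (values invented, beam_width = 2, one prompt): the prompt is replicated per beam, so every row of the prefill top-k tensor is identical, and only the last position of the first row is kept as the (beam_width, 1) starting candidates.

import torch

beam_width = 2
# shape (beam_width, seq_len, beam_width): top-k token ids after prefill
top_ids = torch.tensor([[[11, 12], [21, 22], [31, 32]],
                        [[11, 12], [21, 22], [31, 32]]])
last = top_ids[:, -1]                    # (beam_width, beam_width): [[31, 32], [31, 32]]
row_idx = torch.tensor([0])              # one starting row per prompt
candidates = last[row_idx].view(-1, 1)   # (beam_width, 1): [[31], [32]]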

def generate(self, text=None, output_token=20, input_token=None):
torch.set_num_threads(16) # TODO: set appropriately
self.past_key_value = transformers.cache_utils.DynamicCache.from_legacy_cache()
self.past_key_values_length = 0

self.cnt_expert_hit = 0
self.cnt_expert_all = 0

input_ids, position_ids = self.tokenize(text)

if input_token is not None:
@@ -374,42 +387,90 @@ def generate(self, text=None, output_token=20, input_token=None):
tick = time.time()
is_decode = False
prefill_time, decode_time = 0, 0
decode_strings = ["" for _ in range(input_ids.shape[0])]
search_start = False
probs = torch.full((input_ids.shape[0], 1), 1.0)

for i_token in range(output_token):
# tick = time.time()
print(self.tokenizer.decode(input_ids[0, :]))
logits = self.mixtral_forward(
input_ids,
position_ids,
is_decode,
)
# print('Time:', time.time() - tick)
if self.beam_width == 1:
print(self.tokenizer.decode(input_ids[0]))
# TODO: streaming output for beam search
if is_decode:
for i in range(input_ids.shape[0]):
decode_strings[i] += " " + self.tokenizer.decode(input_ids[i, :])

logits = self.mixtral_forward(input_ids, position_ids, is_decode)

logits = logits.to("cpu")
# logits.shape: (batch_size, seq_len, vocab_size)

# normalize logits
logits = F.softmax(logits, dim=-1)

# greedy search:
# output = torch.argmax(logits, dim=-1)

output = torch.argmax(logits, dim=-1)
self.past_key_values_length += output.shape[-1]
input_ids = output[:, -1].unsqueeze(0).to(self.dev)
position_ids = torch.arange(
self.past_key_values_length,
self.past_key_values_length + 1,
dtype=torch.long,
device=self.dev,
# beam_search:
self.past_key_values_length += logits.shape[1]
if search_start:
new_probs, output = torch.topk(logits, 1, dim=-1)
new_probs = new_probs[:, -1].flatten().view(-1, 1)
else:
new_probs, output = torch.topk(logits, self.beam_width, dim=-1)
new_probs = self.initial_beam_tensor(new_probs)
output = self.initial_beam_tensor(output)
search_start = True
# new_probs = new_probs / new_probs.sum(dim=-1, keepdim=True)
probs = probs * new_probs

input_ids = output[:, -1].flatten().view(-1, 1).to(self.dev)
# input_ids.shape: (batch_size, seq_len=1)

position_ids = (
torch.arange(
self.past_key_values_length,
self.past_key_values_length + 1,
dtype=torch.long,
device=self.dev,
)
.unsqueeze(0)
.view(-1, 1)
)
position_ids = position_ids.unsqueeze(0).view(-1, 1)
# position_ids.shape: (1, 1)
if not is_decode:
prefill_time += time.time() - tick
tick = time.time()
is_decode = True
decode_time = time.time() - tick
return prefill_time, decode_time, self.cnt_expert_hit / self.cnt_expert_all
probs = probs.view(-1, self.beam_width)
max_ids = torch.argmax(probs, dim=-1)

print("--------------------")
print(f"Input: {text}")
print(f"Output: {decode_strings[max_ids[0]]}")

return (
prefill_time,
decode_time,
self.cnt_expert_hit / self.cnt_expert_all,
)
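A condensed sketch (toy probabilities, not repo output) of the beam scoring used in the loop above: each step's top-1 probability is multiplied into a running product per beam, and the final beam is the argmax of those products.

import torch

beam_width = 2
probs = torch.full((beam_width, 1), 1.0)          # running per-beam scores

# assumed top-1 probability of each beam at three decode steps
steps = [torch.tensor([[0.9], [0.6]]),
         torch.tensor([[0.8], [0.9]]),
         torch.tensor([[0.7], [0.8]])]
for new_probs in steps:
    probs = probs * new_probs                     # accumulate beam scores

probs = probs.view(-1, beam_width)                # one row per prompt
best_beam = int(torch.argmax(probs, dim=-1))      # 0, since 0.504 > 0.432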

def tokenize(self, text):
input_ids = []
encodings = self.tokenizer(text, return_tensors="pt")
input_ids = encodings.input_ids.to(self.dev)
input_id = encodings.input_ids.to(self.dev)
for i in range(self.beam_width):
input_ids.append(input_id[0])

input_ids = pad_sequence(
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
).to(self.dev)

position_ids = torch.arange(
0, input_ids.shape[-1], dtype=torch.long, device=self.dev
)
position_ids = position_ids.unsqueeze(0).view(-1, input_ids.shape[-1])

return input_ids, position_ids
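A minimal sketch of the prompt replication done in tokenize, with a hypothetical pre-tokenized prompt standing in for the real tokenizer output: the same ids are repeated once per beam so all beams share the prefix.

import torch
from torch.nn.utils.rnn import pad_sequence

beam_width = 2
input_id = torch.tensor([[1, 334, 920, 28804]])   # stand-in for tokenizer(text).input_ids

# repeat the prompt once per beam; padding is a no-op for equal-length rows
input_ids = pad_sequence([input_id[0] for _ in range(beam_width)],
                         batch_first=True, padding_value=0)
position_ids = torch.arange(0, input_ids.shape[-1]).unsqueeze(0)
# input_ids.shape == (beam_width, seq_len); position_ids.shape == (1, seq_len)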

@torch.no_grad()
@@ -429,14 +490,18 @@ def mixtral_forward(self, input_ids, position_ids, is_decode):
past_key_value=self.past_key_value,
use_cache=True,
)
# inps.shape: (batch_size, seq_len/token_num, embed_dim)
inps = inps_residual + inps
inps_residual = inps
inps = layer.post_attention_layernorm(inps)

inps = inps.view(-1, hidden_dim)
# inps.shape: (batch_size*seq_len*embed_dim/hidden_dim, hidden_dim)
router_logits = layer.block_sparse_moe.gate(inps)
routing_weights = F.softmax(router_logits, dim=1)
# routing_weights.shape: (batch_size*seq_len, num_experts)
routing_weights, selected_experts = torch.topk(routing_weights, 2, dim=-1)
# routing_weights.shape: (batch_size*seq_len, 2)
# selected_experts.shape: (batch_size*seq_len, 2)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
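# Worked example with illustrative numbers (not from the model): if the softmax over
# one token's 8 router logits is [0.05, 0.40, 0.10, 0.25, 0.05, 0.05, 0.05, 0.05],
# top-2 selects experts 1 and 3 with weights [0.40, 0.25], which the line above
# renormalizes to [0.40/0.65, 0.25/0.65] ≈ [0.62, 0.38], so the two expert outputs
# are blended with weights summing to 1.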

# intermediate variable to store the output of experts
@@ -484,7 +549,7 @@ def mixtral_forward(self, input_ids, position_ids, is_decode):

# end of one expert

elif not is_decode:
else:
# prefill stage with offloading
expert_mask = torch.nn.functional.one_hot(
selected_experts, num_classes=8
@@ -495,7 +560,6 @@ def mixtral_forward(self, input_ids, position_ids, is_decode):
cost_per_expert = np.zeros(
(len(experts), 2), dtype=float
) # 0: CPU, 1: GPU

for i_expert in range(len(experts)):
idx, top_2 = torch.where(expert_mask[i_expert])
idxs.append(idx)
@@ -510,11 +574,10 @@ def mixtral_forward(self, input_ids, position_ids, is_decode):
cost_per_expert[i_expert, 1] = 0
self.cnt_expert_hit += top_2.shape[0]
self.cnt_expert_all += top_2.shape[0]

# second, partition experts processing between CPU and GPU so that we can minimize:
# max(sum of cost at CPU, sum of cost at GPU)
# greedy algorithm is just as there are only 8 experts for
# Mixtral
# a brute-force search over all assignments is cheap since Mixtral has only 8 experts
best_config = -1
best_cost = float("inf")
for config in range(1 << len(experts)):
@@ -538,28 +601,6 @@ def mixtral_forward(self, input_ids, position_ids, is_decode):
else:
gpu_experts.append(i_expert)

def run_expert_in_thread():
for i_expert in cpu_experts:
top_2_list = top_2s[i_expert].tolist()
idx_list = idxs[i_expert].tolist()
current_state = inps[None, top_2_list].reshape(-1, hidden_dim)
current_state = self.run_expert_at_cpu(
i_layer,
i_expert,
current_state.to("cpu", non_blocking=True),
routing_weights[top_2_list, idx_list, None].to(
"cpu", non_blocking=True
),
)
inps_after_experts.index_add_(
0,
top_2s[i_expert].to(self.dev, non_blocking=True),
current_state.to(self.dev, non_blocking=True),
)

thread = threading.Thread(target=run_expert_in_thread)
thread.start()

for i_expert in gpu_experts:
top_2_list = top_2s[i_expert].tolist()
idx_list = idxs[i_expert].tolist()
@@ -581,44 +622,23 @@ def run_expert_in_thread():
current_state.to(self.dev, non_blocking=True),
)

thread.join()

else:
# decode stage with offloading
assert input_ids.shape[-1] == 1
expert_0, expert_1 = int(selected_experts[0][0]), int(
selected_experts[0][1]
)
routing_weights_0, routing_weights_1 = (
routing_weights[:, 0, None],
routing_weights[:, 1, None],
)

assert expert_0 != expert_1

self.cnt_expert_all += 2

if self.is_expert_in_gpu(i_layer, expert_0):
inps_after_experts += experts[expert_0](inps, routing_weights_0)
self.cnt_expert_hit += 1
else:
inps_after_experts += self.run_expert_at_cpu(
i_layer,
expert_0,
inps.to("cpu", non_blocking=True),
routing_weights_0.to("cpu", non_blocking=True),
).to(self.dev, non_blocking=True)

if self.is_expert_in_gpu(i_layer, expert_1):
inps_after_experts += experts[expert_1](inps, routing_weights_1)
self.cnt_expert_hit += 1
else:
inps_after_experts += self.run_expert_at_cpu(
for i_expert in cpu_experts:
top_2_list = top_2s[i_expert].tolist()
idx_list = idxs[i_expert].tolist()
current_state = inps[None, top_2_list].reshape(-1, hidden_dim)
current_state = self.run_expert_at_cpu(
i_layer,
expert_1,
inps.to("cpu", non_blocking=True),
routing_weights_1.to("cpu", non_blocking=True),
).to(self.dev, non_blocking=True)
i_expert,
current_state.to("cpu", non_blocking=True),
routing_weights[top_2_list, idx_list, None].to(
"cpu", non_blocking=True
),
)
inps_after_experts.index_add_(
0,
top_2s[i_expert].to(self.dev, non_blocking=True),
current_state.to(self.dev, non_blocking=True),
)

# addition because there's residual connection over moe layer
inps = inps_residual + inps_after_experts.reshape(original_inps_shape)
