diff --git a/.gitignore b/.gitignore
index 68bc17f..082c0a3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -156,5 +156,4 @@ cython_debug/
 # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
 # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 # and can be added to the global gitignore or merged into this file. For a more nuclear
-# option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
\ No newline at end of file
diff --git a/benchmarks/latency.py b/benchmarks/latency.py
index 3b79c75..2b9a978 100644
--- a/benchmarks/latency.py
+++ b/benchmarks/latency.py
@@ -27,6 +27,13 @@
         choices=[0, 1],
         help="0: exeute at GPU (baseline), 1: offload to CPU.",
     )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+        help="batch size for inference.",
+    )
+    parser.add_argument("--beam_num", type=int, default=1, help="Beam search number.")
 
     args = parser.parse_args()
 
@@ -58,15 +65,17 @@
                 # enough input length
                 break
             prefill_time, decode_time, hit_rate = model.generate(
-                text, output_token=output_token, input_token=input_token
+                [text], output_token=output_token, input_token=input_token
            )
             prefill_time_sum += prefill_time
             decode_time_sum += decode_time
             hit_rate_sum += hit_rate
-        print(
-            f"input_token: {input_token}, output_token: {output_token}, "
-            f"prefill_time: {prefill_time_sum / n_sample}, "
-            f"decode_time: {decode_time_sum / n_sample}, "
-            f"hit_rate: {hit_rate_sum / n_sample},"
-            f"{output_token / (prefill_time_sum + decode_time_sum):.2f}token/s"
-        )
+        # write to file
+        with open("latency.txt", "a") as f:
+            f.write(
+                f"input_token: {input_token}, output_token: {output_token}, "
+                f"prefill_time: {prefill_time_sum / n_sample}, "
+                f"decode_time: {decode_time_sum / n_sample}, "
+                f"hit_rate: {hit_rate_sum / n_sample},"
+                f"{output_token * n_sample / (prefill_time_sum + decode_time_sum):.2f}token/s\n"
+            )
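
One note on the hunk above: prefill_time_sum and decode_time_sum are accumulated over n_sample runs, so the tokens-per-second figure now multiplies output_token by n_sample before dividing by the summed time; the previous expression divided a single run's token count by the accumulated time. A minimal sketch of the corrected arithmetic, with made-up numbers:

    # hypothetical values, only to illustrate the corrected tokens-per-second expression
    n_sample = 3
    output_token = 128
    prefill_time_sum = 0.9   # seconds, accumulated over n_sample runs
    decode_time_sum = 8.1    # seconds, accumulated over n_sample runs

    total_tokens = output_token * n_sample                # 384 tokens generated in total
    tokens_per_s = total_tokens / (prefill_time_sum + decode_time_sum)
    print(f"{tokens_per_s:.2f}token/s")                   # 42.67token/s
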
diff --git a/src/fiddler/__init__.py b/src/fiddler/__init__.py
index 7af8c1b..2b6efc3 100644
--- a/src/fiddler/__init__.py
+++ b/src/fiddler/__init__.py
@@ -1,2 +1 @@
-from .infer import FiddlerMixtral
 from .mixtral import FiddlerMixtral
diff --git a/src/fiddler/infer.py b/src/fiddler/infer.py
index b682bcf..96ccecb 100644
--- a/src/fiddler/infer.py
+++ b/src/fiddler/infer.py
@@ -3,6 +3,7 @@
 
 from mixtral import FiddlerMixtral
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
 
@@ -33,9 +34,9 @@
         default=20,
         help="Number of tokens to generate.",
     )
+    parser.add_argument("--beam-width", type=int, default=1, help="Beam search width.")
     args = parser.parse_args()
-
     model = FiddlerMixtral(args)
     prefill_time, decode_time, hit_rate = model.generate(
         args.input, output_token=args.n_token
     )
diff --git a/src/fiddler/mixtral.py b/src/fiddler/mixtral.py
index 6088aab..df40b6a 100644
--- a/src/fiddler/mixtral.py
+++ b/src/fiddler/mixtral.py
@@ -5,6 +5,7 @@
 import numpy as np
 import torch
 import torch.nn.functional as F
+from torch.nn.utils.rnn import pad_sequence
 import transformers
 
 
@@ -28,9 +29,10 @@ def __init__(self, args):
         self.past_key_value = transformers.cache_utils.DynamicCache.from_legacy_cache()
         self.past_key_values_length = 0
         self.cpu_offload = args.cpu_offload
-
+        self.beam_width = args.beam_width
         self.n_layer = len(self.model.layers)
         self.n_expert = len(self.model.layers[0].block_sparse_moe.experts)
+
         # TODO: find this value based on device config
         self.latency_cpu = 7
@@ -49,7 +51,7 @@ def __init__(self, args):
         )
 
         self.set_expert_loc(n_expert_on_gpu)
-        print(self.expert_loc)
+        # print(self.expert_loc)
         self.bring_expert_to_gpu()
 
@@ -355,16 +357,27 @@ def calc_n_expert_on_gpu(self):
         )
         # get the amount of free memory on GPU
         total_mem = torch.cuda.get_device_properties(self.dev).total_memory
-        free_mem = total_mem * 0.95 - torch.cuda.memory_allocated(self.dev)
+        free_mem = total_mem * 0.95 - torch.cuda.memory_allocated(self.dev)  # TODO: magic number
         return int((free_mem) // (n_param * 2))
 
-    def generate(self, text, output_token=20, input_token=None):
+    def initial_beam_tensor(self, input_tensor):
+        # transpose tensor of shape (beam_width, seq_len, beam_width) to (beam_width, 1) properly
+        assert input_tensor.shape[-1] == self.beam_width
+        input_tensor = input_tensor[:, -1]
+        row_idx = torch.tensor(
+            [i * self.beam_width for i in range(input_tensor.shape[0] // self.beam_width)]
+        )
+        output_tensor = input_tensor[row_idx].view(-1, 1)
+        return output_tensor
+
+    def generate(self, text=None, output_token=20, input_token=None):
+        torch.set_num_threads(16)  # TODO: set appropriately
         self.past_key_value = transformers.cache_utils.DynamicCache.from_legacy_cache()
         self.past_key_values_length = 0
         self.cnt_expert_hit = 0
         self.cnt_expert_all = 0
-
+
         input_ids, position_ids = self.tokenize(text)
 
         if input_token is not None:
@@ -374,42 +387,90 @@ def generate(self, text, output_token=20, input_token=None):
         tick = time.time()
         is_decode = False
         prefill_time, decode_time = 0, 0
+        decode_strings = ["" for _ in range(input_ids.shape[0])]
+        search_start = False
+        probs = torch.full((input_ids.shape[0], 1), 1.0)
+
         for i_token in range(output_token):
-            # tick = time.time()
-            print(self.tokenizer.decode(input_ids[0, :]))
-            logits = self.mixtral_forward(
-                input_ids,
-                position_ids,
-                is_decode,
-            )
-            # print('Time:', time.time() - tick)
+            if self.beam_width == 1:
+                print(self.tokenizer.decode(input_ids[0]))
+            # TODO: streaming output for beam search
+            if is_decode:
+                for i in range(input_ids.shape[0]):
+                    decode_strings[i] += " " + self.tokenizer.decode(input_ids[i, :])
+
+            logits = self.mixtral_forward(input_ids, position_ids, is_decode)
             logits = logits.to("cpu")
+            # logits.shape: (batch_size, seq_len, vocab_size)
+
+            # normalize logits
+            logits = F.softmax(logits, dim=-1)
+
+            # greedy search:
+            # output = torch.argmax(logits, dim=-1)
 
-            output = torch.argmax(logits, dim=-1)
-            self.past_key_values_length += output.shape[-1]
-            input_ids = output[:, -1].unsqueeze(0).to(self.dev)
-            position_ids = torch.arange(
-                self.past_key_values_length,
-                self.past_key_values_length + 1,
-                dtype=torch.long,
-                device=self.dev,
+            # beam_search:
+            self.past_key_values_length += logits.shape[1]
+            if search_start:
+                new_probs, output = torch.topk(logits, 1, dim=-1)
+                new_probs = new_probs[:, -1].flatten().view(-1, 1)
+            else:
+                new_probs, output = torch.topk(logits, self.beam_width, dim=-1)
+                new_probs = self.initial_beam_tensor(new_probs)
+                output = self.initial_beam_tensor(output)
+                search_start = True
+            # new_probs = new_probs / new_probs.sum(dim=-1, keepdim=True)
+            probs = probs * new_probs
+
+            input_ids = output[:, -1].flatten().view(-1, 1).to(self.dev)
+            # input_ids.shape: (batch_size, seq_len=1)
+
+            position_ids = (
+                torch.arange(
+                    self.past_key_values_length,
+                    self.past_key_values_length + 1,
+                    dtype=torch.long,
+                    device=self.dev,
+                )
+                .unsqueeze(0)
+                .view(-1, 1)
             )
-            position_ids = position_ids.unsqueeze(0).view(-1, 1)
+            # position_ids.shape: (1, 1)
             if not is_decode:
                 prefill_time += time.time() - tick
                 tick = time.time()
                 is_decode = True
         decode_time = time.time() - tick
-        return prefill_time, decode_time, self.cnt_expert_hit / self.cnt_expert_all
+        probs = probs.view(-1, self.beam_width)
+        max_ids = torch.argmax(probs, dim=-1)
+
+        print("--------------------")
+        print(f"Input: {text}")
+        print(f"Output: {decode_strings[max_ids[0]]}")
+
+        return (
+            prefill_time,
+            decode_time,
+            self.cnt_expert_hit / self.cnt_expert_all,
+        )
 
     def tokenize(self, text):
+        input_ids = []
         encodings = self.tokenizer(text, return_tensors="pt")
-        input_ids = encodings.input_ids.to(self.dev)
+        input_id = encodings.input_ids.to(self.dev)
+        for i in range(self.beam_width):
+            input_ids.append(input_id[0])
+
+        input_ids = pad_sequence(
+            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
+        ).to(self.dev)
+
         position_ids = torch.arange(
             0, input_ids.shape[-1], dtype=torch.long, device=self.dev
         )
         position_ids = position_ids.unsqueeze(0).view(-1, input_ids.shape[-1])
+
         return input_ids, position_ids
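
The generate hunk above swaps plain greedy decoding for a simple beam-style search: tokenize duplicates the prompt beam_width times, the first step keeps the top beam_width tokens via initial_beam_tensor, every later step extends each beam with its top-1 continuation, per-step probabilities are multiplied into probs, and the beam with the highest cumulative probability is printed at the end. A minimal standalone sketch of that bookkeeping, using a toy next_token_probs function that is not part of the patch:

    import torch

    def toy_beam_bookkeeping(next_token_probs, beam_width=2, steps=3):
        # next_token_probs(step, beam_idx) -> 1D tensor of vocabulary probabilities (toy stand-in)
        probs = torch.full((beam_width, 1), 1.0)    # cumulative probability per beam
        beams = [[] for _ in range(beam_width)]
        for step in range(steps):
            if step == 0:
                # first step: fan out from the shared prompt into beam_width candidates
                p, tok = torch.topk(next_token_probs(step, 0), beam_width)
                new_probs, tokens = p.view(-1, 1), tok.view(-1, 1)
            else:
                # later steps: extend each beam greedily with its own top-1 token
                picks = [torch.topk(next_token_probs(step, b), 1) for b in range(beam_width)]
                new_probs = torch.stack([pk.values for pk in picks])
                tokens = torch.stack([pk.indices for pk in picks])
            probs = probs * new_probs               # multiply per-step probabilities per beam
            for b in range(beam_width):
                beams[b].append(int(tokens[b]))
        best = int(torch.argmax(probs.view(-1)))    # highest cumulative probability wins
        return beams[best], probs

    # toy usage with a fixed fake distribution over a 5-token vocabulary
    fake_probs = lambda step, beam: torch.tensor([0.10, 0.20, 0.30, 0.25, 0.15])
    print(toy_beam_bookkeeping(fake_probs))

Unlike full beam search, no cross-beam re-ranking happens after the first step, which matches the hunk's use of topk(logits, 1) once search_start is set.
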
@@ -429,14 +490,18 @@ def mixtral_forward(self, input_ids, position_ids, is_decode):
                 past_key_value=self.past_key_value,
                 use_cache=True,
             )
+            # inps.shape: (batch_size, seq_len/token_num, embed_dim)
             inps = inps_residual + inps
             inps_residual = inps
             inps = layer.post_attention_layernorm(inps)
-
             inps = inps.view(-1, hidden_dim)
+            # inps.shape: (batch_size*seq_len*embed_dim/hidden_dim, hidden_dim)
             router_logits = layer.block_sparse_moe.gate(inps)
             routing_weights = F.softmax(router_logits, dim=1)
+            # routing_weights.shape: (batch_size*seq_len, num_experts)
             routing_weights, selected_experts = torch.topk(routing_weights, 2, dim=-1)
+            # routing_weights.shape: (batch_size*seq_len, 2)
+            # selected_experts.shape: (batch_size*seq_len, 2)
             routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
 
             # intermediate variable to store the output of experts
@@ -484,7 +549,7 @@ def mixtral_forward(self, input_ids, position_ids, is_decode):
 
                 # end of one expert
 
-            elif not is_decode:
+            else:
                 # prefill stage with offloading
                 expert_mask = torch.nn.functional.one_hot(
                     selected_experts, num_classes=8
@@ -495,7 +560,6 @@ def mixtral_forward(self, input_ids, position_ids, is_decode):
                 cost_per_expert = np.zeros(
                     (len(experts), 2), dtype=float
                 )  # 0: CPU, 1: GPU
-
                 for i_expert in range(len(experts)):
                     idx, top_2 = torch.where(expert_mask[i_expert])
                     idxs.append(idx)
@@ -510,11 +574,10 @@ def mixtral_forward(self, input_ids, position_ids, is_decode):
                         cost_per_expert[i_expert, 1] = 0
                         self.cnt_expert_hit += top_2.shape[0]
                     self.cnt_expert_all += top_2.shape[0]
-
+
                 # second, partition experts processing between CPU and GPU so that we can minimize:
                 # max(sum of cost at CPU, sum of cost at GPU)
-                # greedy algorithm is just as there are only 8 experts for
-                # Mixtral
+                # greedy algorithm is just as there are only 8 experts for Mixtral
                 best_config = -1
                 best_cost = float("inf")
                 for config in range(1 << len(experts)):
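
The loop that begins at the end of the hunk above exhaustively scores every bitmask assignment of the activated experts as max(sum of CPU costs, sum of GPU costs), which is cheap because a Mixtral layer has at most 8 experts. A standalone sketch of that selection, with a small made-up cost_per_expert array shaped like the one built above; the bit convention (bit set means GPU) is chosen for the sketch, not taken from the patch:

    import numpy as np

    # made-up per-expert costs: column 0 = cost if run on CPU, column 1 = cost if run on GPU
    cost_per_expert = np.array([[7.0, 70.0], [7.0, 0.0], [7.0, 0.0], [7.0, 70.0]])
    n_expert = len(cost_per_expert)

    best_config, best_cost = -1, float("inf")
    for config in range(1 << n_expert):        # bit i set -> expert i assigned to GPU (sketch convention)
        cpu_cost = sum(cost_per_expert[i, 0] for i in range(n_expert) if not (config >> i) & 1)
        gpu_cost = sum(cost_per_expert[i, 1] for i in range(n_expert) if (config >> i) & 1)
        cost = max(cpu_cost, gpu_cost)         # the objective named in the comment above
        if cost < best_cost:
            best_config, best_cost = config, cost

    gpu_experts = [i for i in range(n_expert) if (best_config >> i) & 1]
    cpu_experts = [i for i in range(n_expert) if not (best_config >> i) & 1]
    print(gpu_experts, cpu_experts, best_cost)  # [1, 2] [0, 3] 14.0
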
@@ -538,28 +601,6 @@ def mixtral_forward(self, input_ids, position_ids, is_decode):
                     else:
                         gpu_experts.append(i_expert)
 
-                def run_expert_in_thread():
-                    for i_expert in cpu_experts:
-                        top_2_list = top_2s[i_expert].tolist()
-                        idx_list = idxs[i_expert].tolist()
-                        current_state = inps[None, top_2_list].reshape(-1, hidden_dim)
-                        current_state = self.run_expert_at_cpu(
-                            i_layer,
-                            i_expert,
-                            current_state.to("cpu", non_blocking=True),
-                            routing_weights[top_2_list, idx_list, None].to(
-                                "cpu", non_blocking=True
-                            ),
-                        )
-                        inps_after_experts.index_add_(
-                            0,
-                            top_2s[i_expert].to(self.dev, non_blocking=True),
-                            current_state.to(self.dev, non_blocking=True),
-                        )
-
-                thread = threading.Thread(target=run_expert_in_thread)
-                thread.start()
-
                 for i_expert in gpu_experts:
                     top_2_list = top_2s[i_expert].tolist()
                     idx_list = idxs[i_expert].tolist()
@@ -581,44 +622,23 @@ def run_expert_in_thread():
                         current_state.to(self.dev, non_blocking=True),
                     )
 
-                thread.join()
-
-            else:
-                # decode stage with offloading
-                assert input_ids.shape[-1] == 1
-                expert_0, expert_1 = int(selected_experts[0][0]), int(
-                    selected_experts[0][1]
-                )
-                routing_weights_0, routing_weights_1 = (
-                    routing_weights[:, 0, None],
-                    routing_weights[:, 1, None],
-                )
-
-                assert expert_0 != expert_1
-
-                self.cnt_expert_all += 2
-
-                if self.is_expert_in_gpu(i_layer, expert_0):
-                    inps_after_experts += experts[expert_0](inps, routing_weights_0)
-                    self.cnt_expert_hit += 1
-                else:
-                    inps_after_experts += self.run_expert_at_cpu(
-                        i_layer,
-                        expert_0,
-                        inps.to("cpu", non_blocking=True),
-                        routing_weights_0.to("cpu", non_blocking=True),
-                    ).to(self.dev, non_blocking=True)
-
-                if self.is_expert_in_gpu(i_layer, expert_1):
-                    inps_after_experts += experts[expert_1](inps, routing_weights_1)
-                    self.cnt_expert_hit += 1
-                else:
-                    inps_after_experts += self.run_expert_at_cpu(
+                for i_expert in cpu_experts:
+                    top_2_list = top_2s[i_expert].tolist()
+                    idx_list = idxs[i_expert].tolist()
+                    current_state = inps[None, top_2_list].reshape(-1, hidden_dim)
+                    current_state = self.run_expert_at_cpu(
                         i_layer,
-                        expert_1,
-                        inps.to("cpu", non_blocking=True),
-                        routing_weights_1.to("cpu", non_blocking=True),
-                    ).to(self.dev, non_blocking=True)
+                        i_expert,
+                        current_state.to("cpu", non_blocking=True),
+                        routing_weights[top_2_list, idx_list, None].to(
+                            "cpu", non_blocking=True
+                        ),
+                    )
+                    inps_after_experts.index_add_(
+                        0,
+                        top_2s[i_expert].to(self.dev, non_blocking=True),
+                        current_state.to(self.dev, non_blocking=True),
+                    )
 
             # addition because there's residual connection over moe layer
             inps = inps_residual + inps_after_experts.reshape(original_inps_shape)
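
For reference, the scatter pattern shared by the GPU- and CPU-expert loops above: torch.where over the one-hot expert mask gives, per expert, the token rows it must process (top_2) and which of the two routing slots picked it (idx), and the expert output, scaled by its routing weight, is added back into the shared buffer with index_add_. A minimal self-contained sketch, assuming the mask is permuted to expert-major order as in the upstream Mixtral implementation (that line is cut off by the hunk) and using plain linear layers as stand-in experts:

    import torch
    import torch.nn.functional as F

    hidden_dim, n_expert, n_token = 4, 2, 3
    inps = torch.randn(n_token, hidden_dim)              # flattened token activations
    router_logits = torch.randn(n_token, n_expert)
    routing_weights = F.softmax(router_logits, dim=1)
    routing_weights, selected_experts = torch.topk(routing_weights, 2, dim=-1)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

    # stand-in experts: one linear layer each instead of Mixtral's gated MLP experts
    experts = [torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(n_expert)]

    inps_after_experts = torch.zeros_like(inps)
    expert_mask = F.one_hot(selected_experts, num_classes=n_expert).permute(2, 1, 0)
    for i_expert in range(n_expert):
        idx, top_2 = torch.where(expert_mask[i_expert])  # routing slot, token rows for this expert
        current_state = experts[i_expert](inps[top_2])   # run the expert on its own tokens
        current_state = current_state * routing_weights[top_2, idx, None]
        inps_after_experts.index_add_(0, top_2, current_state)
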