diff --git a/notebooks/aqlm_cuda_graph.ipynb b/notebooks/aqlm_cuda_graph.ipynb
index 3eb8e83d..f44dd31a 100644
--- a/notebooks/aqlm_cuda_graph.ipynb
+++ b/notebooks/aqlm_cuda_graph.ipynb
@@ -40,7 +40,7 @@
     "%%capture\n",
     "!pip install aqlm[gpu]>=1.1.0\n",
     "!pip install accelerate>=0.27.0\n",
-    "!pip install git+https://github.com/huggingface/transformers.git@main"
+    "!pip install transformers>=4.41.0"
    ]
   },
   {
@@ -210,12 +210,16 @@
    "source": [
     "import torch\n",
     "\n",
-    "def decode_one_tokens(model, cur_token, input_pos, cache_position):\n",
+    "def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):\n",
     "    logits = model(\n",
-    "        cur_token, position_ids=None, cache_position=cache_position, return_dict=False, use_cache=True\n",
+    "        cur_token,\n",
+    "        position_ids=input_pos,\n",
+    "        cache_position=cache_position,\n",
+    "        past_key_values=past_key_values,\n",
+    "        return_dict=False,\n",
+    "        use_cache=True\n",
     "    )[0]\n",
-    "    new_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)\n",
-    "\n",
+    "    new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]\n",
     "    return new_token\n",
     "\n",
     "MAX_NEW_TOKENS = 128"
    ]
   },
   {
@@ -242,7 +246,14 @@
     "\n",
     "input_ids = tokenizer(\"I'm AQLM, \", return_tensors=\"pt\").to(\"cuda\")[\"input_ids\"]\n",
     "seq_length = input_ids.shape[1]\n",
-    "quantized_model._setup_cache(StaticCache, 1, max_cache_len=seq_length + MAX_NEW_TOKENS * 2 + 1)"
+    "\n",
+    "past_key_values = StaticCache(\n",
+    "    quantized_model.config,\n",
+    "    1,\n",
+    "    seq_length + MAX_NEW_TOKENS * 2 + 1,\n",
+    "    quantized_model.device,\n",
+    "    quantized_model.dtype\n",
+    ")"
    ]
   },
   {
@@ -284,7 +295,9 @@
    },
    "outputs": [],
    "source": [
-    "logits = quantized_model(input_ids, cache_position=cache_position, return_dict=False, use_cache=True)[0]\n",
+    "logits = quantized_model(\n",
+    "    input_ids, cache_position=cache_position, past_key_values=past_key_values,return_dict=False, use_cache=True\n",
+    ")[0]\n",
     "next_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)\n",
     "generated_ids[:, [seq_length]] = next_token"
    ]
   },
   {
@@ -314,8 +327,8 @@
     "    cache_position = torch.tensor([seq_length + 1], device=\"cuda\")\n",
     "    for _ in range(1, MAX_NEW_TOKENS):\n",
     "        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):\n",
-    "            next_token = decode_one_tokens(quantized_model, next_token.clone(), None, cache_position)\n",
-    "            generated_ids.index_copy_(1, cache_position, next_token)\n",
+    "            next_token = decode_one_tokens(quantized_model, next_token.clone(), None, cache_position, past_key_values)\n",
+    "            generated_ids[:, cache_position] = next_token.int()\n",
     "        cache_position += 1"
    ]
   },
   {
@@ -363,8 +376,8 @@
     "with torch.no_grad():\n",
     "    for _ in range(MAX_NEW_TOKENS):\n",
     "        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):\n",
-    "            next_token = decode_one_tokens(quantized_model, next_token.clone(), None, cache_position)\n",
-    "            generated_ids.index_copy_(1, cache_position, next_token)\n",
+    "            next_token = decode_one_tokens(quantized_model, next_token.clone(), None, cache_position, past_key_values)\n",
+    "            generated_ids[:, cache_position] = next_token.int()\n",
     "        cache_position += 1\n",
     "end = time.perf_counter()"
    ]
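
For context, a minimal end-to-end sketch (not part of the patch) of how the pieces touched above fit together after the change: an explicit StaticCache is built once and passed as past_key_values to every forward call, replacing the removed private _setup_cache() hook. The checkpoint name is a placeholder, the surrounding cells (model and tokenizer loading, generated_ids allocation, the torch.compile wrapper used for CUDA-graph capture) are reconstructed from the notebook's flow rather than quoted from it, and the index bookkeeping simply mirrors the diff.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

MAX_NEW_TOKENS = 128

# Placeholder AQLM checkpoint; any AQLM-quantized causal LM follows the same flow.
model_name = "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="cuda"
)

def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
    # One greedy decoding step; the pre-allocated static cache is passed in explicitly.
    logits = model(
        cur_token,
        position_ids=input_pos,
        cache_position=cache_position,
        past_key_values=past_key_values,
        return_dict=False,
        use_cache=True,
    )[0]
    return torch.argmax(logits[:, -1], dim=-1)[:, None]

# Optional, as in the notebook: compile the single-token step so it runs under CUDA graphs.
# decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)

input_ids = tokenizer("I'm AQLM, ", return_tensors="pt").to("cuda")["input_ids"]
seq_length = input_ids.shape[1]

# Explicit static KV cache, sized for the prompt plus the tokens we plan to generate:
# positional arguments are (config, max_batch_size, max_cache_len, device, dtype).
past_key_values = StaticCache(
    quantized_model.config,
    1,
    seq_length + MAX_NEW_TOKENS * 2 + 1,
    quantized_model.device,
    quantized_model.dtype,
)

generated_ids = torch.zeros(1, seq_length + MAX_NEW_TOKENS, dtype=torch.int, device="cuda")
generated_ids[:, :seq_length] = input_ids.int()

with torch.no_grad():
    # Prefill: run the whole prompt once to populate the cache.
    cache_position = torch.arange(seq_length, device="cuda")
    logits = quantized_model(
        input_ids,
        cache_position=cache_position,
        past_key_values=past_key_values,
        return_dict=False,
        use_cache=True,
    )[0]
    next_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)
    generated_ids[:, [seq_length]] = next_token

    # Decode token by token, advancing the cache position each step,
    # with the same index bookkeeping as the notebook's loop.
    cache_position = torch.tensor([seq_length + 1], device="cuda")
    for _ in range(1, MAX_NEW_TOKENS):
        next_token = decode_one_tokens(
            quantized_model, next_token.clone(), None, cache_position, past_key_values
        )
        generated_ids[:, cache_position] = next_token.int()
        cache_position += 1

print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))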