
compressed-tensors support for KV cache Quantization #103

Open · wants to merge 1 commit into base: upstream-a564d10af

Conversation


@dbogunowicz dbogunowicz commented Jun 20, 2024

Feature description

Implements quantized KV cache support for transformer models that have been quantized using compressed-tensors.

Introduces (roughly sketched below):

  • CompressedTensorsQuantizedCacheConfig - a config object that stores the static qparams for quantizing/dequantizing the KV cache
  • minor improvements to the QuantizedCache interface, to make it more general for other implementations in the future
  • CompressedTensorsCache - a very lightweight wrapper around QuantizedCache that enables KV cache quantization
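
A hedged sketch of how these pieces could relate; the field names below (num_bits, type_, key_scale, value_scale) are illustrative assumptions rather than the exact attributes added in this PR:

# Illustrative sketch only -- the real classes are defined in this PR and in compressed-tensors.
from dataclasses import dataclass
import torch

@dataclass
class _KVCacheQuantParams:       # assumed shape of the static qparams held by the cache config
    num_bits: int                # e.g. 8
    type_: str                   # e.g. "int"
    key_scale: torch.Tensor      # static scale used to (de)quantize the key cache
    value_scale: torch.Tensor    # static scale used to (de)quantize the value cache

# CompressedTensorsQuantizedCacheConfig carries parameters like the above, loaded from the
# quantized checkpoint via .from_pretrained(model_id); CompressedTensorsCache then consumes
# them inside the generic QuantizedCache quantize/dequantize flow.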

Manual testing:

from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import CompressedTensorsQuantizedCacheConfig
import torch

model_id = "/root/compressed-tensors/llama1.1b_new_quant_out" # quantized model with kv cache quantization enabled

tokenizer = AutoTokenizer.from_pretrained("Xenova/llama2.c-stories15M") # somehow the tokenizer is missing from the model in `model_id`
tokenizer.pad_token_id = tokenizer.eos_token_id
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="eager").to("cuda:0")

inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device)

cache_config = CompressedTensorsQuantizedCacheConfig.from_pretrained(model_id)

out_quant = model.generate(**inputs, cache_implementation="quantized", cache_config=cache_config, min_new_tokens=40, return_dict_in_generate=True)
out = model.generate(**inputs, min_new_tokens=40, return_dict_in_generate=True)

assert out_quant.sequences.allclose(out.sequences) # assert same tokens get generated regardless of cache type
assert out_quant.past_key_values._quantized_key_cache[-1].dtype == torch.int8 # assert that we are actually caching quantized tensors

Note

Compatible with this branch of compressed-tensors: neuralmagic/compressed-tensors#86
Pending items: add tests, add compressed-tensors as a transformers dependency, and fix the ugly import issues (circular imports between transformers and compressed-tensors).
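
On the circular-import point, one common way out is to defer the compressed-tensors import into the function body instead of importing it at module load time. A minimal sketch; the imported module path and class are assumptions for illustration:

def _get_compressed_tensors_quant_args():
    # Deferred import: executed only when KV cache quantization is actually requested,
    # so importing transformers does not require compressed-tensors to be importable
    # (and vice versa).
    from compressed_tensors.quantization import QuantizationArgs
    return QuantizationArgs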

@@ -16,6 +19,13 @@
if is_hqq_available():
    from hqq.core.quantize import Quantizer as HQQQuantizer

if is_compressed_tensors_available() or True: # hack for now
@dbogunowicz dbogunowicz (Author) commented Jun 20, 2024

I will take care of this hack once the PR is in the "landable" state
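
For reference, the eventual guard could mirror how transformers probes other optional backends such as hqq; a hedged sketch (the real helper may end up looking different):

import importlib.util

def is_compressed_tensors_available() -> bool:
    # Probe for the package without importing it eagerly, in the spirit of the
    # existing optional-dependency checks (e.g. is_hqq_available).
    return importlib.util.find_spec("compressed_tensors") is not None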

    self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
    self.value_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
    keys_to_return, values_to_return = key_states, value_states
else:
    dequant_key = self._dequantize(self._quantized_key_cache[layer_idx])
    dequant_value = self._dequantize(self._quantized_value_cache[layer_idx])
    dequant_key = self._dequantize(self._quantized_key_cache[layer_idx], cache_type="key")
@dbogunowicz dbogunowicz (Author) commented:

Instead of rewriting the update() method for CompressedTensorsCache, I decided to expand the interface.
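
Roughly: update() in the base class stays untouched, and the new cache_type argument lets a subclass pick per-stream static qparams. A hedged sketch of what such an override could look like (symmetric int8 for simplicity; attribute names like key_scale/value_scale and the compute_dtype handling are assumptions, and torch / QuantizedCache are assumed to be in scope as in cache_utils.py):

class CompressedTensorsCache(QuantizedCache):
    def _quantize(self, tensor, axis=None, cache_type: str = "key"):
        # Use the static scale that belongs to this cache stream ("key" or "value").
        scale = self.key_scale if cache_type == "key" else self.value_scale
        return torch.clamp(torch.round(tensor / scale), -128, 127).to(torch.int8)

    def _dequantize(self, q_tensor, cache_type: str = "key"):
        scale = self.key_scale if cache_type == "key" else self.value_scale
        return q_tensor.to(self.compute_dtype) * scale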


@staticmethod
def _establish_quant_dtype(num_bits: int, type_: str) -> torch.dtype:
    if num_bits == 8 and type_ == "int":
@dbogunowicz dbogunowicz (Author) commented:

We should potentially add more supported quantization types; to be discussed with @bfineran.
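
If more schemes get added, the mapping could grow along these lines; a sketch only, where the non-int8 entries are assumptions and exactly which (num_bits, type_) pairs we support is the open question:

@staticmethod
def _establish_quant_dtype(num_bits: int, type_: str) -> torch.dtype:
    # Map the compressed-tensors quantization args onto a torch dtype for cache storage.
    # Only the (8, "int") branch exists in this PR; the other entries are possible extensions.
    supported = {
        (8, "int"): torch.int8,
        (8, "uint"): torch.uint8,
        (8, "float"): torch.float8_e4m3fn,  # assumption: fp8 KV cache, requires torch >= 2.1
    }
    if (num_bits, type_) not in supported:
        raise ValueError(
            f"KV cache quantization with num_bits={num_bits} and type={type_} is not supported"
        )
    return supported[(num_bits, type_)]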

@dbogunowicz dbogunowicz requested review from mgoin, Satrat and bfineran June 20, 2024 12:57
@dbogunowicz dbogunowicz (Author) commented:

@mgoin

from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import CompressedTensorsQuantizedCacheConfig, CompressedTensorsCache

model_id = "/root/compressed-tensors/llama1.1b_new_quant_out" # quantized model with kv cache quantization enabled

tokenizer = AutoTokenizer.from_pretrained("Xenova/llama2.c-stories15M") # somehow the tokenizer is missing from the model in `model_id`
tokenizer.pad_token_id = tokenizer.eos_token_id
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="eager").to("cuda:0")

inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device)

cache_config = CompressedTensorsQuantizedCacheConfig.from_pretrained(model_id)
cache_object = CompressedTensorsCache(cache_config)

out = model(**inputs)
out_quant = model(**inputs, past_key_values=cache_object)

assert (out.logits == out_quant.logits).all()

working fine!
