Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: n1ck-guo <[email protected]>
  • Loading branch information
n1ck-guo committed Jul 16, 2024
1 parent d600112 commit b1ab771
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ omit =
*/intel_extension_for_transformers/langchain/**
*/intel_extension_for_transformers/llama_index/**
*/intel_extension_for_transformers/transformers/utils/get_throughput.py
*/intel_extension_for_transformers/transformers/kv_cache_compression/models/**
exclude_lines =
pragma: no cover
raise NotImplementedError
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,6 @@

from .prune.h2o import H2OConfig, H2OKVPruner
from .models.modeling_llama import LlamaForCausalLM
from intel_extension_for_transformers.transformers.utils.utility import LazyImport

# Gaudi support is optional: resolve the Gaudi Llama class lazily so importing
# this package does not fail on machines without the Habana stack installed.
GaudiLlamaForCausalLM = LazyImport(".models.modeling_gaudi_llama.GaudiLlamaForCausalLM")
Original file line number Diff line number Diff line change
Expand Up @@ -693,10 +693,10 @@ def __init__(
# Initialize weights and apply final processing
self.post_init()

def _generate(*args, **kwargs):
    """Wrapper installed over ``self.generate``.

    Forwards *all* positional and keyword arguments (the pre-fix version
    accepted only ``**kwargs`` and broke positional calls), letting the
    pruner set up KV-cache state before generation and tear it down after.
    """
    self.pruner.before_generate(self, *args, **kwargs)
    result = self.ori_generate(*args, **kwargs)
    # after_generate runs even on the happy path only; if before/after must
    # pair on exceptions, a try/finally would be needed — preserved as-is.
    self.pruner.after_generate(self, *args, **kwargs)
    return result

self.ori_generate = self.generate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def local_heavy_hitter_mask(attn_weights, heavy_budget, no_padding_seq_length=No
for token_index in range(heavy_budget+padding_length, seq_length):

tmp_attn_index = nn.functional.softmax(
attn_weights[:,:,token_index,:], dim=-1, dtype=torch.float32).to(dtype_attn_weights)
attn_weights[:,:,token_index,:], dim=-1, dtype=torch.float32).to(dtype_attn_weights)
_, tmp_topk_index = accumulated_attention_score.topk(k=heavy_budget-1, dim=-1)
zeros_index = torch.zeros_like(tmp_attn_index, dtype=torch.bool)
mask_bottom_index = zeros_index.scatter(-1, tmp_topk_index, True) #(head, keys)
Expand Down Expand Up @@ -123,6 +123,9 @@ def __init__(
mean=False
):
## bsz, num_heads, seq_len, head_dim
assert 0 <= heavy_ratio <= 1 and 0 <= recent_ratio <= 1, "ratio should be in [0, 1]"
assert heavy_budget is None or heavy_budget >= 0, "heavy_budget should be non-negative"
assert recent_budget is None or recent_budget >= 0, "recent_budget should be non-negative"
self.heavy_ratio = heavy_ratio
self.recent_ratio = recent_ratio
self.heavy_budget = heavy_budget
Expand Down Expand Up @@ -221,15 +224,22 @@ def self_attn_init(self, module):
)

def before_generate(self, model, inputs, *args, **kwargs):
    """Initialize per-layer H2O KV-cache budgets before ``model.generate``.

    Args:
        model: the model being generated from; scanned via ``named_modules``
            for attention modules (class name containing ``"Attention"``)
            that carry an ``h2o_kv_cache`` state object.
        inputs: prompt tensor; ``inputs.size(-1)`` is the prompt length.
        **kwargs: generation kwargs; ``max_new_tokens`` / ``max_length``
            are consulted to bound the total sequence length.

    Raises:
        AssertionError: if ``self.real_drop`` is not True (H2O only supports
            real-drop mode with ``generate``).
    """
    assert self.real_drop is True, 'H2O only support real drop mode when use generate func.'
    self.past_length = 0
    # Resolve the maximum total sequence length, preferring explicit kwargs
    # over the model config.  NOTE(review): a falsy-value check is used, so
    # max_new_tokens=0 falls through — preserved as-is.
    if kwargs.get('max_new_tokens', None):
        max_length = kwargs['max_new_tokens'] + inputs.size(-1)
    elif kwargs.get('max_length', None):
        max_length = kwargs['max_length']
    else:
        max_length = model.config.max_length
    # If the resolved length cannot even cover the prompt, extend it by the
    # prompt length so the budgets below stay positive.
    if max_length <= inputs.size(-1):
        max_length += inputs.size(-1)
    for _, module in model.named_modules():
        if "Attention" in module.__class__.__name__:
            # Only fill budgets the user left unset; round (not truncate) the
            # ratio-derived budgets so small ratios are not floored to 0.
            if module.h2o_kv_cache.heavy_budget is None:
                module.h2o_kv_cache.heavy_budget = round(max_length * module.h2o_kv_cache.heavy_ratio)
            if module.h2o_kv_cache.recent_budget is None:
                module.h2o_kv_cache.recent_budget = round(max_length * module.h2o_kv_cache.recent_ratio)
            if self.prune_kv_cache_size is None:
                self.prune_kv_cache_size = module.h2o_kv_cache.recent_budget + module.h2o_kv_cache.heavy_budget

Expand Down

0 comments on commit b1ab771

Please sign in to comment.