diff --git a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_llama.py b/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_llama.py
index fd98d194667..3a961b9c1e0 100644
--- a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_llama.py
+++ b/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_llama.py
@@ -137,10 +137,10 @@ def __init__(
         self._init_rope()
         self._init_func = []
-    
+
     def register_init_func(self, func):
         self._init_func.append(func)
-    
+
     def post_init(self):
         for func in self._init_func:
             func(self)
 
@@ -690,13 +690,13 @@ def __init__(
             self.model.layers[layer_idx].self_attn.post_init()
             self.model.layers[layer_idx].self_attn.pruner = self.pruner
-    
+
         # Initialize weights and apply final processing
         self.post_init()
 
         def _generate(**kwargs):
-            self.pruner.before_generate(self, **kwargs)
-            result = self.ori_generate(**kwargs)
+            self.pruner.before_generate(self, **kwargs)
+            result = self.ori_generate(**kwargs)
             self.pruner.after_generate(self, **kwargs)
             return result
 
diff --git a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/prune/__init__.py b/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/prune/__init__.py
index 33a5058953d..7f9ede223e8 100644
--- a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/prune/__init__.py
+++ b/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/prune/__init__.py
@@ -1,2 +1,16 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from .base import PruneConfig, KVPruner
-from .h2o import H2OConfig, H2OKVPruner
\ No newline at end of file
+from .h2o import H2OConfig, H2OKVPruner
diff --git a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/prune/base.py b/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/prune/base.py
index a6000057926..ee0e5b25b72 100644
--- a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/prune/base.py
+++ b/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/prune/base.py
@@ -1,3 +1,17 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 class PruneConfig(dict):
     def __init__(self, real_drop=True):
         self.real_drop = real_drop
@@ -25,7 +39,7 @@ def get_mask(self, model, **kwargs):
     @property
     def past_length(self):
         return self._past_length
-    
+
     @past_length.setter
     def past_length(self, value):
-        self._past_length = value
\ No newline at end of file
+        self._past_length = value
diff --git a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/prune/h2o.py b/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/prune/h2o.py
index e8346c7209d..802ebc495be 100644
--- a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/prune/h2o.py
+++ b/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/prune/h2o.py
@@ -242,7 +242,7 @@ def prune(self, module, query_states, key_states, value_states, causal_mask=None
         if not self.config.real_drop:
             module.h2o_kv_cache.clean_scores()
         return module.h2o_kv_cache(attn_weights, key_states, value_states, **kwargs)
-    
+
     def get_mask(self, module, query_states, key_states, value_states, causal_mask=None, **kwargs):
         attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(module.head_dim)
         if causal_mask is not None:  # no matter the length, we just slice it
@@ -252,4 +252,4 @@ def get_mask(self, module, query_states, key_states, value_states, causal_mask=N
             self.config.recent_ratio,
             attn_weights,
             local=self.config.local)
-        return mask
\ No newline at end of file
+        return mask
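
Note for reviewers: the `_generate` closure in the first file is the hook mechanism the pruner plugs into. The model keeps a handle to the original `generate` and swaps in a wrapper that brackets it with `before_generate`/`after_generate`. A minimal runnable sketch of that pattern follows; the `Dummy*` names are hypothetical stand-ins, and the `self.ori_generate = self.generate` assignment is assumed, since it sits outside the hunk shown above.

# Sketch of the generate() hook pattern used in modeling_llama.py above.
# DummyPruner and DummyModel are hypothetical stand-ins for KVPruner and
# the patched LlamaForCausalLM; only the wiring is illustrated.

class DummyPruner:
    def before_generate(self, model, **kwargs):
        # e.g. allocate/reset per-generation pruning state
        print("before_generate: reset KV-cache scores")

    def after_generate(self, model, **kwargs):
        # e.g. free pruning state once decoding is finished
        print("after_generate: release pruning state")


class DummyModel:
    def generate(self, **kwargs):
        return "generated ids"

    def __init__(self):
        self.pruner = DummyPruner()
        # Assumed to happen outside the hunk shown: keep the original
        # generate, then install the hook-wrapped closure over it.
        self.ori_generate = self.generate

        def _generate(**kwargs):
            self.pruner.before_generate(self, **kwargs)
            result = self.ori_generate(**kwargs)
            self.pruner.after_generate(self, **kwargs)
            return result

        self.generate = _generate


model = DummyModel()
print(model.generate(max_new_tokens=8))

Wrapping at the instance level keeps the stock transformers generation loop untouched while guaranteeing the pruner's per-generation state is set up and torn down on every call.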
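
On the h2o.py side, `get_mask` recomputes raw attention weights (QK^T / sqrt(head_dim) plus the sliced causal mask) and hands them to a scorer driven by `recent_ratio` and `local`. The sketch below illustrates the general heavy-hitter-plus-recent-window selection idea from the H2O line of work; the function name, the `heavy_ratio` argument, and the exact score accumulation are assumptions, not this repository's API.

# Hypothetical sketch of H2O-style KV selection, assuming the published
# Heavy-Hitter Oracle recipe: keep the top-k tokens by accumulated
# attention ("heavy hitters") plus a recent window; everything else is a
# candidate for eviction. Names and arguments are illustrative only and
# do not match H2OKVPruner.get_mask's exact signature.
import torch

def h2o_style_mask(attn_weights, heavy_ratio=0.1, recent_ratio=0.1):
    # attn_weights: (batch, heads, q_len, kv_len), already causal-masked
    kv_len = attn_weights.size(-1)
    n_heavy = max(1, int(kv_len * heavy_ratio))
    n_recent = max(1, int(kv_len * recent_ratio))

    # Accumulated attention each KV position receives across all queries.
    scores = attn_weights.softmax(dim=-1).sum(dim=-2)  # (batch, heads, kv_len)

    keep = torch.zeros_like(scores, dtype=torch.bool)
    # Heavy hitters: positions with the largest accumulated scores.
    keep.scatter_(-1, scores.topk(n_heavy, dim=-1).indices, True)
    # Recent window: always keep the newest positions.
    keep[..., -n_recent:] = True
    return keep  # True = retain this KV-cache entry

weights = torch.randn(1, 2, 8, 8)
print(h2o_style_mask(weights).shape)  # torch.Size([1, 2, 8])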