Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jun 20, 2024
1 parent 91c5f3c commit 4884b3a
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,10 @@ def __init__(
self._init_rope()

self._init_func = []

def register_init_func(self, func):
    """Queue ``func`` to be run later by ``post_init``.

    Args:
        func: Callable appended to ``self._init_func``; ``post_init``
            invokes each queued callable with this object as its only
            argument.
    """
    self._init_func.append(func)

def post_init(self):
    """Run every callable queued in ``self._init_func``, in insertion order.

    Each callable receives this object as its single positional argument.
    Return values are ignored.
    """
    for func in self._init_func:
        func(self)
Expand Down Expand Up @@ -690,13 +690,13 @@ def __init__(
self.model.layers[layer_idx].self_attn.post_init()

self.model.layers[layer_idx].self_attn.pruner = self.pruner

# Initialize weights and apply final processing
self.post_init()

def _generate(**kwargs):
    """Wrap the original ``generate`` call with the pruner's hooks.

    Calls ``self.pruner.before_generate`` once, runs ``self.ori_generate``
    once, then calls ``self.pruner.after_generate``. ``self`` is the free
    variable captured from the enclosing scope.

    Returns:
        Whatever ``self.ori_generate(**kwargs)`` returns.
    """
    # The hook and the generation call must each run exactly once; running
    # them twice would repeat the whole generation pass.
    self.pruner.before_generate(self, **kwargs)
    result = self.ori_generate(**kwargs)
    self.pruner.after_generate(self, **kwargs)
    return result

Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,16 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import PruneConfig, KVPruner
from .h2o import H2OConfig, H2OKVPruner
from .h2o import H2OConfig, H2OKVPruner
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

class PruneConfig(dict):
def __init__(self, real_drop=True):
self.real_drop = real_drop
Expand All @@ -24,7 +38,7 @@ def get_mask(self, model, **kwargs):
@property
def past_length(self):
    """Return the value currently stored in ``self._past_length``."""
    return self._past_length

@past_length.setter
def past_length(self, value):
    """Store ``value`` in ``self._past_length``."""
    self._past_length = value
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ class H2OKVPruner(KVPruner):
def __init__(self, config: H2OConfig) -> None:
    """Keep the H2O configuration and mirror its ``real_drop`` flag.

    Args:
        config: Configuration object; it is stored as ``self.config`` and
            its ``real_drop`` attribute is copied to ``self.real_drop``.
    """
    self.config = config
    self.real_drop = self.config.real_drop


def self_attn_init(self, module):
module.h2o_kv_cache = H2OKVCache(
self.config.heavy_ratio,
Expand All @@ -207,7 +207,7 @@ def self_attn_init(self, module):
self.config.h2o_min_seqlen,
self.config.mean
)

def before_generate(self, model, **kwargs):
self.past_length = 0
max_length = kwargs['max_new_tokens'] if kwargs.get('max_new_tokens') else kwargs['max_length']
Expand All @@ -222,15 +222,15 @@ def after_generate(self, model, **kwargs):
for _, module in model.named_modules():
if "Attention" in module.__class__.__name__:
module.h2o_kv_cache.clean_scores()

def prune(self, module, query_states, key_states, value_states, causal_mask=None, **kwargs):
    """Compute attention scores and pass them to the module's H2O KV cache.

    Args:
        module: Attention module; must expose ``head_dim`` and a callable
            ``h2o_kv_cache`` that also provides ``clean_scores()``.
        query_states: Query tensor, matrix-multiplied against the
            transposed keys.
        key_states: Key tensor; its last two dimensions are transposed
            for the score computation.
        value_states: Value tensor, forwarded unchanged to the cache.
        causal_mask: Optional additive mask applied to the scores.
        **kwargs: Forwarded unchanged to ``module.h2o_kv_cache``.

    Returns:
        Whatever ``module.h2o_kv_cache(...)`` returns.
    """
    # Scaled dot-product scores: Q @ K^T / sqrt(head_dim).
    attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(module.head_dim)
    if causal_mask is not None:
        # NOTE(review): the mask is added as-is (no slicing happens here), so
        # it must already broadcast against the score tensor - confirm callers
        # pass a mask trimmed to the key length.
        attn_weights = attn_weights + causal_mask
    if not self.config.real_drop:
        # Without real dropping, reset the cache's accumulated scores before
        # every call.
        module.h2o_kv_cache.clean_scores()
    return module.h2o_kv_cache(attn_weights, key_states, value_states, **kwargs)

def get_mask(self, module, query_states, key_states, value_states, causal_mask=None, **kwargs):
attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(module.head_dim)
if causal_mask is not None: # no matter the length, we just slice it
Expand All @@ -240,4 +240,4 @@ def get_mask(self, module, query_states, key_states, value_states, causal_mask=N
self.config.recent_ratio,
attn_weights,
local=self.config.local)
return mask
return mask

0 comments on commit 4884b3a

Please sign in to comment.