Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add initial Sample-Layer Attention for GPTQ (PyTorch) #1237

Merged
merged 5 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion model_compression_toolkit/core/common/hessian/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from model_compression_toolkit.core.common.hessian.hessian_scores_request import HessianScoresRequest, HessianMode, HessianScoresGranularity
from model_compression_toolkit.core.common.hessian.hessian_scores_request import (
HessianScoresRequest, HessianMode, HessianScoresGranularity, HessianEstimationDistribution
)
from model_compression_toolkit.core.common.hessian.hessian_info_service import HessianInfoService
import model_compression_toolkit.core.common.hessian.hessian_info_utils as hessian_utils
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import hashlib

import numpy as np
from functools import partial
from tqdm import tqdm
from typing import Callable, List, Dict, Any, Tuple
from typing import Callable, List, Dict, Any, Tuple, TYPE_CHECKING

from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
from model_compression_toolkit.core.common.hessian.hessian_scores_request import HessianScoresRequest, \
HessianScoresGranularity, HessianMode
from model_compression_toolkit.logger import Logger
if TYPE_CHECKING:
from model_compression_toolkit.core.common import BaseNode


class HessianInfoService:
Expand Down Expand Up @@ -228,6 +231,51 @@ def compute(self,
return next_iter_remain_samples if next_iter_remain_samples is not None and len(next_iter_remain_samples) > 0 \
and len(next_iter_remain_samples[0]) > 0 else None

def compute_trackable_per_sample_hessian(self,
                                         hessian_scores_request: HessianScoresRequest,
                                         inputs_batch: List[np.ndarray]) -> Dict[str, Dict['BaseNode', np.ndarray]]:
    """
    Compute the Hessian-approximation score per individual image, keyed by the image's hash.
    We compute the score directly for the given images rather than via a data generator,
    as a data generator might yield different images each time, depending on how it was defined.

    Args:
        hessian_scores_request: hessian scores request. NOTE: its target_nodes list is
            sorted in-place into topological order.
        inputs_batch: a list containing a single batched input array (multi-input
            networks are not supported).

    Returns:
        A dict of Hessian scores per image hash per layer {image hash: {layer: score}}.

    Raises:
        TypeError: if inputs_batch is not a list.
        NotImplementedError: if more than one input array is passed (multiple network inputs).
    """
    # Validate the input before doing any work.
    if not isinstance(inputs_batch, list):
        raise TypeError('Expected a list of inputs')
    if len(inputs_batch) > 1:
        raise NotImplementedError('Per-sample hessian computation is not supported for networks with multiple inputs')

    # Sort target nodes topologically. A name->index dict keeps the sort key O(1)
    # instead of the O(n) list.index lookup per comparison.
    topo_index = {node.name: i for i, node in enumerate(self.graph.get_topo_sorted_nodes())}
    hessian_scores_request.target_nodes.sort(key=lambda n: topo_index[n.name])

    # Get the framework-specific calculator for Hessian-approximation scores.
    fw_hessian_calculator = self.fw_impl.get_hessian_scores_calculator(
        graph=self.graph,
        input_images=inputs_batch,
        hessian_scores_request=hessian_scores_request,
        num_iterations_for_approximation=self.num_iterations_for_approximation)
    hessian_scores = fw_hessian_calculator.compute()

    hessian_score_by_image_hash = {}
    # Iterate over the individual images in the (single) batched input.
    # 'img_idx' indexes one image, not a batch.
    for img_idx in range(inputs_batch[0].shape[0]):
        img_hash = self.calc_image_hash(inputs_batch[0][img_idx])
        hessian_score_by_image_hash[img_hash] = {
            node: score[img_idx] for node, score in zip(hessian_scores_request.target_nodes, hessian_scores)
        }

    return hessian_score_by_image_hash

@staticmethod
def calc_image_hash(image):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add documentation

if len(image.shape) != 3:
raise ValueError(f'Expected 3d image (without batch) for image hash calculation, got {len(image.shape)}')
image_bytes = image.astype(np.float32).tobytes()
return hashlib.md5(image_bytes).hexdigest()

def fetch_hessian(self,
hessian_scores_request: HessianScoresRequest,
required_size: int,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ class HessianScoresGranularity(Enum):
PER_TENSOR = 2


class HessianEstimationDistribution(str, Enum):
    """
    Distribution of the random vectors used by the Hutchinson Hessian estimator.

    Mixes in str so members compare equal to their plain-string values
    (e.g. HessianEstimationDistribution.GAUSSIAN == 'gaussian').
    """
    GAUSSIAN = 'gaussian'
    RADEMACHER = 'rademacher'


class HessianScoresRequest:
"""
Request configuration for the Hessian-approximation scores.
Expand All @@ -53,7 +61,8 @@ class HessianScoresRequest:
def __init__(self,
             mode: HessianMode,
             granularity: HessianScoresGranularity,
             target_nodes: List,
             distribution: HessianEstimationDistribution = HessianEstimationDistribution.GAUSSIAN):
    """
    Attributes:
        mode (HessianMode): Mode of Hessian-approximation score (w.r.t weights or activations).
        granularity (HessianScoresGranularity): Granularity of the Hessian-approximation scores.
        target_nodes (List): Graph nodes for which the Hessian-approximation scores are requested.
        distribution (HessianEstimationDistribution): Distribution of the Hutchinson estimator's
            random vectors. Defaults to GAUSSIAN, preserving pre-existing behavior.
    """
    self.mode = mode  # w.r.t activations or weights
    self.granularity = granularity  # per element, per layer, per channel
    self.target_nodes = target_nodes
    self.distribution = distribution

def __eq__(self, other):
    """Two requests are equal iff the other object is a HessianScoresRequest and all attributes match."""
    # Checks if the other object is an instance of HessianScoresRequest
    # and then checks if all attributes are equal.
    return isinstance(other, HessianScoresRequest) and \
           self.mode == other.mode and \
           self.granularity == other.granularity and \
           self.target_nodes == other.target_nodes and \
           self.distribution == other.distribution

def __hash__(self):
    """Hash over all attributes (kept consistent with __eq__, which also compares distribution)."""
    # Computes the hash based on the attributes.
    # The use of a tuple here ensures that the hash is influenced by all the attributes;
    # target_nodes is converted to a tuple because lists are unhashable.
    return hash((self.mode, self.granularity, tuple(self.target_nodes), self.distribution))
Loading
Loading