-
Notifications
You must be signed in to change notification settings - Fork 53
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add initial Sample-Layer Attention for GPTQ (PyTorch) #1237
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,16 +12,19 @@ | |
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
import hashlib | ||
|
||
import numpy as np | ||
from functools import partial | ||
from tqdm import tqdm | ||
from typing import Callable, List, Dict, Any, Tuple | ||
from typing import Callable, List, Dict, Any, Tuple, TYPE_CHECKING | ||
|
||
from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS | ||
from model_compression_toolkit.core.common.hessian.hessian_scores_request import HessianScoresRequest, \ | ||
HessianScoresGranularity, HessianMode | ||
from model_compression_toolkit.logger import Logger | ||
if TYPE_CHECKING: # pragma: no cover | ||
from model_compression_toolkit.core.common import BaseNode | ||
|
||
|
||
class HessianInfoService: | ||
|
@@ -228,6 +231,61 @@ def compute(self, | |
return next_iter_remain_samples if next_iter_remain_samples is not None and len(next_iter_remain_samples) > 0 \ | ||
and len(next_iter_remain_samples[0]) > 0 else None | ||
|
||
def compute_trackable_per_sample_hessian(self, | ||
hessian_scores_request: HessianScoresRequest, | ||
inputs_batch: List[np.ndarray]) -> Dict[str, Dict['BaseNode', np.ndarray]]: | ||
""" | ||
Compute hessian score per image hash. We compute the score directly for images rather than via data generator, | ||
as the data generator might yield different images each time, depending on how it was defined. | ||
|
||
Args: | ||
hessian_scores_request: hessian scores request | ||
inputs_batch: a list containing a batch of inputs. | ||
|
||
Returns: | ||
A dict of Hessian scores per image hash per layer {image hash: {layer: score}} | ||
""" | ||
topo_sorted_nodes_names = [x.name for x in self.graph.get_topo_sorted_nodes()] | ||
hessian_scores_request.target_nodes.sort(key=lambda x: topo_sorted_nodes_names.index(x.name)) | ||
|
||
hessian_score_by_image_hash = {} | ||
|
||
if not inputs_batch or not isinstance(inputs_batch, list): | ||
raise TypeError('Expected a non-empty list of inputs') # pragma: no cover | ||
if len(inputs_batch) > 1: | ||
raise NotImplementedError('Per-sample hessian computation is not supported for networks with multiple inputs') # pragma: no cover | ||
|
||
# Get the framework-specific calculator Hessian-approximation scores | ||
fw_hessian_calculator = self.fw_impl.get_hessian_scores_calculator(graph=self.graph, | ||
input_images=inputs_batch, | ||
hessian_scores_request=hessian_scores_request, | ||
num_iterations_for_approximation=self.num_iterations_for_approximation) | ||
hessian_scores = fw_hessian_calculator.compute() | ||
for i in range(inputs_batch[0].shape[0]): | ||
img_hash = self.calc_image_hash(inputs_batch[0][i]) | ||
hessian_score_by_image_hash[img_hash] = { | ||
node: score[i] for node, score in zip(hessian_scores_request.target_nodes, hessian_scores) | ||
} | ||
|
||
return hessian_score_by_image_hash | ||
|
||
@staticmethod | ||
def calc_image_hash(image): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add documentation |
||
""" | ||
Calculates hash for an input image. | ||
|
||
Args: | ||
image: input 3D image (without the batch dimension). | ||
|
||
Returns: | ||
Image hash. | ||
|
||
""" | ||
if not len(image.shape) == 3: # pragma: no cover | ||
raise ValueError(f'Expected 3d image (without batch) for image hash calculation, got {len(image.shape)}') | ||
image_bytes = image.astype(np.float32).tobytes() | ||
return hashlib.md5(image_bytes).hexdigest() | ||
|
||
def fetch_hessian(self, | ||
hessian_scores_request: HessianScoresRequest, | ||
required_size: int, | ||
|
@@ -248,7 +306,7 @@ def fetch_hessian(self, | |
OC for per-output-channel when the requested node has OC output-channels, etc.) | ||
""" | ||
|
||
if len(hessian_scores_request.target_nodes) == 0: | ||
if len(hessian_scores_request.target_nodes) == 0: # pragma: no cover | ||
return [] | ||
|
||
if required_size == 0: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,6 +40,14 @@ class HessianScoresGranularity(Enum): | |
PER_TENSOR = 2 | ||
|
||
|
||
class HessianEstimationDistribution(str, Enum): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do we need this? I think you can use the rademacher in all cases and remove this enum and the addition to the gptq config. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think I saw different results, not sure. In any case it changes the existing behavior so I don't think this belongs to this PR. This isn't exposed to user anyway, so there is no problem to remove it later. |
||
""" | ||
Distribution for Hutchinson estimator random vector | ||
""" | ||
GAUSSIAN = 'gaussian' | ||
RADEMACHER = 'rademacher' | ||
|
||
|
||
class HessianScoresRequest: | ||
""" | ||
Request configuration for the Hessian-approximation scores. | ||
|
@@ -53,7 +61,8 @@ class HessianScoresRequest: | |
def __init__(self, | ||
mode: HessianMode, | ||
granularity: HessianScoresGranularity, | ||
target_nodes: List): | ||
target_nodes: List, | ||
distribution: HessianEstimationDistribution = HessianEstimationDistribution.GAUSSIAN): | ||
""" | ||
Attributes: | ||
mode (HessianMode): Mode of Hessian-approximation score (w.r.t weights or activations). | ||
|
@@ -64,16 +73,18 @@ def __init__(self, | |
self.mode = mode # w.r.t activations or weights | ||
self.granularity = granularity # per element, per layer, per channel | ||
self.target_nodes = target_nodes | ||
self.distribution = distribution | ||
|
||
def __eq__(self, other): | ||
# Checks if the other object is an instance of HessianScoresRequest | ||
# and then checks if all attributes are equal. | ||
return isinstance(other, HessianScoresRequest) and \ | ||
self.mode == other.mode and \ | ||
self.granularity == other.granularity and \ | ||
self.target_nodes == other.target_nodes | ||
self.target_nodes == other.target_nodes and \ | ||
self.distribution == other.distribution | ||
|
||
def __hash__(self): | ||
# Computes the hash based on the attributes. | ||
# The use of a tuple here ensures that the hash is influenced by all the attributes. | ||
return hash((self.mode, self.granularity, tuple(self.target_nodes))) | ||
return hash((self.mode, self.granularity, tuple(self.target_nodes), self.distribution)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe assert a non-empty list as well?