Add initial Sample-Layer Attention for GPTQ (PyTorch) #1237

Merged: 5 commits into main, Oct 7, 2024

Conversation

@irenaby (Collaborator) commented Oct 6, 2024

Pull Request Description:

Add Hessian estimation per image hash.
Add a sample-layer attention distillation loss (a sketch appears below).
Add per-layer weights to the soft rounding loss.
Update the GPTQ config and its generation for sample-layer attention.
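
For orientation, here is a minimal sketch of what a sample-layer-attention weighted distillation loss can look like; the function name, shapes, and reduction are illustrative assumptions, not this PR's exact implementation.

```python
import torch

def sla_distillation_loss(student_acts: list, teacher_acts: list,
                          layer_weights: torch.Tensor) -> torch.Tensor:
    # student_acts / teacher_acts: per-layer activation tensors of shape
    # [batch, features...]; layer_weights: [num_layers, batch] attention
    # scores, e.g. derived from per-image Hessian approximations.
    per_layer = []
    for s, t in zip(student_acts, teacher_acts):
        # Per-sample squared error, reduced over all non-batch dimensions.
        per_layer.append((s - t).pow(2).flatten(start_dim=1).mean(dim=1))
    per_layer = torch.stack(per_layer)            # [num_layers, batch]
    # Weight each (sample, layer) error by its attention score.
    return (layer_weights * per_layer).sum(dim=0).mean()
```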

Checklist before requesting a review:

  • I set the appropriate labels on the pull request.
  • I have added/updated the release note draft (if necessary).
  • I have updated the documentation to reflect my changes (if necessary).
  • All functions and files are well documented.
  • All functions and classes have type hints.
  • There is a license in all files.
  • The function and variable names are informative.
  • I have checked for code duplications.
  • I have added new unit tests (if necessary).

@@ -55,98 +56,145 @@ def __init__(self,
hessian_scores_request=hessian_scores_request,
num_iterations_for_approximation=num_iterations_for_approximation)

def forward_pass(self):
Author comment:
no change, just extracted to method


def _compute_per_tensor(self, output, target_activation_tensors):
assert self.hessian_request.granularity == HessianScoresGranularity.PER_TENSOR
ipts_hessian_approx_scores = [torch.tensor([0.0], requires_grad=True, device=output.device)
Author comment:
no change (except line 177), just extracted to method
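
Since the thread references the extracted per-tensor computation without showing it in full, here is a minimal self-contained sketch of Hutchinson-style Hessian score approximation using the same running-mean update that appears later in the diff; it is illustrative, not the PR's code:

```python
import torch

def hutchinson_hessian_score(output: torch.Tensor, inp: torch.Tensor,
                             num_iters: int = 100) -> torch.Tensor:
    """Approximates trace(H), where H is the Hessian of `output` w.r.t.
    `inp`, via Hutchinson's estimator: trace(H) = E[v^T H v] for random
    probe vectors v with E[v v^T] = I."""
    grad = torch.autograd.grad(output, inp, create_graph=True)[0]
    score = torch.tensor(0.0)
    for j in range(num_iters):
        v = torch.randn_like(inp)  # Gaussian probes; Rademacher also works
        # Hessian-vector product via a second backward pass.
        hv = torch.autograd.grad(grad, inp, grad_outputs=v,
                                 retain_graph=True)[0]
        # Running mean over iterations, same update form as in the diff.
        score = (j * score + (v * hv).sum()) / (j + 1)
    return score

# E.g. for y = (x ** 3).sum(), trace(H) = 6 * x.sum():
#   x = torch.randn(4, requires_grad=True)
#   hutchinson_hessian_score((x ** 3).sum(), x)
```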

@@ -40,6 +40,14 @@ class HessianScoresGranularity(Enum):
PER_TENSOR = 2


class HessianEstimationDistribution(str, Enum):
Reviewer comment:
Why do we need this? I think you can use Rademacher in all cases and remove this enum and the addition to the GPTQ config.
The results (on average) should be similar for both methods, if I'm not mistaken (unless you've seen different behavior?).

Author reply:
I think I saw different results, but I'm not sure. In any case it changes the existing behavior, so I don't think this belongs in this PR. This isn't exposed to the user anyway, so there is no problem removing it later.
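
For context, the enum under discussion might be defined along these lines (only the class name appears in the diff; the member names are an assumption based on the discussion). Both distributions satisfy E[v vᵀ] = I, which is why Hutchinson's estimator is unbiased for either choice; only its variance differs.

```python
from enum import Enum

class HessianEstimationDistribution(str, Enum):
    """Distribution of the random probe vectors used for Hutchinson-style
    Hessian estimation."""
    GAUSSIAN = 'gaussian'
    RADEMACHER = 'rademacher'
```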

return hessian_score_by_image_hash

@staticmethod
def calc_image_hash(image):
Reviewer comment:
Please add documentation
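
A documented version might look like this; the hashing scheme shown (hashing the raw tensor bytes) is an assumption for illustration, not necessarily the PR's actual implementation:

```python
import hashlib
import torch

class HessianScoresService:  # hypothetical container class for illustration
    @staticmethod
    def calc_image_hash(image: torch.Tensor) -> str:
        """Computes a deterministic hash for a single input image so that
        per-sample Hessian scores can be cached and reused across batches.

        Args:
            image: A single image tensor (no batch dimension).

        Returns:
            A hex digest uniquely identifying the image contents.
        """
        return hashlib.sha256(image.detach().cpu().numpy().tobytes()).hexdigest()
```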

if not isinstance(inputs_batch, list):
raise TypeError('Expected a list of inputs') # pragma: no cover
if len(inputs_batch) > 1:
raise NotImplementedError('Per-sample hessian computation is not supported for networks with multiple inputs') # pragma: no cover
Reviewer comment:
Maybe assert a non-empty list as well?
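
One way to add that check, extending the validation shown above:

```python
if not isinstance(inputs_batch, list):
    raise TypeError('Expected a list of inputs')  # pragma: no cover
if not inputs_batch:
    raise ValueError('Expected a non-empty list of inputs')  # pragma: no cover
if len(inputs_batch) > 1:
    raise NotImplementedError('Per-sample hessian computation is not supported '
                              'for networks with multiple inputs')  # pragma: no cover
```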

hessian_scores_request=hessian_scores_request,
num_iterations_for_approximation=self.num_iterations_for_approximation)
hessian_scores = fw_hessian_calculator.compute()
for b in range(inputs_batch[0].shape[0]):
Reviewer comment:
Please rename 'b' (it made me think this is a batch, but this is a single image)
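
The suggested rename might read, for example:

```python
# 'image_idx' makes it explicit that this iterates over single images
# within the batch, not over batches.
for image_idx in range(inputs_batch[0].shape[0]):
    ...
```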

output = self.concat_tensors(output_tensors)
return output, target_activation_tensors

def _generate_random_vectors_batch(self, shape, distribution: HessianEstimationDistribution, device) -> torch.Tensor:
Reviewer comment:
Please add missing type hints
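
With the requested annotations, the method could look like this; the body is a sketch consistent with Hutchinson probing, not necessarily the PR's exact implementation:

```python
def _generate_random_vectors_batch(self, shape: torch.Size,
                                   distribution: HessianEstimationDistribution,
                                   device: torch.device) -> torch.Tensor:
    if distribution == HessianEstimationDistribution.GAUSSIAN:
        return torch.randn(shape, device=device)
    if distribution == HessianEstimationDistribution.RADEMACHER:
        # Uniform random +/-1 entries.
        return (2 * torch.randint(0, 2, shape, device=device) - 1).float()
    raise ValueError(f'Unexpected distribution {distribution}')
```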


# Update node Hessian approximation mean over random iterations
ipts_hessian_approx_scores[i] = (j * ipts_hessian_approx_scores[i] + hessian_approx_scores) / (j + 1)

Reviewer comment:
There's an overlap between _compute_per_tensor and _compute_per_channel. Is there a good reason not to extract the common logic into a shared function to avoid this duplication?

Author reply:
It's not exactly the same. I'm sure it can be rewritten, but it's not as straightforward as in the other places that were extracted, so it was not a priority.

self.hessian_service.compute_trackable_per_sample_hessian(request, inputs)
)
for img_hash, v in hessian_score_per_image_per_layer.items():
hessian_score_per_image_per_layer[img_hash] = {k: t.max(axis=0) for k, t in v.items()}
Reviewer comment:
Why are we computing t.max(axis=0)?

Author reply:
That's the definition of sample-layer attention. My understanding is that we are trying to approximate an upper bound, so we take the max over channels (this is per image, so axis=0 is the channel axis).
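
Concretely, the reduction can be pictured like this (a sketch with assumed shapes):

```python
import torch

# Hypothetical per-image Hessian approximation for one layer:
# one score per channel.
per_channel_scores = torch.tensor([0.1, 0.7, 0.3])  # shape [num_channels]

# Sample-layer attention keeps the worst case (upper bound) across channels,
# yielding one scalar score for this (image, layer) pair.
layer_score = per_channel_scores.max(dim=0).values  # tensor(0.7)
```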

self.hessian_score_per_layer = None # for fixed layer weights
self.hessian_score_per_image_per_layer = None # for sample-layer attention
if self.use_sample_layer_attention:
assert (hessian_cfg.norm_scores is False and hessian_cfg.log_norm is False and
Reviewer comment:
Can you please add a comment on why this is true?

Author reply:
Changed to NotImplementedError and added comment.
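
The reworked guard might look something like this (the flag names are from the diff; the exact message and condition are assumptions):

```python
if self.use_sample_layer_attention:
    # Sample-layer attention consumes raw per-sample Hessian scores, so
    # score normalization / log-scaling is not supported (yet).
    if hessian_cfg.norm_scores or hessian_cfg.log_norm:
        raise NotImplementedError(
            'Sample-layer attention is not supported with normalized or '
            'log-normalized Hessian scores.')
```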

img_hashes = [self.hessian_service.calc_image_hash(img) for img in batch]
for img_hash in img_hashes:
if img_hash not in self.hessian_score_per_image_per_layer:
score_per_image_layer_per = self._compute_sample_layer_attention_scores(input_tensors)
Reviewer comment:
I guess you meant "score_per_image_per_layer"?
In general, I think this function needs to be rewritten because it's hard to track what's going on here.

Author reply:
Everything should be rewritten. Added more comments, hope it's clearer now.
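
The intended pattern, computing per-sample scores only for images whose hash has not been cached yet, can be sketched as follows (control flow simplified relative to the actual function):

```python
img_hashes = [self.hessian_service.calc_image_hash(img) for img in batch]

# Compute Hessian scores only for images we have not seen before; results
# are cached per image hash so repeated samples are scored once.
if any(h not in self.hessian_score_per_image_per_layer for h in img_hashes):
    score_per_image_per_layer = self._compute_sample_layer_attention_scores(input_tensors)
    self.hessian_score_per_image_per_layer.update(score_per_image_per_layer)
```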

raise TypeError(f'gradual_activation_quantization argument should be bool or '
f'GradualActivationQuantizationConfig, received {type(gradual_activation_quantization)}') # pragma: no cover
f'GradualActivationQuantizationConfig, received {type(gradual_activation_quantization)}')
Reviewer comment:
Please add a TODO to replace the 'no cover', since this case should be tested.

Author reply:
I think that's true for all the 'no cover' pragmas throughout the code.

@@ -40,32 +40,36 @@ def __init__(self, beta_scheduler: Callable[[int], float]):

self.count_iter = 0

def __call__(self, model: nn.Module, entropy_reg: float):
def __call__(self, model: nn.Module, entropy_reg: float, layer_weights: torch.Tensor = None):
Reviewer comment:
If I'm not mistaken, the default case (where the weighting is all ones) is taken care of outside of this function. If so, it may be better to remove the default value of layer_weights?

Author reply:
You are not mistaken. I will remove it for now due to missing testing/coverage, but in principle I don't see a connection between the two. I think it makes sense to have a default behavior here; the fact that it was eventually more convenient to pass explicit weights in this specific use case shouldn't really affect that.
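
A weighted variant of the soft-rounding regularizer along the lines discussed might look like this; `per_layer_reg` is a hypothetical stand-in for whatever per-layer term the real regularizer accumulates:

```python
from typing import List
import torch

def weighted_soft_round_reg(per_layer_reg: List[torch.Tensor],
                            entropy_reg: float,
                            layer_weights: torch.Tensor) -> torch.Tensor:
    # per_layer_reg: one scalar regularization term per quantized layer.
    # layer_weights: [num_layers] tensor, e.g. derived from Hessian scores;
    # passing torch.ones(num_layers) recovers the unweighted behavior.
    reg = torch.stack(per_layer_reg)  # [num_layers]
    return entropy_reg * (layer_weights * reg).sum()
```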

@irenaby merged commit b26dd82 into main on Oct 7, 2024
35 checks passed