Feat (llm/awq): activation-aware weight scaling #1213
base: dev
Conversation
@@ -251,6 +250,48 @@ def apply(self, model, is_training, quantization_enabled):
        self.enable_param_quantization(model, is_training)


class disable_enable_quantization:
Potentially interesting for all use cases where we need to do this. I'd expose flags to disable weight/act/bias quantization
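A minimal sketch of what such a context manager with per-kind flags could look like. The proxy attribute names (`weight_quant`, `input_quant`, `bias_quant`) and the `disable_quant` toggle follow common Brevitas conventions, but treat the whole snippet as an assumption, not the PR's actual API:

```python
# Sketch only: disable quantization per tensor kind inside a context,
# restoring the previous state on exit. Attribute names are assumptions.
class disable_enable_quantization:

    def __init__(self, model, disable_weight=True, disable_act=True, disable_bias=True):
        self.model = model
        self.kinds = []
        if disable_weight:
            self.kinds.append('weight_quant')
        if disable_act:
            self.kinds.append('input_quant')
        if disable_bias:
            self.kinds.append('bias_quant')
        self.saved = []  # (proxy, previous disable_quant value)

    def __enter__(self):
        for module in self.model.modules():
            for kind in self.kinds:
                proxy = getattr(module, kind, None)
                if proxy is not None and hasattr(proxy, 'disable_quant'):
                    self.saved.append((proxy, proxy.disable_quant))
                    proxy.disable_quant = True
        return self.model

    def __exit__(self, exc_type, exc_value, traceback):
        for proxy, prev in self.saved:
            proxy.disable_quant = prev
        self.saved.clear()
        return False
```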
src/brevitas/graph/equalize.py (outdated)
@@ -780,9 +781,11 @@ def _no_equalize():
    for module in chain(src_axes.values(), sink_axes.values()):
        rewriters.extend(module.instantiate_rewriters(rewriter_class, scaling_factors))

-   # Apply rewriters before offloading
+   # Apply rewriters before offloading, if parametrize_inplace is True. Note that parametrizations
+   # are not applied immediately, to prevent potential errors if the model is offloaded.
Can you elaborate a bit more on the issue here?
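For context, a hedged sketch of the failure mode that deferred application presumably avoids: `torch.nn.utils.parametrize.register_parametrization` eagerly runs the parametrization's forward to validate it, which can fail once a weight has been offloaded (simulated below with the meta device). Everything besides the torch APIs is illustrative:

```python
import torch
import torch.nn as nn
from torch.nn.utils import parametrize

class ScaleWeight(nn.Module):
    """Toy parametrization that rescales a weight."""

    def __init__(self, scale):
        super().__init__()
        self.register_buffer('scale', scale)

    def forward(self, weight):
        return weight * self.scale

# On a regular module the eager validation pass succeeds
linear = nn.Linear(4, 4)
parametrize.register_parametrization(linear, 'weight', ScaleWeight(torch.ones(4, 1)))

# On an offloaded module (meta device as a stand-in) it does not,
# which is one reason to apply rewriters before offloading
offloaded = nn.Linear(4, 4).to('meta')
try:
    parametrize.register_parametrization(offloaded, 'weight', ScaleWeight(torch.ones(4, 1)))
except Exception as exc:
    print(f'registration on an offloaded weight failed: {exc}')
```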
    raise ValueError  # early exit to break later inference

    # patch layer 0 to catch input and kwargs
-   layers[0] = Catcher(layers[0])
+   blocks[0] = Catcher(blocks[0])
I don't think we need this part of the codebase; why can't we do what we do in GPTQ to catch the input to the first block?
We could also move that piece of code to some utils in examples/common/generative.
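For reference, a sketch of the Catcher trick being discussed, as used in the AWQ/GPTQ reference implementations: wrap the first block so the first forward call records its input and kwargs, then abort the rest of the run with an exception the caller catches. The usage names (`blocks`, `model`, `calibration_batch`) are assumptions from the surrounding code:

```python
import torch.nn as nn

class Catcher(nn.Module):
    """Records the first block's input and kwargs, then aborts inference."""

    def __init__(self, module):
        super().__init__()
        self.module = module
        self.inputs = None
        self.kwargs = None

    def forward(self, x, **kwargs):
        self.inputs = x
        self.kwargs = kwargs
        raise ValueError  # early exit to break later inference

# Usage sketch: run one calibration batch, catching the deliberate exit.
# blocks[0] = Catcher(blocks[0])
# try:
#     model(calibration_batch)
# except ValueError:
#     pass
# first_input, first_kwargs = blocks[0].inputs, blocks[0].kwargs
# blocks[0] = blocks[0].module  # restore the original block
```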
src/brevitas/utils/python_utils.py (outdated)
@@ -64,3 +65,30 @@ def run(*args, **kwargs):
        return function(*args, **kwargs)

    return run


def longest_common_prefix(strings: List[str]):
This seems overly specific to AWQ; I'm not sure it should live here.
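The diff above only shows the signature; for reference, a straightforward implementation (not necessarily the PR's) can delegate to the standard library:

```python
import os
from typing import List

def longest_common_prefix(strings: List[str]) -> str:
    # os.path.commonprefix compares character by character, so it works
    # on arbitrary strings, not just filesystem paths
    if not strings:
        return ''
    return os.path.commonprefix(strings)
```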
"ffn.act": block.ffn.act, | ||
"ffn.down_proj": block.ffn.down_proj,}, | ||
)) | ||
elif "falcon" in str(block.__class__).lower(): |
Only Llama for now
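For comparison, a hedged sketch of a Llama-only mapping. The attribute names follow HuggingFace's `LlamaDecoderLayer`; the helper itself is hypothetical, not the PR's code:

```python
# Hypothetical helper: name-to-module map for a HuggingFace LlamaDecoderLayer,
# from which AWQ regions could be built
def llama_block_modules(block):
    return {
        "self_attn.q_proj": block.self_attn.q_proj,
        "self_attn.k_proj": block.self_attn.k_proj,
        "self_attn.v_proj": block.self_attn.v_proj,
        "self_attn.o_proj": block.self_attn.o_proj,
        "mlp.gate_proj": block.mlp.gate_proj,
        "mlp.up_proj": block.mlp.up_proj,
        "mlp.down_proj": block.mlp.down_proj,
    }
```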
Reason for this PR

Implementation of AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration.
Using weight-only quantization and the following configuration:
Changes Made in this PR

- `RegionAWQ`, inheriting from `Region`, to aggregate the information of the modules on which AWQ optimizes the scale.
- `auto_scale` and `auto_clip`, reworked to rely on Brevitas quantizers (a sketch of the scale search is given below).
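For readers unfamiliar with AWQ, a minimal sketch of the scale search the paper describes (Lin et al., 2023): for each candidate exponent alpha, scale the weights channel-wise by the mean activation magnitude raised to alpha, fake-quantize, fold the inverse scale into the input, and keep the alpha with the smallest output error. The toy `fake_quant` stands in for a Brevitas weight quantizer; nothing here is the PR's actual `auto_scale` implementation:

```python
import torch

def fake_quant(w: torch.Tensor, n_bits: int = 4) -> torch.Tensor:
    # Toy symmetric per-output-channel quantizer, for illustration only
    qmax = 2 ** (n_bits - 1) - 1
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / qmax
    return (w / scale).round().clamp(-qmax - 1, qmax) * scale

def search_awq_scale(weight: torch.Tensor, x: torch.Tensor, n_grid: int = 20):
    # weight: [out_features, in_features]; x: [n_tokens, in_features]
    act_mean = x.abs().mean(dim=0).clamp(min=1e-8)  # per-input-channel statistic
    ref_out = x @ weight.t()
    best_err, best_scale = float('inf'), None
    for i in range(n_grid):
        alpha = i / n_grid
        s = act_mean ** alpha
        s = s / (s.max() * s.min()).sqrt()  # normalize, as in the reference code
        q_w = fake_quant(weight * s)        # scale weights up, then quantize
        out = (x / s) @ q_w.t()             # fold the inverse scale into the input
        err = (out - ref_out).pow(2).mean().item()
        if err < best_err:
            best_err, best_scale = err, s
    return best_scale
```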
Testing Summary

- Tested `apply_awq` against the author's repository.

Risk Highlight
Checklist

- Pull request is against the `dev` branch.