
Commit

Explain wrapper
pablomlago committed Feb 12, 2025
1 parent 13ea528 commit 99bbb4f
Showing 2 changed files with 15 additions and 5 deletions.
9 changes: 9 additions & 0 deletions src/brevitas/graph/equalize.py
@@ -450,6 +450,15 @@ def transpose(tensor: torch.Tensor, axis: int):
return tensor.permute(shape)


# When fuse_scaling = False, the scaling parameters are instances of nn.Parameter,
# which are registered to the scaling modules (used in the parametrization of the
# weights). By default, these parameters have requires_grad set to True, and the
# forward pass of the parametrization modules is run when the parametrizations are
# registered (when unsafe is False, see __init__ of ParametrizationList), as well as
# when a module is part of multiple regions. Therefore, wrapping
# _cross_layer_equalization with torch.no_grad() prevents gradient functions (and
# gradient-related intermediate tensors) from being recorded when running these
# forwards, thus avoiding unnecessary memory consumption while the algorithm runs.
@torch.no_grad()
def _cross_layer_equalization(
model: nn.Module,
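
The comment above describes standard torch.nn.utils.parametrize behavior (ParametrizationList lives in that module) rather than anything Brevitas-specific. The toy sketch below is a minimal illustration, not Brevitas code: the Scale module and the shapes are made up to stand in for a scaling module holding an unfused scale parameter. It shows that registering a parametrization runs its forward pass, and that wrapping the work in torch.no_grad() keeps those passes from building an autograd graph.

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class Scale(nn.Module):
    """Hypothetical stand-in for a scaling module holding an unfused scale parameter."""

    def __init__(self, num_features):
        super().__init__()
        # requires_grad defaults to True, as for the unfused scaling parameters
        self.scale = nn.Parameter(torch.ones(num_features, 1))

    def forward(self, weight):
        return weight * self.scale

linear = nn.Linear(4, 4)

with torch.no_grad():
    # register_parametrization runs Scale.forward to validate it (unsafe=False);
    # inside no_grad, no gradient functions or intermediate tensors are recorded.
    parametrize.register_parametrization(linear, "weight", Scale(4))
    print(linear.weight.grad_fn)  # None

# Outside no_grad, every access to linear.weight re-runs the parametrization and
# rebuilds an autograd graph over the scale parameter.
print(linear.weight.grad_fn)  # a MulBackward0 node

In the Brevitas pass, the same effect is obtained by decorating _cross_layer_equalization with @torch.no_grad(), so every parametrization forward triggered while the algorithm runs stays outside the autograd graph.
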
11 changes: 6 additions & 5 deletions src/brevitas_examples/llm/main.py
@@ -108,11 +108,12 @@ def fused_rotation_no_fx(model, calibration_loader, args):
use_parametrized_rotations=args.optimize_rotations)
new_model, rewriters = eq.apply(new_model)
rewriters = fix_rewriter(rewriters, model, 'weight')
for r in rewriters:
# The weights between model and new_model are tied, so this check prevents
# rotating the weights twice
if not isinstance(r, ModuleInstanceTransformTensor):
model = r.apply(model)
with torch.no_grad():
for r in rewriters:
# The weights between model and new_model are tied, so this check prevents
# rotating the weights twice
if not isinstance(r, ModuleInstanceTransformTensor):
model = r.apply(model)
remove_hooks(new_model)


