Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Nov 19, 2024
1 parent d9fff9a commit c399e9b
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions src/brevitas/core/scaling/standalone.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ def __init__(
self.restrict_clamp_threshold = _RestrictClampValue(
restrict_value_impl=restrict_threshold_impl)
self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module()
self.clamp_scaling = _ClampValue(scaling_min_val)

@brevitas.jit.script_method
def forward(self, placeholder: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
Expand All @@ -178,9 +177,8 @@ def forward(self, placeholder: Tensor, threshold: Optional[Tensor] = None) -> Te
# We first apply any restriction to scaling
# For IntQuant, this is no-op, retrocompatible.
threshold = self.restrict_clamp_threshold(self.restrict_threshold_pre(threshold))
# Clamping avoids eventual log(0) with restrict_val
value = self.clamp_scaling(self.value)
value = abs_binary_sign_grad(self.restrict_clamp_scaling(value))
# We can clamp after restrict val since the learned parameter is already in log-domain
value = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value))
return value / threshold

def _load_from_state_dict(
Expand Down Expand Up @@ -426,6 +424,7 @@ def forward(self, stats_input: Tensor, threshold: Optional[Tensor] = None) -> Te
threshold = self.restrict_threshold(self.restrict_threshold_pre(threshold))
out = self.restrict_scaling(out)
out = out / threshold
# We can clamp after restrict val since the learned parameter is already in log-domain
out = abs_binary_sign_grad(self.clamp_scaling(out))
return out

Expand Down

0 comments on commit c399e9b

Please sign in to comment.