400 add math ops and refactored fixed point config #404
base: develop
Conversation
Signed-off-by: Leo Buron <[email protected]>
Signed-off-by: Leo Buron <[email protected]>
Signed-off-by: Leo Buron <[email protected]>
The `if` branches (at least the one in the rounding function) I'd like to see addressed. Everything else I mention is just comments, I guess.
Oh... almost forgot: not sure there are tests for the old fxp config (I fear not). Do you want to add a few for the rounding behaviour? A lot of the rest here is hard to test and not really test-worthy, I guess.
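A rough sketch of the kind of rounding tests meant here; it assumes `FixedPointConfigV2` and `_round` are importable under the names used in the diff below, that the config accepts keyword arguments, and that `_round` returns values on the fixed-point grid in float scale:

import torch

# Placeholder import; the real module path in this repo may differ.
# from ... import FixedPointConfigV2, _round


def test_rounding_is_deterministic_without_stochastic_rounding():
    config = FixedPointConfigV2(total_bits=8, frac_bits=4, stochastic_rounding=False)
    x = torch.tensor([0.3, -1.7, 2.55])
    assert torch.equal(_round(x, config), _round(x, config))


def test_rounded_values_lie_on_the_fixed_point_grid():
    config = FixedPointConfigV2(total_bits=8, frac_bits=4, stochastic_rounding=True)
    x = torch.rand(100)
    scaled = _round(x, config) * 2**config.frac_bits
    assert torch.equal(scaled, scaled.round())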
def _round(number: torch.Tensor, fxp_conf: FixedPointConfigV2) -> torch.Tensor:
    if fxp_conf.stochastic_rounding:
That `if` statement will not be found during tracing, therefore we might prefer to use two different functions. Following this thought, we'd end up with two different models, one using stochastic fxp rounding while the other doesn't. Not sure if that'd be OK for you, but I think it's the cleaner approach. That way you wouldn't even have to add the stochastic rounding flag to `fxp_config`, and the `fxp_config` could just be `(total_bits, frac_bits)`.
Also, from experience, flag arguments like this often make it hard to know which rounding method is actually used, especially if the `fxp_conf` object is mutable.
For now I think it'd even be fine to just add a comment that you have to be aware of this when tracing. Or we flag entire modules as not being fit for tracing. Maybe @mokouMonday has an opinion here (I guess you're one of the people for whom that matters).
Rest looks good ;).
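For illustration, a minimal sketch of that two-function variant; the rounding bodies are generic textbook implementations, not the code from this PR, and the choice between the two would be made when the module is built rather than inside the traced graph:

import torch


def _round_deterministic(number: torch.Tensor, frac_bits: int) -> torch.Tensor:
    scale = 2.0**frac_bits
    return torch.round(number * scale) / scale


def _round_stochastic(number: torch.Tensor, frac_bits: int) -> torch.Tensor:
    scale = 2.0**frac_bits
    scaled = number * scale
    lower = torch.floor(scaled)
    # Round up with probability equal to the fractional remainder.
    round_up = (torch.rand_like(scaled) < (scaled - lower)).to(number.dtype)
    return (lower + round_up) / scale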
class RoundToFixedPoint(torch.autograd.Function):
    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor:
        if len(args) != 3:
You won't catch the case where people supply more parameters as kwargs, but I don't even think you need to handle that here.
        x: torch.Tensor = args[0]
        fxp_config: FixedPointConfigV2 | None = args[1]
        ctx.grad_fxp_config: FixedPointConfigV2 | None = args[2]
        if fxp_config is None:
This instead might be an actual problem. The RoundToFixedPoint function not rounding to fixed point but instead acting as an identity seems like a Liskov violation to me. I'd never expect this function to not somehow round a value. We could actually raise an error in case no fxp config is given.
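A sketch of what that could look like inside the forward shown above; the error type and message are my assumption:

        if fxp_config is None:
            raise ValueError(
                "RoundToFixedPoint needs a FixedPointConfigV2; "
                "skip the op entirely if no rounding is wanted."
            )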
@dataclass
class FixedPointConfigV2:
We could add `__slots__`:

class FixedPointConfig:
    __slots__ = ("frac_bits", "total_bits", "stochastic_rounding")

    def __init__(self, total_bits: int, frac_bits: int, stochastic_rounding: bool):
        # ...
This prevents adding more fields to an object of this type after creation and makes attribute access faster. I.e.,

fxp = FixedPointConfig(4, 2, False)
fxp.something_else = 4  # raises AttributeError at runtime
Finally, since this lives in the quantized_grads fixed_point package right now, you could just call this `FixedPointConfig` as well, no problem. Another option is: we place it in the same package as the old `FixedPointConfig` and call it `FixedPointConfigV2`.
I think both of these would make it easier to replace the old config with your newer one.
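If the project is on Python 3.10 or newer, the same effect is also available while keeping the dataclass style from the diff; this is just an alternative to the handwritten `__slots__`, not what the PR currently does:

from dataclasses import dataclass


@dataclass(slots=True)
class FixedPointConfig:
    total_bits: int
    frac_bits: int
    stochastic_rounding: bool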
This feature adds basic support for future use of quantized fixed point gradients. It includes a new fixed point config, refactored from the existing one because the old one was doing too much. I also changed the naming a bit, so there is now a fixed point configuration with data fields only. The quantization function is moved to the quantize_to_fixed_point file and uses the fixed point configuration.
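Roughly, the intended usage after this refactor looks like the following sketch; the constructor arguments and the meaning of the third argument to apply are inferred from the diff hunks above and may not match the final API exactly:

import torch

config = FixedPointConfigV2(total_bits=8, frac_bits=4, stochastic_rounding=False)
x = torch.randn(16)

# forward takes (x, fxp_config, grad_fxp_config) as positional arguments
y = RoundToFixedPoint.apply(x, config, None)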