
400 add math ops and refactored fixed point config #404

Open · wants to merge 4 commits into base: develop
Conversation

LeoBuron (Member)

This feature adds basic support for future use of quantized fixed-point gradients. It includes a refactor of the existing fixed-point config, which was doing too much. I also changed the naming a bit, so there is now a fixed-point configuration that holds data fields only. The quantization function has been moved to the quantize_to_fixed_point file and uses the fixed-point configuration.
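
Roughly, the new split separates a data-only config from a standalone quantization function. A minimal sketch of that layout (the quantize_to_fixed_point function name is assumed from the file name, and the round-and-clamp details are illustrative, not taken from the PR):

from dataclasses import dataclass

import torch


@dataclass
class FixedPointConfigV2:
    """Data fields only; no behaviour lives on the config."""

    total_bits: int
    frac_bits: int
    stochastic_rounding: bool = False


def quantize_to_fixed_point(x: torch.Tensor, conf: FixedPointConfigV2) -> torch.Tensor:
    """Round x onto the fixed-point grid and clamp to the representable range."""
    scale = 2.0 ** conf.frac_bits
    max_int = 2 ** (conf.total_bits - 1) - 1
    min_int = -(2 ** (conf.total_bits - 1))
    return torch.round(x * scale).clamp(min_int, max_int) / scale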

@LeoBuron LeoBuron added the enhancement New feature or request label Nov 28, 2024
@LeoBuron LeoBuron self-assigned this Nov 28, 2024
@glencoe glencoe (Contributor) left a comment


The if branches (at least the one in the rounding function) are something I'd like to see addressed. Everything else I mention is just comments, I guess.

Oh... almost forgot: I'm not sure there are tests for the old fxp config (I fear not). Do you want to add a few for the rounding behaviour? A lot of the rest here is hard to test and not really test-worthy, I guess.
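
A hypothetical pytest sketch for such rounding tests, assuming _round maps values onto the real-valued grid defined by the config and assuming the import path shown (both are guesses, not taken from the PR):

import pytest
import torch

from quantized_grads.fixed_point import FixedPointConfigV2, _round  # assumed import path


def test_round_to_nearest_on_half_grid():
    conf = FixedPointConfigV2(total_bits=8, frac_bits=1, stochastic_rounding=False)
    x = torch.tensor([0.26, -0.74])
    # with one fractional bit the grid is {..., -1.0, -0.5, 0.0, 0.5, ...}
    assert torch.equal(_round(x, conf), torch.tensor([0.5, -0.5]))


def test_stochastic_rounding_is_unbiased_on_average():
    conf = FixedPointConfigV2(total_bits=8, frac_bits=1, stochastic_rounding=True)
    x = torch.full((10_000,), 0.25)
    # 0.25 lies halfway between 0.0 and 0.5, so the sample mean should stay close to 0.25
    assert _round(x, conf).mean().item() == pytest.approx(0.25, abs=0.02)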



def _round(number: torch.Tensor, fxp_conf: FixedPointConfigV2) -> torch.Tensor:
    if fxp_conf.stochastic_rounding:

That if statement will not be found during tracing, so we might prefer to use two different functions. Following this thought, we'd end up with two different models: one using fxp rounding while the other doesn't. I'm not sure whether that'd be OK for you, but I think it's the cleaner approach. That way you wouldn't even have to add the stochastic rounding flag to fxp_config, and the fxp_config could just be (total_bits, frac_bits).

Also, from experience, flag arguments like this often make it hard to know which rounding method is actually used, especially if the fxp_conf object is mutable.

For now I think it'd even be fine to just add a comment that you have to be aware of this when tracing. Or we flag entire modules as not being fit for tracing. Maybe @mokouMonday has an opinion here (I guess you're one of the people for whom that matters).

Rest looks good ;).
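
For illustration, a minimal sketch of the suggested split into two trace-friendly functions (names and the stochastic-rounding formula are illustrative, not from the PR):

import torch


def round_half_to_even(number: torch.Tensor) -> torch.Tensor:
    """Deterministic rounding; no data-dependent branch, so it traces cleanly."""
    return torch.round(number)


def round_stochastic(number: torch.Tensor) -> torch.Tensor:
    """Stochastic rounding: round down, then bump up with probability equal to the remainder."""
    floor = torch.floor(number)
    return floor + (torch.rand_like(number) < (number - floor)).to(number.dtype)


# The caller picks one of the two functions when the module is constructed,
# so the traced graph contains exactly one rounding implementation.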

class RoundToFixedPoint(torch.autograd.Function):
    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor:
        if len(args) != 3:

You won't catch the case where people supply additional parameters as kwargs, but I don't think you even need to handle that here.

        x: torch.Tensor = args[0]
        fxp_config: FixedPointConfigV2 | None = args[1]
        ctx.grad_fxp_config: FixedPointConfigV2 | None = args[2]
        if fxp_config is None:

This, instead, might be an actual problem. The RoundToFixedPoint function not rounding to fixed point but acting as an identity instead seems like a Liskov violation to me: I'd never expect this function not to round a value at all. We could actually raise an error in case no fxp config is given.
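
A sketch of that suggestion, replacing the identity branch quoted above (the error message is illustrative):

        # Fail loudly when no config is supplied instead of silently passing x through.
        if fxp_config is None:
            raise ValueError(
                "RoundToFixedPoint requires a FixedPointConfigV2 for the forward pass; got None"
            )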



@dataclass
class FixedPointConfigV2:

we could add a

class FixedPointConfig:
  __slots__ = ("frac_bits", "total_bits", "stochastic_rounding")
  
  def __init__(self, total_bits: int, frac_bits: int, stochastic_rounding: bool):
    #...

This prevents adding more fields to an object of this type after creation and makes attribute access way faster. I.e.,

fxp = FixedPointConfig(4, 2, False)
fxp.something_else = 4 # runtime error

Finally, since this lives in the quantized_grads fixed_point package right now, you could just call this FixedPointConfig as well, no problem. Another option is to place it in the same package as the old FixedPointConfig and call it FixedPointConfigV2.

I think either of these would make it easier to replace the old config with your newer one.
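
If the project targets Python 3.10 or newer, one possible middle ground (just an option, not something from the PR) is to let dataclass generate the slots itself:

from dataclasses import dataclass


@dataclass(slots=True)  # slots=True needs Python 3.10+; frozen=True would also address the mutability concern above
class FixedPointConfig:
    total_bits: int
    frac_bits: int
    stochastic_rounding: bool = False


fxp = FixedPointConfig(4, 2, False)
fxp.something_else = 4  # AttributeError, same protection as the manual __slots__ version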
