Support mixed MX element dtype in mx_mm function and MXLinear #1667
base: main
Conversation
Following the MXFP and quantization literature, it is useful to support different element dtypes for activations, weights and gradients. This PR simply adds a more general interface to mx_mm; a similar choice could be made for MXLinear.

General issue: #1666
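To make the terminology concrete, here is a rough, self-contained sketch of what casting to an "MX element dtype" means in the simulated setting: each tensor is fake-quantized in blocks sharing a power-of-two scale, and activations, weights and gradients can each use a different element dtype. The helper name, block handling and scale choice below are simplifications for illustration, not torchao's mx_formats implementation.

```python
# Illustrative only: simplified block-scaled fake quantization, *not*
# torchao's mx_formats code. `fake_quant_mx` and the power-of-two scale
# heuristic are assumptions made for this sketch.
import torch


def fake_quant_mx(x: torch.Tensor, elem_dtype: torch.dtype, block_size: int = 32) -> torch.Tensor:
    """Quantize-dequantize x along the last dim in blocks of block_size."""
    orig_shape = x.shape
    x = x.reshape(-1, block_size)
    # Shared power-of-two scale per block, derived from the block max.
    amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = torch.exp2(torch.floor(torch.log2(amax)))
    x_q = (x / scale).to(elem_dtype).to(x.dtype) * scale
    return x_q.reshape(orig_shape)


# Mixed element dtypes: e.g. fp8 e4m3 for activations and weights,
# while gradient tensors (not shown here) could use fp8 e5m2.
x_hp = torch.randn(4, 64)
w_hp = torch.randn(32, 64)
y = fake_quant_mx(x_hp, torch.float8_e4m3fn) @ fake_quant_mx(w_hp, torch.float8_e4m3fn).t()
```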
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1667
Note: links to docs will display an error until the docs builds have been completed.
❗ 1 active SEV. If your PR is affected, please view it below.
✅ No failures as of commit 5c8eb6d with merge base 8afd10e.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@@ -23,25 +23,31 @@ class mx_mm(torch.autograd.Function):
# 1. input @ weight_t = output (forward pass)
# 2. grad_output @ weight = grad_input (backward pass)
# 3. input_t @ grad_output = grad_weight (backward pass)
#
# input, weight and grad_output have each their own MX element dtype.
nit: "can have"?
Done
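For context on the three matmuls annotated in the diff above, here is a hedged, self-contained sketch of an autograd function where input, weight and grad_output each get their own element dtype. The class and helper names, the bare quantize-dequantize, and the argument layout are assumptions for illustration; torchao's actual mx_mm is more involved.

```python
# Illustrative stand-in showing one element dtype per tensor. `_fq` is a bare
# quantize-dequantize (no MX block scaling) so the sketch stays self-contained.
import torch


def _fq(x: torch.Tensor, elem_dtype: torch.dtype) -> torch.Tensor:
    # Quantize-dequantize stand-in for an MX cast.
    return x.to(elem_dtype).to(x.dtype)


class MixedDtypeMM(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_hp, weight_hp, in_dtype, w_dtype, grad_dtype):
        ctx.save_for_backward(input_hp, weight_hp)
        ctx.dtypes = (in_dtype, w_dtype, grad_dtype)
        # forward: input @ weight_t = output
        return _fq(input_hp, in_dtype) @ _fq(weight_hp, w_dtype).t()

    @staticmethod
    def backward(ctx, grad_output_hp):
        input_hp, weight_hp = ctx.saved_tensors
        in_dtype, w_dtype, grad_dtype = ctx.dtypes
        grad_q = _fq(grad_output_hp, grad_dtype)
        # backward: grad_output @ weight = grad_input
        grad_input = grad_q @ _fq(weight_hp, w_dtype)
        # backward: grad_output_t @ input = grad_weight
        grad_weight = grad_q.t() @ _fq(input_hp, in_dtype)
        return grad_input, grad_weight, None, None, None


# e.g. e4m3 for input/weight, e5m2 for grad_output
x = torch.randn(4, 64, requires_grad=True)
w = torch.randn(32, 64, requires_grad=True)
y = MixedDtypeMM.apply(x, w, torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float8_e5m2)
y.sum().backward()
```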
this makes sense, it would be great to cover with a test. The easiest place to test it would be here (MXLinear). Would you be interested in doing that in this PR?

by the way, pytorch/pytorch#146414 outlines bringing MX dtypes to PyTorch core, and we plan to evolve …
…er factory method. Passing a tuple of 3 element dtypes avoids introducing a breaking change in the current interface of `MXLinear` and `swap_linear_with_mx_linear`. Some additional unit test coverage has been added on MXLinear.
I added the support of this feature in the mx_mm function and MXLinear. I expanded the coverage in the test you mentioned (plus a small test on the factory side to check the 2 cases above are working properly).

Thanks for the link on the PyTorch MX plan 👍 I would assume that the MX "simulated" mode is going to stay in TorchAO for some time, as it is very useful for testing + getting ready for MX hardware until it is widely available.
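A sketch of the kind of coverage being discussed, assuming the tuple-of-three-element-dtypes interface described in the commit message; the import path and exact signatures are assumptions and may not match the merged code.

```python
# Hedged test sketch: checks that MXLinear runs forward/backward with either a
# single element dtype or a (input, weight, grad_output) tuple.
import pytest
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.mx_linear import MXLinear  # assumed path


@pytest.mark.parametrize(
    "elem_dtype",
    [
        torch.float8_e4m3fn,  # single dtype for input, weight and grad_output
        (torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float8_e5m2),  # mixed
    ],
)
def test_mx_linear_mixed_elem_dtypes(elem_dtype):
    m = nn.Linear(64, 32, bias=False)
    mx = MXLinear.from_float(m, elem_dtype, block_size=32)  # signature assumed
    x = torch.randn(4, 64, requires_grad=True)
    y = mx(x)
    y.sum().backward()
    assert y.shape == (4, 32)
    assert x.grad is not None and mx.weight.grad is not None
```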
""" | ||
|
||
@classmethod | ||
@torch.no_grad() | ||
def from_float(cls, mod, elem_dtype, block_size): | ||
mod.__class__ = MXLinear | ||
mod.elem_dtype = elem_dtype | ||
# Single element dtype passed for input, weight and gradient. |
nit: can we do
def from_float(
    ...,
    elem_dtype,
    ...,
    elem_dtype_weight_override=None,
    elem_dtype_grad_output_override=None,
    ...
): ...
we plan to create a proper config object for this in the future, but for now would be good to keep things simple and avoid mixing types in the API (such as dtype vs tuple)
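A minimal stand-in showing how the suggested override kwargs could resolve to per-tensor dtypes inside from_float, falling back to the single elem_dtype when no override is passed. This is a sketch of the suggestion above, not the merged implementation; the attribute names are illustrative.

```python
import torch
import torch.nn as nn


class MXLinearSketch(nn.Linear):
    """Stand-in for MXLinear, only to show how the overrides could resolve."""

    @classmethod
    @torch.no_grad()
    def from_float(
        cls,
        mod,
        elem_dtype,
        block_size,
        elem_dtype_weight_override=None,
        elem_dtype_grad_output_override=None,
    ):
        mod.__class__ = cls
        # A single elem_dtype applies to input, weight and grad_output
        # unless an override is passed for the latter two.
        mod.in_elem_dtype = elem_dtype
        mod.w_elem_dtype = (
            elem_dtype_weight_override
            if elem_dtype_weight_override is not None
            else elem_dtype
        )
        mod.grad_elem_dtype = (
            elem_dtype_grad_output_override
            if elem_dtype_grad_output_override is not None
            else elem_dtype
        )
        mod.block_size = block_size
        return mod


# e.g. MXLinearSketch.from_float(nn.Linear(64, 32), torch.float8_e4m3fn, 32,
#      elem_dtype_grad_output_override=torch.float8_e5m2)
```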
Should I then enforce named arguments in MXLinear.from_float and swap_linear_with_mx_linear for block_size and filter_fn? And have a default block_size=32?
sounds reasonable!
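Concretely, the keyword-only arguments agreed on here could look like the sketch below; the `*` marker is what enforces passing block_size and filter_fn by name, and block_size=32 is the default proposed above. Parameter names follow this thread, not necessarily the merged code.

```python
# Sketch only: everything after `*` must be passed by keyword; the body is elided.
def swap_linear_with_mx_linear(
    model,
    elem_dtype,
    elem_dtype_weight_override=None,
    elem_dtype_grad_output_override=None,
    *,
    block_size=32,
    filter_fn=None,
):
    ...


# Callers must now name these arguments, e.g.:
# swap_linear_with_mx_linear(model, torch.float8_e4m3fn, block_size=32, filter_fn=my_filter)
```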
yep! great to hear this is useful.