
Support mixed MX element dtype in mx_mm function and MXLinear. #1667

Open · wants to merge 2 commits into base: main
Conversation

@balancap commented Feb 5, 2025

Following the MXFP and quantization literature, it is useful to support different element dtypes for activations, weights and gradients. This PR simply adds a more general interface to mx_mm; a similar choice could be made for MXLinear.

General issue: #1666
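For example (an illustrative recipe from the FP8/MX training literature, not something this PR prescribes; the PyTorch float8 dtype names below are stand-ins for MX element dtypes):

import torch

# Illustrative per-tensor element dtypes: more mantissa (e4m3) for the
# forward-pass tensors, more exponent range (e5m2) for gradients.
elem_dtype_input = torch.float8_e4m3fn
elem_dtype_weight = torch.float8_e4m3fn
elem_dtype_grad_output = torch.float8_e5m2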


pytorch-bot bot commented Feb 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1667

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 5c8eb6d with merge base 8afd10e:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the "CLA Signed" label Feb 5, 2025
@@ -23,25 +23,31 @@ class mx_mm(torch.autograd.Function):
# 1. input @ weight_t = output (forward pass)
# 2. grad_output @ weight = grad_input (backward pass)
# 3. input_t @ grad_output = grad_weight (backward pass)
#
# input, weight and grad_output each have their own MX element dtype.
Contributor

nit: "can have"?

Author

Done
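To make the three-gemm structure in the hunk above concrete, here is a minimal sketch (not torchao's actual implementation) of an mx_mm-style autograd function where input, weight and grad_output each carry their own element dtype; `to_mx` is a hypothetical stand-in for torchao's MX quantize/dequantize round-trip:

import torch

def to_mx(x: torch.Tensor, elem_dtype, block_size: int = 32) -> torch.Tensor:
    # Hypothetical stand-in: a real version would quantize x to the MX
    # format `elem_dtype` with per-block shared scales, then dequantize.
    return x

class MXMMSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, in_dtype, w_dtype, grad_dtype):
        ctx.save_for_backward(input, weight)
        ctx.dtypes = (in_dtype, w_dtype, grad_dtype)
        # 1. input @ weight_t = output (forward pass)
        return to_mx(input, in_dtype) @ to_mx(weight, w_dtype).t()

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        in_dtype, w_dtype, grad_dtype = ctx.dtypes
        g = to_mx(grad_output, grad_dtype)
        # 2. grad_output @ weight = grad_input (backward pass)
        grad_input = g @ to_mx(weight, w_dtype)
        # 3. grad_output_t @ input = grad_weight (backward pass)
        grad_weight = g.t() @ to_mx(input, in_dtype)
        return grad_input, grad_weight, None, None, None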

@vkuzo (Contributor) commented Feb 5, 2025

this makes sense, it would be great to cover with a test

the easiest place to test it would be here (`def test_linear_eager(elem_dtype, bias, input_shape):`), and that requires adding this to MXLinear. Would you be interested in doing that in this PR?

by the way, pytorch/pytorch#146414 outlines bringing MX dtypes to PyTorch core, and we plan to evolve torchao/prototype/mx_formats/ accordingly
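A sketch of how the mixed-dtype case might fold into that test's parametrization (the dtype list and test body are assumptions, not the file's actual contents):

import itertools

import pytest
import torch

# Assumed supported set; the real test file defines its own list.
SUPPORTED_ELEM_DTYPES = [torch.float8_e4m3fn, torch.float8_e5m2]

# Cover both a single shared dtype and every mixed
# (input, weight, grad_output) combination.
ELEM_DTYPE_CASES = SUPPORTED_ELEM_DTYPES + list(
    itertools.product(SUPPORTED_ELEM_DTYPES, repeat=3)
)

@pytest.mark.parametrize("elem_dtype", ELEM_DTYPE_CASES)
@pytest.mark.parametrize("bias", [False, True])
def test_linear_eager(elem_dtype, bias):
    # Body elided: swap an nn.Linear to MXLinear with `elem_dtype`
    # and compare forward/backward against the fp32 reference.
    ...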

@vkuzo added the "topic: improvement" label Feb 5, 2025
…er factory method.

Passing a tuple of 3 element dtypes avoids introducing a breaking change in the current interface of `MXLinear` and `swap_linear_with_mx_linear`.

Some additional unit test coverage has been added for MXLinear.
@balancap changed the title from "Support mixed MX element dtype in mx_mm function." to "Support mixed MX element dtype in mx_mm function and MXLinear." Feb 5, 2025
@balancap (Author) commented Feb 5, 2025

I added support for this feature in MXLinear too. In order to avoid breaking the interface (and to keep things simple in the single-dtype case), you can now pass either a single element dtype or a tuple of 3.

I expanded the coverage in the test you mentioned (plus a small test on the factory side to check that the two cases above work properly).

Thanks for the link on the PyTorch MX plan 👍 I would assume that the MX "simulated" mode is going to stay in TorchAO for some time, as it is very useful for testing + getting ready for MX hardware until it is widely available.
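In other words, a usage sketch of the two call styles described above (import path and dtype names are assumptions, not verified against the diff):

import torch

from torchao.prototype.mx_formats.mx_linear import MXLinear  # path assumed

# Single element dtype, shared by input, weight and grad_output:
m = torch.nn.Linear(128, 256)
MXLinear.from_float(m, torch.float8_e4m3fn, block_size=32)

# Tuple of 3, one dtype per (input, weight, grad_output):
m2 = torch.nn.Linear(128, 256)
MXLinear.from_float(
    m2, (torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float8_e5m2), block_size=32
)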

"""

@classmethod
@torch.no_grad()
def from_float(cls, mod, elem_dtype, block_size):
mod.__class__ = MXLinear
mod.elem_dtype = elem_dtype
# Single element dtype passed for input, weight and gradient.
@vkuzo (Contributor) Feb 5, 2025

nit: can we do

def from_float(
    ...,
    elem_dtype,
    ...,
    elem_dtype_weight_override=None,
    elem_dtype_grad_output_override=None,
    ...
): ...

we plan to create a proper config object for this in the future, but for now would be good to keep things simple and avoid mixing types in the API (such as dtype vs tuple)
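For comparison, call shapes under the override-style API suggested above might look like this (illustrative only; the full signature is elided in the snippet):

import torch

from torchao.prototype.mx_formats.mx_linear import MXLinear  # path assumed

mod = torch.nn.Linear(128, 256)

# Shared dtype everywhere: pass only elem_dtype.
MXLinear.from_float(mod, elem_dtype=torch.float8_e4m3fn, block_size=32)

# Mixed: override only the tensors that differ from elem_dtype.
mod2 = torch.nn.Linear(128, 256)
MXLinear.from_float(
    mod2,
    elem_dtype=torch.float8_e4m3fn,
    elem_dtype_grad_output_override=torch.float8_e5m2,
    block_size=32,
)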

Author

Should I then enforce named arguments in `MXLinear.from_float` and `swap_linear_with_mx_linear` for `block_size` and `filter_fn`? And have a default `block_size=32`?

Contributor

sounds reasonable!
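The keyword-only enforcement discussed in this thread is just Python's bare `*` marker; a sketch of the signature shape only, not the real implementation:

# Everything after `*` must be passed by name.
def swap_linear_with_mx_linear(model, elem_dtype, *, block_size=32, filter_fn=None):
    ...

# swap_linear_with_mx_linear(model, dtype, block_size=64)  # ok
# swap_linear_with_mx_linear(model, dtype, 64)             # TypeError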

@vkuzo (Contributor) commented Feb 5, 2025

> I would assume that the MX "simulated" mode is going to stay in TorchAO for some time, as it is very useful for testing + getting ready for MX hardware until it is widely available.

yep! great to hear this is useful.
