Skip to content

[JAX] Decouple Recipe and ScalingMode #1728

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

jberchtold-nvidia
Copy link
Collaborator

Description

Currently the recipe and scaling mode are coupled in the TE/JAX extension. This is okay for the current recipes, such as delayed scaling, current scaling, and MXFP8 block scaling, as there is only a single scaling mode used per recipe. However, for the DeepSeek recipe this assumption no longer holds. For the DeepSeek recipe we will need 1x128 1D block scaling for inputs and 128x128 2D block scaling for weights. As a result, we need to decouple the two concepts of recipe and scaling mode.

This PR only decouples the recipe and scaling mode, it does not implement the DeepSeek recipe.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Define UsageContext which defines the context in which a quantizer will be used (e.g. is the quantizer used for x, kernel, or grad)
  • Add RecipeManager classes that provides recipe-specific functionality for quantization
  • Replace QuantizeConfig.SCALING_MODE with QuantizeConfig.RECIPE_MANAGER and update QuantizeFactory to use the recipe manager instead

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Jeremy Berchtold <[email protected]>
@jberchtold-nvidia jberchtold-nvidia force-pushed the dev/jberchtold/jax-scaling-mode-and-recipe-decoupling branch from aa85930 to 575f4c4 Compare April 29, 2025 01:12
@jberchtold-nvidia
Copy link
Collaborator Author

/te-ci L0

…e instead of no quantization

Signed-off-by: Jeremy Berchtold <[email protected]>
Signed-off-by: Jeremy Berchtold <[email protected]>
@jberchtold-nvidia
Copy link
Collaborator Author

/te-ci L0

@jberchtold-nvidia jberchtold-nvidia requested a review from phu0ngng May 1, 2025 20:18
Comment on lines +172 to +175
class UsageContext:
"""Context of where a particular quantizer will be used which is needed by some recipes."""

usage_type: UsageType
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, why do we need a new class just to wrap around an enum?



@dataclass
class QuantizerParams:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a look at this QuantizerParams and on how it is used, and I would prefer not to have this class for the following reasons:

  1. Whenever we need to query the scaling_mode or q_dtype or q_layout info, instead of
QuantizeConfig.RECIPE_MANAGER.get_quantizer_params(UsageContext(UsageType.X)).scaling_mode

We could have simply done

QuantizeConfig.RECIPE_MANAGER.get_scaling_mode(UsageType.X)

It is way simpler for other people to follow and make a contribution later.
2. With this QuantizerParam, we add one more level of object inside the Quantizer, which does not give any benefits. For the Quantizer create, we could do

q_x = QuantizerFactory.create(RECIPE_MANAGER.get_scaling_mode(UsageType.X),
                                                    RECIPE_MANAGER.get_quantize_dtype(UsageType.X),
                                                    RECIPE_MANAGER.get_quantize_layout(UsageType.X),
                                                    **args_x)

cls.INITIALIZED = True
cls.MARGIN = fp8_recipe.margin if "margin" in dir(fp8_recipe) else 0.0
cls.FP8_FORMAT = fp8_recipe.fp8_format
cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT)
cls.SCALING_MODE = _get_scaling_mode(fp8_recipe)
cls.RECIPE_MANAGER = recipe_manager
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we merge QuantizeConfig and RecipeManager into a single class?
I don't see a clear need for them to exist separately.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants