-
Notifications
You must be signed in to change notification settings - Fork 416
[JAX] Decouple Recipe and ScalingMode #1728
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[JAX] Decouple Recipe and ScalingMode #1728
Conversation
Signed-off-by: Jeremy Berchtold <[email protected]>
aa85930
to
575f4c4
Compare
/te-ci L0 |
…e instead of no quantization Signed-off-by: Jeremy Berchtold <[email protected]>
Signed-off-by: Jeremy Berchtold <[email protected]>
/te-ci L0 |
class UsageContext: | ||
"""Context of where a particular quantizer will be used which is needed by some recipes.""" | ||
|
||
usage_type: UsageType |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, why do we need a new class just to wrap around an enum?
|
||
|
||
@dataclass | ||
class QuantizerParams: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have a look at this QuantizerParams
and on how it is used, and I would prefer not to have this class for the following reasons:
- Whenever we need to query the
scaling_mode
orq_dtype
orq_layout
info, instead of
QuantizeConfig.RECIPE_MANAGER.get_quantizer_params(UsageContext(UsageType.X)).scaling_mode
We could have simply done
QuantizeConfig.RECIPE_MANAGER.get_scaling_mode(UsageType.X)
It is way simpler for other people to follow and make a contribution later.
2. With this QuantizerParam
, we add one more level of object inside the Quantizer, which does not give any benefits. For the Quantizer create, we could do
q_x = QuantizerFactory.create(RECIPE_MANAGER.get_scaling_mode(UsageType.X),
RECIPE_MANAGER.get_quantize_dtype(UsageType.X),
RECIPE_MANAGER.get_quantize_layout(UsageType.X),
**args_x)
cls.INITIALIZED = True | ||
cls.MARGIN = fp8_recipe.margin if "margin" in dir(fp8_recipe) else 0.0 | ||
cls.FP8_FORMAT = fp8_recipe.fp8_format | ||
cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT) | ||
cls.SCALING_MODE = _get_scaling_mode(fp8_recipe) | ||
cls.RECIPE_MANAGER = recipe_manager |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why don't we merge QuantizeConfig
and RecipeManager
into a single class?
I don't see a clear need for them to exist separately.
Description
Currently the recipe and scaling mode are coupled in the TE/JAX extension. This is okay for the current recipes, such as delayed scaling, current scaling, and MXFP8 block scaling, as there is only a single scaling mode used per recipe. However, for the DeepSeek recipe this assumption no longer holds. For the DeepSeek recipe we will need 1x128 1D block scaling for inputs and 128x128 2D block scaling for weights. As a result, we need to decouple the two concepts of recipe and scaling mode.
This PR only decouples the recipe and scaling mode, it does not implement the DeepSeek recipe.
Type of change
Changes
UsageContext
which defines the context in which a quantizer will be used (e.g. is the quantizer used for x, kernel, or grad)RecipeManager
classes that provides recipe-specific functionality for quantizationQuantizeConfig.SCALING_MODE
withQuantizeConfig.RECIPE_MANAGER
and updateQuantizeFactory
to use the recipe manager insteadChecklist: