diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/measure.py b/neural_compressor/torch/algorithms/fp8_quant/_core/measure.py
index 8f52990227f..d6afa0bdb7b 100644
--- a/neural_compressor/torch/algorithms/fp8_quant/_core/measure.py
+++ b/neural_compressor/torch/algorithms/fp8_quant/_core/measure.py
@@ -143,6 +143,9 @@ def register_patched_measure_modules(model, mod_list, observer_class, d_shapes=N
         patched_types.add(type(mod))
 
         set_hqt_config(mod, top_level_config) # set config in the module, as it consumed by the patched module
+        if mod_type == "dynamic_moe" and hasattr(mod, "num_experts"):
+            # override default number of outputs for dynamic moe
+            mod_types[mod_type].num_outputs = mod.num_experts+1
         mod_extra_config = (
             init_measure_object(
                 mod,
diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/patching_common.py b/neural_compressor/torch/algorithms/fp8_quant/_core/patching_common.py
index 91338ccc160..d843a859ebc 100644
--- a/neural_compressor/torch/algorithms/fp8_quant/_core/patching_common.py
+++ b/neural_compressor/torch/algorithms/fp8_quant/_core/patching_common.py
@@ -82,6 +82,7 @@ def create_mod_info_recursion(parent):
     "FusedMoE": ModuleInfo("linear", PatchedMixtralMoE, False),
     "GaudiMixtralSparseMoeBlock": ModuleInfo("dynamic_moe", PatchedGaudiMixtralSparseMoeBlock),
     "VllmMixtureOfExpertsOp": ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOp),
+    "GaudiDeepseekV3MoE": ModuleInfo("dynamic_moe", PatchedGaudiDeepseekV3MoE),
 }
 
 
diff --git a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
index b69804255fd..9765ca14966 100755
--- a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
+++ b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
@@ -164,7 +164,7 @@ def init_linear(instance, mod_extra_config):
 
 def init_mixture_of_experts_linears(instance):
     parent_name = instance.orig_mod_parent.__class__.__name__
-    if parent_name == "MixtralBlockSparseTop2MLP":
+    if parent_name == "MixtralBlockSparseTop2MLP" or (parent_name == "GaudiDeepseekV3MLP" and instance.orig_mod_parent.add_dummy_quant_input):
         # this linear is part of MixtureOfExperts block
         # MoE linears hold the weights but their forward logic is done using the dynamic op
         # therefore no measure object is saved causing no quant object as well
@@ -730,6 +730,103 @@ def extra_repr(self) -> str:
         )
 
 
+class PatchedGaudiDeepseekV3MoE(PatchedModuleBase):
+    def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
+        super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
+        self.forward = self.forward_orig
+        if self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
+            self.dynamic_moe_op = get_hpu_quantized_func_wrapper(OP_TYPE.DYNAMIC_MOE, self.scale_format)
+            self.quant_input = self._mod_extra_config.inputs[0]
+            self.register_scale("scale_input", mod_extra_config.scale.inputs[0], self.scale_format)
+            self.register_scale(
+                "scale_intermediate",
+                [mod_extra_config.scale.inputs[x] for x in range(1, self.config.n_routed_experts+1)],
+                self.scale_format,
+            )
+            mod.call_dynamic_moe_op = self.call_dynamic_moe_quant_op
+        elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE):
+            mod.call_dynamic_moe_op = self.call_dynamic_moe_measure_op
+
+    def call_dynamic_moe_quant_op(self,
+                                  hidden_states,
+                                  topk_idx,
+                                  topk_weight,
+                                  experts_min,
+                                  experts_max,
+                                  activation="silu"):
+        experts_range = range(experts_min, experts_max)
+        gate_proj_list = [self.experts[i].gate_proj.weight.squeeze() for i in experts_range]
+        down_proj_list = [self.experts[i].down_proj.weight.squeeze() for i in experts_range]
+        up_proj_list = [self.experts[i].up_proj.weight.squeeze() for i in experts_range]
+        scale_gate_proj = [self.experts[i].gate_proj.scale_weight for i in experts_range]
+        scale_down_proj = [self.experts[i].down_proj.scale_weight for i in experts_range]
+        scale_up_proj = [self.experts[i].up_proj.scale_weight for i in experts_range]
+        qinput = self.quant_input(hidden_states)
+
+        return self.dynamic_moe_op(
+            hidden_states=qinput,
+            expert_routing_table=topk_idx,
+            router_weights=topk_weight,
+            w1=gate_proj_list,
+            w2=up_proj_list,
+            w3=down_proj_list,
+            d_scale_w1=scale_gate_proj,
+            d_scale_w2=scale_up_proj,
+            d_scale_w3=scale_down_proj,
+            d_scale_hidden_states=self.scale_input,
+            d_scale_intermediate_hidden_states=self.scale_intermediate[experts_min:experts_max],
+            permuted_weights=False,
+            activation=activation,
+            experts_min=experts_min,
+            experts_max=experts_max - 1,
+        )
+
+    def call_dynamic_moe_measure_op(self,
+                                    hidden_states,
+                                    topk_idx,
+                                    topk_weight,
+                                    experts_min,
+                                    experts_max,
+                                    permuted_weights=True,
+                                    activation="silu"):
+        experts_range = range(experts_min, experts_max)
+        gate_proj_list = [self.experts[i].gate_proj.weight.squeeze() for i in experts_range]
+        down_proj_list = [self.experts[i].down_proj.weight.squeeze() for i in experts_range]
+        up_proj_list = [self.experts[i].up_proj.weight.squeeze() for i in experts_range]
+        measure_input((hidden_states,), observer=self._mod_extra_config.inputs)
+        output, intermediate_amax = torch.ops.hpu.mixture_of_experts.fp8_measurement(
+            hidden_states=hidden_states,
+            expert_routing_table=topk_idx,
+            router_weights=topk_weight,
+            w1=gate_proj_list,
+            w2=up_proj_list,
+            w3=down_proj_list,
+            permuted_weights=permuted_weights,
+            activation=activation,
+            experts_min=experts_min,
+            experts_max=experts_max - 1,
+            measurement_mode=True,
+        )
+
+        # Update output and intermediate measures separately due to chunked MoE
+        measure_output([output], [self._mod_extra_config.outputs[0]])
+        output_measure_list = []
+        for i in range(experts_max-experts_min):
+            output_measure_list.append(intermediate_amax[i])
+        measure_output(output_measure_list, self._mod_extra_config.outputs[experts_min+1:experts_max+1])
+        return output
+
+    def extra_repr(self) -> str:
+        member_names = ["scale_input"]
+        for x in range(1, self.config.n_routed_experts+1):
+            member_names.append("scale_intermediate["+str(x)+"]")
+        return extra_representation(
+            self.extra_repr_org(),
+            self.class_name_org,
+            get_current_repr(self, *member_names),
+        )
+
+
 class PatchedVllmMixtureOfExpertsOp(PatchedModuleBase):
     def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
diff --git a/neural_compressor/torch/algorithms/fp8_quant/utils/patched_module_restore_registry.py b/neural_compressor/torch/algorithms/fp8_quant/utils/patched_module_restore_registry.py
index 9126019508f..50f3ce33653 100644
--- a/neural_compressor/torch/algorithms/fp8_quant/utils/patched_module_restore_registry.py
+++ b/neural_compressor/torch/algorithms/fp8_quant/utils/patched_module_restore_registry.py
@@ -140,3 +140,10 @@ def __init__(self, patched_mod, *args, **kwargs):
         super().__init__()
         self.__dict__.update(patched_mod.__dict__)
         self.extra_repr = patched_mod.extra_repr_org
+
+@helper_mod_register(name="GaudiDeepseekV3MoE")
+class GaudiDeepseekV3MoE(torch.nn.Module):
+    def __init__(self, patched_mod, *args, **kwargs):
+        super().__init__()
+        self.__dict__.update(patched_mod.__dict__)
+        self.extra_repr = patched_mod.extra_repr_org
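A minimal sketch of the observer layout this patch assumes for dynamic-MoE measurement: slot 0 observes the MoE block output and slots 1..num_experts observe the per-expert intermediate amax, which is why num_outputs is overridden to mod.num_experts + 1 and the intermediate measures are written to outputs[experts_min+1:experts_max+1]. The names num_experts and record_chunk below are illustrative only, not part of the library:

    # Illustrative sketch; "num_experts" and "record_chunk" are hypothetical names.
    num_experts = 8                           # assumed expert count for the example
    outputs = [None] * (num_experts + 1)      # mirrors num_outputs = num_experts + 1

    def record_chunk(block_output, intermediate_amax, experts_min, experts_max):
        # Slot 0 always observes the MoE block output.
        outputs[0] = block_output
        # Expert i's intermediate amax lands in slot i + 1, i.e. the
        # outputs[experts_min + 1 : experts_max + 1] slice for this expert chunk.
        for i, amax in enumerate(intermediate_amax):
            outputs[experts_min + 1 + i] = amax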