diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index ae1fbc7719a0..6a333862a909 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -119,7 +119,8 @@ def __init__( if use_fp8: self.op_hooks.append(FP8Hook()) if overlap_allgather: - self.op_hook = ZeroOpHook() + self.op_hooks.append(ZeroOpHook()) + if use_fp8 or overlap_allgather: for p in module.parameters(): if p.requires_grad and type(p) is not ColoParameter: p.__class__ = ColoParameter