diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 42bb49bc9ca1..8cc511a5610f 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -100,7 +100,7 @@ def forward(self, *args, **kwargs): if self.convert_fn is not None: args = tree_map(self.convert_fn, args) kwargs = tree_map(self.convert_fn, kwargs) - ctx = ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext() + ctx = ColoParamOpHookManager.use_hooks(*self.op_hooks) if self.overlap_allgather else nullcontext() with ctx: return super().forward(*args, **kwargs)