diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py
index 560bc9003c6..96fbb2ca24b 100644
--- a/invokeai/app/invocations/denoise_latents.py
+++ b/invokeai/app/invocations/denoise_latents.py
@@ -901,7 +901,7 @@ def step_callback(state: PipelineIntermediateState) -> None:
                 # ext: freeu, seamless, ip adapter, lora
                 ext_manager.patch_unet(unet, cached_weights),
             ):
-                sd_backend = StableDiffusionBackend(unet, scheduler)
+                sd_backend = StableDiffusionBackend()
                 denoise_ctx.unet = unet
                 result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
 
diff --git a/invokeai/backend/stable_diffusion/denoise_context.py b/invokeai/backend/stable_diffusion/denoise_context.py
index 9060d549776..7642c45c3ad 100644
--- a/invokeai/backend/stable_diffusion/denoise_context.py
+++ b/invokeai/backend/stable_diffusion/denoise_context.py
@@ -96,7 +96,8 @@ class DenoiseContext:
     timestep: Optional[torch.Tensor] = None
 
     # Arguments which will be passed to UNet model.
-    # Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None.
+    # Available in `PRE_UNET_FORWARD`/`POST_UNET_FORWARD` callbacks
+    # and in the `UNET_FORWARD` override; otherwise it will be None.
     unet_kwargs: Optional[UNetKwargs] = None
 
     # SchedulerOutput class returned from step function(normally, generated by scheduler).
@@ -109,7 +110,8 @@ class DenoiseContext:
     latent_model_input: Optional[torch.Tensor] = None
 
     # [TMP] Defines on which conditionings current unet call will be runned.
-    # Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None.
+    # Available in `PRE_UNET_FORWARD`/`POST_UNET_FORWARD` callbacks
+    # and in the `UNET_FORWARD` override; otherwise it will be None.
     conditioning_mode: Optional[ConditioningMode] = None
 
     # [TMP] Noise predictions from negative conditioning.
diff --git a/invokeai/backend/stable_diffusion/diffusion_backend.py b/invokeai/backend/stable_diffusion/diffusion_backend.py
index 4191db734f9..9df7d18d229 100644
--- a/invokeai/backend/stable_diffusion/diffusion_backend.py
+++ b/invokeai/backend/stable_diffusion/diffusion_backend.py
@@ -1,25 +1,19 @@
 from __future__ import annotations
 
 import torch
-from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
-from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
+from diffusers.schedulers.scheduling_utils import SchedulerOutput
 from tqdm.auto import tqdm
 
 from invokeai.app.services.config.config_default import get_config
 from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, UNetKwargs
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
 from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+from invokeai.backend.stable_diffusion.extension_override_type import ExtensionOverrideType
 from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
 
 
 class StableDiffusionBackend:
-    def __init__(
-        self,
-        unet: UNet2DConditionModel,
-        scheduler: SchedulerMixin,
-    ):
-        self.unet = unet
-        self.scheduler = scheduler
+    def __init__(self):
         config = get_config()
         self._sequential_guidance = config.sequential_guidance
 
@@ -31,7 +25,7 @@ def latents_from_embeddings(self, ctx: DenoiseContext, ext_manager: ExtensionsMa
 
         if ctx.inputs.noise is not None:
             batch_size = ctx.latents.shape[0]
-            # latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
+            # latents = noise * ctx.scheduler.init_noise_sigma # it's like in t2l according to diffusers
             ctx.latents = ctx.scheduler.add_noise(
                 ctx.latents, ctx.inputs.noise, ctx.inputs.init_timestep.expand(batch_size)
             )
@@ -49,7 +43,7 @@ def latents_from_embeddings(self, ctx: DenoiseContext, ext_manager: ExtensionsMa
             ext_manager.run_callback(ExtensionCallbackType.PRE_STEP, ctx)
 
             # ext: tiles? [override: step]
-            ctx.step_output = self.step(ctx, ext_manager)
+            ctx.step_output = ext_manager.run_override(ExtensionOverrideType.STEP, self.step, ctx, ext_manager)
 
             # ext: inpaint[post_step, priority=high] (apply mask to preview on non-inpaint models)
             # ext: preview[post_step, priority=low]
@@ -77,7 +71,9 @@ def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> Scheduler
             ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2)
 
         # ext: override combine_noise_preds
-        ctx.noise_pred = self.combine_noise_preds(ctx)
+        ctx.noise_pred = ext_manager.run_override(
+            ExtensionOverrideType.COMBINE_NOISE_PREDS, self.combine_noise_preds, ctx, ext_manager
+        )
 
         # ext: cfg_rescale [modify_noise_prediction]
         # TODO: rename
@@ -94,17 +90,6 @@ def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> Scheduler
 
         return step_output
 
-    @staticmethod
-    def combine_noise_preds(ctx: DenoiseContext) -> torch.Tensor:
-        guidance_scale = ctx.inputs.conditioning_data.guidance_scale
-        if isinstance(guidance_scale, list):
-            guidance_scale = guidance_scale[ctx.step_index]
-
-        # Note: Although this `torch.lerp(...)` line is logically equivalent to the current CFG line, it seems to result
-        # in slightly different outputs. It is suspected that this is caused by small precision differences.
-        # return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
-        return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
-
     def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: ConditioningMode):
         sample = ctx.latent_model_input
         if conditioning_mode == ConditioningMode.Both:
@@ -122,15 +107,13 @@ def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditio
         ctx.conditioning_mode = conditioning_mode
         ctx.inputs.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode)
 
-        # ext: controlnet/ip/t2i [pre_unet]
-        ext_manager.run_callback(ExtensionCallbackType.PRE_UNET, ctx)
+        # ext: controlnet/ip/t2i [pre_unet_forward]
+        ext_manager.run_callback(ExtensionCallbackType.PRE_UNET_FORWARD, ctx)
 
-        # ext: inpaint [pre_unet, priority=low]
-        # or
-        # ext: inpaint [override: unet_forward]
-        noise_pred = self._unet_forward(**vars(ctx.unet_kwargs))
+        # ext: inpaint model/ic-light [override: unet_forward]
+        noise_pred = ext_manager.run_override(ExtensionOverrideType.UNET_FORWARD, self.unet_forward, ctx, ext_manager)
 
-        ext_manager.run_callback(ExtensionCallbackType.POST_UNET, ctx)
+        ext_manager.run_callback(ExtensionCallbackType.POST_UNET_FORWARD, ctx)
 
         # clean up locals
         ctx.unet_kwargs = None
@@ -138,5 +121,17 @@ def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditio
 
         return noise_pred
 
-    def _unet_forward(self, **kwargs) -> torch.Tensor:
-        return self.unet(**kwargs).sample
+    # The extensions manager is passed as an argument so that overrides can access it.
+    def combine_noise_preds(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> torch.Tensor:
+        guidance_scale = ctx.inputs.conditioning_data.guidance_scale
+        if isinstance(guidance_scale, list):
+            guidance_scale = guidance_scale[ctx.step_index]
+
+        # Note: Although this `torch.lerp(...)` line is logically equivalent to the current CFG line, it seems to result
+        # in slightly different outputs. It is suspected that this is caused by small precision differences.
+        # return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
+        return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
+
+    # The extensions manager is passed as an argument so that overrides can access it.
+    def unet_forward(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> torch.Tensor:
+        return ctx.unet(**vars(ctx.unet_kwargs)).sample
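For reviewers, a minimal standalone sketch of the dispatch contract used above (stub classes, not the real ones): `run_override` invokes the registered override with the default implementation prepended as `orig_func`, so an override can wrap the default, replace it entirely, or delegate to it.

```python
from typing import Any, Callable, Dict

# Stub stand-ins to illustrate the dispatch contract only; the real classes
# live in extensions_manager.py and extensions/base.py.
class StubExtensionsManager:
    def __init__(self) -> None:
        self._overrides: Dict[str, Callable[..., Any]] = {}

    def run_override(self, override_type: str, orig_function: Callable[..., Any], *args, **kwargs):
        override = self._overrides.get(override_type)
        if override is not None:
            # The registered override receives the default implementation
            # as its first argument and decides whether to call it.
            return override(orig_function, *args, **kwargs)
        return orig_function(*args, **kwargs)


def default_unet_forward(ctx: str) -> str:
    return f"default({ctx})"


def wrapping_override(orig_func: Callable[..., Any], ctx: str) -> str:
    # Adjust inputs/outputs while still delegating to the default.
    return f"wrapped[{orig_func(ctx)}]"


mgr = StubExtensionsManager()
assert mgr.run_override("unet_forward", default_unet_forward, "ctx") == "default(ctx)"

mgr._overrides["unet_forward"] = wrapping_override
assert mgr.run_override("unet_forward", default_unet_forward, "ctx") == "wrapped[default(ctx)]"
```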
diff --git a/invokeai/backend/stable_diffusion/extension_callback_type.py b/invokeai/backend/stable_diffusion/extension_callback_type.py
index e4c365007ba..8dfb1441568 100644
--- a/invokeai/backend/stable_diffusion/extension_callback_type.py
+++ b/invokeai/backend/stable_diffusion/extension_callback_type.py
@@ -7,6 +7,6 @@ class ExtensionCallbackType(Enum):
     POST_DENOISE_LOOP = "post_denoise_loop"
     PRE_STEP = "pre_step"
     POST_STEP = "post_step"
-    PRE_UNET = "pre_unet"
-    POST_UNET = "post_unet"
+    PRE_UNET_FORWARD = "pre_unet_forward"
+    POST_UNET_FORWARD = "post_unet_forward"
     POST_COMBINE_NOISE_PREDS = "post_combine_noise_preds"
diff --git a/invokeai/backend/stable_diffusion/extension_override_type.py b/invokeai/backend/stable_diffusion/extension_override_type.py
new file mode 100644
index 00000000000..9256a736fd4
--- /dev/null
+++ b/invokeai/backend/stable_diffusion/extension_override_type.py
@@ -0,0 +1,7 @@
+from enum import Enum
+
+
+class ExtensionOverrideType(Enum):
+    STEP = "step"
+    UNET_FORWARD = "unet_forward"
+    COMBINE_NOISE_PREDS = "combine_noise_preds"
diff --git a/invokeai/backend/stable_diffusion/extensions/base.py b/invokeai/backend/stable_diffusion/extensions/base.py
index 820d5d32a37..2667e7344fd 100644
--- a/invokeai/backend/stable_diffusion/extensions/base.py
+++ b/invokeai/backend/stable_diffusion/extensions/base.py
@@ -2,7 +2,7 @@
 
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
 
 import torch
 from diffusers import UNet2DConditionModel
@@ -10,6 +10,7 @@
 if TYPE_CHECKING:
     from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
     from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+    from invokeai.backend.stable_diffusion.extension_override_type import ExtensionOverrideType
 
 
 @dataclass
@@ -35,22 +36,54 @@ def _decorator(function):
     return _decorator
 
 
+@dataclass
+class OverrideMetadata:
+    override_type: ExtensionOverrideType
+
+
+@dataclass
+class OverrideFunctionWithMetadata:
+    metadata: OverrideMetadata
+    function: Callable[..., Any]
+
+
+def override(override_type: ExtensionOverrideType):
+    def _decorator(function):
+        function._ext_metadata = OverrideMetadata(
+            override_type=override_type,
+        )
+        return function
+
+    return _decorator
+
+
 class ExtensionBase:
     def __init__(self):
         self._callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}
+        self._overrides: Dict[ExtensionOverrideType, OverrideFunctionWithMetadata] = {}
 
         # Register all of the callback methods for this instance.
         for func_name in dir(self):
             func = getattr(self, func_name)
             metadata = getattr(func, "_ext_metadata", None)
-            if metadata is not None and isinstance(metadata, CallbackMetadata):
-                if metadata.callback_type not in self._callbacks:
-                    self._callbacks[metadata.callback_type] = []
-                self._callbacks[metadata.callback_type].append(CallbackFunctionWithMetadata(metadata, func))
+            if metadata is not None:
+                if isinstance(metadata, CallbackMetadata):
+                    if metadata.callback_type not in self._callbacks:
+                        self._callbacks[metadata.callback_type] = []
+                    self._callbacks[metadata.callback_type].append(CallbackFunctionWithMetadata(metadata, func))
+                elif isinstance(metadata, OverrideMetadata):
+                    if metadata.override_type in self._overrides:
+                        raise RuntimeError(
+                            f"Override {metadata.override_type} defined multiple times in {type(self).__qualname__}"
+                        )
+                    self._overrides[metadata.override_type] = OverrideFunctionWithMetadata(metadata, func)
 
     def get_callbacks(self):
         return self._callbacks
 
+    def get_overrides(self):
+        return self._overrides
+
     @contextmanager
     def patch_extension(self, ctx: DenoiseContext):
         yield None
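As a usage sketch (a hypothetical extension, not part of this PR): an extension claims an override point by decorating a method, `ExtensionBase.__init__` discovers it via `_ext_metadata`, and each `ExtensionOverrideType` may be claimed at most once per extension. The override method receives `(orig_func, ctx, ext_manager)`, matching how `run_override` is called from the backend.

```python
# Hypothetical extension; class and method names are illustrative only.
import torch

from invokeai.backend.stable_diffusion.extension_override_type import ExtensionOverrideType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, override


class FixedScaleCFGExt(ExtensionBase):
    """Replaces the default noise-prediction combination with a fixed guidance scale."""

    def __init__(self, scale: float = 7.5):
        super().__init__()
        self._scale = scale

    @override(ExtensionOverrideType.COMBINE_NOISE_PREDS)
    def combine_noise_preds(self, orig_func, ctx, ext_manager) -> torch.Tensor:
        # Same CFG formula as the default implementation, but with a
        # hard-coded scale instead of ctx.inputs.conditioning_data.guidance_scale.
        return ctx.negative_noise_pred + self._scale * (
            ctx.positive_noise_pred - ctx.negative_noise_pred
        )
```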
diff --git a/invokeai/backend/stable_diffusion/extensions/controlnet.py b/invokeai/backend/stable_diffusion/extensions/controlnet.py
index a48a681af3f..4b8b748a1ef 100644
--- a/invokeai/backend/stable_diffusion/extensions/controlnet.py
+++ b/invokeai/backend/stable_diffusion/extensions/controlnet.py
@@ -68,8 +68,8 @@ def resize_image(self, ctx: DenoiseContext):
             resize_mode=self._resize_mode,
         )
 
-    @callback(ExtensionCallbackType.PRE_UNET)
-    def pre_unet_step(self, ctx: DenoiseContext):
+    @callback(ExtensionCallbackType.PRE_UNET_FORWARD)
+    def pre_unet_forward(self, ctx: DenoiseContext):
         # skip if model not active in current step
         total_steps = len(ctx.inputs.timesteps)
         first_step = math.floor(self._begin_step_percent * total_steps)
diff --git a/invokeai/backend/stable_diffusion/extensions/inpaint_model.py b/invokeai/backend/stable_diffusion/extensions/inpaint_model.py
index 6ee8ef6311c..cfe44f8125f 100644
--- a/invokeai/backend/stable_diffusion/extensions/inpaint_model.py
+++ b/invokeai/backend/stable_diffusion/extensions/inpaint_model.py
@@ -1,15 +1,17 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Callable, Optional
 
 import torch
 from diffusers import UNet2DConditionModel
 
 from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
-from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
+from invokeai.backend.stable_diffusion.extension_override_type import ExtensionOverrideType
+from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback, override
 
 if TYPE_CHECKING:
     from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
+    from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
 
 
 class InpaintModelExt(ExtensionBase):
@@ -68,9 +70,8 @@ def init_tensors(self, ctx: DenoiseContext):
             self._masked_latents = torch.zeros_like(ctx.latents[:1])
         self._masked_latents = self._masked_latents.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
 
-    # Do last so that other extensions works with normal latents
-    @callback(ExtensionCallbackType.PRE_UNET, order=1000)
-    def append_inpaint_layers(self, ctx: DenoiseContext):
+    @override(ExtensionOverrideType.UNET_FORWARD)
+    def append_inpaint_layers(self, orig_func: Callable[..., Any], ctx: DenoiseContext, ext_manager: ExtensionsManager):
         batch_size = ctx.unet_kwargs.sample.shape[0]
         b_mask = torch.cat([self._mask] * batch_size)
         b_masked_latents = torch.cat([self._masked_latents] * batch_size)
@@ -78,6 +79,7 @@ def append_inpaint_layers(self, ctx: DenoiseContext):
             [ctx.unet_kwargs.sample, b_mask, b_masked_latents],
             dim=1,
         )
+        return orig_func(ctx, ext_manager)
 
     # Restore unmasked part as inpaint model can change unmasked part slightly
     @callback(ExtensionCallbackType.POST_DENOISE_LOOP)
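Note on the change above: moving this from a late-ordered `PRE_UNET` callback to a `UNET_FORWARD` override keeps the channel concatenation as the last thing to happen before the UNet call, after which the override delegates to `orig_func`. A shape sketch, assuming SD-1.5-style 4-channel latents (producing the 9-channel input inpaint UNets expect):

```python
import torch

# Illustrative shapes only: 4 latent channels + 1 mask channel
# + 4 masked-latent channels = the 9 input channels of an inpaint UNet.
batch_size = 2
sample = torch.randn(batch_size, 4, 64, 64)   # ctx.unet_kwargs.sample
mask = torch.randn(1, 1, 64, 64)              # self._mask
masked_latents = torch.randn(1, 4, 64, 64)    # self._masked_latents

b_mask = torch.cat([mask] * batch_size)                      # (2, 1, 64, 64)
b_masked_latents = torch.cat([masked_latents] * batch_size)  # (2, 4, 64, 64)
augmented = torch.cat([sample, b_mask, b_masked_latents], dim=1)
assert augmented.shape == (batch_size, 9, 64, 64)
```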
diff --git a/invokeai/backend/stable_diffusion/extensions/t2i_adapter.py b/invokeai/backend/stable_diffusion/extensions/t2i_adapter.py
index 5c290ea4e79..c7d1fc40646 100644
--- a/invokeai/backend/stable_diffusion/extensions/t2i_adapter.py
+++ b/invokeai/backend/stable_diffusion/extensions/t2i_adapter.py
@@ -96,8 +96,8 @@ def _run_model(
 
         return model(t2i_image)
 
-    @callback(ExtensionCallbackType.PRE_UNET)
-    def pre_unet_step(self, ctx: DenoiseContext):
+    @callback(ExtensionCallbackType.PRE_UNET_FORWARD)
+    def pre_unet_forward(self, ctx: DenoiseContext):
         # skip if model not active in current step
         total_steps = len(ctx.inputs.timesteps)
         first_step = math.floor(self._begin_step_percent * total_steps)
diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py
index c8d585406a8..b9389e83bea 100644
--- a/invokeai/backend/stable_diffusion/extensions_manager.py
+++ b/invokeai/backend/stable_diffusion/extensions_manager.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from contextlib import ExitStack, contextmanager
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
 
 import torch
 from diffusers import UNet2DConditionModel
@@ -11,7 +11,12 @@
 if TYPE_CHECKING:
     from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
     from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
-    from invokeai.backend.stable_diffusion.extensions.base import CallbackFunctionWithMetadata, ExtensionBase
+    from invokeai.backend.stable_diffusion.extension_override_type import ExtensionOverrideType
+    from invokeai.backend.stable_diffusion.extensions.base import (
+        CallbackFunctionWithMetadata,
+        ExtensionBase,
+        OverrideFunctionWithMetadata,
+    )
 
 
 class ExtensionsManager:
@@ -21,11 +26,19 @@ def __init__(self, is_canceled: Optional[Callable[[], bool]] = None):
         # A list of extensions in the order that they were added to the ExtensionsManager.
         self._extensions: List[ExtensionBase] = []
         self._ordered_callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}
+        self._overrides: Dict[ExtensionOverrideType, OverrideFunctionWithMetadata] = {}
 
     def add_extension(self, extension: ExtensionBase):
         self._extensions.append(extension)
         self._regenerate_ordered_callbacks()
 
+        for override_type, override in extension.get_overrides().items():
+            if override_type in self._overrides:
+                raise RuntimeError(
+                    f"Override {override_type} already defined by {self._overrides[override_type].function.__qualname__}"
+                )
+            self._overrides[override_type] = override
+
     def _regenerate_ordered_callbacks(self):
         """Regenerates self._ordered_callbacks. Intended to be called each time a new extension is added."""
         self._ordered_callbacks = {}
@@ -51,6 +64,16 @@ def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext
         for cb in callbacks:
             cb.function(ctx)
 
+    def run_override(self, override_type: ExtensionOverrideType, orig_function: Callable[..., Any], *args, **kwargs):
+        if self._is_canceled and self._is_canceled():
+            raise CanceledException
+
+        override = self._overrides.get(override_type, None)
+        if override is not None:
+            return override.function(orig_function, *args, **kwargs)
+        else:
+            return orig_function(*args, **kwargs)
+
     @contextmanager
     def patch_extensions(self, ctx: DenoiseContext):
         if self._is_canceled and self._is_canceled():
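Finally, a sketch of the single-owner rule enforced in `add_extension` (hypothetical extensions): a second extension claiming an already-taken `ExtensionOverrideType` fails at registration time rather than mid-denoise.

```python
# Hypothetical demonstration of the conflict check in add_extension.
from invokeai.backend.stable_diffusion.extension_override_type import ExtensionOverrideType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, override
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager


class ExtA(ExtensionBase):
    @override(ExtensionOverrideType.STEP)
    def step_a(self, orig_func, ctx, ext_manager):
        return orig_func(ctx, ext_manager)


class ExtB(ExtensionBase):
    @override(ExtensionOverrideType.STEP)
    def step_b(self, orig_func, ctx, ext_manager):
        return orig_func(ctx, ext_manager)


mgr = ExtensionsManager()
mgr.add_extension(ExtA())
try:
    mgr.add_extension(ExtB())  # second claimant of ExtensionOverrideType.STEP
except RuntimeError as err:
    print(err)  # "Override ExtensionOverrideType.STEP already defined by ExtA.step_a"
```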