diff --git a/generator_process/actions/depth_to_image.py b/generator_process/actions/depth_to_image.py
index c2111c4f..99b74731 100644
--- a/generator_process/actions/depth_to_image.py
+++ b/generator_process/actions/depth_to_image.py
@@ -194,7 +194,7 @@ def __call__(
         width = width or self.unet.config.sample_size * self.vae_scale_factor
 
         # 1. Check inputs
-        self.check_inputs(prompt, height, width, strength, callback_steps)
+        self.check_inputs(prompt=prompt, image=image, mask_image=depth_image, height=height, width=width, strength=strength, callback_steps=callback_steps, output_type=output_type)
 
         # 2. Define call parameters
         batch_size = 1 if isinstance(prompt, str) else len(prompt)
@@ -366,11 +366,10 @@ def __call__(
 
         # Inference
         with torch.inference_mode() if device not in ('mps', "dml") else nullcontext():
-            def callback(pipe, step, timestep, callback_kwargs):
+            def callback(step, _, latents):
                 if future.check_cancelled():
                     raise InterruptedError()
-                future.add_response(step_latents(pipe, step_preview_mode, callback_kwargs["latents"], generator, step, steps))
-                return callback_kwargs
+                future.add_response(step_latents(pipe, step_preview_mode, latents, generator, step, steps))
             try:
                 result = pipe(
                     prompt=prompt,
@@ -383,7 +382,7 @@ def callback(pipe, step, timestep, callback_kwargs):
                     num_inference_steps=steps,
                     guidance_scale=cfg_scale,
                     generator=generator,
-                    callback_on_step_end=callback,
+                    callback=callback,
                     callback_steps=1,
                     output_type="np"
                 )
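
For context: the second and third hunks switch the pipeline call from the newer diffusers `callback_on_step_end` hook to the legacy `callback`/`callback_steps` hook, and the two hooks expect different callback signatures. Below is a minimal sketch of that difference; the `run_denoise_loop_*` functions are hypothetical stand-ins for a pipeline's sampling loop, not diffusers or this repository's code. Only the callback signatures mirror the real API.

```python
# Sketch of the two diffusers step-callback conventions, using hypothetical
# stand-ins for the pipeline's denoising loop.

def run_denoise_loop_legacy(steps, callback, callback_steps=1):
    # Legacy convention: pipe(..., callback=fn, callback_steps=n).
    # The loop calls fn(step, timestep, latents) and discards the return value.
    for step in range(steps):
        timestep, latents = 1000 - step, f"latents@{step}"  # placeholder values
        if step % callback_steps == 0:
            callback(step, timestep, latents)

def run_denoise_loop_modern(steps, callback_on_step_end):
    # Newer convention: pipe(..., callback_on_step_end=fn).
    # The loop calls fn(pipe, step, timestep, callback_kwargs) and expects the
    # (possibly modified) kwargs dict back, which is why the removed callback
    # had to `return callback_kwargs`.
    pipe = None  # placeholder for the pipeline instance
    for step in range(steps):
        callback_kwargs = {"latents": f"latents@{step}"}
        callback_kwargs = callback_on_step_end(pipe, step, 1000 - step, callback_kwargs)

# The patched callback matches the legacy shape; its unused second positional
# argument (the timestep) becomes `_`:
run_denoise_loop_legacy(3, lambda step, _, latents: print(f"preview at step {step}"))
run_denoise_loop_modern(3, lambda pipe, step, timestep, kwargs: kwargs)
```

Since the legacy convention ignores the callback's return value, the patch can drop the `return callback_kwargs` line and read `latents` directly from the third argument instead of from `callback_kwargs["latents"]`.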