From 9cc90064b1865f5112921e562909366416cb9e2a Mon Sep 17 00:00:00 2001 From: Ryan Y <101505765+KeepNoob@users.noreply.github.com> Date: Wed, 20 Mar 2024 10:04:37 +0800 Subject: [PATCH] Update pipeline_ddpm.py Inspect the signature of the scheduler, then decide to whether to pass the generator as an augment to the self.scheduler.step() --- src/diffusers/pipelines/ddpm/pipeline_ddpm.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index 093a3cdfe512..2119c701bac9 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -109,12 +109,18 @@ def __call__( # set step values self.scheduler.set_timesteps(num_inference_steps) + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + for t in self.progress_bar(self.scheduler.timesteps): # 1. predict noise model_output model_output = self.unet(image, t).sample # 2. compute previous image: x_t -> x_t-1 - image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample + if accepts_generator: + image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample + else: + image = self.scheduler.step(model_output, t, image, ).prev_sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy()