
refactor DPMSolverMultistepScheduler using sigmas #4986

Merged: 39 commits merged into main from dpm-mstep-sigma-2 on Sep 19, 2023
Conversation


@yiyixuxu yiyixuxu commented Sep 12, 2023

I think I overcomplicated things in #4690 by trying to follow the k-diffusion implementation.

Trying again with a simpler approach here: we simply calculate lambda_t, sigma_t, and alpha_t from sigma.
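As a rough illustration of that relationship, here is a minimal sketch in plain Python (the helper name is hypothetical, and this assumes the usual variance-preserving parameterization rather than quoting the diffusers code):

```python
import math

def sigma_to_alpha_sigma_t(sigma):
    # For a variance-preserving process, alpha_t^2 + sigma_t^2 = 1 and
    # sigma = sigma_t / alpha_t, so both follow directly from sigma.
    alpha_t = 1.0 / math.sqrt(1.0 + sigma ** 2)
    sigma_t = sigma * alpha_t
    return alpha_t, sigma_t

alpha_t, sigma_t = sigma_to_alpha_sigma_t(2.0)
lambda_t = math.log(alpha_t) - math.log(sigma_t)  # simplifies to -log(sigma)
```

Under this parameterization lambda_t reduces to -log(sigma), which is why working with sigmas directly is enough.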

to-do

  • dpmsolver_multistep
  • dpmsolver_singlestep
  • DEIS
  • UniPC

testing

Notes:
I compared the results against the current implementation (for DPM multistep only). There is a slight numerical difference when using Karras sigmas; for example, in the last testing example below, on the 4th row (sde-dpmsolver++, use_karras_sigmas=True), you can see that the outputs are slightly different.
I also had to loosen a test here: #4986 (comment).
However, the new implementation is more accurate because it uses the sigmas directly.
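To quantify "slightly different" rather than eyeballing it, one could compare the saved outputs pixel by pixel. A hypothetical helper, not part of this PR:

```python
def max_abs_diff(a, b):
    # largest per-pixel difference between two flattened image buffers
    return max(abs(x - y) for x, y in zip(a, b))

# e.g. on two flattened uint8 buffers loaded from the saved PNGs:
print(max_abs_diff([0, 128, 255], [1, 128, 250]))  # -> 5
```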

import torch
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline

# make sure to update your branch name
branch = "dpm-mstep-sigma-2"
# branch = "main"


# test 1: dpmsolver++, use_karras_sigmas=False
seed = 33

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("cuda")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config, use_karras_sigmas=False
)

prompt = "an astronaut riding a horse on mars"

generator = torch.Generator(device="cuda").manual_seed(seed)
image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]

image.save(f"[{branch}]_test_dpmsolver++.png")


# test 2: dpmsolver++, use_karras_sigmas=True
seed = 33

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("cuda")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config, use_karras_sigmas=True
)

prompt = "an astronaut riding a horse on mars"

generator = torch.Generator(device="cuda").manual_seed(seed)
image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]

image.save(f"[{branch}]_test_dpmsolver++_k_sigma.png")

# test 3: sde-dpmsolver++, use_karras_sigmas=False

seed = 33

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("cuda")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config, use_karras_sigmas=False, algorithm_type="sde-dpmsolver++"
)

prompt = "an astronaut riding a horse on mars"

generator = torch.Generator(device="cuda").manual_seed(seed)
image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]

image.save(f"[{branch}]_test_sde-dpmsolver++.png")

# test 4: sde-dpmsolver++, use_karras_sigmas=True
seed = 33

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("cuda")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
)

prompt = "an astronaut riding a horse on mars"

generator = torch.Generator(device="cuda").manual_seed(seed)
image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]
image.save(f"[{branch}]_test_sde-dpmsolver++_k_sigma.png")
| main | this PR |
|---|---|
| dpm-testing _test_dpmsolver++ | dpm-mstep-sigma-2 _test_dpmsolver++ |
| dpm-testing _test_dpmsolver++_k_sigma | dpm-mstep-sigma-2 _test_dpmsolver++_k_sigma |
| dpm-testing _test_sde-dpmsolver++ | dpm-mstep-sigma-2 _test_sde-dpmsolver++ |
| dpm-testing _test_sde-dpmsolver++_k_sigma | dpm-mstep-sigma-2 _test_sde-dpmsolver++_k_sigma |
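For reference, the use_karras_sigmas option exercised above uses a schedule that interpolates between sigma_max and sigma_min in sigma^(1/rho) space. A minimal sketch (the function name and parameter values are illustrative, not the diffusers implementation):

```python
def karras_sigmas(sigma_min, sigma_max, num_steps, rho=7.0):
    # interpolate in sigma^(1/rho) space, which concentrates steps
    # near the low-noise end of the schedule
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    ramp = [i / (num_steps - 1) for i in range(num_steps)]
    return [(max_inv_rho + t * (min_inv_rho - max_inv_rho)) ** rho for t in ramp]

sigmas = karras_sigmas(0.1, 10.0, 5)  # descending from sigma_max to sigma_min
```

Because several continuous sigmas can map back to the same discrete timestep, a sigma-based scheduler has to tolerate duplicated timesteps, which is what the test changes below are about.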

@@ -54,6 +54,7 @@ def check_over_configs(self, time_step=0, **config):

output, new_output = sample, sample
for t in range(time_step, time_step + scheduler.config.solver_order + 1):
t = scheduler.timesteps[t]
@yiyixuxu yiyixuxu Sep 13, 2023
We used 0, 1, 2 directly as timestep indices here, but those values aren't even in self.timesteps. In the previous implementation, an index outside the range of the timesteps would silently default to the last timestep, which I don't think was intended. I changed it here; I think it makes more sense this way. Let me know if that's not the case.
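The change can be illustrated with a plain-Python stand-in for the schedule (the values are hypothetical):

```python
timesteps = [999, 749, 499, 249]  # hypothetical descending schedule

# before: loop indices 0, 1, 2 were passed to step() as if they were timesteps
# after: each index selects an actual entry from the schedule
selected = [timesteps[i] for i in range(0, 3)]
print(selected)  # -> [999, 749, 499]
```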

@@ -241,11 +242,3 @@ def test_fp16_support(self):
sample = scheduler.step(residual, t, sample).prev_sample

assert sample.dtype == torch.float16

def test_unique_timesteps(self, **config):

We don't need this test anymore because duplicated timesteps are now allowed.

@yiyixuxu

@patrickvonplaten another review please :)

for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)

scheduler.set_timesteps(scheduler.config.num_train_timesteps)
assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps
assert len(scheduler.timesteps) == scheduler.num_inference_steps
nice!

@patrickvonplaten patrickvonplaten left a comment
Very nice!

@yiyixuxu yiyixuxu merged commit 8263cf0 into main Sep 19, 2023
11 of 12 checks passed
@yiyixuxu yiyixuxu deleted the dpm-mstep-sigma-2 branch September 19, 2023 21:21
rvorias commented Sep 22, 2023

@yiyixuxu This refactor has broken Stable Diffusion (1.5) img2img for UniPC.

Both use denoising strength = 0.6.

Using UniPC code before this commit:
[image]

Using UniPC after this commit:
[image]

ljk1291 commented Sep 22, 2023

> @yiyixuxu This refactor has broken Stable Diffusion (1.5) img2img for UniPC.
> Both use denoising strength = 0.6.
> Using UniPC code before this commit: [image]
> Using UniPC after this commit: [image]

Same with DPMSolverMultistepScheduler using Karras sigmas: it just returns noise with anything under 1.0 denoising strength. I haven't tested it without Karras sigmas.
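For context on why strength < 1.0 matters here: img2img pipelines typically run only the tail of the schedule, along the lines of the sketch below (an approximation of the usual truncation logic, not the exact diffusers code). Any mismatch between the truncated timesteps and the sigma schedule only surfaces when strength < 1.0.

```python
def tail_of_schedule(num_inference_steps, strength):
    # with strength s, roughly s * num_inference_steps denoising steps run,
    # starting partway into the timestep/sigma schedule
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    return t_start, init_timestep

print(tail_of_schedule(20, 0.6))  # -> (8, 12): skip 8 steps, run 12
```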

@yiyixuxu

@ljk1291
Oh, thanks for reporting! I'm looking into it now!
If you have a reproducible example and can open a new issue, that would be great!

@burgalon

It seems like this bug still resurfaces somehow when using the legacy inpainting pipeline, for some reason.

@yiyixuxu

Hi @burgalon

Can you be more specific? Could you maybe open a bug report and provide an example?

thanks!

YiYi


burgalon commented Nov 26, 2023

@yiyixuxu
oops... happened to reply here #4631 (comment)

yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
@yiyixuxu yiyixuxu mentioned this pull request Jan 8, 2024
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>