Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

flux controlnet fix (control_modes batch & others) #9507

Merged
merged 7 commits into from
Sep 26, 2024

Conversation

yiyixuxu
Copy link
Collaborator

@yiyixuxu yiyixuxu commented Sep 23, 2024

this PR:

  1. fix control_mode shape when num_images_per_prompt > 1 (continue the PR from @christopher-beckham Fix flux controlnet mode to take into account batch size #9406)
  2. fix a bug when multi-controlnet (follow up on from Several fixes to Flux ControlNet pipelines #9472)
  3. fix a bug when use multi-controlnet with regular controllers that may not have single blocks

slow tests here

import torch
torch.cuda.set_device(2) 

# num_images_per_prompt = 1
num_images_per_prompt = 2
branch = "test" 
# branch = "main"

import gc
def flush():
    gc.collect()
    torch.cuda.empty_cache()

# test1: single controlnet (canny)
import torch
from diffusers.utils import load_image
from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
from diffusers.models.controlnet_flux import FluxControlNetModel

base_model = 'black-forest-labs/FLUX.1-dev'
controlnet_model = 'InstantX/FLUX.1-dev-Controlnet-Canny'
controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
pipe = FluxControlNetPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16)
pipe.to("cuda")

control_image = load_image("https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny/resolve/main/canny.jpg")
prompt = "A girl in city, 25 years old, cool, futuristic"
generator = torch.Generator(device="cuda").manual_seed(42)
images_out = pipe(
    prompt, 
    control_image=control_image,
    controlnet_conditioning_scale=0.6,
    num_inference_steps=28, 
    num_images_per_prompt=num_images_per_prompt,
    guidance_scale=3.5,
    generator=generator,
).images
print(images_out)
for i, image in enumerate(images_out):
    image.save(f"yiyi_test_7_{branch}_num_images_per_prompt_{num_images_per_prompt}_test1_out_{i}.png")


del pipe
flush()

# test2: single controlnet (union)
import torch
from diffusers.utils import load_image
from diffusers import FluxControlNetPipeline, FluxControlNetModel

base_model = 'black-forest-labs/FLUX.1-dev'
controlnet_model = 'InstantX/FLUX.1-dev-Controlnet-Union'

controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
pipe = FluxControlNetPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16)
pipe.to("cuda")

control_image_canny = load_image("https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union-alpha/resolve/main/images/canny.jpg")
controlnet_conditioning_scale = 0.5
control_mode = 0

width, height = control_image_canny.size

prompt = 'A bohemian-style female travel blogger with sun-kissed skin and messy beach waves.'
generator = torch.Generator(device="cuda").manual_seed(42)

images_out = pipe(
    prompt, 
    control_image=control_image_canny,
    control_mode=control_mode,
    width=width,
    height=height,
    controlnet_conditioning_scale=controlnet_conditioning_scale,
    num_inference_steps=24, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator,
    guidance_scale=3.5,
).images
for i, image in enumerate(images_out):
    image.save(f"yiyi_test_7_{branch}_num_images_per_prompt_{num_images_per_prompt}_test2_out_{i}.png")

del pipe
flush()



# test3: multiple controlnet (regular controlnets)
# note that we only have 1 regular controlnet now, so testing with 2 canny (this has no real use case but want to make sure it works regardless)
import torch
from diffusers.utils import load_image
from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
from diffusers.models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel

base_model = 'black-forest-labs/FLUX.1-dev'
controlnet_canny = 'InstantX/FLUX.1-dev-Controlnet-Canny'
controlnet = FluxControlNetModel.from_pretrained(controlnet_canny, torch_dtype=torch.bfloat16)
multi_controlnet = FluxMultiControlNetModel([controlnet] * 2)
pipe = FluxControlNetPipeline.from_pretrained(base_model, controlnet=multi_controlnet, torch_dtype=torch.bfloat16)
pipe.to("cuda")

control_image = load_image("https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny/resolve/main/canny.jpg")
prompt = "A girl in city, 25 years old, cool, futuristic"
generator = torch.Generator(device="cuda").manual_seed(42)
# currently does not work on main branch
try:    
    images_out = pipe(
        prompt,
        control_image=[control_image, control_image],
        controlnet_conditioning_scale=[0.6, 0.6],
        num_inference_steps=28,
        guidance_scale=3.5,
        num_images_per_prompt=num_images_per_prompt,
        generator=generator,
    ).images
    for i, image in enumerate(images_out):
        image.save(f"yiyi_test_7_{branch}_num_images_per_prompt_{num_images_per_prompt}_test3_out_{i}.png")
except Exception as e:
    print(e)

del pipe
flush()

# test4: multi controlnet with union
import torch
from diffusers.utils import load_image
from diffusers import FluxControlNetPipeline, FluxControlNetModel, FluxMultiControlNetModel

base_model = 'black-forest-labs/FLUX.1-dev'
controlnet_model_union = 'InstantX/FLUX.1-dev-Controlnet-Union'

controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union, torch_dtype=torch.bfloat16)
controlnet = FluxMultiControlNetModel([controlnet_union]) # we always recommend loading via FluxMultiControlNetModel

pipe = FluxControlNetPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = 'A bohemian-style female travel blogger with sun-kissed skin and messy beach waves.'
control_image_depth = load_image("https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union/resolve/main/images/depth.jpg")
control_mode_depth = 2

control_image_canny = load_image("https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union/resolve/main/images/canny.jpg")
control_mode_canny = 0

width, height = control_image_canny.size
generator = torch.Generator(device="cuda").manual_seed(42)
out_images = pipe(
    prompt,
    control_image=[control_image_depth, control_image_canny],
    control_mode=[control_mode_depth, control_mode_canny],
    width=width,
    height=height,
    controlnet_conditioning_scale=[0.2, 0.4],
    num_inference_steps=24, 
    num_images_per_prompt=num_images_per_prompt,
    guidance_scale=3.5,
    generator=generator,
).images


for i, out in enumerate(out_images):
    out.save(f'yiyi_test_7_{branch}_num_images_per_prompt_{num_images_per_prompt}_test4_out_{i}.png')

control_single_block_samples, single_block_samples
)
]
if block_samples is not None and control_block_samples is not None:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

flux controlnet output can contain None depends on the controlnet has single layers or not,
e.g. union has it https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union/blob/main/config.json
but canny does not https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny/blob/main/config.json#L16

control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
control_mode = control_mode.reshape([-1, 1])
control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see more context on this PR #9406

guidance = (
torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None
)
if isinstance(self.controlnet, FluxMultiControlNetModel):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this PR #9472 makes it possible to use a schnell base with dev controlnet, however it breaks the case when it is a FluxMultiControlNetModel because FluxMultiControlNetModel does not config

Copy link

@xziayro xziayro Sep 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is safer to use this :
in controlnet_flux.py, in the foward method check this:

if isinstance(self.time_text_embed, CombinedTimestepTextProjEmbeddings):
    guidance = None

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Member

@asomoza asomoza left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM I also tested it just in case.

@christopher-beckham
Copy link
Contributor

christopher-beckham commented Sep 25, 2024

One suggestion I have, it doesn't appear that any checking is actually done to see the lengths of all the arguments in zip are the same length. This could lead to a "silent bug" where e.g. the user may accidentally mis-specify something like conditioning_scale=[0.6] (i.e. a length-1 list instead of 2) and zip doesn't give any warnings if the length of each of its arguments are not the same.

# Regular Multi-ControlNets
# load all ControlNets into memories
else:
for i, (image, mode, scale, controlnet) in enumerate(
zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets)
):

It might be useful to add a check in for that before the zip call. I wrote this below:

lengths = [
    len(controlnet_cond), len(controlnet_mode), len(conditioning_scale), 
        len(self.nets)
]
if len(set(lengths)) != 1:
    raise ValueError(
        "`controlnet_cond`, `controlnet_mode`, `conditioning_scale`, `self.nets` " + 
        f"  must all have the same length, we got (respectively): {lengths}"
    )

Otherwise it lgtm. Thanks!

Edit: actually ideally it would also be good to check if they're lists too... Type checking in Python is annoying, isn't it?

@yiyixuxu
Copy link
Collaborator Author

@christopher-beckham nice suggestion! thanks!
I think we should add a input check for all our controlnet (will do that in a follow up PR)

@yiyixuxu yiyixuxu merged commit 9cd3755 into main Sep 26, 2024
17 of 18 checks passed
@yiyixuxu yiyixuxu deleted the flux_controlnet_modes_yiyi branch September 26, 2024 05:09
@christopher-beckham
Copy link
Contributor

@yiyixuxu thanks! Just fyi I have something already in the works for the input checking, and can PR that today if you want. It's actually a lot easier than I thought it was, just some logic which needs to be pulled in from the corresponding controlnet pipeline for SDXL. Thanks.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants