flux controlnet fix (control_modes batch & others) #9507

Merged
merged 7 commits into from
Sep 26, 2024
fix use_guidance when controlnet is a multi and does not have config
yiyixuxu committed Sep 23, 2024
commit 560449d8f3c529f2777c88da6305d25ad15687ba
19 changes: 12 additions & 7 deletions src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
@@ -752,7 +752,7 @@ def __call__(
                 if not isinstance(control_mode, int):
                     raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or `None`")
                 control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
-                control_mode = control_mode.view(-1,1).expand(control_image.shape[0], 1)
+                control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)
Collaborator Author

See PR #9406 for more context.


         elif isinstance(self.controlnet, FluxMultiControlNetModel):
             control_images = []
@@ -789,11 +789,13 @@ def __call__(

             # Here we ensure that `control_mode` has the same length as the control_image.
             if isinstance(control_mode, list) and len(control_mode) != len(control_image):
-                raise ValueError("For Multi-ControlNet, `control_mode` must be a list of the same " +
-                                 " length as the number of controlnets (control images) specified")
+                raise ValueError(
+                    "For Multi-ControlNet, `control_mode` must be a list of the same "
+                    + " length as the number of controlnets (control images) specified"
+                )
             if not isinstance(control_mode, list):
                 control_mode = [control_mode] * len(control_image)
-            # set control mode
+            # set control mode
             control_modes = []
             for cmode in control_mode:
                 if cmode is None:
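
The length check and broadcast in this hunk can be sketched in isolation as below. This is a minimal illustration, not the pipeline's code: `normalize_control_mode` is a hypothetical helper name, and plain Python lists stand in for the control images.

```python
from typing import List, Optional, Sequence, Union


def normalize_control_mode(
    control_mode: Union[None, int, List[Optional[int]]],
    control_image: Sequence[object],
) -> List[Optional[int]]:
    """Hypothetical helper: return one control mode per control image."""
    # A list must already line up one-to-one with the control images.
    if isinstance(control_mode, list) and len(control_mode) != len(control_image):
        raise ValueError(
            "For Multi-ControlNet, `control_mode` must be a list of the same "
            "length as the number of controlnets (control images) specified"
        )
    # A single value (an int or None) is repeated once per control image.
    if not isinstance(control_mode, list):
        control_mode = [control_mode] * len(control_image)
    return control_mode
```

With two control images, a scalar mode such as `2` becomes `[2, 2]`, `None` becomes `[None, None]`, and a list of the wrong length raises `ValueError`.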
@@ -846,9 +848,12 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)

-                guidance = (
-                    torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None
-                )
+                if isinstance(self.controlnet, FluxMultiControlNetModel):
Collaborator Author

PR #9472 made it possible to use a schnell base with a dev controlnet. However, it breaks the case where the controlnet is a `FluxMultiControlNetModel`, because `FluxMultiControlNetModel` does not have a `config`.

@xziayro Sep 25, 2024

I think it would be safer to handle this in `controlnet_flux.py`, with a check like this in the `forward` method:

if isinstance(self.time_text_embed, CombinedTimestepTextProjEmbeddings):
    guidance = None

+                    use_guidance = self.controlnet.nets[0].config.guidance_embeds
+                else:
+                    use_guidance = self.controlnet.config.guidance_embeds
+
+                guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None
                 guidance = guidance.expand(latents.shape[0]) if guidance is not None else None

                 # controlnet
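
The `use_guidance` lookup in this hunk can be sketched without loading any models. The classes below are hypothetical stand-ins, not diffusers classes: `FakeControlNet` plays the role of a single controlnet with a `config`, and `FakeMultiControlNet` plays the role of a multi wrapper that exposes `nets` but no `config`, as described in the review comment.

```python
from types import SimpleNamespace


class FakeControlNet:
    """Stand-in for a single controlnet that carries a `config`."""

    def __init__(self, guidance_embeds: bool):
        self.config = SimpleNamespace(guidance_embeds=guidance_embeds)


class FakeMultiControlNet:
    """Stand-in for a multi-controlnet wrapper: has `nets`, no `config`."""

    def __init__(self, nets):
        self.nets = list(nets)


def resolve_use_guidance(controlnet) -> bool:
    # Mirrors the branch in the diff: a multi wrapper has no config of its
    # own, so the flag is read from its first sub-network.
    if isinstance(controlnet, FakeMultiControlNet):
        return controlnet.nets[0].config.guidance_embeds
    return controlnet.config.guidance_embeds
```

Note that, as in the diff, only the first sub-network is consulted, so a wrapper holding controlnets with different `guidance_embeds` settings follows whatever `nets[0]` says.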