Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

flux controlnet fix (control_modes batch & others) #9507

Merged
merged 7 commits into from
Sep 26, 2024
23 changes: 12 additions & 11 deletions src/diffusers/models/controlnet_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,16 +502,17 @@ def forward(
control_block_samples = block_samples
control_single_block_samples = single_block_samples
else:
control_block_samples = [
control_block_sample + block_sample
for control_block_sample, block_sample in zip(control_block_samples, block_samples)
]

control_single_block_samples = [
control_single_block_sample + block_sample
for control_single_block_sample, block_sample in zip(
control_single_block_samples, single_block_samples
)
]
if block_samples is not None and control_block_samples is not None:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Flux ControlNet output can contain None depending on whether the controlnet has single layers or not,
e.g. Union has them https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union/blob/main/config.json
but Canny does not https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny/blob/main/config.json#L16

control_block_samples = [
control_block_sample + block_sample
for control_block_sample, block_sample in zip(control_block_samples, block_samples)
]
if single_block_samples is not None and control_single_block_samples is not None:
control_single_block_samples = [
control_single_block_sample + block_sample
for control_single_block_sample, block_sample in zip(
control_single_block_samples, single_block_samples
)
]

return control_block_samples, control_single_block_samples
39 changes: 25 additions & 14 deletions src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,10 +747,12 @@ def __call__(
width_control_image,
)

# set control mode
# Here we ensure that `control_mode` has the same length as the control_image.
if control_mode is not None:
if not isinstance(control_mode, int):
raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or `None`")
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
control_mode = control_mode.reshape([-1, 1])
control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see more context on this PR #9406


elif isinstance(self.controlnet, FluxMultiControlNetModel):
control_images = []
Expand Down Expand Up @@ -785,16 +787,22 @@ def __call__(

control_image = control_images

# Here we ensure that `control_mode` has the same length as the control_image.
if isinstance(control_mode, list) and len(control_mode) != len(control_image):
raise ValueError(
"For Multi-ControlNet, `control_mode` must be a list of the same "
+ " length as the number of controlnets (control images) specified"
)
if not isinstance(control_mode, list):
control_mode = [control_mode] * len(control_image)
# set control mode
control_mode_ = []
if isinstance(control_mode, list):
for cmode in control_mode:
if cmode is None:
control_mode_.append(-1)
else:
control_mode_.append(cmode)
control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
control_mode = control_mode.reshape([-1, 1])
control_modes = []
for cmode in control_mode:
if cmode is None:
cmode = -1
control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long)
control_modes.append(control_mode)
control_mode = control_modes

# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
Expand Down Expand Up @@ -840,9 +848,12 @@ def __call__(
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)

guidance = (
torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None
)
if isinstance(self.controlnet, FluxMultiControlNetModel):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR #9472 makes it possible to use a Schnell base model with a Dev controlnet; however, it breaks the case where the controlnet is a FluxMultiControlNetModel, because FluxMultiControlNetModel does not have a config attribute.

Copy link

@xziayro xziayro Sep 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is safer to use this: in controlnet_flux.py, in the forward method, check the following:

if isinstance(self.time_text_embed, CombinedTimestepTextProjEmbeddings):
    guidance = None

use_guidance = self.controlnet.nets[0].config.guidance_embeds
else:
use_guidance = self.controlnet.config.guidance_embeds

guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None

# controlnet
Expand Down
Loading