Skip to content

Commit

Permalink
flux controlnet fix (control_modes batch & others) (#9507)
Browse files Browse the repository at this point in the history
* flux controlnet mode to take into account batch size

* incorporate yiyixuxu's suggestions (cleaner logic) as well as clean up control mode handling for multi case

* fix

* fix use_guidance when controlnet is a multi and does not have config

---------

Co-authored-by: Christopher Beckham <christopher.j.beckham@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
  • Loading branch information
3 people committed Sep 26, 2024
1 parent 1c6ede9 commit 9cd3755
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 25 deletions.
23 changes: 12 additions & 11 deletions src/diffusers/models/controlnet_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,16 +502,17 @@ def forward(
control_block_samples = block_samples
control_single_block_samples = single_block_samples
else:
control_block_samples = [
control_block_sample + block_sample
for control_block_sample, block_sample in zip(control_block_samples, block_samples)
]

control_single_block_samples = [
control_single_block_sample + block_sample
for control_single_block_sample, block_sample in zip(
control_single_block_samples, single_block_samples
)
]
if block_samples is not None and control_block_samples is not None:
control_block_samples = [
control_block_sample + block_sample
for control_block_sample, block_sample in zip(control_block_samples, block_samples)
]
if single_block_samples is not None and control_single_block_samples is not None:
control_single_block_samples = [
control_single_block_sample + block_sample
for control_single_block_sample, block_sample in zip(
control_single_block_samples, single_block_samples
)
]

return control_block_samples, control_single_block_samples
39 changes: 25 additions & 14 deletions src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,10 +747,12 @@ def __call__(
width_control_image,
)

# set control mode
# Here we ensure that `control_mode` has the same length as the control_image.
if control_mode is not None:
if not isinstance(control_mode, int):
raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or `None`")
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
control_mode = control_mode.reshape([-1, 1])
control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)

elif isinstance(self.controlnet, FluxMultiControlNetModel):
control_images = []
Expand Down Expand Up @@ -785,16 +787,22 @@ def __call__(

control_image = control_images

# Here we ensure that `control_mode` has the same length as the control_image.
if isinstance(control_mode, list) and len(control_mode) != len(control_image):
raise ValueError(
"For Multi-ControlNet, `control_mode` must be a list of the same "
+ " length as the number of controlnets (control images) specified"
)
if not isinstance(control_mode, list):
control_mode = [control_mode] * len(control_image)
# set control mode
control_mode_ = []
if isinstance(control_mode, list):
for cmode in control_mode:
if cmode is None:
control_mode_.append(-1)
else:
control_mode_.append(cmode)
control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
control_mode = control_mode.reshape([-1, 1])
control_modes = []
for cmode in control_mode:
if cmode is None:
cmode = -1
control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long)
control_modes.append(control_mode)
control_mode = control_modes

# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
Expand Down Expand Up @@ -840,9 +848,12 @@ def __call__(
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)

guidance = (
torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None
)
if isinstance(self.controlnet, FluxMultiControlNetModel):
use_guidance = self.controlnet.nets[0].config.guidance_embeds
else:
use_guidance = self.controlnet.config.guidance_embeds

guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None

# controlnet
Expand Down

0 comments on commit 9cd3755

Please sign in to comment.