-
Notifications
You must be signed in to change notification settings - Fork 5.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
flux controlnet fix (control_modes batch & others) #9507
Changes from 1 commit
370f382
c333d89
7c95f0b
61e0950
db8178a
560449d
53bee55
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -752,7 +752,7 @@ def __call__( | |
if not isinstance(control_mode, int): | ||
raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or `None`") | ||
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) | ||
control_mode = control_mode.view(-1,1).expand(control_image.shape[0], 1) | ||
control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1) | ||
|
||
elif isinstance(self.controlnet, FluxMultiControlNetModel): | ||
control_images = [] | ||
|
@@ -789,11 +789,13 @@ def __call__( | |
|
||
# Here we ensure that `control_mode` has the same length as the control_image. | ||
if isinstance(control_mode, list) and len(control_mode) != len(control_image): | ||
raise ValueError("For Multi-ControlNet, `control_mode` must be a list of the same " + | ||
" length as the number of controlnets (control images) specified") | ||
raise ValueError( | ||
"For Multi-ControlNet, `control_mode` must be a list of the same " | ||
+ " length as the number of controlnets (control images) specified" | ||
) | ||
if not isinstance(control_mode, list): | ||
control_mode = [control_mode] * len(control_image) | ||
# set control mode | ||
# set control mode | ||
control_modes = [] | ||
for cmode in control_mode: | ||
if cmode is None: | ||
|
@@ -846,9 +848,12 @@ def __call__( | |
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML | ||
timestep = t.expand(latents.shape[0]).to(latents.dtype) | ||
|
||
guidance = ( | ||
torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None | ||
) | ||
if isinstance(self.controlnet, FluxMultiControlNetModel): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this PR #9472 makes it possible to use a schnell base with dev controlnet, however it breaks the case when it is a FluxMultiControlNetModel because FluxMultiControlNetModel does not There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it is safer to use this :
|
||
use_guidance = self.controlnet.nets[0].config.guidance_embeds | ||
else: | ||
use_guidance = self.controlnet.config.guidance_embeds | ||
|
||
guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None | ||
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None | ||
|
||
# controlnet | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
see more context on this PR #9406