Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add SDXL support #43

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions merge_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@
type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR"], case_sensitive=False),
default="INFO",
)
@click.option("-xl", "--sdxl", "sdxl", is_flag=True)
def main(
model_a,
model_b,
Expand All @@ -137,6 +138,7 @@ def main(
presets_alpha_lambda,
presets_beta_lambda,
logging_level,
sdxl,
):
if logging_level:
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging_level)
Expand All @@ -157,6 +159,7 @@ def main(
block_weights_preset_beta_b,
presets_alpha_lambda,
presets_beta_lambda,
sdxl,
)

merged = merge_models(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "sd-meh"
version = "0.9.4"
version = "0.10.0"
description = "stable diffusion merging execution helper"
authors = ["s1dlx <s1dlx@proton.me>"]
license = "MIT"
Expand Down
2 changes: 1 addition & 1 deletion sd_meh/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.9.4"
__version__ = "0.10.0"
32 changes: 27 additions & 5 deletions sd_meh/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
NUM_OUTPUT_BLOCKS = 12
NUM_TOTAL_BLOCKS = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + NUM_OUTPUT_BLOCKS

NUM_INPUT_BLOCKS_XL = 9
NUM_OUTPUT_BLOCKS_XL = 9
NUM_TOTAL_BLOCKS_XL = NUM_INPUT_BLOCKS_XL + NUM_MID_BLOCK + NUM_OUTPUT_BLOCKS_XL

KEY_POSITION_IDS = ".".join(
[
"cond_stage_model",
Expand Down Expand Up @@ -144,6 +148,11 @@ def merge_models(
) -> Dict:
thetas = load_thetas(models, prune, device, precision)

sdxl = (
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight"
in thetas["model_a"].keys()
)

logging.info(f"start merging with {merge_mode} method")
if re_basin:
merged = rebasin_merge(
Expand All @@ -157,6 +166,7 @@ def merge_models(
device=device,
work_device=work_device,
threads=threads,
sdxl=sdxl,
)
else:
merged = simple_merge(
Expand All @@ -169,6 +179,7 @@ def merge_models(
device=device,
work_device=work_device,
threads=threads,
sdxl=sdxl,
)

return un_prune_model(merged, thetas, models, device, prune, precision)
Expand Down Expand Up @@ -221,6 +232,7 @@ def simple_merge(
device: str = "cpu",
work_device: Optional[str] = None,
threads: int = 1,
sdxl: bool = False,
) -> Dict:
futures = []
with tqdm(thetas["model_a"].keys(), desc="stage 1") as progress:
Expand All @@ -238,6 +250,7 @@ def simple_merge(
weights_clip,
device,
work_device,
sdxl,
)
futures.append(future)

Expand Down Expand Up @@ -270,6 +283,7 @@ def rebasin_merge(
device="cpu",
work_device=None,
threads: int = 1,
sdxl: bool = False,
):
# WARNING: not sure how this does when 3 models are involved...

Expand Down Expand Up @@ -299,6 +313,7 @@ def rebasin_merge(
device,
work_device,
threads,
sdxl,
)

log_vram("simple merge done")
Expand Down Expand Up @@ -367,6 +382,7 @@ def merge_key(
weights_clip: bool = False,
device: str = "cpu",
work_device: Optional[str] = None,
sdxl: bool = False,
) -> Optional[Tuple[str, Dict]]:
if work_device is None:
work_device = device
Expand All @@ -391,16 +407,22 @@ def merge_key(
if "time_embed" in key:
weight_index = 0 # before input blocks
elif ".out." in key:
weight_index = NUM_TOTAL_BLOCKS - 1 # after output blocks
weight_index = (
NUM_TOTAL_BLOCKS_XL - 1 if sdxl else NUM_TOTAL_BLOCKS - 1
) # after output blocks
elif m := re_inp.search(key):
weight_index = int(m.groups()[0])
elif re_mid.search(key):
weight_index = NUM_INPUT_BLOCKS
weight_index = NUM_INPUT_BLOCKS_XL if sdxl else NUM_INPUT_BLOCKS
elif m := re_out.search(key):
weight_index = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + int(m.groups()[0])
weight_index = (
(NUM_INPUT_BLOCKS_XL if sdxl else NUM_INPUT_BLOCKS)
+ NUM_MID_BLOCK
+ int(m.groups()[0])
)

if weight_index >= NUM_TOTAL_BLOCKS:
raise ValueError(f"illegal block index {key}")
if weight_index >= (NUM_TOTAL_BLOCKS_XL if sdxl else NUM_TOTAL_BLOCKS):
raise ValueError(f"illegal block index {weight_index} for key {key}")

if weight_index >= 0:
current_bases = {k: w[weight_index] for k, w in weights.items()}
Expand Down
6 changes: 2 additions & 4 deletions sd_meh/rebasin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2200,11 +2200,9 @@ def apply_permutation(ps: PermutationSpec, perm, params):
def update_model_a(ps: PermutationSpec, perm, model_a, new_alpha):
for k in model_a:
try:
perm_params = get_permuted_param(
ps, perm, k, model_a
)
perm_params = get_permuted_param(ps, perm, k, model_a)
model_a[k] = model_a[k] * (1 - new_alpha) + new_alpha * perm_params
except RuntimeError: # dealing with pix2pix and inpainting models
except RuntimeError: # dealing with pix2pix and inpainting models
continue
return model_a

Expand Down
17 changes: 11 additions & 6 deletions sd_meh/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging

from sd_meh import merge_methods
from sd_meh.merge import NUM_TOTAL_BLOCKS
from sd_meh.merge import NUM_TOTAL_BLOCKS, NUM_TOTAL_BLOCKS_XL
from sd_meh.presets import BLOCK_WEIGHTS_PRESETS

MERGE_METHODS = dict(inspect.getmembers(merge_methods, inspect.isfunction))
Expand All @@ -13,25 +13,25 @@
]


def compute_weights(weights, base):
def compute_weights(weights, base, sdxl: bool = False):
if not weights:
return [base] * NUM_TOTAL_BLOCKS
return [base] * (NUM_TOTAL_BLOCKS_XL if sdxl else NUM_TOTAL_BLOCKS)

if "," not in weights:
return weights

w_alpha = list(map(float, weights.split(",")))
if len(w_alpha) == NUM_TOTAL_BLOCKS:
if len(w_alpha) == (NUM_TOTAL_BLOCKS_XL if sdxl else NUM_TOTAL_BLOCKS):
return w_alpha


def assemble_weights_and_bases(preset, weights, base, greek_letter):
def assemble_weights_and_bases(preset, weights, base, greek_letter, sdxl: bool = False):
logging.info(f"Assembling {greek_letter} w&b")
if preset:
logging.info(f"Using {preset} preset")
base, *weights = BLOCK_WEIGHTS_PRESETS[preset]
bases = {greek_letter: base}
weights = {greek_letter: compute_weights(weights, base)}
weights = {greek_letter: compute_weights(weights, base, sdxl)}

logging.info(f"base_{greek_letter}: {bases[greek_letter]}")
logging.info(f"{greek_letter} weights: {weights[greek_letter]}")
Expand Down Expand Up @@ -70,12 +70,14 @@ def weights_and_bases(
block_weights_preset_beta_b,
presets_alpha_lambda,
presets_beta_lambda,
sdxl: bool = False,
):
weights, bases = assemble_weights_and_bases(
block_weights_preset_alpha,
weights_alpha,
base_alpha,
"alpha",
sdxl,
)

if block_weights_preset_alpha_b:
Expand All @@ -85,6 +87,7 @@ def weights_and_bases(
None,
None,
"alpha",
sdxl,
)
weights, bases = interpolate_presets(
weights,
Expand All @@ -101,6 +104,7 @@ def weights_and_bases(
weights_beta,
base_beta,
"beta",
sdxl,
)

if block_weights_preset_beta_b:
Expand All @@ -110,6 +114,7 @@ def weights_and_bases(
None,
None,
"beta",
sdxl,
)
weights, bases = interpolate_presets(
weights,
Expand Down
Loading