SVD update with conv support + LoRA add, weight update following recent updates #140

Merged · 14 commits · Jan 29, 2023
Binary file modified contents/disney_lora.jpg
Binary file modified contents/lion_illust.jpg
Binary file modified contents/lora_pti_example.jpg
Binary file modified contents/pop_art.jpg
Binary file removed example_loras/analog_svd_distill.pt
Binary file not shown.
Binary file removed example_loras/analog_svd_distill.text_encoder.pt
Binary file not shown.
Binary file added example_loras/analog_svd_rank4.safetensors
Binary file not shown.
Binary file added example_loras/analog_svd_rank8.safetensors
Binary file not shown.
Binary file added example_loras/modern_disney_svd.safetensors
Binary file not shown.
173 changes: 115 additions & 58 deletions lora_diffusion/cli_lora_add.py
@@ -3,9 +3,15 @@
import shutil
import fire
from diffusers import StableDiffusionPipeline
from safetensors.torch import safe_open, save_file

import torch
from .lora import tune_lora_scale, weight_apply_lora
from .lora import (
tune_lora_scale,
patch_pipe,
collapse_lora,
monkeypatch_remove_lora,
)
from .to_ckpt_v2 import convert_to_ckpt


@@ -18,7 +24,8 @@ def add(
path_1: str,
path_2: str,
output_path: str,
alpha: float = 0.5,
alpha_1: float = 0.5,
alpha_2: float = 0.5,
mode: Literal[
"lpl",
"upl",
@@ -28,79 +35,116 @@
):
print("Lora Add, mode " + mode)
if mode == "lpl":
for _path_1, _path_2, opt in [(path_1, path_2, "unet")] + (
[(_text_lora_path(path_1), _text_lora_path(path_2), "text_encoder")]
if with_text_lora
else []
):
print("Loading", _path_1, _path_2)
out_list = []
if opt == "text_encoder":
if not os.path.exists(_path_1):
print(f"No text encoder found in {_path_1}, skipping...")
continue
if not os.path.exists(_path_2):
print(f"No text encoder found in {_path_1}, skipping...")
continue

l1 = torch.load(_path_1)
l2 = torch.load(_path_2)

l1pairs = zip(l1[::2], l1[1::2])
l2pairs = zip(l2[::2], l2[1::2])

for (x1, y1), (x2, y2) in zip(l1pairs, l2pairs):
# print("Merging", x1.shape, y1.shape, x2.shape, y2.shape)
x1.data = alpha * x1.data + (1 - alpha) * x2.data
y1.data = alpha * y1.data + (1 - alpha) * y2.data

out_list.append(x1)
out_list.append(y1)

if opt == "unet":

print("Saving merged UNET to", output_path)
torch.save(out_list, output_path)

elif opt == "text_encoder":
print("Saving merged text encoder to", _text_lora_path(output_path))
torch.save(
out_list,
_text_lora_path(output_path),
)
if path_1.endswith(".pt") and path_2.endswith(".pt"):
for _path_1, _path_2, opt in [(path_1, path_2, "unet")] + (
[(_text_lora_path(path_1), _text_lora_path(path_2), "text_encoder")]
if with_text_lora
else []
):
print("Loading", _path_1, _path_2)
out_list = []
if opt == "text_encoder":
if not os.path.exists(_path_1):
print(f"No text encoder found in {_path_1}, skipping...")
continue
if not os.path.exists(_path_2):
print(f"No text encoder found in {_path_1}, skipping...")
continue

l1 = torch.load(_path_1)
l2 = torch.load(_path_2)

l1pairs = zip(l1[::2], l1[1::2])
l2pairs = zip(l2[::2], l2[1::2])

for (x1, y1), (x2, y2) in zip(l1pairs, l2pairs):
# print("Merging", x1.shape, y1.shape, x2.shape, y2.shape)
x1.data = alpha_1 * x1.data + alpha_2 * x2.data
y1.data = alpha_1 * y1.data + alpha_2 * y2.data

out_list.append(x1)
out_list.append(y1)

if opt == "unet":

print("Saving merged UNET to", output_path)
torch.save(out_list, output_path)

elif opt == "text_encoder":
print("Saving merged text encoder to", _text_lora_path(output_path))
torch.save(
out_list,
_text_lora_path(output_path),
)

elif path_1.endswith(".safetensors") and path_2.endswith(".safetensors"):
safeloras_1 = safe_open(path_1, framework="pt", device="cpu")
safeloras_2 = safe_open(path_2, framework="pt", device="cpu")

metadata = dict(safeloras_1.metadata())
metadata.update(dict(safeloras_2.metadata()))

ret_tensor = {}

for keys in set(list(safeloras_1.keys()) + list(safeloras_2.keys())):
if keys.startswith("text_encoder") or keys.startswith("unet"):

tens1 = safeloras_1.get_tensor(keys)
tens2 = safeloras_2.get_tensor(keys)

tens = alpha_1 * tens1 + alpha_2 * tens2
ret_tensor[keys] = tens
else:
if keys in safeloras_1.keys():

tens1 = safeloras_1.get_tensor(keys)
else:
tens1 = safeloras_2.get_tensor(keys)

ret_tensor[keys] = tens1

save_file(ret_tensor, output_path, metadata)

elif mode == "upl":

print(
f"Merging UNET/CLIP from {path_1} with LoRA from {path_2} to {output_path}. Merging ratio : {alpha_1}."
)

loaded_pipeline = StableDiffusionPipeline.from_pretrained(
path_1,
).to("cpu")

weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), alpha=alpha)
if with_text_lora:
patch_pipe(loaded_pipeline, path_2)

collapse_lora(loaded_pipeline.unet, alpha_1)
collapse_lora(loaded_pipeline.text_encoder, alpha_1)

weight_apply_lora(
loaded_pipeline.text_encoder,
torch.load(_text_lora_path(path_2)),
alpha=alpha,
target_replace_module=["CLIPAttention"],
)
monkeypatch_remove_lora(loaded_pipeline.unet)
monkeypatch_remove_lora(loaded_pipeline.text_encoder)

loaded_pipeline.save_pretrained(output_path)

elif mode == "upl-ckpt-v2":

assert output_path.endswith(".ckpt"), "Only .ckpt files are supported"
name = os.path.basename(output_path)[0:-5]

print(
f"You will be using {name} as the token in A1111 webui. Make sure {name} is unique enough token."
)

loaded_pipeline = StableDiffusionPipeline.from_pretrained(
path_1,
).to("cpu")

weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), alpha=alpha)
if with_text_lora:
weight_apply_lora(
loaded_pipeline.text_encoder,
torch.load(_text_lora_path(path_2)),
alpha=alpha,
target_replace_module=["CLIPAttention"],
)
tok_dict = patch_pipe(loaded_pipeline, path_2, patch_ti=False)

collapse_lora(loaded_pipeline.unet, alpha_1)
collapse_lora(loaded_pipeline.text_encoder, alpha_1)

monkeypatch_remove_lora(loaded_pipeline.unet)
monkeypatch_remove_lora(loaded_pipeline.text_encoder)

_tmp_output = output_path + ".tmp"

Expand All @@ -109,6 +153,19 @@ def add(
# remove the tmp_output folder
shutil.rmtree(_tmp_output)

keys = sorted(tok_dict.keys())
tok_catted = torch.stack([tok_dict[k] for k in keys])
ret = {
"string_to_token": {"*": torch.tensor(265)},
"string_to_param": {"*": tok_catted},
"name": name,
}

torch.save(ret, output_path[:-5] + ".pt")
print(
f"Textual embedding saved as {output_path[:-5]}.pt, put it in the embedding folder and use it as {name} in A1111 repo, "
)

else:
print("Unknown mode", mode)
raise ValueError(f"Unknown mode {mode}")
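For reference, the new "upl" branch above reduces to a short standalone flow: patch a pipeline with a LoRA file, collapse the LoRA deltas into the base weights at the chosen scale, then strip the LoRA wrappers so a plain pipeline can be saved. The following is a minimal sketch of that flow using the repository's own lora_diffusion.lora helpers; the base model id, LoRA path, and scale value are placeholders, not part of this PR.

from diffusers import StableDiffusionPipeline
from lora_diffusion.lora import patch_pipe, collapse_lora, monkeypatch_remove_lora

base_model = "runwayml/stable-diffusion-v1-5"              # placeholder base model
lora_path = "example_loras/modern_disney_svd.safetensors"  # LoRA file added in this PR
scale = 0.8                                                # plays the role of alpha_1

pipe = StableDiffusionPipeline.from_pretrained(base_model).to("cpu")

# Inject the LoRA weights (and any learned token embeddings) into the pipeline.
patch_pipe(pipe, lora_path)

# Fold the scaled LoRA update into the base UNet and text encoder weights.
collapse_lora(pipe.unet, scale)
collapse_lora(pipe.text_encoder, scale)

# Remove the LoRA wrappers so only plain modules remain, then save a regular pipeline.
monkeypatch_remove_lora(pipe.unet)
monkeypatch_remove_lora(pipe.text_encoder)
pipe.save_pretrained("merged_pipeline")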
41 changes: 25 additions & 16 deletions lora_diffusion/cli_lora_pti.py
@@ -177,6 +177,7 @@ def loss_step(
scheduler,
t_mutliplier=1.0,
mixed_precision=False,
mask_temperature=1.0,
):
weight_dtype = torch.float32

@@ -231,15 +232,14 @@
)
)
# resize to match model_pred
mask = (
F.interpolate(
mask.float(),
size=model_pred.shape[-2:],
mode="nearest",
)
+ 0.05
mask = F.interpolate(
mask.float(),
size=model_pred.shape[-2:],
mode="nearest",
)

mask = (mask + 0.01).pow(mask_temperature)

mask = mask / mask.max()

model_pred = model_pred * mask
@@ -422,6 +422,7 @@ def perform_tuning(
lr_scheduler_lora,
lora_unet_target_modules,
lora_clip_target_modules,
mask_temperature,
):

progress_bar = tqdm(range(num_steps))
@@ -447,6 +448,7 @@
scheduler,
t_mutliplier=0.8,
mixed_precision=True,
mask_temperature=mask_temperature,
)
loss.backward()
torch.nn.utils.clip_grad_norm_(
@@ -506,7 +508,7 @@ def train(
stochastic_attribute: Optional[str] = None,
perform_inversion: bool = True,
use_template: Literal[None, "object", "style"] = None,
placeholder_tokens: str = "<s>",
placeholder_tokens: str = "",
placeholder_token_at_data: Optional[str] = None,
initializer_tokens: Optional[str] = None,
class_prompt: Optional[str] = None,
@@ -536,6 +538,7 @@
continue_inversion_lr: Optional[float] = None,
use_face_segmentation_condition: bool = False,
use_mask_captioned_data: bool = False,
mask_temperature: float = 1.0,
scale_lr: bool = False,
lr_scheduler: str = "linear",
lr_warmup_steps: int = 0,
@@ -568,9 +571,14 @@
if output_dir is not None:
os.makedirs(output_dir, exist_ok=True)
# print(placeholder_tokens, initializer_tokens)
placeholder_tokens = placeholder_tokens.split("|")
if len(placeholder_tokens) == 0:
placeholder_tokens = []
print("PTI : Placeholder Tokens not given, using null token")
else:
placeholder_tokens = placeholder_tokens.split("|")

if initializer_tokens is None:
print("PTI : Initializer Token not give, random inits")
print("PTI : Initializer Tokens not given, doing random inits")
initializer_tokens = ["<rand-0.017>"] * len(placeholder_tokens)
else:
initializer_tokens = initializer_tokens.split("|")
@@ -588,8 +596,8 @@
else:
token_map = {"DUMMY": "".join(placeholder_tokens)}

print("Placeholder Tokens", placeholder_tokens)
print("Initializer Tokens", initializer_tokens)
print("PTI : Placeholder Tokens", placeholder_tokens)
print("PTI : Initializer Tokens", initializer_tokens)

# get the models
text_encoder, vae, unet, tokenizer, placeholder_token_ids = get_models(
@@ -639,7 +647,7 @@
train_dataset, train_batch_size, tokenizer, vae, text_encoder
)

index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_ids[0]
index_no_updates = torch.arange(len(tokenizer)) != -1

for tok_id in placeholder_token_ids:
index_no_updates[tok_id] = False
@@ -704,18 +712,18 @@
unet, r=lora_rank, target_replace_module=lora_unet_target_modules
)
else:
print("USING EXTENDED UNET!!!")
print("PTI : USING EXTENDED UNET!!!")
lora_unet_target_modules = (
lora_unet_target_modules | UNET_EXTENDED_TARGET_REPLACE
)
print("Will replace modules: ", lora_unet_target_modules)
print("PTI : Will replace modules: ", lora_unet_target_modules)

unet_lora_params, _ = inject_trainable_lora_extended(
unet, r=lora_rank, target_replace_module=lora_unet_target_modules
)
print(f"PTI : has {len(unet_lora_params)} lora")

print("Before training:")
print("PTI : Before training:")
inspect_lora(unet)

params_to_optimize = [
@@ -787,6 +795,7 @@
lr_scheduler_lora=lr_scheduler_lora,
lora_unet_target_modules=lora_unet_target_modules,
lora_clip_target_modules=lora_clip_target_modules,
mask_temperature=mask_temperature,
)


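The main behavioural change in loss_step above is the new mask_temperature knob: after resizing to the latent resolution, the conditioning mask is offset by 0.01 and raised to a power instead of simply being offset by 0.05, so the masked loss can be sharpened or softened. Below is a minimal sketch of that computation, with illustrative tensor sizes and a temperature value chosen for the example rather than taken from the PR.

import torch
import torch.nn.functional as F

mask = torch.rand(1, 1, 512, 512)   # e.g. a face-segmentation mask in [0, 1]
latent_size = (64, 64)              # spatial size of model_pred inside loss_step
mask_temperature = 2.0              # > 1 concentrates the loss on strongly masked pixels

# Mirror the updated loss_step: resize, temper, renormalize to a maximum of 1.
mask = F.interpolate(mask.float(), size=latent_size, mode="nearest")
mask = (mask + 0.01).pow(mask_temperature)
mask = mask / mask.max()

# The diffusion loss is then computed on model_pred * mask against target * mask.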