LoRA Manager for managing multiple loras with different scaling parameters & tokens #172

Merged (3 commits) on Feb 9, 2023
Binary file added example_loras/and.safetensors
Binary file not shown.
1 change: 1 addition & 0 deletions lora_diffusion/__init__.py
@@ -2,3 +2,4 @@
from .dataset import *
from .utils import *
from .preprocess_files import *
from .lora_manager import *
53 changes: 1 addition & 52 deletions lora_diffusion/cli_lora_add.py
@@ -12,6 +12,7 @@
collapse_lora,
monkeypatch_remove_lora,
)
from .lora_manager import lora_join
from .to_ckpt_v2 import convert_to_ckpt


@@ -20,58 +21,6 @@ def _text_lora_path(path: str) -> str:
return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])


def lora_join(lora_safetenors: list):
metadatas = [dict(safelora.metadata()) for safelora in lora_safetenors]
total_metadata = {}
total_tensor = {}
total_rank = 0
ranklist = []
for _metadata in metadatas:
rankset = []
for k, v in _metadata.items():
if k.endswith("rank"):
rankset.append(int(v))

assert len(set(rankset)) == 1, "Rank should be the same per model"
total_rank += rankset[0]
total_metadata.update(_metadata)
ranklist.append(rankset[0])

tensorkeys = set()
for safelora in lora_safetenors:
tensorkeys.update(safelora.keys())

for keys in tensorkeys:
if keys.startswith("text_encoder") or keys.startswith("unet"):
tensorset = [safelora.get_tensor(keys) for safelora in lora_safetenors]

is_down = keys.endswith("down")

if is_down:
_tensor = torch.cat(tensorset, dim=0)
assert _tensor.shape[0] == total_rank
else:
_tensor = torch.cat(tensorset, dim=1)
assert _tensor.shape[1] == total_rank

total_tensor[keys] = _tensor
keys_rank = ":".join(keys.split(":")[:-1]) + ":rank"
total_metadata[keys_rank] = str(total_rank)
token_size_list = []
for idx, safelora in enumerate(lora_safetenors):
tokens = [k for k, v in safelora.metadata().items() if v == "<embed>"]
for jdx, token in enumerate(sorted(tokens)):
if total_metadata.get(token, None) is not None:
del total_metadata[token]
total_tensor[f"<s{idx}-{jdx}>"] = safelora.get_tensor(token)
total_metadata[f"<s{idx}-{jdx}>"] = "<embed>"
print(f"Embedding {token} replaced to <s{idx}-{jdx}>")

token_size_list.append(len(tokens))

return total_tensor, total_metadata, ranklist, token_size_list


def add(
path_1: str,
path_2: str,
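Note: the lora_join helper deleted above is moved into the new lora_diffusion/lora_manager.py below, and cli_lora_add.py now imports it from there. As a rough illustration of the joining scheme it implements (the ranks, dimensions, and scales here are invented for the example, not taken from the PR):

import torch

# Two hypothetical LoRAs with ranks r1 and r2 for the same layer.
r1, r2, d_in, d_out = 4, 8, 768, 320
down_a, down_b = torch.randn(r1, d_in), torch.randn(r2, d_in)
up_a, up_b = torch.randn(d_out, r1), torch.randn(d_out, r2)

# lora_join concatenates "down" weights along dim 0 and "up" weights along dim 1,
# so the joined pair has rank r1 + r2.
down = torch.cat([down_a, down_b], dim=0)   # (r1 + r2, d_in)
up = torch.cat([up_a, up_b], dim=1)         # (d_out, r1 + r2)

# LoRAManager.tune (below) repeats one scale per LoRA over that LoRA's rank block,
# which is equivalent to scaling each LoRA's weight update independently.
scales = [0.8, 0.5]
diag = torch.tensor([scales[0]] * r1 + [scales[1]] * r2)
delta = up @ torch.diag(diag) @ down        # (d_out, d_in)
assert torch.allclose(delta, 0.8 * up_a @ down_a + 0.5 * up_b @ down_b, atol=1e-4)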
134 changes: 134 additions & 0 deletions lora_diffusion/lora_manager.py
@@ -0,0 +1,134 @@
from typing import List
import torch
from safetensors import safe_open
from diffusers import StableDiffusionPipeline
from .lora import (
monkeypatch_or_replace_safeloras,
apply_learned_embed_in_clip,
set_lora_diag,
parse_safeloras_embeds,
)


def lora_join(lora_safetenors: list):
metadatas = [dict(safelora.metadata()) for safelora in lora_safetenors]
total_metadata = {}
total_tensor = {}
total_rank = 0
ranklist = []
for _metadata in metadatas:
rankset = []
for k, v in _metadata.items():
if k.endswith("rank"):
rankset.append(int(v))

assert len(set(rankset)) == 1, "Rank should be the same per model"
total_rank += rankset[0]
total_metadata.update(_metadata)
ranklist.append(rankset[0])

tensorkeys = set()
for safelora in lora_safetenors:
tensorkeys.update(safelora.keys())

for keys in tensorkeys:
if keys.startswith("text_encoder") or keys.startswith("unet"):
tensorset = [safelora.get_tensor(keys) for safelora in lora_safetenors]

is_down = keys.endswith("down")

if is_down:
_tensor = torch.cat(tensorset, dim=0)
assert _tensor.shape[0] == total_rank
else:
_tensor = torch.cat(tensorset, dim=1)
assert _tensor.shape[1] == total_rank

total_tensor[keys] = _tensor
keys_rank = ":".join(keys.split(":")[:-1]) + ":rank"
total_metadata[keys_rank] = str(total_rank)
token_size_list = []
for idx, safelora in enumerate(lora_safetenors):
tokens = [k for k, v in safelora.metadata().items() if v == "<embed>"]
for jdx, token in enumerate(sorted(tokens)):

total_tensor[f"<s{idx}-{jdx}>"] = safelora.get_tensor(token)
total_metadata[f"<s{idx}-{jdx}>"] = "<embed>"

print(f"Embedding {token} replaced to <s{idx}-{jdx}>")

if total_metadata.get(token, None) is not None:
del total_metadata[token]

token_size_list.append(len(tokens))

return total_tensor, total_metadata, ranklist, token_size_list


class DummySafeTensorObject:
def __init__(self, tensor: dict, metadata):
self.tensor = tensor
self._metadata = metadata

def keys(self):
return self.tensor.keys()

def metadata(self):
return self._metadata

def get_tensor(self, key):
return self.tensor[key]


class LoRAManager:
def __init__(self, lora_paths_list: List[str], pipe: StableDiffusionPipeline):

self.lora_paths_list = lora_paths_list
self.pipe = pipe
self._setup()

def _setup(self):

self._lora_safetenors = [
safe_open(path, framework="pt", device="cpu")
for path in self.lora_paths_list
]

(
total_tensor,
total_metadata,
self.ranklist,
self.token_size_list,
) = lora_join(self._lora_safetenors)

self.total_safelora = DummySafeTensorObject(total_tensor, total_metadata)

monkeypatch_or_replace_safeloras(self.pipe, self.total_safelora)
tok_dict = parse_safeloras_embeds(self.total_safelora)

apply_learned_embed_in_clip(
tok_dict,
self.pipe.text_encoder,
self.pipe.tokenizer,
token=None,
idempotent=True,
)

def tune(self, scales):

diags = []
for scale, rank in zip(scales, self.ranklist):
diags = diags + [scale] * rank

set_lora_diag(self.pipe.unet, torch.tensor(diags))

def prompt(self, prompt):
if prompt is not None:
for idx, tok_size in enumerate(self.token_size_list):
prompt = prompt.replace(
f"<{idx + 1}>",
"".join([f"<s{idx}-{jdx}>" for jdx in range(tok_size)]),
)
# TODO : Rescale LoRA + Text inputs based on prompt scale params

return prompt
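
For context, a minimal usage sketch of the new LoRAManager class based on the code above; the base model id, LoRA file paths, scales, and prompt are hypothetical, and the import path assumes the "from .lora_manager import *" line added to __init__.py:

from diffusers import StableDiffusionPipeline
from lora_diffusion import LoRAManager

# Hypothetical base model and LoRA files (each a .safetensors with "<embed>" token metadata).
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")
manager = LoRAManager(["lora_a.safetensors", "lora_b.safetensors"], pipe)

# One scale per LoRA; tune() repeats each scale over that LoRA's rank block.
manager.tune([0.8, 0.5])

# prompt() rewrites "<1>", "<2>", ... to the remapped embedding tokens <s0-0>..., <s1-0>...
prompt = manager.prompt("a photo of <1> in the style of <2>")
image = pipe(prompt).images[0]

Since tune() only resets the diagonal on the joined rank dimension, it appears the scales can be changed between generations without re-joining the LoRAs or re-patching the pipeline.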