Commit
Add support for GLIGEN textbox model.
commit 3696d16 (1 parent: 472b1cc)
Showing 9 changed files with 491 additions and 28 deletions.
@@ -0,0 +1,343 @@
import math  # used by GatedSelfAttentionDense2 for the square-grid sanity checks

import torch
from torch import nn, einsum
from ldm.modules.attention import CrossAttention
from inspect import isfunction


def exists(val):
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


# feedforward
class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * torch.nn.functional.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU()
        ) if not glu else GEGLU(dim, inner_dim)

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)


class GatedCrossAttentionDense(nn.Module):
    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()

        self.attn = CrossAttention(
            query_dim=query_dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head)
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
        self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))

        # this can be useful: the magnitude of tanh(alpha) can be changed
        # externally; for example, when it is set to 0, the model behaves
        # exactly like the original one
        self.scale = 1

    def forward(self, x, objs):

        x = x + self.scale * \
            torch.tanh(self.alpha_attn) * self.attn(self.norm1(x), objs, objs)
        x = x + self.scale * \
            torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))

        return x


class GatedSelfAttentionDense(nn.Module):
    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()

        # we need a linear projection since we concatenate the visual
        # features and the object features
        self.linear = nn.Linear(context_dim, query_dim)

        self.attn = CrossAttention(
            query_dim=query_dim,
            context_dim=query_dim,
            heads=n_heads,
            dim_head=d_head)
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
        self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))

        # this can be useful: the magnitude of tanh(alpha) can be changed
        # externally; for example, when it is set to 0, the model behaves
        # exactly like the original one
        self.scale = 1

    def forward(self, x, objs):

        N_visual = x.shape[1]
        objs = self.linear(objs)

        x = x + self.scale * torch.tanh(self.alpha_attn) * self.attn(
            self.norm1(torch.cat([x, objs], dim=1)))[:, 0:N_visual, :]
        x = x + self.scale * \
            torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))

        return x


class GatedSelfAttentionDense2(nn.Module):
    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()

        # we need a linear projection since we concatenate the visual
        # features and the object features
        self.linear = nn.Linear(context_dim, query_dim)

        self.attn = CrossAttention(
            query_dim=query_dim, context_dim=query_dim, dim_head=d_head)
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
        self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))

        # this can be useful: the magnitude of tanh(alpha) can be changed
        # externally; for example, when it is set to 0, the model behaves
        # exactly like the original one
        self.scale = 1

    def forward(self, x, objs):

        B, N_visual, _ = x.shape
        B, N_ground, _ = objs.shape

        objs = self.linear(objs)

        # sanity check: both token sets must form square grids
        size_v = math.sqrt(N_visual)
        size_g = math.sqrt(N_ground)
        assert int(size_v) == size_v, "Visual tokens must be square rootable"
        assert int(size_g) == size_g, "Grounding tokens must be square rootable"
        size_v = int(size_v)
        size_g = int(size_g)

        # select the grounding tokens and resize them to the visual token
        # size to use as a residual
        out = self.attn(self.norm1(torch.cat([x, objs], dim=1)))[
            :, N_visual:, :]
        out = out.permute(0, 2, 1).reshape(B, -1, size_g, size_g)
        out = torch.nn.functional.interpolate(
            out, (size_v, size_v), mode='bicubic')
        residual = out.reshape(B, -1, N_visual).permute(0, 2, 1)

        # add the residual to the visual features
        x = x + self.scale * torch.tanh(self.alpha_attn) * residual
        x = x + self.scale * \
            torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))

        return x


class FourierEmbedder():
    def __init__(self, num_freqs=64, temperature=100):

        self.num_freqs = num_freqs
        self.temperature = temperature
        self.freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)

    @torch.no_grad()
    def __call__(self, x, cat_dim=-1):
        "x: tensor of arbitrary shape. cat_dim: dimension to concatenate on"
        out = []
        for freq in self.freq_bands:
            out.append(torch.sin(freq * x))
            out.append(torch.cos(freq * x))
        return torch.cat(out, cat_dim)


class PositionNet(nn.Module):
    def __init__(self, in_dim, out_dim, fourier_freqs=8):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
        self.position_dim = fourier_freqs * 2 * 4  # 2 is sin&cos, 4 is xyxy

        self.linears = nn.Sequential(
            nn.Linear(self.in_dim + self.position_dim, 512),
            nn.SiLU(),
            nn.Linear(512, 512),
            nn.SiLU(),
            nn.Linear(512, out_dim),
        )

        self.null_positive_feature = torch.nn.Parameter(
            torch.zeros([self.in_dim]))
        self.null_position_feature = torch.nn.Parameter(
            torch.zeros([self.position_dim]))

    def forward(self, boxes, masks, positive_embeddings):
        B, N, _ = boxes.shape
        masks = masks.unsqueeze(-1)

        # embed the box positions (may include padding as a placeholder)
        xyxy_embedding = self.fourier_embedder(boxes)  # B*N*4 --> B*N*C

        # learnable null embedding
        positive_null = self.null_positive_feature.view(1, 1, -1)
        xyxy_null = self.null_position_feature.view(1, 1, -1)

        # replace padding with the learnable null embedding
        positive_embeddings = positive_embeddings * \
            masks + (1 - masks) * positive_null
        xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null

        objs = self.linears(
            torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
        assert objs.shape == torch.Size([B, N, self.out_dim])
        return objs


class Gligen(nn.Module):
    def __init__(self, modules, position_net, key_dim):
        super().__init__()
        self.module_list = nn.ModuleList(modules)
        self.position_net = position_net
        self.key_dim = key_dim
        self.max_objs = 30

    def _set_position(self, boxes, masks, positive_embeddings):
        objs = self.position_net(boxes, masks, positive_embeddings)

        def func(key, x):
            module = self.module_list[key]
            return module(x, objs)
        return func

    def set_position(self, latent_image_shape, position_params, device):
        batch, c, h, w = latent_image_shape
        masks = torch.zeros([self.max_objs], device="cpu")
        boxes = []
        positive_embeddings = []
        for p in position_params:
            x1 = (p[4]) / w
            y1 = (p[3]) / h
            x2 = (p[4] + p[2]) / w
            y2 = (p[3] + p[1]) / h
            masks[len(boxes)] = 1.0
            boxes += [torch.tensor((x1, y1, x2, y2)).unsqueeze(0)]
            positive_embeddings += [p[0]]
        append_boxes = []
        append_conds = []
        if len(boxes) < self.max_objs:
            append_boxes = [torch.zeros(
                [self.max_objs - len(boxes), 4], device="cpu")]
            append_conds = [torch.zeros(
                [self.max_objs - len(boxes), self.key_dim], device="cpu")]

        box_out = torch.cat(
            boxes + append_boxes).unsqueeze(0).repeat(batch, 1, 1)
        masks = masks.unsqueeze(0).repeat(batch, 1)
        conds = torch.cat(positive_embeddings +
                          append_conds).unsqueeze(0).repeat(batch, 1, 1)
        return self._set_position(
            box_out.to(device),
            masks.to(device),
            conds.to(device))

    def set_empty(self, latent_image_shape, device):
        batch, c, h, w = latent_image_shape
        masks = torch.zeros([self.max_objs], device="cpu").repeat(batch, 1)
        box_out = torch.zeros([self.max_objs, 4],
                              device="cpu").repeat(batch, 1, 1)
        conds = torch.zeros([self.max_objs, self.key_dim],
                            device="cpu").repeat(batch, 1, 1)
        return self._set_position(
            box_out.to(device),
            masks.to(device),
            conds.to(device))

    def cleanup(self):
        pass

    def get_models(self):
        return [self]


def load_gligen(sd):
    sd_k = sd.keys()
    output_list = []
    key_dim = 768
    for a in ["input_blocks", "middle_block", "output_blocks"]:
        for b in range(20):
            k_temp = filter(lambda k: "{}.{}.".format(a, b)
                            in k and ".fuser." in k, sd_k)
            k_temp = map(lambda k: (k, k.split(".fuser.")[-1]), k_temp)

            n_sd = {}
            for k in k_temp:
                n_sd[k[1]] = sd[k[0]]
            if len(n_sd) > 0:
                query_dim = n_sd["linear.weight"].shape[0]
                key_dim = n_sd["linear.weight"].shape[1]

                if key_dim == 768:  # SD1.x
                    n_heads = 8
                    d_head = query_dim // n_heads
                else:
                    d_head = 64
                    n_heads = query_dim // d_head

                gated = GatedSelfAttentionDense(
                    query_dim, key_dim, n_heads, d_head)
                gated.load_state_dict(n_sd, strict=False)
                output_list.append(gated)

    if "position_net.null_positive_feature" in sd_k:
        in_dim = sd["position_net.null_positive_feature"].shape[0]
        out_dim = sd["position_net.linears.4.weight"].shape[0]

        class WeightsLoader(torch.nn.Module):
            pass
        w = WeightsLoader()
        w.position_net = PositionNet(in_dim, out_dim)
        w.load_state_dict(sd, strict=False)

    gligen = Gligen(output_list, w.position_net, key_dim)
    return gligen
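
A minimal usage sketch (not part of the commit) of how the new module might be driven: it assumes a GLIGEN textbox state dict has already been loaded into `sd` (the checkpoint path here is a placeholder), and it infers the layout of each position_params entry — (conditioning embedding, height, width, y, x) in latent pixels — from the indexing in set_position above.

import torch

# hypothetical checkpoint path; sd must contain the ".fuser." and position_net keys
sd = torch.load("gligen_textbox_checkpoint.pt", map_location="cpu")
gligen = load_gligen(sd)

# one grounded phrase: a text embedding plus a box given as (height, width, y, x)
cond_embedding = torch.zeros([1, gligen.key_dim])  # placeholder for a real text embedding
position_params = [(cond_embedding, 32, 32, 0, 0)]

latent_shape = (1, 4, 64, 64)  # batch, channels, latent height, latent width
patch_fn = gligen.set_position(latent_shape, position_params, device="cpu")

# patch_fn(key, x) runs module_list[key] on hidden states x with the grounding
# tokens attached; the sampler is expected to call it once per fuser layer.

set_empty builds the same kind of function with all-null grounding tokens, which is what an unconditioned pass would use.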