Commit

[ssl] add nestrq
Mddct committed Sep 20, 2024
1 parent a198c1d commit 42798bc
Showing 2 changed files with 200 additions and 0 deletions.
1 change: 1 addition & 0 deletions wenet/ssl/bestrq/bestrq_model.py
@@ -279,6 +279,7 @@ def _stack_features(self, input: torch.Tensor,
def _compute_loss(self, input: torch.Tensor, target: torch.Tensor,
mask: torch.Tensor) -> torch.Tensor:
logits = input.transpose(1, 2).contiguous().view(-1, input.size(-1))
mask = mask.unsqueeze(2).repeat(1, 1, self.num_codebooks)
loss = torch.nn.functional.cross_entropy(
logits,
target.contiguous().view(-1),
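
Note on the one-line bestrq change above: _compute_loss flattens the logits to [B * T * num_codebooks, num_embeddings], so the [B, T] frame mask has to be repeated once per codebook before it is flattened alongside them. A minimal shape sketch of that alignment (dimensions here are illustrative, not from the commit):

import torch

B, T, G, V = 2, 5, 3, 11  # batch, frames, codebooks, codebook size
logits = torch.randn(B, G, T, V)
target = torch.randint(V, (B, T, G))
mask = torch.ones(B, T, dtype=torch.bool)

flat = logits.transpose(1, 2).contiguous().view(-1, V)  # [B*T*G, V]
mask = mask.unsqueeze(2).repeat(1, 1, G)                # [B, T, G]
loss = torch.nn.functional.cross_entropy(flat,
                                         target.contiguous().view(-1),
                                         reduction='none')
loss = (loss * mask.view(-1)).sum() / mask.sum()
assert mask.view(-1).numel() == flat.size(0)
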
199 changes: 199 additions & 0 deletions wenet/ssl/nestrq/nestrq_model.py
@@ -0,0 +1,199 @@
import math
from typing import Dict, Tuple
import torch
from wenet.ssl.bestrq.bestrq_model import quantize_vector

from wenet.transformer.attention import RelPositionMultiHeadedAttention
from wenet.transformer.encoder_layer import ConformerEncoderLayer
from wenet.utils.mask import make_non_pad_mask


class NestRQModel(torch.nn.Module):
""" https://arxiv.org/pdf/2409.08680
"""

def __init__(
self,
encoder: torch.nn.Module,
num_mel_bins: int = 80,
embedding_dim: int = 16,
num_embeddings: int = 8192,
num_codebooks: int = 1,
out_bias: bool = False,
) -> None:
super().__init__()
self.num_codebooks = num_codebooks
self.num_embeddings = num_embeddings

# encoder
self.encoder = encoder
        # one softmax head per codebook
self.encoder_top_n_out = torch.nn.parameter.Parameter(
torch.empty(self.num_codebooks, self.encoder.output_size(),
num_embeddings))
torch.nn.init.trunc_normal_(self.encoder_top_n_out, std=0.02)
self.out_bias = out_bias
if self.out_bias:
self.encoder_top_n_out_bias = torch.nn.parameter.Parameter(
torch.empty(self.num_codebooks, num_embeddings))
torch.nn.init.zeros_(self.encoder_top_n_out_bias)

        # stack input features (e.g. fbank) so that one stacked frame
        # corresponds to one encoder output frame after subsampling
self.stack_frames = self.encoder.embed.right_context + 1
self.stride = self.encoder.embed.subsampling_rate
input_dim = num_mel_bins * self.stride

        # frozen random projection
self.projection = torch.nn.parameter.Parameter(
torch.empty(input_dim, embedding_dim * self.num_codebooks),
requires_grad=False,
)
torch.nn.init.xavier_uniform_(self.projection)
self.norm = torch.nn.LayerNorm(self.stack_frames * num_mel_bins,
eps=1e-6,
elementwise_affine=False,
bias=False)

# codebook
        # [num_embeddings, num_codebooks, embedding_dim] means
        # [C, G, D], see quantize_vector
self.embeddings = torch.nn.parameter.Parameter(
torch.empty(num_embeddings, self.num_codebooks, embedding_dim),
requires_grad=False,
)
torch.nn.init.normal_(self.embeddings)
self.embeddings /= (self.embeddings.norm(dim=-1, p=2, keepdim=True) +
1e-8)

        # force reset encoder parameters
self.reset_encoder_parameter()

def reset_encoder_parameter(self):
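        """Re-initialize encoder attention/conv weights in place."""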

def _reset_parameter(module: torch.nn.Module):
if isinstance(module, torch.nn.Linear):
torch.nn.init.trunc_normal_(module.weight.data,
mean=0.0,
std=0.02)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, torch.nn.Conv1d):
torch.nn.init.kaiming_normal_(module.weight)
if module.bias is not None:
k = math.sqrt(module.groups /
(module.in_channels * module.kernel_size[0]))
torch.nn.init.uniform_(module.bias, a=-k, b=k)
elif isinstance(module, torch.Tensor):
torch.nn.init.trunc_normal_(module)
else:
                raise NotImplementedError("other modules not supported yet")

encoders = self.encoder.encoders
        for layer in encoders:
self_attn = layer.self_attn
_reset_parameter(self_attn.linear_q)
_reset_parameter(self_attn.linear_k)
_reset_parameter(self_attn.linear_v)
_reset_parameter(self_attn.linear_out)
if isinstance(self_attn, RelPositionMultiHeadedAttention):
_reset_parameter(self_attn.pos_bias_u)
_reset_parameter(self_attn.pos_bias_v)
if isinstance(layer, ConformerEncoderLayer):
conv1, conv2 = (layer.conv_module.pointwise_conv1,
layer.conv_module.depthwise_conv)
_reset_parameter(conv1)
_reset_parameter(conv2)

def forward(
self,
batch: Dict,
device: torch.device,
):
        xs = batch['feats'].to(device)
        xs_lens = batch['feats_lengths'].to(device)

        # 1 stack fbank; stack_out_mask is used to mask the
        # next-token-prediction loss
        stack_input, stack_out_mask = self._stack_features(xs, xs_lens)

        # 2 get nearest embedding as discrete targets
        target_ids = self._nearest_embedding_idx(stack_input)

        # 3 forward xxx-former block and its subsampling layer
        # TODO(mddct): encoder causal mask
        out, out_mask = self.encoder(xs, xs_lens)
        # truncate targets and mask to the encoder output length
        target_ids = target_ids[:, :out_mask.size(1), :]
        stack_out_mask = stack_out_mask[:, :out_mask.size(1)]

# 4 get logits
out = out.unsqueeze(1) # [B, 1, T', dim]
top_n_out = self.encoder_top_n_out.unsqueeze(
0) # [1, num_codebooks, dim, num_embeddings]
out = torch.matmul(out,
top_n_out) # [B, num_codebooks, T', num_embeddings]
if self.out_bias:
out = out + self.encoder_top_n_out_bias.unsqueeze(0).unsqueeze(2)

        # shift logits and targets for next token prediction:
        # the output at frame t predicts the target of frame t + 1
        out = out[:, :, :-1, :]
        target_ids = target_ids[:, 1:, :]
        masks = out_mask.squeeze(1) * stack_out_mask
        masks = masks[:, 1:]

# 5 compute loss
loss = self._compute_loss(out, target_ids, mask=masks)

# 6 other info: num codes used in batch, unique num codes used in batch
num_codes = masks.sum() * self.num_codebooks
uniq_num_codes = torch.tensor(
torch.unique(target_ids * masks.unsqueeze(2)).numel()).detach()
        ids_corr = out.argmax(dim=-1).transpose(1, 2) == target_ids
codes_acc = (ids_corr * masks.unsqueeze(2)).sum() / num_codes
return {
"codes_acc": codes_acc,
"loss": loss,
"num_codes": num_codes,
"uniq_num_codes": uniq_num_codes,
"th_accuracy": codes_acc,
}

def _stack_features(
self, input: torch.Tensor,
input_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

mask = make_non_pad_mask(input_lens)
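        # window the [B, T] frame mask with the same size/stride as the
        # features; a stacked frame is valid only if all its raw frames are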
mask_stride = mask.unfold(
1,
size=self.stack_frames,
step=self.stride,
)
        subsampling_mask, _ = torch.min(mask_stride, dim=-1)

stack_input = input.unfold(1, size=self.stack_frames, step=self.stride)
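        # unfold gave [B, n, D, stack_frames]; swap the last two dims to
        # [B, n, stack_frames, D] so each window flattens frame-by-frame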
stack_input = stack_input.transpose(-1, -2)
b, n, f, d = stack_input.size()
stack_input = stack_input.reshape(b, n, f * d)

        return self.norm(stack_input), subsampling_mask

def _compute_loss(self, input: torch.Tensor, target: torch.Tensor,
mask: torch.Tensor) -> torch.Tensor:
logits = input.transpose(1, 2).contiguous().view(-1, input.size(-1))
mask = mask.unsqueeze(2).repeat(1, 1, self.num_codebooks)
loss = torch.nn.functional.cross_entropy(
logits,
target.contiguous().view(-1),
reduction='none',
)
loss = (loss * mask.view(-1)).sum() / mask.sum()
return loss

def _nearest_embedding_idx(self, xs: torch.Tensor) -> torch.Tensor:
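        # frozen random projection into codebook space, then l2-normalize
        # to match the normalized codebook entries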
xs = torch.matmul(xs, self.projection.to(xs.device))
xs = xs / (xs.norm(dim=-1, p=2, keepdim=True) + 1e-8)
codebooks = self.embeddings
B, T, C = xs.size()
xs_flatten = xs.view(B * T, C)
_, codes, _ = quantize_vector(xs_flatten, codebooks)

return codes.reshape(B, T, -1) # [B, T, num_codebooks]
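
A hedged usage sketch, not part of this commit: wiring NestRQModel to wenet's ConformerEncoder and running one forward/backward step. The encoder hyperparameters and batch shapes below are illustrative assumptions.

import torch
from wenet.transformer.encoder import ConformerEncoder
from wenet.ssl.nestrq.nestrq_model import NestRQModel

# illustrative setup; any wenet encoder exposing embed.right_context,
# embed.subsampling_rate and output_size() should fit
encoder = ConformerEncoder(input_size=80, output_size=256)
model = NestRQModel(encoder, num_mel_bins=80, num_codebooks=1)

batch = {
    'feats': torch.randn(4, 200, 80),  # [B, T, num_mel_bins] fbank
    'feats_lengths': torch.tensor([200, 180, 160, 150]),
}
info = model(batch, torch.device('cpu'))
info['loss'].backward()
print(info['loss'].item(), info['codes_acc'].item())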
