Skip to content

Commit

Permalink
optimize
Browse files Browse the repository at this point in the history
  • Loading branch information
yxlllc committed Aug 23, 2024
1 parent b5cfa41 commit 9920db7
Showing 1 changed file with 20 additions and 18 deletions.
38 changes: 20 additions & 18 deletions ddsp/vocoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,27 +634,31 @@ def __init__(self,
'noise_magnitude': win_length // 2 + 1,
'noise_phase': win_length // 2 + 1
}
self.unit2ctrl = Unit2Control(n_unit, n_spk, split_map, use_pitch_aug=use_pitch_aug, use_naive_v2=True)

self.unit2ctrl = Unit2Control(n_unit, n_spk, split_map, use_pitch_aug=use_pitch_aug, use_naive_v2=True, use_conv_stack=True)

def fast_source_gen(self, f0_frames):
    """Generate the combtooth excitation directly from frame-level f0.

    The phase inside every frame is evaluated in closed form (f0 is treated
    as varying linearly from one frame to the next), so no sample-level
    cumulative sum is required; only one wrapped value per frame is
    accumulated along the frame axis.

    Args:
        f0_frames: fundamental frequency per frame, B x n_frames x 1
            (same layout as documented for ``forward``).

    Returns:
        combtooth: B x (n_frames * block_size) excitation signal.
        phase_frames: B x n_frames x 1, ``2 * pi`` times the wrapped phase
            at the first sample position of each frame.
    """
    hop = self.block_size
    sample_idx = torch.arange(hop, device=f0_frames.device)

    # Frequency in cycles per sample, and its change towards the next frame.
    # The last frame has no successor, so its change is zero-padded.
    freq = f0_frames / self.sampling_rate
    freq_step = F.pad(freq[:, 1:, :] - freq[:, :-1, :], (0, 0, 0, 1))

    # Phase (in cycles) accumulated inside each frame: constant-frequency
    # part plus the contribution of the linear frequency sweep.
    steady_part = freq * (sample_idx + 1)
    sweep_part = 0.5 * freq_step * sample_idx * (sample_idx + 1) / hop
    cycles = steady_part + sweep_part

    # Linearly interpolated per-sample frequency, used to scale the sinc.
    freq_per_sample = freq + freq_step * sample_idx / hop

    # Carry the end-of-frame phase over to the following frames. The value
    # is wrapped before the cumulative sum (done in float32) so the running
    # total stays small.
    frame_end = cycles[..., -1:].float()
    wrapped_end = torch.fmod(frame_end + 0.5, 1.0) - 0.5
    carried = torch.cumsum(wrapped_end, dim=1)
    carried = torch.fmod(carried, 1.0).to(f0_frames)
    # Shift by one frame: frame k starts from the phase reached at the end
    # of frame k - 1, and the first frame starts from zero.
    start_offset = F.pad(carried[:, :-1, :], (0, 0, 1, 0))

    cycles = cycles + start_offset
    cycles = cycles - torch.round(cycles)

    batch = f0_frames.shape[0]
    combtooth = torch.sinc(cycles / (freq_per_sample + 1e-5)).reshape(batch, -1)
    phase_frames = 2 * np.pi * cycles[:, :, :1]
    return combtooth, phase_frames

def forward(self, units_frames, f0_frames, volume_frames, spk_id=None, spk_mix_dict=None, aug_shift=None, initial_phase=None, infer=True, **kwargs):
'''
units_frames: B x n_frames x n_unit
f0_frames: B x n_frames x 1
volume_frames: B x n_frames x 1
spk_id: B x 1
'''
# exciter phase
f0 = upsample(f0_frames, self.block_size)
if infer:
x = torch.cumsum(f0.double() / self.sampling_rate, axis=1)
else:
x = torch.cumsum(f0 / self.sampling_rate, axis=1)
if initial_phase is not None:
x += initial_phase.to(x) / 2 / np.pi
x = x - torch.round(x)
x = x.to(f0)

phase_frames = 2 * np.pi * x[:, ::self.block_size, :]
# combtooth exciter signal
combtooth, phase_frames = self.fast_source_gen(f0_frames)

# parameter prediction
ctrls, hidden = self.unit2ctrl(units_frames, f0_frames, phase_frames, volume_frames, spk_id=spk_id, spk_mix_dict=spk_mix_dict, aug_shift=aug_shift)
Expand All @@ -664,9 +668,7 @@ def forward(self, units_frames, f0_frames, volume_frames, spk_id=None, spk_mix_d
noise_filter= torch.exp(ctrls['noise_magnitude'] + 1.j * np.pi * ctrls['noise_phase']) / 128
noise_filter = torch.cat((noise_filter, noise_filter[:,-1:,:]), 1)

# combtooth exciter signal
combtooth = torch.sinc(self.sampling_rate * x / (f0 + 1e-3))
combtooth = combtooth.squeeze(-1)
# harmonic part filter
if combtooth.shape[-1] > self.win_length // 2:
pad_mode = 'reflect'
else:
Expand All @@ -681,7 +683,7 @@ def forward(self, units_frames, f0_frames, volume_frames, spk_id=None, spk_mix_d
return_complex = True,
pad_mode = pad_mode)

# noise exciter signal
# noise part filter
noise = torch.randn_like(combtooth)
noise_stft = torch.stft(
noise,
Expand Down

0 comments on commit 9920db7

Please sign in to comment.