Modify model for ONNX compatibility (#87)
This commit is contained in:
198
kokoro/custom_stft.py
Normal file
198
kokoro/custom_stft.py
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
from attr import attr
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import numpy as np
|
||||||
|
from scipy.signal import get_window
|
||||||
|
|
||||||
|
class CustomSTFT(nn.Module):
    """
    STFT/iSTFT without unfold/complex ops, using conv1d and conv_transpose1d.

    - forward STFT  => real-part conv1d + imag-part conv1d
    - inverse STFT  => real-part conv_transpose1d + imag-part conv_transpose1d + sum
    - avoids F.unfold, so easier to export to ONNX
    - uses replicate or constant padding for 'center=True' to approximate 'reflect'
      (reflect is not supported for dynamic shapes in ONNX)
    """

    def __init__(
        self,
        filter_length=800,
        hop_length=200,
        win_length=800,
        window="hann",
        center=True,
        pad_mode="replicate",  # or 'constant'
    ):
        super().__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_fft = filter_length
        self.center = center
        self.pad_mode = pad_mode

        # Number of frequency bins for real-valued STFT with onesided=True
        self.freq_bins = self.n_fft // 2 + 1

        # Build the analysis window (periodic, matching torch.stft's default)
        win_np = get_window(window, self.win_length, fftbins=True).astype(np.float32)
        window_tensor = torch.from_numpy(win_np)
        if self.win_length < self.n_fft:
            # Zero-pad up to n_fft
            extra = self.n_fft - self.win_length
            window_tensor = F.pad(window_tensor, (0, extra))
        elif self.win_length > self.n_fft:
            window_tensor = window_tensor[: self.n_fft]
        self.register_buffer("window", window_tensor)

        # Precompute forward DFT (real, imag).
        # PyTorch stft uses e^{-j 2 pi k n / N} => real=cos(...), imag=-sin(...)
        n = np.arange(self.n_fft)
        k = np.arange(self.freq_bins)
        angle = 2 * np.pi * np.outer(k, n) / self.n_fft  # shape (freq_bins, n_fft)
        dft_real = np.cos(angle)
        dft_imag = -np.sin(angle)  # note negative sign

        # Fold the window into the DFT basis => shape (freq_bins, n_fft).
        # These become conv1d weights of shape (freq_bins, 1, n_fft).
        forward_window = window_tensor.numpy()  # shape (n_fft,)
        forward_real = dft_real * forward_window  # (freq_bins, n_fft)
        forward_imag = dft_imag * forward_window

        forward_real_torch = torch.from_numpy(forward_real).float()
        forward_imag_torch = torch.from_numpy(forward_imag).float()

        # Conv1d weight layout: (out_channels=freq_bins, in_channels=1, kernel=n_fft)
        self.register_buffer(
            "weight_forward_real", forward_real_torch.unsqueeze(1)
        )
        self.register_buffer(
            "weight_forward_imag", forward_imag_torch.unsqueeze(1)
        )

        # Precompute inverse DFT.
        # NOTE(review): a perfect real iSTFT would double bins 1..freq_bins-2
        # (DC/Nyquist not doubled, n_fft even). That logic is intentionally
        # omitted; with a Hann window and typical overlap this still gives a
        # good approximate reconstruction.
        inv_scale = 1.0 / self.n_fft
        n = np.arange(self.n_fft)
        angle_t = 2 * np.pi * np.outer(n, k) / self.n_fft  # shape (n_fft, freq_bins)
        idft_cos = np.cos(angle_t).T  # => (freq_bins, n_fft)
        idft_sin = np.sin(angle_t).T  # => (freq_bins, n_fft)

        # Window applied again for overlap-add, with the 1/n_fft scale folded in
        inv_window = window_tensor.numpy() * inv_scale
        backward_real = idft_cos * inv_window  # (freq_bins, n_fft)
        backward_imag = idft_sin * inv_window

        # iSTFT is implemented as two conv_transpose1d calls with stride=hop
        self.register_buffer(
            "weight_backward_real", torch.from_numpy(backward_real).float().unsqueeze(1)
        )
        self.register_buffer(
            "weight_backward_imag", torch.from_numpy(backward_imag).float().unsqueeze(1)
        )

    def transform(self, waveform: torch.Tensor):
        """
        Forward STFT.

        Args:
            waveform: input signal of shape (B, T).

        Returns:
            (magnitude, phase), each of shape (B, freq_bins, frames).
        """
        # Optional center padding, mirroring torch.stft(center=True) but with
        # an ONNX-exportable padding mode instead of 'reflect'.
        if self.center:
            pad_len = self.n_fft // 2
            waveform = F.pad(waveform, (pad_len, pad_len), mode=self.pad_mode)

        x = waveform.unsqueeze(1)  # => (B, 1, T); conv1d expects (B, C, T)
        # Real part => shape (B, freq_bins, frames)
        real_out = F.conv1d(
            x,
            self.weight_forward_real,
            bias=None,
            stride=self.hop_length,
            padding=0,
        )
        # Imag part
        imag_out = F.conv1d(
            x,
            self.weight_forward_imag,
            bias=None,
            stride=self.hop_length,
            padding=0,
        )

        # Magnitude; the tiny epsilon keeps sqrt differentiable at zero
        magnitude = torch.sqrt(real_out**2 + imag_out**2 + 1e-14)
        phase = torch.atan2(imag_out, real_out)
        # Where imag == 0 and real < 0, PyTorch's atan2 returns +pi while
        # ONNX's returns -pi. Correct with torch.where: unlike boolean-mask
        # assignment (phase[mask] = pi), torch.where keeps static shapes and
        # does not export as NonZero+Scatter.
        correction_mask = (imag_out == 0) & (real_out < 0)
        phase = torch.where(
            correction_mask, torch.full_like(phase, torch.pi), phase
        )
        return magnitude, phase

    def inverse(self, magnitude: torch.Tensor, phase: torch.Tensor, length=None):
        """
        Inverse STFT via overlap-add.

        Args:
            magnitude: (B, freq_bins, frames).
            phase:     (B, freq_bins, frames).
            length:    optional exact output length to clamp to.

        Returns:
            Waveform of shape (B, 1, T) (channel dim kept by conv_transpose1d).
        """
        # Re-create real/imag parts => shape (B, freq_bins, frames)
        real_part = magnitude * torch.cos(phase)
        imag_part = magnitude * torch.sin(phase)

        # conv_transpose1d treats "frames" as the input length and overlap-adds
        # each frame's windowed contribution with stride=hop_length.
        real_rec = F.conv_transpose1d(
            real_part,
            self.weight_backward_real,  # shape (freq_bins, 1, filter_length)
            bias=None,
            stride=self.hop_length,
            padding=0,
        )
        imag_rec = F.conv_transpose1d(
            imag_part,
            self.weight_backward_imag,
            bias=None,
            stride=self.hop_length,
            padding=0,
        )
        # Real iFFT combines the branches with a minus on the imaginary part
        waveform = real_rec - imag_rec  # (B, 1, time)

        # Undo the center padding added in transform()
        if self.center:
            pad_len = self.n_fft // 2
            waveform = waveform[..., pad_len:-pad_len]

        # Clamp to a caller-requested length if given
        if length is not None:
            waveform = waveform[..., :length]

        return waveform

    def forward(self, x: torch.Tensor):
        """
        Full STFT -> iSTFT pass: returns the time-domain reconstruction.

        Args:
            x: waveform of shape (B, T).

        Returns:
            Reconstruction of shape (B, 1, T), clamped to the input length.
        """
        mag, phase = self.transform(x)
        return self.inverse(mag, phase, length=x.shape[-1])
|
||||||
@@ -7,6 +7,8 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from kokoro.custom_stft import CustomSTFT
|
||||||
|
|
||||||
|
|
||||||
# https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py
|
# https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py
|
||||||
def init_weights(m, mean=0.0, std=0.01):
|
def init_weights(m, mean=0.0, std=0.01):
|
||||||
@@ -254,7 +256,7 @@ class SourceModuleHnNSF(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class Generator(nn.Module):
|
class Generator(nn.Module):
|
||||||
def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size):
|
def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size, disable_complex=False):
|
||||||
super(Generator, self).__init__()
|
super(Generator, self).__init__()
|
||||||
self.num_kernels = len(resblock_kernel_sizes)
|
self.num_kernels = len(resblock_kernel_sizes)
|
||||||
self.num_upsamples = len(upsample_rates)
|
self.num_upsamples = len(upsample_rates)
|
||||||
@@ -289,7 +291,11 @@ class Generator(nn.Module):
|
|||||||
self.ups.apply(init_weights)
|
self.ups.apply(init_weights)
|
||||||
self.conv_post.apply(init_weights)
|
self.conv_post.apply(init_weights)
|
||||||
self.reflection_pad = nn.ReflectionPad1d((1, 0))
|
self.reflection_pad = nn.ReflectionPad1d((1, 0))
|
||||||
self.stft = TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
|
self.stft = (
|
||||||
|
CustomSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
|
||||||
|
if disable_complex
|
||||||
|
else TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x, s, f0):
|
def forward(self, x, s, f0):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -383,7 +389,8 @@ class Decoder(nn.Module):
|
|||||||
upsample_initial_channel,
|
upsample_initial_channel,
|
||||||
resblock_dilation_sizes,
|
resblock_dilation_sizes,
|
||||||
upsample_kernel_sizes,
|
upsample_kernel_sizes,
|
||||||
gen_istft_n_fft, gen_istft_hop_size):
|
gen_istft_n_fft, gen_istft_hop_size,
|
||||||
|
disable_complex=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
|
self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
|
||||||
self.decode = nn.ModuleList()
|
self.decode = nn.ModuleList()
|
||||||
@@ -396,7 +403,7 @@ class Decoder(nn.Module):
|
|||||||
self.asr_res = nn.Sequential(weight_norm(nn.Conv1d(512, 64, kernel_size=1)))
|
self.asr_res = nn.Sequential(weight_norm(nn.Conv1d(512, 64, kernel_size=1)))
|
||||||
self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates,
|
self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates,
|
||||||
upsample_initial_channel, resblock_dilation_sizes,
|
upsample_initial_channel, resblock_dilation_sizes,
|
||||||
upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size)
|
upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size, disable_complex=disable_complex)
|
||||||
|
|
||||||
def forward(self, asr, F0_curve, N, s):
|
def forward(self, asr, F0_curve, N, s):
|
||||||
F0 = self.F0_conv(F0_curve.unsqueeze(1))
|
F0 = self.F0_conv(F0_curve.unsqueeze(1))
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ class KModel(torch.nn.Module):
|
|||||||
|
|
||||||
REPO_ID = 'hexgrad/Kokoro-82M'
|
REPO_ID = 'hexgrad/Kokoro-82M'
|
||||||
|
|
||||||
def __init__(self, config: Union[Dict, str, None] = None, model: Optional[str] = None):
|
def __init__(self, config: Union[Dict, str, None] = None, model: Optional[str] = None, disable_complex: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if not isinstance(config, dict):
|
if not isinstance(config, dict):
|
||||||
if not config:
|
if not config:
|
||||||
@@ -49,7 +49,7 @@ class KModel(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
self.decoder = Decoder(
|
self.decoder = Decoder(
|
||||||
dim_in=config['hidden_dim'], style_dim=config['style_dim'],
|
dim_in=config['hidden_dim'], style_dim=config['style_dim'],
|
||||||
dim_out=config['n_mels'], **config['istftnet']
|
dim_out=config['n_mels'], disable_complex=disable_complex, **config['istftnet']
|
||||||
)
|
)
|
||||||
if not model:
|
if not model:
|
||||||
model = hf_hub_download(repo_id=KModel.REPO_ID, filename='kokoro-v1_0.pth')
|
model = hf_hub_download(repo_id=KModel.REPO_ID, filename='kokoro-v1_0.pth')
|
||||||
@@ -72,30 +72,29 @@ class KModel(torch.nn.Module):
|
|||||||
pred_dur: Optional[torch.LongTensor] = None
|
pred_dur: Optional[torch.LongTensor] = None
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward_with_tokens(
|
||||||
self,
|
self,
|
||||||
phonemes: str,
|
input_ids: torch.LongTensor,
|
||||||
ref_s: torch.FloatTensor,
|
ref_s: torch.FloatTensor,
|
||||||
speed: Number = 1,
|
speed: Number = 1
|
||||||
return_output: bool = False # MARK: BACKWARD COMPAT
|
) -> tuple[torch.FloatTensor, torch.LongTensor]:
|
||||||
) -> Union['KModel.Output', torch.FloatTensor]:
|
input_lengths = torch.full(
|
||||||
input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes)))
|
(input_ids.shape[0],),
|
||||||
logger.debug(f"phonemes: {phonemes} -> input_ids: {input_ids}")
|
input_ids.shape[-1],
|
||||||
assert len(input_ids)+2 <= self.context_length, (len(input_ids)+2, self.context_length)
|
device=input_ids.device,
|
||||||
input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(self.device)
|
dtype=torch.long
|
||||||
input_lengths = torch.LongTensor([input_ids.shape[-1]]).to(self.device)
|
)
|
||||||
|
|
||||||
text_mask = torch.arange(input_lengths.max()).unsqueeze(0).expand(input_lengths.shape[0], -1).type_as(input_lengths)
|
text_mask = torch.arange(input_lengths.max()).unsqueeze(0).expand(input_lengths.shape[0], -1).type_as(input_lengths)
|
||||||
text_mask = torch.gt(text_mask+1, input_lengths.unsqueeze(1)).to(self.device)
|
text_mask = torch.gt(text_mask+1, input_lengths.unsqueeze(1)).to(self.device)
|
||||||
bert_dur = self.bert(input_ids, attention_mask=(~text_mask).int())
|
bert_dur = self.bert(input_ids, attention_mask=(~text_mask).int())
|
||||||
d_en = self.bert_encoder(bert_dur).transpose(-1, -2)
|
d_en = self.bert_encoder(bert_dur).transpose(-1, -2)
|
||||||
ref_s = ref_s.to(self.device)
|
|
||||||
s = ref_s[:, 128:]
|
s = ref_s[:, 128:]
|
||||||
d = self.predictor.text_encoder(d_en, s, input_lengths, text_mask)
|
d = self.predictor.text_encoder(d_en, s, input_lengths, text_mask)
|
||||||
x, _ = self.predictor.lstm(d)
|
x, _ = self.predictor.lstm(d)
|
||||||
duration = self.predictor.duration_proj(x)
|
duration = self.predictor.duration_proj(x)
|
||||||
duration = torch.sigmoid(duration).sum(axis=-1) / speed
|
duration = torch.sigmoid(duration).sum(axis=-1) / speed
|
||||||
pred_dur = torch.round(duration).clamp(min=1).long().squeeze()
|
pred_dur = torch.round(duration).clamp(min=1).long().squeeze()
|
||||||
logger.debug(f"pred_dur: {pred_dur}")
|
|
||||||
indices = torch.repeat_interleave(torch.arange(input_ids.shape[1], device=self.device), pred_dur)
|
indices = torch.repeat_interleave(torch.arange(input_ids.shape[1], device=self.device), pred_dur)
|
||||||
pred_aln_trg = torch.zeros((input_ids.shape[1], indices.shape[0]), device=self.device)
|
pred_aln_trg = torch.zeros((input_ids.shape[1], indices.shape[0]), device=self.device)
|
||||||
pred_aln_trg[indices, torch.arange(indices.shape[0])] = 1
|
pred_aln_trg[indices, torch.arange(indices.shape[0])] = 1
|
||||||
@@ -104,5 +103,37 @@ class KModel(torch.nn.Module):
|
|||||||
F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
|
F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
|
||||||
t_en = self.text_encoder(input_ids, input_lengths, text_mask)
|
t_en = self.text_encoder(input_ids, input_lengths, text_mask)
|
||||||
asr = t_en @ pred_aln_trg
|
asr = t_en @ pred_aln_trg
|
||||||
audio = self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu()
|
audio = self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze()
|
||||||
return self.Output(audio=audio, pred_dur=pred_dur.cpu()) if return_output else audio
|
return audio, pred_dur
|
||||||
|
|
||||||
|
def forward(
    self,
    phonemes: str,
    ref_s: torch.FloatTensor,
    speed: Number = 1,
    return_output: bool = False
) -> Union['KModel.Output', torch.FloatTensor]:
    """Synthesize audio from a phoneme string and a reference style vector.

    Tokenizes the phonemes, delegates to forward_with_tokens, and moves the
    results back to CPU. Returns a KModel.Output when return_output is True,
    otherwise just the audio tensor.
    """
    token_ids = [i for i in (self.vocab.get(p) for p in phonemes) if i is not None]
    logger.debug(f"phonemes: {phonemes} -> input_ids: {token_ids}")
    # +2 accounts for the boundary (0) tokens added on each side below
    assert len(token_ids)+2 <= self.context_length, (len(token_ids)+2, self.context_length)
    input_ids = torch.LongTensor([[0, *token_ids, 0]]).to(self.device)
    ref_s = ref_s.to(self.device)
    audio, pred_dur = self.forward_with_tokens(input_ids, ref_s, speed)
    audio = audio.squeeze().cpu()
    pred_dur = pred_dur.cpu() if pred_dur is not None else None
    logger.debug(f"pred_dur: {pred_dur}")
    if return_output:
        return self.Output(audio=audio, pred_dur=pred_dur)
    return audio
|
||||||
|
|
||||||
|
class KModelForONNX(torch.nn.Module):
    """Thin wrapper exposing a tensor-only forward, suitable for ONNX export.

    ONNX tracing cannot handle KModel.forward's string tokenization, so this
    wrapper takes pre-tokenized input_ids directly.
    """

    def __init__(self, kmodel: KModel):
        super().__init__()
        # The wrapped model; all computation is delegated to it
        self.kmodel = kmodel

    def forward(
        self,
        input_ids: torch.LongTensor,
        ref_s: torch.FloatTensor,
        speed: Number = 1
    ) -> torch.FloatTensor:
        """
        Synthesize audio from token ids and a reference style vector.

        Returns only the waveform. The predicted durations from
        forward_with_tokens are intentionally discarded, so the return
        annotation is the single tensor actually returned (the original
        annotation promised a tuple that was never produced).
        """
        waveform, _ = self.kmodel.forward_with_tokens(input_ids, ref_s, speed)
        return waveform
|
||||||
|
|||||||
@@ -50,21 +50,21 @@ class TextEncoder(nn.Module):
|
|||||||
def forward(self, x, input_lengths, m):
    """Encode token ids into masked feature sequences.

    x: (B, T) token ids; input_lengths: (B,) valid lengths; m: (B, T) bool pad
    mask (True on padding). Returns (B, chn, T) with padded positions zeroed.
    """
    x = self.embedding(x)   # [B, T, emb]
    x = x.transpose(1, 2)   # [B, emb, T]
    m = m.unsqueeze(1)
    x.masked_fill_(m, 0.0)
    for conv in self.cnn:
        x = conv(x)
        x.masked_fill_(m, 0.0)
    x = x.transpose(1, 2)   # [B, T, chn]
    # pack_padded_sequence requires lengths on the CPU
    lengths = input_lengths if input_lengths.device == torch.device('cpu') else input_lengths.to('cpu')
    packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
    self.lstm.flatten_parameters()
    packed, _ = self.lstm(packed)
    x, _ = nn.utils.rnn.pad_packed_sequence(packed, batch_first=True)
    x = x.transpose(-1, -2)
    # Re-pad the time dimension back out to the full mask length
    x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]], device=x.device)
    x_pad[:, :, :x.shape[-1]] = x
    x = x_pad
    x.masked_fill_(m, 0.0)
    return x
|
||||||
|
|
||||||
@@ -108,17 +108,15 @@ class ProsodyPredictor(nn.Module):
|
|||||||
|
|
||||||
def forward(self, texts, style, text_lengths, alignment, m):
    """Predict per-token durations and alignment-expanded prosody features.

    Returns (duration, en): duration of shape (B, T) and en, the encoded text
    features projected through the alignment matrix.
    """
    d = self.text_encoder(texts, style, text_lengths, m)
    m = m.unsqueeze(1)
    # pack_padded_sequence needs lengths on the CPU
    lengths = text_lengths if text_lengths.device == torch.device('cpu') else text_lengths.to('cpu')
    packed = nn.utils.rnn.pack_padded_sequence(d, lengths, batch_first=True, enforce_sorted=False)
    self.lstm.flatten_parameters()
    packed, _ = self.lstm(packed)
    x, _ = nn.utils.rnn.pad_packed_sequence(packed, batch_first=True)
    # Restore the full (masked) time length
    x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]], device=x.device)
    x_pad[:, :x.shape[1], :] = x
    x = x_pad
    duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=False))
    en = (d.transpose(-1, -2) @ alignment)
    return duration.squeeze(-1), en
|
||||||
@@ -148,32 +146,33 @@ class DurationEncoder(nn.Module):
|
|||||||
self.sty_dim = sty_dim
|
self.sty_dim = sty_dim
|
||||||
|
|
||||||
def forward(self, x, style, text_lengths, m):
|
def forward(self, x, style, text_lengths, m):
    """Run the duration encoder stack over style-conditioned text features.

    x: (C, B, T)-permutable features; style: style vector broadcast over time;
    text_lengths: (B,) valid lengths; m: (B, T) bool pad mask (True on padding).
    Alternates AdaLayerNorm blocks with packed LSTM blocks and returns the
    result transposed back to (B, T, C').
    """
    masks = m
    x = x.permute(2, 0, 1)
    s = style.expand(x.shape[0], x.shape[1], -1)
    x = torch.cat([x, s], axis=-1)
    x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
    x = x.transpose(0, 1)
    x = x.transpose(-1, -2)
    # pack_padded_sequence needs CPU lengths; hoisted out of the loop because
    # it is loop-invariant (was recomputed for every LSTM block)
    lengths = text_lengths if text_lengths.device == torch.device('cpu') else text_lengths.to('cpu')
    for block in self.lstms:
        if isinstance(block, AdaLayerNorm):
            x = block(x.transpose(-1, -2), style).transpose(-1, -2)
            x = torch.cat([x, s.permute(1, 2, 0)], axis=1)
            x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
        else:
            x = x.transpose(-1, -2)
            x = nn.utils.rnn.pack_padded_sequence(
                x, lengths, batch_first=True, enforce_sorted=False)
            block.flatten_parameters()
            x, _ = block(x)
            x, _ = nn.utils.rnn.pad_packed_sequence(
                x, batch_first=True)
            x = F.dropout(x, p=self.dropout, training=False)
            x = x.transpose(-1, -2)
            # Restore the full time length before the next block
            x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]], device=x.device)
            x_pad[:, :, :x.shape[-1]] = x
            x = x_pad
    return x.transpose(-1, -2)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
81
tests/test_custom_stft.py
Normal file
81
tests/test_custom_stft.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from kokoro.custom_stft import CustomSTFT
|
||||||
|
from kokoro.istftnet import TorchSTFT
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def sample_audio():
    """A 1-second, 440 Hz sine tone at 16 kHz, shaped (1, T)."""
    sample_rate = 16000
    duration = 1.0  # seconds
    t = torch.linspace(0, duration, int(sample_rate * duration))
    frequency = 440.0  # Hz
    tone = torch.sin(2 * np.pi * frequency * t)
    return tone.unsqueeze(0)  # add batch dimension
|
||||||
|
|
||||||
|
|
||||||
|
def test_stft_reconstruction(sample_audio):
    """The CustomSTFT analysis/synthesis round trip should match TorchSTFT."""
    custom_stft = CustomSTFT(filter_length=800, hop_length=200, win_length=800)
    torch_stft = TorchSTFT(filter_length=800, hop_length=200, win_length=800)

    # Run the same signal through both implementations
    custom_output = custom_stft(sample_audio)
    torch_output = torch_stft(sample_audio)

    assert torch.allclose(custom_output, torch_output, rtol=1e-3, atol=1e-3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_magnitude_phase_consistency(sample_audio):
    """Magnitudes of both implementations agree away from the boundary frames."""
    custom_stft = CustomSTFT(filter_length=800, hop_length=200, win_length=800)
    torch_stft = TorchSTFT(filter_length=800, hop_length=200, win_length=800)

    custom_mag, _ = custom_stft.transform(sample_audio)
    torch_mag, _ = torch_stft.transform(sample_audio)

    # Skip two frames on each side: the padding modes differ at the edges
    assert torch.allclose(custom_mag[..., 2:-2], torch_mag[..., 2:-2], rtol=1e-2, atol=1e-2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_batch_processing():
    """A batched input should yield a batched (batch, 1, time) output."""
    batch_size = 4
    sample_rate = 16000
    duration = 0.1  # short signal keeps the test fast
    t = torch.linspace(0, duration, int(sample_rate * duration))
    signals = torch.sin(2 * np.pi * 440.0 * t).unsqueeze(0).repeat(batch_size, 1)

    custom_stft = CustomSTFT(filter_length=800, hop_length=200, win_length=800)
    output = custom_stft(signals)

    assert output.shape[0] == batch_size
    assert output.dim() == 3  # (batch, 1, time)
|
||||||
|
|
||||||
|
|
||||||
|
def test_different_window_sizes():
    """The round trip should work across several FFT sizes."""
    signal = torch.randn(1, 16000)  # 1 second of random noise

    for filter_length in (512, 1024, 2048):
        custom_stft = CustomSTFT(
            filter_length=filter_length,
            hop_length=filter_length // 4,
            win_length=filter_length,
        )
        output = custom_stft(signal)
        # The reconstruction should cover at least the original duration
        assert output.shape[-1] >= signal.shape[-1]
|
||||||
Reference in New Issue
Block a user