diff --git a/kokoro/custom_stft.py b/kokoro/custom_stft.py new file mode 100644 index 0000000..15f3378 --- /dev/null +++ b/kokoro/custom_stft.py @@ -0,0 +1,197 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from scipy.signal import get_window + +class CustomSTFT(nn.Module): + """ + STFT/iSTFT without unfold/complex ops, using conv1d and conv_transpose1d. + + - forward STFT => Real-part conv1d + Imag-part conv1d + - inverse STFT => Real-part conv_transpose1d + Imag-part conv_transpose1d + sum + - avoids F.unfold, so easier to export to ONNX + - uses replicate or constant padding for 'center=True' to approximate 'reflect' + (reflect is not supported for dynamic shapes in ONNX) + """ + + def __init__( + self, + filter_length=800, + hop_length=200, + win_length=800, + window="hann", + center=True, + pad_mode="replicate", # or 'constant' + ): + super().__init__() + self.filter_length = filter_length + self.hop_length = hop_length + self.win_length = win_length + self.n_fft = filter_length + self.center = center + self.pad_mode = pad_mode + + # Number of frequency bins for real-valued STFT with onesided=True + self.freq_bins = self.n_fft // 2 + 1 + + # Build window + win_np = get_window(window, self.win_length, fftbins=True).astype(np.float32) + window_tensor = torch.from_numpy(win_np) + if self.win_length < self.n_fft: + # Zero-pad up to n_fft + extra = self.n_fft - self.win_length + window_tensor = F.pad(window_tensor, (0, extra)) + elif self.win_length > self.n_fft: + window_tensor = window_tensor[: self.n_fft] + self.register_buffer("window", window_tensor) + + # Precompute forward DFT (real, imag) + # PyTorch stft uses e^{-j 2 pi k n / N} => real=cos(...), imag=-sin(...) 
+ n = np.arange(self.n_fft) + k = np.arange(self.freq_bins) + angle = 2 * np.pi * np.outer(k, n) / self.n_fft # shape (freq_bins, n_fft) + dft_real = np.cos(angle) + dft_imag = -np.sin(angle) # note negative sign + + # Combine window and dft => shape (freq_bins, filter_length) + # We'll make 2 conv weight tensors of shape (freq_bins, 1, filter_length). + forward_window = window_tensor.numpy() # shape (n_fft,) + forward_real = dft_real * forward_window # (freq_bins, n_fft) + forward_imag = dft_imag * forward_window + + # Convert to PyTorch + forward_real_torch = torch.from_numpy(forward_real).float() + forward_imag_torch = torch.from_numpy(forward_imag).float() + + # Register as Conv1d weight => (out_channels, in_channels, kernel_size) + # out_channels = freq_bins, in_channels=1, kernel_size=n_fft + self.register_buffer( + "weight_forward_real", forward_real_torch.unsqueeze(1) + ) + self.register_buffer( + "weight_forward_imag", forward_imag_torch.unsqueeze(1) + ) + + # Precompute inverse DFT + # Real iFFT formula => scale = 1/n_fft, doubling for bins 1..freq_bins-2 if n_fft even, etc. + # For simplicity, we won't do the "DC/nyquist not doubled" approach here. + # If you want perfect real iSTFT, you can add that logic. + # This version just yields good approximate reconstruction with Hann + typical overlap. + inv_scale = 1.0 / self.n_fft + n = np.arange(self.n_fft) + angle_t = 2 * np.pi * np.outer(n, k) / self.n_fft # shape (n_fft, freq_bins) + idft_cos = np.cos(angle_t).T # => (freq_bins, n_fft) + idft_sin = np.sin(angle_t).T # => (freq_bins, n_fft) + + # Multiply by window again for typical overlap-add + # We also incorporate the scale factor 1/n_fft + inv_window = window_tensor.numpy() * inv_scale + backward_real = idft_cos * inv_window # (freq_bins, n_fft) + backward_imag = idft_sin * inv_window + + # We'll implement iSTFT as real+imag conv_transpose with stride=hop. 
+ self.register_buffer( + "weight_backward_real", torch.from_numpy(backward_real).float().unsqueeze(1) + ) + self.register_buffer( + "weight_backward_imag", torch.from_numpy(backward_imag).float().unsqueeze(1) + ) + + + + def transform(self, waveform: torch.Tensor): + """ + Forward STFT => returns magnitude, phase + Output shape => (batch, freq_bins, frames) + """ + # waveform shape => (B, T). conv1d expects (B, 1, T). + # Optional center pad + if self.center: + pad_len = self.n_fft // 2 + waveform = F.pad(waveform, (pad_len, pad_len), mode=self.pad_mode) + + x = waveform.unsqueeze(1) # => (B, 1, T) + # Convolution to get real part => shape (B, freq_bins, frames) + real_out = F.conv1d( + x, + self.weight_forward_real, + bias=None, + stride=self.hop_length, + padding=0, + ) + # Imag part + imag_out = F.conv1d( + x, + self.weight_forward_imag, + bias=None, + stride=self.hop_length, + padding=0, + ) + + # magnitude, phase + magnitude = torch.sqrt(real_out**2 + imag_out**2 + 1e-14) + phase = torch.atan2(imag_out, real_out) + # Handle the case where imag_out is 0 and real_out is negative to correct ONNX atan2 to match PyTorch + # In this case, PyTorch returns pi, ONNX returns -pi + correction_mask = (imag_out == 0) & (real_out < 0) + phase[correction_mask] = torch.pi + return magnitude, phase + + + def inverse(self, magnitude: torch.Tensor, phase: torch.Tensor, length=None): + """ + Inverse STFT => returns waveform shape (B, T). + """ + # magnitude, phase => (B, freq_bins, frames) + # Re-create real/imag => shape (B, freq_bins, frames) + real_part = magnitude * torch.cos(phase) + imag_part = magnitude * torch.sin(phase) + + # conv_transpose wants shape (B, freq_bins, frames). 
We'll treat "frames" as time dimension + # so we do (B, freq_bins, frames) => (B, freq_bins, frames) + # But PyTorch conv_transpose1d expects (B, in_channels, input_length) + real_part = real_part # (B, freq_bins, frames) + imag_part = imag_part + + # real iSTFT => convolve with "backward_real", "backward_imag", and sum + # We'll do 2 conv_transpose calls, each giving (B, 1, time), + # then add them => (B, 1, time). + real_rec = F.conv_transpose1d( + real_part, + self.weight_backward_real, # shape (freq_bins, 1, filter_length) + bias=None, + stride=self.hop_length, + padding=0, + ) + imag_rec = F.conv_transpose1d( + imag_part, + self.weight_backward_imag, + bias=None, + stride=self.hop_length, + padding=0, + ) + # sum => (B, 1, time) + waveform = real_rec - imag_rec # typical real iFFT has minus for imaginary part + + # If we used "center=True" in forward, we should remove pad + if self.center: + pad_len = self.n_fft // 2 + # Because of transposed convolution, total length might have extra samples + # We remove `pad_len` from start & end if possible + waveform = waveform[..., pad_len:-pad_len] + + # If a specific length is desired, clamp + if length is not None: + waveform = waveform[..., :length] + + # shape => (B, T) + return waveform + + def forward(self, x: torch.Tensor): + """ + Full STFT -> iSTFT pass: returns time-domain reconstruction. + Same interface as your original code. 
+ """ + mag, phase = self.transform(x) + return self.inverse(mag, phase, length=x.shape[-1]) diff --git a/kokoro/istftnet.py b/kokoro/istftnet.py index 01289b2..e61e510 100644 --- a/kokoro/istftnet.py +++ b/kokoro/istftnet.py @@ -7,6 +7,8 @@ import torch import torch.nn as nn import torch.nn.functional as F +from kokoro.custom_stft import CustomSTFT + # https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py def init_weights(m, mean=0.0, std=0.01): @@ -254,7 +256,7 @@ class SourceModuleHnNSF(nn.Module): class Generator(nn.Module): - def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size): + def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size, disable_complex=False): super(Generator, self).__init__() self.num_kernels = len(resblock_kernel_sizes) self.num_upsamples = len(upsample_rates) @@ -289,7 +291,11 @@ class Generator(nn.Module): self.ups.apply(init_weights) self.conv_post.apply(init_weights) self.reflection_pad = nn.ReflectionPad1d((1, 0)) - self.stft = TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft) + self.stft = ( + CustomSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft) + if disable_complex + else TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft) + ) def forward(self, x, s, f0): with torch.no_grad(): @@ -383,7 +389,8 @@ class Decoder(nn.Module): upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, - gen_istft_n_fft, gen_istft_hop_size): + gen_istft_n_fft, gen_istft_hop_size, + disable_complex=False): super().__init__() self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim) self.decode = nn.ModuleList() @@ -396,7 +403,7 @@ class 
Decoder(nn.Module): self.asr_res = nn.Sequential(weight_norm(nn.Conv1d(512, 64, kernel_size=1))) self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, - upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size) + upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size, disable_complex=disable_complex) def forward(self, asr, F0_curve, N, s): F0 = self.F0_conv(F0_curve.unsqueeze(1)) diff --git a/kokoro/model.py b/kokoro/model.py index 27ab5c1..25659ad 100644 --- a/kokoro/model.py +++ b/kokoro/model.py @@ -26,7 +26,7 @@ class KModel(torch.nn.Module): REPO_ID = 'hexgrad/Kokoro-82M' - def __init__(self, config: Union[Dict, str, None] = None, model: Optional[str] = None): + def __init__(self, config: Union[Dict, str, None] = None, model: Optional[str] = None, disable_complex: bool = False): super().__init__() if not isinstance(config, dict): if not config: @@ -49,7 +49,7 @@ class KModel(torch.nn.Module): ) self.decoder = Decoder( dim_in=config['hidden_dim'], style_dim=config['style_dim'], - dim_out=config['n_mels'], **config['istftnet'] + dim_out=config['n_mels'], disable_complex=disable_complex, **config['istftnet'] ) if not model: model = hf_hub_download(repo_id=KModel.REPO_ID, filename='kokoro-v1_0.pth') @@ -72,30 +72,29 @@ class KModel(torch.nn.Module): pred_dur: Optional[torch.LongTensor] = None @torch.no_grad() - def forward( + def forward_with_tokens( self, - phonemes: str, + input_ids: torch.LongTensor, ref_s: torch.FloatTensor, - speed: Number = 1, - return_output: bool = False # MARK: BACKWARD COMPAT - ) -> Union['KModel.Output', torch.FloatTensor]: - input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes))) - logger.debug(f"phonemes: {phonemes} -> input_ids: {input_ids}") - assert len(input_ids)+2 <= self.context_length, (len(input_ids)+2, self.context_length) - input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(self.device) - input_lengths = 
torch.LongTensor([input_ids.shape[-1]]).to(self.device) + speed: Number = 1 + ) -> tuple[torch.FloatTensor, torch.LongTensor]: + input_lengths = torch.full( + (input_ids.shape[0],), + input_ids.shape[-1], + device=input_ids.device, + dtype=torch.long + ) + text_mask = torch.arange(input_lengths.max()).unsqueeze(0).expand(input_lengths.shape[0], -1).type_as(input_lengths) text_mask = torch.gt(text_mask+1, input_lengths.unsqueeze(1)).to(self.device) bert_dur = self.bert(input_ids, attention_mask=(~text_mask).int()) d_en = self.bert_encoder(bert_dur).transpose(-1, -2) - ref_s = ref_s.to(self.device) s = ref_s[:, 128:] d = self.predictor.text_encoder(d_en, s, input_lengths, text_mask) x, _ = self.predictor.lstm(d) duration = self.predictor.duration_proj(x) duration = torch.sigmoid(duration).sum(axis=-1) / speed pred_dur = torch.round(duration).clamp(min=1).long().squeeze() - logger.debug(f"pred_dur: {pred_dur}") indices = torch.repeat_interleave(torch.arange(input_ids.shape[1], device=self.device), pred_dur) pred_aln_trg = torch.zeros((input_ids.shape[1], indices.shape[0]), device=self.device) pred_aln_trg[indices, torch.arange(indices.shape[0])] = 1 @@ -104,5 +103,37 @@ class KModel(torch.nn.Module): F0_pred, N_pred = self.predictor.F0Ntrain(en, s) t_en = self.text_encoder(input_ids, input_lengths, text_mask) asr = t_en @ pred_aln_trg - audio = self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu() - return self.Output(audio=audio, pred_dur=pred_dur.cpu()) if return_output else audio + audio = self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze() + return audio, pred_dur + + def forward( + self, + phonemes: str, + ref_s: torch.FloatTensor, + speed: Number = 1, + return_output: bool = False + ) -> Union['KModel.Output', torch.FloatTensor]: + input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes))) + logger.debug(f"phonemes: {phonemes} -> input_ids: {input_ids}") + assert len(input_ids)+2 <= self.context_length, 
(len(input_ids)+2, self.context_length) + input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(self.device) + ref_s = ref_s.to(self.device) + audio, pred_dur = self.forward_with_tokens(input_ids, ref_s, speed) + audio = audio.squeeze().cpu() + pred_dur = pred_dur.cpu() if pred_dur is not None else None + logger.debug(f"pred_dur: {pred_dur}") + return self.Output(audio=audio, pred_dur=pred_dur) if return_output else audio + +class KModelForONNX(torch.nn.Module): + def __init__(self, kmodel: KModel): + super().__init__() + self.kmodel = kmodel + + def forward( + self, + input_ids: torch.LongTensor, + ref_s: torch.FloatTensor, + speed: Number = 1 + ) -> tuple[torch.FloatTensor, torch.LongTensor]: + waveform, duration = self.kmodel.forward_with_tokens(input_ids, ref_s, speed) + return waveform, duration diff --git a/kokoro/modules.py b/kokoro/modules.py index 2b72307..05d1575 100644 --- a/kokoro/modules.py +++ b/kokoro/modules.py @@ -50,21 +50,21 @@ class TextEncoder(nn.Module): def forward(self, x, input_lengths, m): x = self.embedding(x) # [B, T, emb] x = x.transpose(1, 2) # [B, emb, T] - m = m.to(input_lengths.device).unsqueeze(1) + m = m.unsqueeze(1) x.masked_fill_(m, 0.0) for c in self.cnn: x = c(x) x.masked_fill_(m, 0.0) x = x.transpose(1, 2) # [B, T, chn] - input_lengths = input_lengths.cpu().numpy() - x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True, enforce_sorted=False) + lengths = input_lengths if input_lengths.device == torch.device('cpu') else input_lengths.to('cpu') + x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False) self.lstm.flatten_parameters() x, _ = self.lstm(x) x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True) x = x.transpose(-1, -2) - x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]]) + x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]], device=x.device) x_pad[:, :, :x.shape[-1]] = x - x = x_pad.to(x.device) + x = x_pad x.masked_fill_(m, 0.0) return x @@ -108,17 
+108,15 @@ class ProsodyPredictor(nn.Module): def forward(self, texts, style, text_lengths, alignment, m): d = self.text_encoder(texts, style, text_lengths, m) - batch_size = d.shape[0] - text_size = d.shape[1] - input_lengths = text_lengths.cpu().numpy() - x = nn.utils.rnn.pack_padded_sequence(d, input_lengths, batch_first=True, enforce_sorted=False) - m = m.to(text_lengths.device).unsqueeze(1) + m = m.unsqueeze(1) + lengths = text_lengths if text_lengths.device == torch.device('cpu') else text_lengths.to('cpu') + x = nn.utils.rnn.pack_padded_sequence(d, lengths, batch_first=True, enforce_sorted=False) self.lstm.flatten_parameters() x, _ = self.lstm(x) x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True) - x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]]) + x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]], device=x.device) x_pad[:, :x.shape[1], :] = x - x = x_pad.to(x.device) + x = x_pad duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=False)) en = (d.transpose(-1, -2) @ alignment) return duration.squeeze(-1), en @@ -148,32 +146,33 @@ class DurationEncoder(nn.Module): self.sty_dim = sty_dim def forward(self, x, style, text_lengths, m): - masks = m.to(text_lengths.device) + masks = m x = x.permute(2, 0, 1) s = style.expand(x.shape[0], x.shape[1], -1) x = torch.cat([x, s], axis=-1) x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0) x = x.transpose(0, 1) - input_lengths = text_lengths.cpu().numpy() x = x.transpose(-1, -2) for block in self.lstms: if isinstance(block, AdaLayerNorm): x = block(x.transpose(-1, -2), style).transpose(-1, -2) - x = torch.cat([x, s.permute(1, -1, 0)], axis=1) + x = torch.cat([x, s.permute(1, 2, 0)], axis=1) x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0) else: + lengths = text_lengths if text_lengths.device == torch.device('cpu') else text_lengths.to('cpu') x = x.transpose(-1, -2) x = nn.utils.rnn.pack_padded_sequence( - x, input_lengths, batch_first=True, 
enforce_sorted=False) + x, lengths, batch_first=True, enforce_sorted=False) block.flatten_parameters() x, _ = block(x) x, _ = nn.utils.rnn.pad_packed_sequence( x, batch_first=True) x = F.dropout(x, p=self.dropout, training=False) x = x.transpose(-1, -2) - x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]]) + x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]], device=x.device) x_pad[:, :, :x.shape[-1]] = x - x = x_pad.to(x.device) + x = x_pad + return x.transpose(-1, -2) diff --git a/tests/test_custom_stft.py b/tests/test_custom_stft.py new file mode 100644 index 0000000..103c083 --- /dev/null +++ b/tests/test_custom_stft.py @@ -0,0 +1,81 @@ +import torch +import numpy as np +import pytest +from kokoro.custom_stft import CustomSTFT +from kokoro.istftnet import TorchSTFT +import torch.nn.functional as F + + +@pytest.fixture +def sample_audio(): + # Generate a sample audio signal (sine wave) + sample_rate = 16000 + duration = 1.0 # seconds + t = torch.linspace(0, duration, int(sample_rate * duration)) + frequency = 440.0 # Hz + signal = torch.sin(2 * np.pi * frequency * t) + return signal.unsqueeze(0) # Add batch dimension + + +def test_stft_reconstruction(sample_audio): + # Initialize both STFT implementations + custom_stft = CustomSTFT(filter_length=800, hop_length=200, win_length=800) + torch_stft = TorchSTFT(filter_length=800, hop_length=200, win_length=800) + + # Process through both implementations + custom_output = custom_stft(sample_audio) + torch_output = torch_stft(sample_audio) + + # Compare outputs + assert torch.allclose(custom_output, torch_output, rtol=1e-3, atol=1e-3) + + +def test_magnitude_phase_consistency(sample_audio): + custom_stft = CustomSTFT(filter_length=800, hop_length=200, win_length=800) + torch_stft = TorchSTFT(filter_length=800, hop_length=200, win_length=800) + + # Get magnitude and phase from both implementations + custom_mag, custom_phase = custom_stft.transform(sample_audio) + torch_mag, torch_phase = 
torch_stft.transform(sample_audio) + + # Compare magnitudes ignoring the boundary frames + custom_mag_center = custom_mag[..., 2:-2] + torch_mag_center = torch_mag[..., 2:-2] + assert torch.allclose(custom_mag_center, torch_mag_center, rtol=1e-2, atol=1e-2) + + +def test_batch_processing(): + # Create a batch of signals + batch_size = 4 + sample_rate = 16000 + duration = 0.1 # shorter duration for faster testing + t = torch.linspace(0, duration, int(sample_rate * duration)) + frequency = 440.0 + signals = torch.sin(2 * np.pi * frequency * t).unsqueeze(0).repeat(batch_size, 1) + + custom_stft = CustomSTFT(filter_length=800, hop_length=200, win_length=800) + + # Process batch + output = custom_stft(signals) + + # Check output shape + assert output.shape[0] == batch_size + assert len(output.shape) == 3 # (batch, 1, time) + + +def test_different_window_sizes(): + signal = torch.randn(1, 16000) # 1 second of random noise + + # Test with different window sizes + for filter_length in [512, 1024, 2048]: + custom_stft = CustomSTFT( + filter_length=filter_length, + hop_length=filter_length // 4, + win_length=filter_length, + ) + + # Forward and backward transform + output = custom_stft(signal) + + # Check that output length is reasonable + assert output.shape[-1] >= signal.shape[-1]