Modify model for ONNX compatibility (#87)

This commit is contained in:
Adrian Lyjak
2025-02-15 13:05:57 -06:00
committed by GitHub
parent ce71a10c57
commit 93abff8795
5 changed files with 354 additions and 38 deletions

81
tests/test_custom_stft.py Normal file
View File

@@ -0,0 +1,81 @@
import torch
import numpy as np
import pytest
from kokoro.custom_stft import CustomSTFT
from kokoro.istftnet import TorchSTFT
import torch.nn.functional as F
@pytest.fixture
def sample_audio():
# Generate a sample audio signal (sine wave)
sample_rate = 16000
duration = 1.0 # seconds
t = torch.linspace(0, duration, int(sample_rate * duration))
frequency = 440.0 # Hz
signal = torch.sin(2 * np.pi * frequency * t)
return signal.unsqueeze(0) # Add batch dimension
def test_stft_reconstruction(sample_audio):
# Initialize both STFT implementations
custom_stft = CustomSTFT(filter_length=800, hop_length=200, win_length=800)
torch_stft = TorchSTFT(filter_length=800, hop_length=200, win_length=800)
# Process through both implementations
custom_output = custom_stft(sample_audio)
torch_output = torch_stft(sample_audio)
# Compare outputs
assert torch.allclose(custom_output, torch_output, rtol=1e-3, atol=1e-3)
def test_magnitude_phase_consistency(sample_audio):
custom_stft = CustomSTFT(filter_length=800, hop_length=200, win_length=800)
torch_stft = TorchSTFT(filter_length=800, hop_length=200, win_length=800)
# Get magnitude and phase from both implementations
custom_mag, custom_phase = custom_stft.transform(sample_audio)
torch_mag, torch_phase = torch_stft.transform(sample_audio)
# Compare magnitudes ignoring the boundary frames
custom_mag_center = custom_mag[..., 2:-2]
torch_mag_center = torch_mag[..., 2:-2]
assert torch.allclose(custom_mag_center, torch_mag_center, rtol=1e-2, atol=1e-2)
def test_batch_processing():
# Create a batch of signals
batch_size = 4
sample_rate = 16000
duration = 0.1 # shorter duration for faster testing
t = torch.linspace(0, duration, int(sample_rate * duration))
frequency = 440.0
signals = torch.sin(2 * np.pi * frequency * t).unsqueeze(0).repeat(batch_size, 1)
custom_stft = CustomSTFT(filter_length=800, hop_length=200, win_length=800)
# Process batch
output = custom_stft(signals)
# Check output shape
assert output.shape[0] == batch_size
assert len(output.shape) == 3 # (batch, 1, time)
def test_different_window_sizes():
signal = torch.randn(1, 16000) # 1 second of random noise
# Test with different window sizes
for filter_length in [512, 1024, 2048]:
custom_stft = CustomSTFT(
filter_length=filter_length,
hop_length=filter_length // 4,
win_length=filter_length,
)
# Forward and backward transform
output = custom_stft(signal)
# Check that output length is reasonable
assert output.shape[-1] >= signal.shape[-1]