Files
kokoro/tests/test_custom_stft.py
2025-02-15 11:05:57 -08:00

82 lines
2.7 KiB
Python

import torch
import numpy as np
import pytest
from kokoro.custom_stft import CustomSTFT
from kokoro.istftnet import TorchSTFT
import torch.nn.functional as F
@pytest.fixture
def sample_audio():
    """Provide a 440 Hz sine tone as a tensor with a leading batch axis.

    The time axis holds ``int(sample_rate * duration)`` points spread evenly
    over ``[0, duration]`` by ``torch.linspace`` (both endpoints included).
    """
    sample_rate = 16000
    duration = 1.0  # seconds
    frequency = 440.0  # Hz

    num_samples = int(sample_rate * duration)
    t = torch.linspace(0, duration, num_samples)
    waveform = torch.sin(2 * np.pi * frequency * t)

    # Prepend a batch dimension of size one.
    return waveform.unsqueeze(0)
def test_stft_reconstruction(sample_audio):
    """CustomSTFT and TorchSTFT, called on the same audio, must agree.

    Both modules are built with identical settings and their outputs are
    compared with ``torch.allclose`` at rtol = atol = 1e-3.
    """
    stft_settings = {"filter_length": 800, "hop_length": 200, "win_length": 800}
    custom_stft = CustomSTFT(**stft_settings)
    torch_stft = TorchSTFT(**stft_settings)

    # Run the same input through each implementation.
    from_custom = custom_stft(sample_audio)
    from_torch = torch_stft(sample_audio)

    outputs_match = torch.allclose(from_custom, from_torch, rtol=1e-3, atol=1e-3)
    assert outputs_match
def test_magnitude_phase_consistency(sample_audio):
    """Magnitudes from both ``transform`` methods must agree away from the edges.

    ``transform`` yields a (magnitude, phase) pair for each implementation.
    Only the magnitudes are compared here; the phases are unpacked but not
    checked. Two entries are trimmed from each end of the last axis so that
    the boundary frames are ignored.
    """
    stft_settings = {"filter_length": 800, "hop_length": 200, "win_length": 800}
    custom_stft = CustomSTFT(**stft_settings)
    torch_stft = TorchSTFT(**stft_settings)

    custom_mag, _custom_phase = custom_stft.transform(sample_audio)
    torch_mag, _torch_phase = torch_stft.transform(sample_audio)

    # Index that keeps every leading axis and drops 2 entries at each end
    # of the final axis.
    interior = (Ellipsis, slice(2, -2))
    magnitudes_match = torch.allclose(
        custom_mag[interior],
        torch_mag[interior],
        rtol=1e-2,
        atol=1e-2,
    )
    assert magnitudes_match
def test_batch_processing():
    """CustomSTFT must keep the batch size and return a 3-D tensor.

    The batch is one short 440 Hz sine tone repeated ``batch_size`` times.
    """
    batch_size = 4
    sample_rate = 16000
    duration = 0.1  # kept short so the test runs quickly
    frequency = 440.0

    num_samples = int(sample_rate * duration)
    t = torch.linspace(0, duration, num_samples)
    tone = torch.sin(2 * np.pi * frequency * t)
    # Stack identical copies of the tone along a new leading batch axis.
    signals = tone.unsqueeze(0).repeat(batch_size, 1)

    custom_stft = CustomSTFT(filter_length=800, hop_length=200, win_length=800)
    output = custom_stft(signals)

    result_shape = output.shape
    assert result_shape[0] == batch_size
    assert len(result_shape) == 3  # (batch, 1, time)
def test_different_window_sizes():
    """CustomSTFT output must be at least as long as its input for several sizes.

    For each ``filter_length`` the module is built with
    ``hop_length = filter_length // 4`` and ``win_length = filter_length``,
    then called on the same noise signal. The last dimension of the output
    must be no shorter than the last dimension of the input.
    """
    # A seeded, local generator keeps the noise reproducible between runs
    # without modifying the global torch RNG state used by other tests.
    rng = torch.Generator().manual_seed(0)
    signal = torch.randn(1, 16000, generator=rng)  # 16000 samples of noise

    for filter_length in (512, 1024, 2048):
        custom_stft = CustomSTFT(
            filter_length=filter_length,
            hop_length=filter_length // 4,
            win_length=filter_length,
        )
        output = custom_stft(signal)

        # Name the failing configuration; a bare assert inside a loop does
        # not reveal which iteration broke.
        assert output.shape[-1] >= signal.shape[-1], (
            f"filter_length={filter_length}: output length "
            f"{output.shape[-1]} is shorter than input length "
            f"{signal.shape[-1]}"
        )