82 lines
2.7 KiB
Python
82 lines
2.7 KiB
Python
import torch
|
|
import numpy as np
|
|
import pytest
|
|
from kokoro.custom_stft import CustomSTFT
|
|
from kokoro.istftnet import TorchSTFT
|
|
import torch.nn.functional as F
|
|
|
|
|
|
@pytest.fixture
def sample_audio():
    """Return a 1-second, 440 Hz sine wave sampled at 16 kHz, shape (1, 16000).

    The time axis is built as ``arange(n) / sample_rate`` so consecutive
    samples are exactly ``1 / sample_rate`` apart. ``torch.linspace(0, duration, n)``
    would include the end point, spacing samples ``duration / (n - 1)`` apart
    and silently shifting the effective sample rate away from ``sample_rate``.
    """
    sample_rate = 16000
    duration = 1.0  # seconds
    num_samples = int(sample_rate * duration)
    t = torch.arange(num_samples, dtype=torch.float32) / sample_rate
    frequency = 440.0  # Hz
    signal = torch.sin(2 * np.pi * frequency * t)
    return signal.unsqueeze(0)  # Add batch dimension -> (1, num_samples)
|
|
|
|
|
|
def test_stft_reconstruction(sample_audio):
    """Check that CustomSTFT and TorchSTFT forward passes agree on a sine wave.

    Both modules are built with identical parameters (filter/window 800,
    hop 200) and called on the same input. Their outputs must have the same
    shape and match element-wise within rtol=atol=1e-3.
    """
    # Initialize both STFT implementations with identical parameters.
    custom_stft = CustomSTFT(filter_length=800, hop_length=200, win_length=800)
    torch_stft = TorchSTFT(filter_length=800, hop_length=200, win_length=800)

    # Process through both implementations.
    custom_output = custom_stft(sample_audio)
    torch_output = torch_stft(sample_audio)

    # Compare shapes explicitly first: torch.allclose broadcasts its inputs,
    # so e.g. (1, 1, T) vs (1, T) could compare without complaint, and a
    # length mismatch would only surface as an opaque broadcasting error.
    assert custom_output.shape == torch_output.shape, (
        f"shape mismatch: custom {tuple(custom_output.shape)} "
        f"vs torch {tuple(torch_output.shape)}"
    )
    assert torch.allclose(custom_output, torch_output, rtol=1e-3, atol=1e-3)
|
|
|
|
|
|
def test_magnitude_phase_consistency(sample_audio):
    """Check that CustomSTFT and TorchSTFT produce matching STFT magnitudes.

    Both ``transform`` calls return a (magnitude, phase) pair; only the
    magnitudes are compared, within rtol=atol=1e-2. The first and last two
    frames are excluded, presumably because the two implementations treat the
    signal edges differently.

    NOTE(review): despite the test name, phase is not asserted on — confirm
    whether a phase comparison (e.g. via the complex spectrum) is wanted.
    """
    custom_stft = CustomSTFT(filter_length=800, hop_length=200, win_length=800)
    torch_stft = TorchSTFT(filter_length=800, hop_length=200, win_length=800)

    # transform() returns (magnitude, phase); the phase half is not used here.
    custom_mag, _ = custom_stft.transform(sample_audio)
    torch_mag, _ = torch_stft.transform(sample_audio)

    # Frame/bin counts must agree before slicing; otherwise the interior
    # windows below could line up different frames or broadcast silently.
    assert custom_mag.shape == torch_mag.shape, (
        f"magnitude shape mismatch: custom {tuple(custom_mag.shape)} "
        f"vs torch {tuple(torch_mag.shape)}"
    )

    # Compare magnitudes ignoring the boundary frames.
    custom_mag_center = custom_mag[..., 2:-2]
    torch_mag_center = torch_mag[..., 2:-2]
    assert torch.allclose(custom_mag_center, torch_mag_center, rtol=1e-2, atol=1e-2)
|
|
|
|
|
|
def test_batch_processing():
    """Check that CustomSTFT processes a batch of signals consistently.

    Every batch row holds the same 0.1 s, 440 Hz sine at 16 kHz. The output
    must have layout ``(batch, 1, time)``, and because the inputs are
    identical, every output row must match the first one.
    """
    batch_size = 4
    sample_rate = 16000
    duration = 0.1  # shorter duration for faster testing
    num_samples = int(sample_rate * duration)
    # arange / sample_rate spaces samples exactly 1 / sample_rate apart;
    # linspace over [0, duration] would include the end point and skew it.
    t = torch.arange(num_samples, dtype=torch.float32) / sample_rate
    frequency = 440.0
    signals = torch.sin(2 * np.pi * frequency * t).unsqueeze(0).repeat(batch_size, 1)

    custom_stft = CustomSTFT(filter_length=800, hop_length=200, win_length=800)

    # Process batch
    output = custom_stft(signals)

    # Check output layout: (batch, 1, time).
    assert output.dim() == 3
    assert output.shape[0] == batch_size
    assert output.shape[1] == 1

    # Identical inputs must yield identical outputs, i.e. no batch item's
    # result depends on the other items in the batch.
    assert torch.allclose(output, output[:1].expand_as(output), atol=1e-6)
|
|
|
|
|
|
def test_different_window_sizes():
    """Check CustomSTFT's forward pass across several window sizes.

    For each filter length in (512, 1024, 2048) the window equals the filter
    length and the hop is a quarter of it. The output must be at least as
    long as the input signal.
    """
    # A seeded local generator keeps the noise reproducible from run to run
    # without mutating torch's global RNG state.
    generator = torch.Generator().manual_seed(0)
    signal = torch.randn(1, 16000, generator=generator)  # 1 second of random noise

    # Test with different window sizes
    for filter_length in [512, 1024, 2048]:
        custom_stft = CustomSTFT(
            filter_length=filter_length,
            hop_length=filter_length // 4,
            win_length=filter_length,
        )

        # Forward and backward transform
        output = custom_stft(signal)

        # Check that output length is reasonable; name the failing size.
        assert output.shape[-1] >= signal.shape[-1], (
            f"filter_length={filter_length}: output length {output.shape[-1]} "
            f"< input length {signal.shape[-1]}"
        )
|