* KPipeline

* misaki[en]>=0.6.0

* Fix typo

* Iteration over recursion

* Expose KModel

* Rename model
This commit is contained in:
hexgrad
2025-01-27 21:02:59 -08:00
committed by GitHub
parent 7c248d19d7
commit de2acfcc8a
7 changed files with 260 additions and 556 deletions

4
kokoro/__init__.py Normal file
View File

@@ -0,0 +1,4 @@
__version__ = '0.2.0'
from .models import KModel
from .pipeline import KPipeline

View File

@@ -1,149 +0,0 @@
import phonemizer
import re
import torch
def split_num(num):
    """Spell out a time or year regex match so it reads naturally aloud."""
    num = num.group()
    if '.' in num:
        # Decimals are handled by point_num, not here.
        return num
    elif ':' in num:
        hour, minute = (int(part) for part in num.split(':'))
        if minute == 0:
            return f"{hour} o'clock"
        elif minute < 10:
            return f'{hour} oh {minute}'
        return f'{hour} {minute}'
    year = int(num[:4])
    if year < 1100 or year % 1000 < 10:
        # Early years and round ones (e.g. 2005) read fine as plain numbers.
        return num
    left, right = num[:2], int(num[2:4])
    suffix = 's' if num.endswith('s') else ''
    if 100 <= year % 1000 <= 999:
        if right == 0:
            return f'{left} hundred{suffix}'
        elif right < 10:
            return f'{left} oh {right}{suffix}'
    return f'{left} {right}{suffix}'
def flip_money(m):
    """Rewrite a currency regex match as words, e.g. '$5.50' -> '5 dollars and 50 cents'."""
    text = m.group()
    amount = text[1:]
    unit = 'dollar' if text[0] == '$' else 'pound'
    if text[-1].isalpha():
        # e.g. '$3 billion' -> '3 billion dollars'
        return f'{amount} {unit}s'
    if '.' not in text:
        plural = '' if amount == '1' else 's'
        return f'{amount} {unit}{plural}'
    whole, frac = amount.split('.')
    plural = '' if whole == '1' else 's'
    # Right-pad so '.5' means 50 cents, not 5.
    cents = int(frac.ljust(2, '0'))
    if text[0] == '$':
        coins = 'cent' if cents == 1 else 'cents'
    else:
        coins = 'penny' if cents == 1 else 'pence'
    return f'{whole} {unit}{plural} and {cents} {coins}'
def point_num(num):
    """Read a decimal regex match aloud: '3.14' -> '3 point 1 4'."""
    whole, frac = num.group().split('.')
    spelled_digits = ' '.join(frac)
    return f'{whole} point {spelled_digits}'
def normalize_text(text):
    """Normalize raw English text ahead of phonemization.

    Straightens quotes, maps CJK punctuation to ASCII, expands titles,
    numbers, times, years and currency into words, and cleans whitespace.
    Substitution order matters and is preserved exactly.
    """
    # Straighten curly quotes; repurpose parentheses as guillemets so they
    # survive as distinct punctuation downstream.
    text = text.replace(chr(8216), "'").replace(chr(8217), "'")
    text = text.replace('«', chr(8220)).replace('»', chr(8221))
    text = text.replace(chr(8220), '"').replace(chr(8221), '"')
    text = text.replace('(', '«').replace(')', '»')
    # Full-width CJK punctuation -> ASCII equivalent plus a space.
    for cjk, ascii_punct in zip('、。!,:;?', ',.!,:;?'):
        text = text.replace(cjk, ascii_punct + ' ')
    # Ordered regex pipeline: whitespace, titles, numbers, currency, acronyms.
    substitutions = (
        (r'[^\S \n]', ' '),
        (r' +', ' '),
        (r'(?<=\n) +(?=\n)', ''),
        (r'\bD[Rr]\.(?= [A-Z])', 'Doctor'),
        (r'\b(?:Mr\.|MR\.(?= [A-Z]))', 'Mister'),
        (r'\b(?:Ms\.|MS\.(?= [A-Z]))', 'Miss'),
        (r'\b(?:Mrs\.|MRS\.(?= [A-Z]))', 'Mrs'),
        (r'\betc\.(?! [A-Z])', 'etc'),
        (r'(?i)\b(y)eah?\b', r"\1e'a"),
        (r'\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)', split_num),
        (r'(?<=\d),(?=\d)', ''),
        (r'(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b', flip_money),
        (r'\d*\.\d+', point_num),
        (r'(?<=\d)-(?=\d)', ' to '),
        (r'(?<=\d)S', ' S'),
        (r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S"),
        (r"(?<=X')S\b", 's'),
        (r'(?:[A-Za-z]\.){2,} [a-z]', lambda m: m.group().replace('.', '-')),
        (r'(?i)(?<=[A-Z])\.(?=[A-Z])', '-'),
    )
    for pattern, repl in substitutions:
        text = re.sub(pattern, repl, text)
    return text.strip()
def get_vocab():
    """Build the symbol -> token-id mapping used by tokenize() and phonemize().

    Index 0 is the pad symbol '$'; punctuation, Latin letters, and IPA
    symbols follow in a fixed order, so ids are stable across runs.
    """
    _pad = "$"
    _punctuation = ';:,.!?¡¿—…"«»“” '
    _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
    _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'"
    symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
    # Dict comprehension replaces the original `for i in range(len(symbols))`
    # index loop; any duplicate symbol keeps its last index either way.
    return {symbol: index for index, symbol in enumerate(symbols)}
# Module-level symbol->id table, built once at import time; consumed by
# tokenize() and phonemize() below.
VOCAB = get_vocab()
def tokenize(ps):
    """Map a phoneme string to a list of vocab ids, silently dropping unknown symbols."""
    return [VOCAB[phoneme] for phoneme in ps if phoneme in VOCAB]
# One espeak backend per supported language code:
#   'a' -> American English (en-us), 'b' -> British English (en-gb).
phonemizers = dict(
    a=phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True),
    b=phonemizer.backend.EspeakBackend(language='en-gb', preserve_punctuation=True, with_stress=True),
)
def phonemize(text, lang, norm=True):
    """Convert text to an IPA phoneme string for `lang` ('a' = US, 'b' = GB).

    When `norm` is true, the text is first run through normalize_text().
    Symbols outside VOCAB are stripped from the result.
    """
    if norm:
        text = normalize_text(text)
    results = phonemizers[lang].phonemize([text])
    ps = results[0] if results else ''
    # https://en.wiktionary.org/wiki/kokoro#English
    ps = ps.replace('kəkˈoːɹoʊ', 'kˈoʊkəɹoʊ').replace('kəkˈɔːɹəʊ', 'kˈəʊkəɹəʊ')
    ps = ps.replace('ʲ', 'j').replace('r', 'ɹ').replace('x', 'k').replace('ɬ', 'l')
    ps = re.sub(r'(?<=[a-zɹː])(?=hˈʌndɹɪd)', ' ', ps)
    ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', 'z', ps)
    if lang == 'a':
        ps = re.sub(r'(?<=nˈaɪn)ti(?!ː)', 'di', ps)
    # Keep only symbols the model vocabulary knows.
    ps = ''.join(p for p in ps if p in VOCAB)
    return ps.strip()
def length_to_mask(lengths):
    """Build a boolean padding mask: True at positions past each sequence's length.

    lengths: 1-D integer tensor; output shape is (len(lengths), lengths.max()).
    """
    positions = torch.arange(lengths.max()).type_as(lengths)
    positions = positions.unsqueeze(0).expand(lengths.shape[0], -1)
    return positions + 1 > lengths.unsqueeze(1)
@torch.no_grad()
def forward(model, tokens, ref_s, speed):
    """Run one inference pass: token ids + reference style -> waveform (numpy).

    tokens: list of vocab ids (without boundary tokens; zeros are added here).
    ref_s:  reference style tensor; ref_s[:, :128] feeds the decoder and
            ref_s[:, 128:] feeds the duration/prosody predictor.
    speed:  scalar; larger values shorten predicted durations (faster speech).
    """
    device = ref_s.device
    # Wrap the sequence in 0-id boundary tokens, batch dimension of 1.
    tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
    input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
    text_mask = length_to_mask(input_lengths).to(device)
    bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
    d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
    s = ref_s[:, 128:]
    d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
    x, _ = model.predictor.lstm(d)
    duration = model.predictor.duration_proj(x)
    # Per-token frame counts, scaled by speed; clamp so every token gets >= 1 frame.
    duration = torch.sigmoid(duration).sum(axis=-1) / speed
    pred_dur = torch.round(duration).clamp(min=1).long()
    # Build a hard token->frame alignment matrix from the predicted durations.
    pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
    c_frame = 0
    for i in range(pred_aln_trg.size(0)):
        pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1
        c_frame += pred_dur[0,i].item()
    # Expand encoder features to frame rate, then predict F0 and noise curves.
    en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
    F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
    t_en = model.text_encoder(tokens, input_lengths, text_mask)
    asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
    return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
def generate(model, text, voicepack, lang='a', speed=1, ps=None):
    """Synthesize speech for `text`.

    Returns (audio, phonemes): the waveform from forward() and the phoneme
    string actually fed to the model; returns None if tokenization yields
    nothing.

    model:     the TTS model, passed through to forward().
    voicepack: indexable by token count -> reference style tensor.
    lang:      'a' (American) or 'b' (British) phonemizer backend.
    speed:     duration scale; >1 speaks faster.
    ps:        optional pre-computed phoneme string; when falsy, `text`
               is phonemized.
    """
    ps = ps or phonemize(text, lang)
    tokens = tokenize(ps)
    if not tokens:
        return None
    elif len(tokens) > 510:
        # Model context is 512 positions including the two boundary tokens.
        tokens = tokens[:510]
        print('Truncated to 510 tokens')
    ref_s = voicepack[len(tokens)]
    out = forward(model, tokens, ref_s, speed)
    # Invert VOCAB once instead of linearly scanning it for every token
    # (the original used next() over VOCAB.items() per token: O(T * |V|)).
    id_to_symbol = {v: k for k, v in VOCAB.items()}
    ps = ''.join(id_to_symbol[i] for i in tokens)
    return out, ps

View File

@@ -1,12 +1,12 @@
# https://github.com/yl4579/StyleTTS2/blob/main/Modules/istftnet.py # https://github.com/yl4579/StyleTTS2/blob/main/Modules/istftnet.py
from scipy.signal import get_window from scipy.signal import get_window
from torch.nn import Conv1d, ConvTranspose1d from torch.nn.utils import weight_norm
from torch.nn.utils import weight_norm, remove_weight_norm
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
# https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py # https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py
def init_weights(m, mean=0.0, std=0.01): def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__ classname = m.__class__.__name__
@@ -16,7 +16,6 @@ def init_weights(m, mean=0.0, std=0.01):
def get_padding(kernel_size, dilation=1): def get_padding(kernel_size, dilation=1):
return int((kernel_size*dilation - dilation)/2) return int((kernel_size*dilation - dilation)/2)
LRELU_SLOPE = 0.1
class AdaIN1d(nn.Module): class AdaIN1d(nn.Module):
def __init__(self, style_dim, num_features): def __init__(self, style_dim, num_features):
@@ -30,45 +29,41 @@ class AdaIN1d(nn.Module):
gamma, beta = torch.chunk(h, chunks=2, dim=1) gamma, beta = torch.chunk(h, chunks=2, dim=1)
return (1 + gamma) * self.norm(x) + beta return (1 + gamma) * self.norm(x) + beta
class AdaINResBlock1(torch.nn.Module):
class AdaINResBlock1(nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64): def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
super(AdaINResBlock1, self).__init__() super(AdaINResBlock1, self).__init__()
self.convs1 = nn.ModuleList([ self.convs1 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))), padding=get_padding(kernel_size, dilation[0]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]))), padding=get_padding(kernel_size, dilation[1]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]))) padding=get_padding(kernel_size, dilation[2])))
]) ])
self.convs1.apply(init_weights) self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList([ self.convs2 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))), padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))), padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))) padding=get_padding(kernel_size, 1)))
]) ])
self.convs2.apply(init_weights) self.convs2.apply(init_weights)
self.adain1 = nn.ModuleList([ self.adain1 = nn.ModuleList([
AdaIN1d(style_dim, channels), AdaIN1d(style_dim, channels),
AdaIN1d(style_dim, channels), AdaIN1d(style_dim, channels),
AdaIN1d(style_dim, channels), AdaIN1d(style_dim, channels),
]) ])
self.adain2 = nn.ModuleList([ self.adain2 = nn.ModuleList([
AdaIN1d(style_dim, channels), AdaIN1d(style_dim, channels),
AdaIN1d(style_dim, channels), AdaIN1d(style_dim, channels),
AdaIN1d(style_dim, channels), AdaIN1d(style_dim, channels),
]) ])
self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))]) self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))]) self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])
def forward(self, x, s): def forward(self, x, s):
for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2): for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
xt = n1(x, s) xt = n1(x, s)
@@ -80,13 +75,8 @@ class AdaINResBlock1(torch.nn.Module):
x = xt + x x = xt + x
return x return x
def remove_weight_norm(self):
for l in self.convs1: class TorchSTFT(nn.Module):
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
class TorchSTFT(torch.nn.Module):
def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'): def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
super().__init__() super().__init__()
self.filter_length = filter_length self.filter_length = filter_length
@@ -99,22 +89,21 @@ class TorchSTFT(torch.nn.Module):
input_data, input_data,
self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device), self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device),
return_complex=True) return_complex=True)
return torch.abs(forward_transform), torch.angle(forward_transform) return torch.abs(forward_transform), torch.angle(forward_transform)
def inverse(self, magnitude, phase): def inverse(self, magnitude, phase):
inverse_transform = torch.istft( inverse_transform = torch.istft(
magnitude * torch.exp(phase * 1j), magnitude * torch.exp(phase * 1j),
self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device)) self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device))
return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation
def forward(self, input_data): def forward(self, input_data):
self.magnitude, self.phase = self.transform(input_data) self.magnitude, self.phase = self.transform(input_data)
reconstruction = self.inverse(self.magnitude, self.phase) reconstruction = self.inverse(self.magnitude, self.phase)
return reconstruction return reconstruction
class SineGen(torch.nn.Module):
class SineGen(nn.Module):
""" Definition of sine generator """ Definition of sine generator
SineGen(samp_rate, harmonic_num = 0, SineGen(samp_rate, harmonic_num = 0,
sine_amp = 0.1, noise_std = 0.003, sine_amp = 0.1, noise_std = 0.003,
@@ -129,7 +118,6 @@ class SineGen(torch.nn.Module):
Note: when flag_for_pulse is True, the first time step of a voiced Note: when flag_for_pulse is True, the first time step of a voiced
segment is always sin(np.pi) or cos(0) segment is always sin(np.pi) or cos(0)
""" """
def __init__(self, samp_rate, upsample_scale, harmonic_num=0, def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
sine_amp=0.1, noise_std=0.003, sine_amp=0.1, noise_std=0.003,
voiced_threshold=0, voiced_threshold=0,
@@ -156,52 +144,25 @@ class SineGen(torch.nn.Module):
# convert to F0 in rad. The interger part n can be ignored # convert to F0 in rad. The interger part n can be ignored
# because 2 * np.pi * n doesn't affect phase # because 2 * np.pi * n doesn't affect phase
rad_values = (f0_values / self.sampling_rate) % 1 rad_values = (f0_values / self.sampling_rate) % 1
# initial phase noise (no noise for fundamental component) # initial phase noise (no noise for fundamental component)
rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \ rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
device=f0_values.device)
rand_ini[:, 0] = 0 rand_ini[:, 0] = 0
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
# instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad) # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
if not self.flag_for_pulse: if not self.flag_for_pulse:
# # for normal case rad_values = F.interpolate(rad_values.transpose(1, 2), scale_factor=1/self.upsample_scale, mode="linear").transpose(1, 2)
# # To prevent torch.cumsum numerical overflow,
# # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
# # Buffer tmp_over_one_idx indicates the time step to add -1.
# # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
# tmp_over_one = torch.cumsum(rad_values, 1) % 1
# tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
# cumsum_shift = torch.zeros_like(rad_values)
# cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
# phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
scale_factor=1/self.upsample_scale,
mode="linear").transpose(1, 2)
# tmp_over_one = torch.cumsum(rad_values, 1) % 1
# tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
# cumsum_shift = torch.zeros_like(rad_values)
# cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale, phase = F.interpolate(phase.transpose(1, 2) * self.upsample_scale, scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
sines = torch.sin(phase) sines = torch.sin(phase)
else: else:
# If necessary, make sure that the first time step of every # If necessary, make sure that the first time step of every
# voiced segments is sin(pi) or cos(0) # voiced segments is sin(pi) or cos(0)
# This is used for pulse-train generation # This is used for pulse-train generation
# identify the last time step in unvoiced segments # identify the last time step in unvoiced segments
uv = self._f02uv(f0_values) uv = self._f02uv(f0_values)
uv_1 = torch.roll(uv, shifts=-1, dims=1) uv_1 = torch.roll(uv, shifts=-1, dims=1)
uv_1[:, -1, :] = 1 uv_1[:, -1, :] = 1
u_loc = (uv < 1) * (uv_1 > 0) u_loc = (uv < 1) * (uv_1 > 0)
# get the instantanouse phase # get the instantanouse phase
tmp_cumsum = torch.cumsum(rad_values, dim=1) tmp_cumsum = torch.cumsum(rad_values, dim=1)
# different batch needs to be processed differently # different batch needs to be processed differently
@@ -212,11 +173,9 @@ class SineGen(torch.nn.Module):
# each voiced segments # each voiced segments
tmp_cumsum[idx, :, :] = 0 tmp_cumsum[idx, :, :] = 0
tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
# rad_values - tmp_cumsum: remove the accumulation of i.phase # rad_values - tmp_cumsum: remove the accumulation of i.phase
# within the previous voiced segment. # within the previous voiced segment.
i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1) i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
# get the sines # get the sines
sines = torch.cos(i_phase * 2 * np.pi) sines = torch.cos(i_phase * 2 * np.pi)
return sines return sines
@@ -228,32 +187,27 @@ class SineGen(torch.nn.Module):
output sine_tensor: tensor(batchsize=1, length, dim) output sine_tensor: tensor(batchsize=1, length, dim)
output uv: tensor(batchsize=1, length, 1) output uv: tensor(batchsize=1, length, 1)
""" """
f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
device=f0.device)
# fundamental component # fundamental component
fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)) fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
# generate sine waveforms # generate sine waveforms
sine_waves = self._f02sine(fn) * self.sine_amp sine_waves = self._f02sine(fn) * self.sine_amp
# generate uv signal # generate uv signal
# uv = torch.ones(f0.shape) # uv = torch.ones(f0.shape)
# uv = uv * (f0 > self.voiced_threshold) # uv = uv * (f0 > self.voiced_threshold)
uv = self._f02uv(f0) uv = self._f02uv(f0)
# noise: for unvoiced should be similar to sine_amp # noise: for unvoiced should be similar to sine_amp
# std = self.sine_amp/3 -> max value ~ self.sine_amp # std = self.sine_amp/3 -> max value ~ self.sine_amp
# . for voiced regions is self.noise_std # for voiced regions is self.noise_std
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
noise = noise_amp * torch.randn_like(sine_waves) noise = noise_amp * torch.randn_like(sine_waves)
# first: set the unvoiced part to 0 by uv # first: set the unvoiced part to 0 by uv
# then: additive noise # then: additive noise
sine_waves = sine_waves * uv + noise sine_waves = sine_waves * uv + noise
return sine_waves, uv, noise return sine_waves, uv, noise
class SourceModuleHnNSF(torch.nn.Module): class SourceModuleHnNSF(nn.Module):
""" SourceModule for hn-nsf """ SourceModule for hn-nsf
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
add_noise_std=0.003, voiced_threshod=0) add_noise_std=0.003, voiced_threshod=0)
@@ -270,21 +224,17 @@ class SourceModuleHnNSF(torch.nn.Module):
noise_source (batchsize, length 1) noise_source (batchsize, length 1)
uv (batchsize, length, 1) uv (batchsize, length, 1)
""" """
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1, def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
add_noise_std=0.003, voiced_threshod=0): add_noise_std=0.003, voiced_threshod=0):
super(SourceModuleHnNSF, self).__init__() super(SourceModuleHnNSF, self).__init__()
self.sine_amp = sine_amp self.sine_amp = sine_amp
self.noise_std = add_noise_std self.noise_std = add_noise_std
# to produce sine waveforms # to produce sine waveforms
self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num, self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
sine_amp, add_noise_std, voiced_threshod) sine_amp, add_noise_std, voiced_threshod)
# to merge source harmonics into a single excitation # to merge source harmonics into a single excitation
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) self.l_linear = nn.Linear(harmonic_num + 1, 1)
self.l_tanh = torch.nn.Tanh() self.l_tanh = nn.Tanh()
def forward(self, x): def forward(self, x):
""" """
@@ -297,80 +247,63 @@ class SourceModuleHnNSF(torch.nn.Module):
with torch.no_grad(): with torch.no_grad():
sine_wavs, uv, _ = self.l_sin_gen(x) sine_wavs, uv, _ = self.l_sin_gen(x)
sine_merge = self.l_tanh(self.l_linear(sine_wavs)) sine_merge = self.l_tanh(self.l_linear(sine_wavs))
# source for noise branch, in the same shape as uv # source for noise branch, in the same shape as uv
noise = torch.randn_like(uv) * self.sine_amp / 3 noise = torch.randn_like(uv) * self.sine_amp / 3
return sine_merge, noise, uv return sine_merge, noise, uv
def padDiff(x):
return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0)
class Generator(torch.nn.Module): class Generator(nn.Module):
def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size): def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size):
super(Generator, self).__init__() super(Generator, self).__init__()
self.num_kernels = len(resblock_kernel_sizes) self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates) self.num_upsamples = len(upsample_rates)
resblock = AdaINResBlock1
self.m_source = SourceModuleHnNSF( self.m_source = SourceModuleHnNSF(
sampling_rate=24000, sampling_rate=24000,
upsample_scale=np.prod(upsample_rates) * gen_istft_hop_size, upsample_scale=np.prod(upsample_rates) * gen_istft_hop_size,
harmonic_num=8, voiced_threshod=10) harmonic_num=8, voiced_threshod=10)
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * gen_istft_hop_size) self.f0_upsamp = nn.Upsample(scale_factor=np.prod(upsample_rates) * gen_istft_hop_size)
self.noise_convs = nn.ModuleList() self.noise_convs = nn.ModuleList()
self.noise_res = nn.ModuleList() self.noise_res = nn.ModuleList()
self.ups = nn.ModuleList() self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
self.ups.append(weight_norm( self.ups.append(weight_norm(
ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)), nn.ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
k, u, padding=(k-u)//2))) k, u, padding=(k-u)//2)))
self.resblocks = nn.ModuleList() self.resblocks = nn.ModuleList()
for i in range(len(self.ups)): for i in range(len(self.ups)):
ch = upsample_initial_channel//(2**(i+1)) ch = upsample_initial_channel//(2**(i+1))
for j, (k, d) in enumerate(zip(resblock_kernel_sizes,resblock_dilation_sizes)): for j, (k, d) in enumerate(zip(resblock_kernel_sizes,resblock_dilation_sizes)):
self.resblocks.append(resblock(ch, k, d, style_dim)) self.resblocks.append(AdaINResBlock1(ch, k, d, style_dim))
c_cur = upsample_initial_channel // (2 ** (i + 1)) c_cur = upsample_initial_channel // (2 ** (i + 1))
if i + 1 < len(upsample_rates):
if i + 1 < len(upsample_rates): #
stride_f0 = np.prod(upsample_rates[i + 1:]) stride_f0 = np.prod(upsample_rates[i + 1:])
self.noise_convs.append(Conv1d( self.noise_convs.append(nn.Conv1d(
gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2)) gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
self.noise_res.append(resblock(c_cur, 7, [1,3,5], style_dim)) self.noise_res.append(AdaINResBlock1(c_cur, 7, [1,3,5], style_dim))
else: else:
self.noise_convs.append(Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1)) self.noise_convs.append(nn.Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1))
self.noise_res.append(resblock(c_cur, 11, [1,3,5], style_dim)) self.noise_res.append(AdaINResBlock1(c_cur, 11, [1,3,5], style_dim))
self.post_n_fft = gen_istft_n_fft self.post_n_fft = gen_istft_n_fft
self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3)) self.conv_post = weight_norm(nn.Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
self.ups.apply(init_weights) self.ups.apply(init_weights)
self.conv_post.apply(init_weights) self.conv_post.apply(init_weights)
self.reflection_pad = torch.nn.ReflectionPad1d((1, 0)) self.reflection_pad = nn.ReflectionPad1d((1, 0))
self.stft = TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft) self.stft = TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
def forward(self, x, s, f0): def forward(self, x, s, f0):
with torch.no_grad(): with torch.no_grad():
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
har_source, noi_source, uv = self.m_source(f0) har_source, noi_source, uv = self.m_source(f0)
har_source = har_source.transpose(1, 2).squeeze(1) har_source = har_source.transpose(1, 2).squeeze(1)
har_spec, har_phase = self.stft.transform(har_source) har_spec, har_phase = self.stft.transform(har_source)
har = torch.cat([har_spec, har_phase], dim=1) har = torch.cat([har_spec, har_phase], dim=1)
for i in range(self.num_upsamples): for i in range(self.num_upsamples):
x = F.leaky_relu(x, LRELU_SLOPE) x = F.leaky_relu(x, negative_slope=0.1)
x_source = self.noise_convs[i](har) x_source = self.noise_convs[i](har)
x_source = self.noise_res[i](x_source, s) x_source = self.noise_res[i](x_source, s)
x = self.ups[i](x) x = self.ups[i](x)
if i == self.num_upsamples - 1: if i == self.num_upsamples - 1:
x = self.reflection_pad(x) x = self.reflection_pad(x)
x = x + x_source x = x + x_source
xs = None xs = None
for j in range(self.num_kernels): for j in range(self.num_kernels):
@@ -384,38 +317,22 @@ class Generator(torch.nn.Module):
spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :]) spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :]) phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
return self.stft.inverse(spec, phase) return self.stft.inverse(spec, phase)
def fw_phase(self, x, s):
for i in range(self.num_upsamples):
x = F.leaky_relu(x, LRELU_SLOPE)
x = self.ups[i](x)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i*self.num_kernels+j](x, s)
else:
xs += self.resblocks[i*self.num_kernels+j](x, s)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.reflection_pad(x)
x = self.conv_post(x)
spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
return spec, phase
def remove_weight_norm(self):
print('Removing weight norm...')
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
class UpSample1d(nn.Module):
def __init__(self, layer_type):
super().__init__()
self.layer_type = layer_type
def forward(self, x):
if self.layer_type == 'none':
return x
else:
return F.interpolate(x, scale_factor=2, mode='nearest')
class AdainResBlk1d(nn.Module): class AdainResBlk1d(nn.Module):
def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2), def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2), upsample='none', dropout_p=0.0):
upsample='none', dropout_p=0.0):
super().__init__() super().__init__()
self.actv = actv self.actv = actv
self.upsample_type = upsample self.upsample_type = upsample
@@ -423,13 +340,11 @@ class AdainResBlk1d(nn.Module):
self.learned_sc = dim_in != dim_out self.learned_sc = dim_in != dim_out
self._build_weights(dim_in, dim_out, style_dim) self._build_weights(dim_in, dim_out, style_dim)
self.dropout = nn.Dropout(dropout_p) self.dropout = nn.Dropout(dropout_p)
if upsample == 'none': if upsample == 'none':
self.pool = nn.Identity() self.pool = nn.Identity()
else: else:
self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1)) self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
def _build_weights(self, dim_in, dim_out, style_dim): def _build_weights(self, dim_in, dim_out, style_dim):
self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1)) self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1)) self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
@@ -458,59 +373,36 @@ class AdainResBlk1d(nn.Module):
out = self._residual(x, s) out = self._residual(x, s)
out = (out + self._shortcut(x)) / np.sqrt(2) out = (out + self._shortcut(x)) / np.sqrt(2)
return out return out
class UpSample1d(nn.Module):
def __init__(self, layer_type):
super().__init__()
self.layer_type = layer_type
def forward(self, x):
if self.layer_type == 'none':
return x
else:
return F.interpolate(x, scale_factor=2, mode='nearest')
class Decoder(nn.Module): class Decoder(nn.Module):
def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80, def __init__(self, dim_in, style_dim, dim_out,
resblock_kernel_sizes = [3,7,11], resblock_kernel_sizes,
upsample_rates = [10, 6], upsample_rates,
upsample_initial_channel=512, upsample_initial_channel,
resblock_dilation_sizes=[[1,3,5], [1,3,5], [1,3,5]], resblock_dilation_sizes,
upsample_kernel_sizes=[20, 12], upsample_kernel_sizes,
gen_istft_n_fft=20, gen_istft_hop_size=5): gen_istft_n_fft, gen_istft_hop_size):
super().__init__() super().__init__()
self.decode = nn.ModuleList()
self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim) self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
self.decode = nn.ModuleList()
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim)) self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim)) self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim)) self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True)) self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1)) self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1)) self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
self.asr_res = nn.Sequential(weight_norm(nn.Conv1d(512, 64, kernel_size=1)))
self.asr_res = nn.Sequential(
weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
)
self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates, self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates,
upsample_initial_channel, resblock_dilation_sizes, upsample_initial_channel, resblock_dilation_sizes,
upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size) upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size)
def forward(self, asr, F0_curve, N, s): def forward(self, asr, F0_curve, N, s):
F0 = self.F0_conv(F0_curve.unsqueeze(1)) F0 = self.F0_conv(F0_curve.unsqueeze(1))
N = self.N_conv(N.unsqueeze(1)) N = self.N_conv(N.unsqueeze(1))
x = torch.cat([asr, F0, N], axis=1) x = torch.cat([asr, F0, N], axis=1)
x = self.encode(x, s) x = self.encode(x, s)
asr_res = self.asr_res(asr) asr_res = self.asr_res(asr)
res = True res = True
for block in self.decode: for block in self.decode:
if res: if res:
@@ -518,6 +410,5 @@ class Decoder(nn.Module):
x = block(x, s) x = block(x, s)
if block.upsample_type != "none": if block.upsample_type != "none":
res = False res = False
x = self.generator(x, s, F0_curve) x = self.generator(x, s, F0_curve)
return x return x

View File

@@ -1,35 +1,28 @@
# https://github.com/yl4579/StyleTTS2/blob/main/models.py # https://github.com/yl4579/StyleTTS2/blob/main/models.py
from istftnet import AdaIN1d, Decoder from .istftnet import AdaIN1d, AdainResBlk1d, Decoder
from munch import Munch from torch.nn.utils import weight_norm
from pathlib import Path from transformers import AlbertConfig, AlbertModel
from plbert import load_plbert
from torch.nn.utils import weight_norm, spectral_norm
import json
import numpy as np import numpy as np
import os
import os.path as osp
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
class LinearNorm(torch.nn.Module):
class LinearNorm(nn.Module):
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
super(LinearNorm, self).__init__() super(LinearNorm, self).__init__()
self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) self.linear_layer = nn.Linear(in_dim, out_dim, bias=bias)
nn.init.xavier_uniform_(self.linear_layer.weight, gain=nn.init.calculate_gain(w_init_gain))
torch.nn.init.xavier_uniform_(
self.linear_layer.weight,
gain=torch.nn.init.calculate_gain(w_init_gain))
def forward(self, x): def forward(self, x):
return self.linear_layer(x) return self.linear_layer(x)
class LayerNorm(nn.Module): class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5): def __init__(self, channels, eps=1e-5):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.eps = eps self.eps = eps
self.gamma = nn.Parameter(torch.ones(channels)) self.gamma = nn.Parameter(torch.ones(channels))
self.beta = nn.Parameter(torch.zeros(channels)) self.beta = nn.Parameter(torch.zeros(channels))
@@ -37,12 +30,12 @@ class LayerNorm(nn.Module):
x = x.transpose(1, -1) x = x.transpose(1, -1)
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
return x.transpose(1, -1) return x.transpose(1, -1)
class TextEncoder(nn.Module): class TextEncoder(nn.Module):
def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)): def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
super().__init__() super().__init__()
self.embedding = nn.Embedding(n_symbols, channels) self.embedding = nn.Embedding(n_symbols, channels)
padding = (kernel_size - 1) // 2 padding = (kernel_size - 1) // 2
self.cnn = nn.ModuleList() self.cnn = nn.ModuleList()
for _ in range(depth): for _ in range(depth):
@@ -52,8 +45,6 @@ class TextEncoder(nn.Module):
actv, actv,
nn.Dropout(0.2), nn.Dropout(0.2),
)) ))
# self.cnn = nn.Sequential(*self.cnn)
self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True) self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)
def forward(self, x, input_lengths, m): def forward(self, x, input_lengths, m):
@@ -61,234 +52,110 @@ class TextEncoder(nn.Module):
x = x.transpose(1, 2) # [B, emb, T] x = x.transpose(1, 2) # [B, emb, T]
m = m.to(input_lengths.device).unsqueeze(1) m = m.to(input_lengths.device).unsqueeze(1)
x.masked_fill_(m, 0.0) x.masked_fill_(m, 0.0)
for c in self.cnn: for c in self.cnn:
x = c(x) x = c(x)
x.masked_fill_(m, 0.0) x.masked_fill_(m, 0.0)
x = x.transpose(1, 2) # [B, T, chn] x = x.transpose(1, 2) # [B, T, chn]
input_lengths = input_lengths.cpu().numpy() input_lengths = input_lengths.cpu().numpy()
x = nn.utils.rnn.pack_padded_sequence( x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True, enforce_sorted=False)
x, input_lengths, batch_first=True, enforce_sorted=False)
self.lstm.flatten_parameters() self.lstm.flatten_parameters()
x, _ = self.lstm(x) x, _ = self.lstm(x)
x, _ = nn.utils.rnn.pad_packed_sequence( x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
x, batch_first=True)
x = x.transpose(-1, -2) x = x.transpose(-1, -2)
x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]]) x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
x_pad[:, :, :x.shape[-1]] = x x_pad[:, :, :x.shape[-1]] = x
x = x_pad.to(x.device) x = x_pad.to(x.device)
x.masked_fill_(m, 0.0) x.masked_fill_(m, 0.0)
return x return x
def inference(self, x):
    """Encode a token-id tensor without masking (single-sequence inference path).

    Mirrors ``forward`` but skips length masking and padding.
    Bug fix: ``self.cnn`` is an ``nn.ModuleList`` (not callable), so the
    original ``x = self.cnn(x)`` raised at runtime; iterate the blocks
    exactly as ``forward`` does.
    """
    x = self.embedding(x)
    x = x.transpose(1, 2)  # [B, T, emb] -> [B, emb, T] for Conv1d blocks
    for c in self.cnn:
        x = c(x)
    x = x.transpose(1, 2)  # back to [B, T, chn] for the LSTM
    self.lstm.flatten_parameters()
    x, _ = self.lstm(x)
    return x
def length_to_mask(self, lengths):
    """Build a [B, max_len] boolean padding mask from 1-D sequence lengths.

    Entry (i, j) is True when position j lies past the end of sequence i
    (j + 1 > lengths[i]), i.e. True marks padding.
    """
    positions = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
    return torch.gt(positions + 1, lengths.unsqueeze(1))
class UpSample1d(nn.Module):
    """Optional 2x temporal upsampler for [B, C, T] tensors.

    ``layer_type == 'none'`` makes the module an identity; any other value
    applies nearest-neighbour interpolation with scale factor 2 along the
    last dimension.
    """
    def __init__(self, layer_type):
        super().__init__()
        self.layer_type = layer_type

    def forward(self, x):
        if self.layer_type != 'none':
            return F.interpolate(x, scale_factor=2, mode='nearest')
        return x
class AdainResBlk1d(nn.Module):
    """1-D residual block conditioned on a style vector via AdaIN.

    ``forward(x, s)`` returns ``(residual(x, s) + shortcut(x)) / sqrt(2)``.
    When ``dim_in != dim_out`` the shortcut is projected by a learned 1x1
    conv; ``upsample != 'none'`` doubles the temporal resolution of both
    the residual path (via a depthwise transposed conv) and the shortcut
    (via nearest-neighbour interpolation).
    """
    def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
                 upsample='none', dropout_p=0.0):
        super().__init__()
        self.actv = actv
        self.upsample_type = upsample  # inspected by callers to detect upsampling blocks
        self.upsample = UpSample1d(upsample)
        self.learned_sc = dim_in != dim_out  # shortcut needs a projection iff dims differ
        self._build_weights(dim_in, dim_out, style_dim)
        self.dropout = nn.Dropout(dropout_p)
        if upsample == 'none':
            self.pool = nn.Identity()
        else:
            # depthwise transposed conv performs the 2x upsample on the residual path
            self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2,
                                                       groups=dim_in, padding=1, output_padding=1))

    def _build_weights(self, dim_in, dim_out, style_dim):
        """Create the convs and style-conditioned norms of the residual path."""
        self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
        self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
        self.norm1 = AdaIN1d(style_dim, dim_in)
        self.norm2 = AdaIN1d(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))

    def _shortcut(self, x):
        """Skip branch: optional upsample, then optional 1x1 projection."""
        x = self.upsample(x)
        return self.conv1x1(x) if self.learned_sc else x

    def _residual(self, x, s):
        """Main branch: (norm -> act -> conv) twice, upsampling after the first norm."""
        x = self.norm1(x, s)
        x = self.actv(x)
        x = self.pool(x)
        x = self.conv1(self.dropout(x))
        x = self.norm2(x, s)
        x = self.actv(x)
        return self.conv2(self.dropout(x))

    def forward(self, x, s):
        # Scale by 1/sqrt(2) to keep the summed activations at unit variance.
        return (self._residual(x, s) + self._shortcut(x)) / np.sqrt(2)
class AdaLayerNorm(nn.Module): class AdaLayerNorm(nn.Module):
def __init__(self, style_dim, channels, eps=1e-5): def __init__(self, style_dim, channels, eps=1e-5):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.eps = eps self.eps = eps
self.fc = nn.Linear(style_dim, channels*2) self.fc = nn.Linear(style_dim, channels*2)
def forward(self, x, s): def forward(self, x, s):
x = x.transpose(-1, -2) x = x.transpose(-1, -2)
x = x.transpose(1, -1) x = x.transpose(1, -1)
h = self.fc(s) h = self.fc(s)
h = h.view(h.size(0), h.size(1), 1) h = h.view(h.size(0), h.size(1), 1)
gamma, beta = torch.chunk(h, chunks=2, dim=1) gamma, beta = torch.chunk(h, chunks=2, dim=1)
gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1) gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
x = F.layer_norm(x, (self.channels,), eps=self.eps) x = F.layer_norm(x, (self.channels,), eps=self.eps)
x = (1 + gamma) * x + beta x = (1 + gamma) * x + beta
return x.transpose(1, -1).transpose(-1, -2) return x.transpose(1, -1).transpose(-1, -2)
class ProsodyPredictor(nn.Module): class ProsodyPredictor(nn.Module):
def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1): def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
super().__init__() super().__init__()
self.text_encoder = DurationEncoder(sty_dim=style_dim, d_model=d_hid,nlayers=nlayers, dropout=dropout)
self.text_encoder = DurationEncoder(sty_dim=style_dim,
d_model=d_hid,
nlayers=nlayers,
dropout=dropout)
self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True) self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
self.duration_proj = LinearNorm(d_hid, max_dur) self.duration_proj = LinearNorm(d_hid, max_dur)
self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True) self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
self.F0 = nn.ModuleList() self.F0 = nn.ModuleList()
self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout)) self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout)) self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout)) self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
self.N = nn.ModuleList() self.N = nn.ModuleList()
self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout)) self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout)) self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout)) self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0) self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0) self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
def forward(self, texts, style, text_lengths, alignment, m): def forward(self, texts, style, text_lengths, alignment, m):
d = self.text_encoder(texts, style, text_lengths, m) d = self.text_encoder(texts, style, text_lengths, m)
batch_size = d.shape[0] batch_size = d.shape[0]
text_size = d.shape[1] text_size = d.shape[1]
# predict duration
input_lengths = text_lengths.cpu().numpy() input_lengths = text_lengths.cpu().numpy()
x = nn.utils.rnn.pack_padded_sequence( x = nn.utils.rnn.pack_padded_sequence(d, input_lengths, batch_first=True, enforce_sorted=False)
d, input_lengths, batch_first=True, enforce_sorted=False)
m = m.to(text_lengths.device).unsqueeze(1) m = m.to(text_lengths.device).unsqueeze(1)
self.lstm.flatten_parameters() self.lstm.flatten_parameters()
x, _ = self.lstm(x) x, _ = self.lstm(x)
x, _ = nn.utils.rnn.pad_packed_sequence( x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
x, batch_first=True)
x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]]) x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
x_pad[:, :x.shape[1], :] = x x_pad[:, :x.shape[1], :] = x
x = x_pad.to(x.device) x = x_pad.to(x.device)
duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=False))
duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))
en = (d.transpose(-1, -2) @ alignment) en = (d.transpose(-1, -2) @ alignment)
return duration.squeeze(-1), en return duration.squeeze(-1), en
def F0Ntrain(self, x, s): def F0Ntrain(self, x, s):
x, _ = self.shared(x.transpose(-1, -2)) x, _ = self.shared(x.transpose(-1, -2))
F0 = x.transpose(-1, -2) F0 = x.transpose(-1, -2)
for block in self.F0: for block in self.F0:
F0 = block(F0, s) F0 = block(F0, s)
F0 = self.F0_proj(F0) F0 = self.F0_proj(F0)
N = x.transpose(-1, -2) N = x.transpose(-1, -2)
for block in self.N: for block in self.N:
N = block(N, s) N = block(N, s)
N = self.N_proj(N) N = self.N_proj(N)
return F0.squeeze(1), N.squeeze(1) return F0.squeeze(1), N.squeeze(1)
def length_to_mask(self, lengths):
    """Return a [B, max_len] bool mask in which True marks padded positions."""
    max_len = lengths.max()
    idx = torch.arange(max_len).type_as(lengths).unsqueeze(0).expand(lengths.shape[0], -1)
    # Position j of row i is padding when j + 1 exceeds lengths[i].
    return idx + 1 > lengths.unsqueeze(1)
class DurationEncoder(nn.Module): class DurationEncoder(nn.Module):
def __init__(self, sty_dim, d_model, nlayers, dropout=0.1): def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
super().__init__() super().__init__()
self.lstms = nn.ModuleList() self.lstms = nn.ModuleList()
for _ in range(nlayers): for _ in range(nlayers):
self.lstms.append(nn.LSTM(d_model + sty_dim, self.lstms.append(nn.LSTM(d_model + sty_dim, d_model // 2, num_layers=1, batch_first=True, bidirectional=True, dropout=dropout))
d_model // 2,
num_layers=1,
batch_first=True,
bidirectional=True,
dropout=dropout))
self.lstms.append(AdaLayerNorm(sty_dim, d_model)) self.lstms.append(AdaLayerNorm(sty_dim, d_model))
self.dropout = dropout self.dropout = dropout
self.d_model = d_model self.d_model = d_model
self.sty_dim = sty_dim self.sty_dim = sty_dim
def forward(self, x, style, text_lengths, m): def forward(self, x, style, text_lengths, m):
masks = m.to(text_lengths.device) masks = m.to(text_lengths.device)
x = x.permute(2, 0, 1) x = x.permute(2, 0, 1)
s = style.expand(x.shape[0], x.shape[1], -1) s = style.expand(x.shape[0], x.shape[1], -1)
x = torch.cat([x, s], axis=-1) x = torch.cat([x, s], axis=-1)
x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0) x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
x = x.transpose(0, 1) x = x.transpose(0, 1)
input_lengths = text_lengths.cpu().numpy() input_lengths = text_lengths.cpu().numpy()
x = x.transpose(-1, -2) x = x.transpose(-1, -2)
for block in self.lstms: for block in self.lstms:
if isinstance(block, AdaLayerNorm): if isinstance(block, AdaLayerNorm):
x = block(x.transpose(-1, -2), style).transpose(-1, -2) x = block(x.transpose(-1, -2), style).transpose(-1, -2)
@@ -302,71 +169,73 @@ class DurationEncoder(nn.Module):
x, _ = block(x) x, _ = block(x)
x, _ = nn.utils.rnn.pad_packed_sequence( x, _ = nn.utils.rnn.pad_packed_sequence(
x, batch_first=True) x, batch_first=True)
x = F.dropout(x, p=self.dropout, training=self.training) x = F.dropout(x, p=self.dropout, training=False)
x = x.transpose(-1, -2) x = x.transpose(-1, -2)
x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]]) x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
x_pad[:, :, :x.shape[-1]] = x x_pad[:, :, :x.shape[-1]] = x
x = x_pad.to(x.device) x = x_pad.to(x.device)
return x.transpose(-1, -2) return x.transpose(-1, -2)
def inference(self, x, style):
x = self.embedding(x.transpose(-1, -2)) * np.sqrt(self.d_model) # https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py
style = style.expand(x.shape[0], x.shape[1], -1) class CustomAlbert(AlbertModel):
x = torch.cat([x, style], axis=-1) def forward(self, *args, **kwargs):
src = self.pos_encoder(x) outputs = super().forward(*args, **kwargs)
output = self.transformer_encoder(src).transpose(0, 1) return outputs.last_hidden_state
return output
def length_to_mask(self, lengths): class KModel(nn.Module):
def __init__(self, config, path):
super().__init__()
self.bert = CustomAlbert(AlbertConfig(vocab_size=config['n_token'], **config['plbert']))
self.bert_encoder = nn.Linear(self.bert.config.hidden_size, config['hidden_dim'])
self.predictor = ProsodyPredictor(
style_dim=config['style_dim'], d_hid=config['hidden_dim'],
nlayers=config['n_layer'], max_dur=config['max_dur'], dropout=config['dropout']
)
self.text_encoder = TextEncoder(
channels=config['hidden_dim'], kernel_size=config['text_encoder_kernel_size'],
depth=config['n_layer'], n_symbols=config['n_token']
)
self.decoder = Decoder(
dim_in=config['hidden_dim'], style_dim=config['style_dim'],
dim_out=config['n_mels'], **config['istftnet']
)
for key, state_dict in torch.load(path, map_location='cpu', weights_only=True).items():
assert hasattr(self, key), key
try:
getattr(self, key).load_state_dict(state_dict)
except:
state_dict = {k[7:]: v for k, v in state_dict.items()}
getattr(self, key).load_state_dict(state_dict, strict=False)
@classmethod
def length_to_mask(cls, lengths):
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths) mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
mask = torch.gt(mask+1, lengths.unsqueeze(1)) mask = torch.gt(mask+1, lengths.unsqueeze(1))
return mask return mask
# https://github.com/yl4579/StyleTTS2/blob/main/utils.py @torch.no_grad()
def recursive_munch(d): def forward(self, input_ids, ref_s, speed):
if isinstance(d, dict): device = ref_s.device
return Munch((k, recursive_munch(v)) for k, v in d.items()) input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(device)
elif isinstance(d, list): input_lengths = torch.LongTensor([input_ids.shape[-1]]).to(device)
return [recursive_munch(v) for v in d] text_mask = type(self).length_to_mask(input_lengths).to(device)
else: bert_dur = self.bert(input_ids, attention_mask=(~text_mask).int())
return d d_en = self.bert_encoder(bert_dur).transpose(-1, -2)
s = ref_s[:, 128:]
def build_model(path, device): d = self.predictor.text_encoder(d_en, s, input_lengths, text_mask)
config = Path(__file__).parent / 'config.json' x, _ = self.predictor.lstm(d)
assert config.exists(), f'Config path incorrect: config.json not found at {config}' duration = self.predictor.duration_proj(x)
with open(config, 'r') as r: duration = torch.sigmoid(duration).sum(axis=-1) / speed
args = recursive_munch(json.load(r)) pred_dur = torch.round(duration).clamp(min=1).long()
assert args.decoder.type == 'istftnet', f'Unknown decoder type: {args.decoder.type}' pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels, c_frame = 0
resblock_kernel_sizes = args.decoder.resblock_kernel_sizes, for i in range(pred_aln_trg.size(0)):
upsample_rates = args.decoder.upsample_rates, pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1
upsample_initial_channel=args.decoder.upsample_initial_channel, c_frame += pred_dur[0,i].item()
resblock_dilation_sizes=args.decoder.resblock_dilation_sizes, en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
upsample_kernel_sizes=args.decoder.upsample_kernel_sizes, F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
gen_istft_n_fft=args.decoder.gen_istft_n_fft, gen_istft_hop_size=args.decoder.gen_istft_hop_size) t_en = self.text_encoder(input_ids, input_lengths, text_mask)
text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token) asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout) return self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
bert = load_plbert()
bert_encoder = nn.Linear(bert.config.hidden_size, args.hidden_dim)
for parent in [bert, bert_encoder, predictor, decoder, text_encoder]:
for child in parent.children():
if isinstance(child, nn.RNNBase):
child.flatten_parameters()
model = Munch(
bert=bert.to(device).eval(),
bert_encoder=bert_encoder.to(device).eval(),
predictor=predictor.to(device).eval(),
decoder=decoder.to(device).eval(),
text_encoder=text_encoder.to(device).eval(),
)
for key, state_dict in torch.load(path, map_location='cpu', weights_only=True)['net'].items():
assert key in model, key
try:
model[key].load_state_dict(state_dict)
except:
state_dict = {k[7:]: v for k, v in state_dict.items()}
model[key].load_state_dict(state_dict, strict=False)
return model

103
kokoro/pipeline.py Normal file
View File

@@ -0,0 +1,103 @@
from .models import KModel
from huggingface_hub import hf_hub_download
from misaki import en, espeak
import json
import os
import re
import torch
# Supported G2P language codes mapped to human-readable names.
LANG_CODES = dict(
    a='American English',
    b='British English',
)
# Hugging Face Hub repo that hosts the model config, weights, and voice packs.
REPO_ID = 'hexgrad/Kokoro-82M'
class KPipeline:
    """End-to-end text-to-speech pipeline: grapheme -> phoneme (misaki G2P) -> KModel audio.

    Lazily downloads the model config, weights, and voice packs from the
    Hugging Face Hub on first use.
    """
    def __init__(self, lang_code='a', config_path=None, model_path=None, trf=False):
        """Load config/weights (downloading if paths are not given) and set up G2P.

        lang_code: key into LANG_CODES ('a' American, 'b' British English).
        trf: passed through to misaki's en.G2P (transformer-based G2P toggle).
        """
        assert lang_code in LANG_CODES, (lang_code, LANG_CODES)
        self.lang_code = lang_code
        if config_path is None:
            config_path = hf_hub_download(repo_id=REPO_ID, filename='config.json')
        assert os.path.exists(config_path)
        with open(config_path, 'r') as r:
            config = json.load(r)
        if model_path is None:
            model_path = hf_hub_download(repo_id=REPO_ID, filename='kokoro-v1_0.pth')
        assert os.path.exists(model_path)
        try:
            # espeak handles words missing from the G2P dictionary; optional.
            fallback = espeak.EspeakFallback(british=lang_code=='b')
        except Exception as e:
            print('WARNING: EspeakFallback not enabled. Out-of-dictionary words will be skipped.', e)
            fallback = None
        self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback)
        self.vocab = config['vocab']  # phoneme -> input-id mapping
        self.model = KModel(config, model_path)
        self.voices = {}  # cache: voice name -> loaded style tensor pack

    def load_voice(self, voice):
        """Load a voice style pack into the cache (no-op if already loaded).

        Accepts either a bare voice name (downloaded from the Hub) or a
        local path ending in '.pt'. Warns when the voice's language prefix
        does not match this pipeline's lang_code.
        """
        if voice in self.voices:
            return
        v = voice.split('/')[-1]
        if not v.startswith(self.lang_code):
            v = LANG_CODES.get(v, voice)
            p = LANG_CODES.get(self.lang_code, self.lang_code)
            print(f'WARNING: Loading {v} voice into {p} pipeline. Phonemes may be mismatched.')
        voice_path = voice if voice.endswith('.pt') else hf_hub_download(repo_id=REPO_ID, filename=f'voices/{voice}.pt')
        assert os.path.exists(voice_path)
        self.voices[voice] = torch.load(voice_path, weights_only=True)

    @classmethod
    # NOTE(review): bumps={')', ''} — the empty string looks like a rendering
    # artifact of '"' (closing quote); confirm against the upstream source.
    def waterfall_last(cls, pairs, next_count, waterfall=['!.?…', ':;', ',—'], bumps={')', ''}):
        """Pick a split index into `pairs` so the emitted chunk stays <= 510 phonemes.

        Scans punctuation classes from strongest (sentence-final) to weakest,
        choosing the last pair whose phonemes end in that class; `bumps`
        lets a trailing bracket/quote ride along with the chunk. Falls back
        to splitting after all pairs.
        """
        for w in waterfall:
            z = next((i for i, (_, ps) in reversed(list(enumerate(pairs))) if ps.strip() in set(w)), None)
            if z is not None:
                z += 1
                if z < len(pairs) and pairs[z][1].strip() in bumps:
                    z += 1
                _, ps = zip(*pairs[:z])
                if next_count - len(''.join(ps)) <= 510:
                    return z
        return len(pairs)

    def tokenize(self, tokens):
        """Yield (graphemes, phonemes) chunks, each at most 510 phonemes long.

        Accumulates (text, phoneme) pairs; when the running phoneme count
        would exceed 510, flushes a chunk at a punctuation boundary chosen
        by waterfall_last. Phonemes not in self.vocab are dropped.
        """
        pairs = []
        count = 0  # phoneme length of the pending chunk
        for w in tokens:
            for t in (w if isinstance(w, list) else [w]):
                if t.phonemes is None:
                    continue
                next_ps = ' ' if t.prespace and pairs and not pairs[-1][1].endswith(' ') and t.phonemes else ''
                next_ps += ''.join(filter(lambda p: p in self.vocab, t.phonemes.replace('ɾ', 'T'))) # American English: ɾ => T
                next_ps += ' ' if t.whitespace else ''
                next_count = count + len(next_ps.rstrip())
                if next_count > 510:
                    # Flush the largest prefix that still fits, then continue.
                    z = type(self).waterfall_last(pairs, next_count)
                    text, ps = zip(*pairs[:z])
                    ps = ''.join(ps)
                    yield ''.join(text).strip(), ps.strip()
                    pairs = pairs[z:]
                    count -= len(ps)
                    if not pairs:
                        next_ps = next_ps.lstrip()
                pairs.append((t.text + t.whitespace, next_ps))
                count += len(next_ps)
        if pairs:
            text, ps = zip(*pairs)
            yield ''.join(text).strip(), ''.join(ps).strip()

    def __call__(self, text, voice='af', speed=1, split_pattern=r'\n+'):
        """Synthesize speech, yielding (graphemes, phonemes, audio) per chunk.

        text: a string (split on split_pattern) or a pre-split list of strings.
        The voice pack is indexed by len(input_ids)-1 to pick a
        length-conditioned style vector.
        """
        assert isinstance(text, str) or isinstance(text, list), type(text)
        self.load_voice(voice)
        if isinstance(text, str) and split_pattern:
            text = re.split(split_pattern, text.strip())
        for graphemes in text:
            _, tokens = self.g2p(graphemes)
            for gs, ps in self.tokenize(tokens):
                if not ps:
                    continue
                elif len(ps) > 510:
                    print('TODO: Unexpected len(ps) > 510', len(ps), ps)
                    continue
                input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), ps)))
                assert input_ids and len(input_ids) <= 510, input_ids
                yield gs, ps, self.model(input_ids, self.voices[voice][len(input_ids)-1], speed)

View File

@@ -1,15 +0,0 @@
# https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py
from transformers import AlbertConfig, AlbertModel
class CustomAlbert(AlbertModel):
def forward(self, *args, **kwargs):
# Call the original forward method
outputs = super().forward(*args, **kwargs)
# Only return the last_hidden_state
return outputs.last_hidden_state
def load_plbert():
plbert_config = {'vocab_size': 178, 'hidden_size': 768, 'num_attention_heads': 12, 'intermediate_size': 2048, 'max_position_embeddings': 512, 'num_hidden_layers': 12, 'dropout': 0.1}
albert_base_configuration = AlbertConfig(**plbert_config)
bert = CustomAlbert(albert_base_configuration)
return bert

View File

@@ -2,11 +2,12 @@ from setuptools import setup, find_packages
setup( setup(
name='kokoro', # Name of the package name='kokoro', # Name of the package
version='0.1.0', # Initial version version='0.2.0', # Initial version
packages=find_packages(), # Automatically finds packages packages=find_packages(), # Automatically finds packages
install_requires=[ # List your dependencies here install_requires=[ # List your dependencies here
'huggingface_hub',
'misaki[en]>=0.6.1',
'numpy', 'numpy',
'phonemizer',
'scipy', 'scipy',
'torch', 'torch',
'transformers', 'transformers',