Dev (#11)
* KPipeline * misaki[en]>=0.6.0 * Fix typo * Iteration over recursion * Expose KModel * Rename model
This commit is contained in:
4
kokoro/__init__.py
Normal file
4
kokoro/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
__version__ = '0.2.0'
|
||||||
|
|
||||||
|
from .models import KModel
|
||||||
|
from .pipeline import KPipeline
|
||||||
149
kokoro/infer.py
149
kokoro/infer.py
@@ -1,149 +0,0 @@
|
|||||||
import phonemizer
|
|
||||||
import re
|
|
||||||
import torch
|
|
||||||
|
|
||||||
def split_num(num):
|
|
||||||
num = num.group()
|
|
||||||
if '.' in num:
|
|
||||||
return num
|
|
||||||
elif ':' in num:
|
|
||||||
h, m = [int(n) for n in num.split(':')]
|
|
||||||
if m == 0:
|
|
||||||
return f"{h} o'clock"
|
|
||||||
elif m < 10:
|
|
||||||
return f'{h} oh {m}'
|
|
||||||
return f'{h} {m}'
|
|
||||||
year = int(num[:4])
|
|
||||||
if year < 1100 or year % 1000 < 10:
|
|
||||||
return num
|
|
||||||
left, right = num[:2], int(num[2:4])
|
|
||||||
s = 's' if num.endswith('s') else ''
|
|
||||||
if 100 <= year % 1000 <= 999:
|
|
||||||
if right == 0:
|
|
||||||
return f'{left} hundred{s}'
|
|
||||||
elif right < 10:
|
|
||||||
return f'{left} oh {right}{s}'
|
|
||||||
return f'{left} {right}{s}'
|
|
||||||
|
|
||||||
def flip_money(m):
|
|
||||||
m = m.group()
|
|
||||||
bill = 'dollar' if m[0] == '$' else 'pound'
|
|
||||||
if m[-1].isalpha():
|
|
||||||
return f'{m[1:]} {bill}s'
|
|
||||||
elif '.' not in m:
|
|
||||||
s = '' if m[1:] == '1' else 's'
|
|
||||||
return f'{m[1:]} {bill}{s}'
|
|
||||||
b, c = m[1:].split('.')
|
|
||||||
s = '' if b == '1' else 's'
|
|
||||||
c = int(c.ljust(2, '0'))
|
|
||||||
coins = f"cent{'' if c == 1 else 's'}" if m[0] == '$' else ('penny' if c == 1 else 'pence')
|
|
||||||
return f'{b} {bill}{s} and {c} {coins}'
|
|
||||||
|
|
||||||
def point_num(num):
|
|
||||||
a, b = num.group().split('.')
|
|
||||||
return ' point '.join([a, ' '.join(b)])
|
|
||||||
|
|
||||||
def normalize_text(text):
|
|
||||||
text = text.replace(chr(8216), "'").replace(chr(8217), "'")
|
|
||||||
text = text.replace('«', chr(8220)).replace('»', chr(8221))
|
|
||||||
text = text.replace(chr(8220), '"').replace(chr(8221), '"')
|
|
||||||
text = text.replace('(', '«').replace(')', '»')
|
|
||||||
for a, b in zip('、。!,:;?', ',.!,:;?'):
|
|
||||||
text = text.replace(a, b+' ')
|
|
||||||
text = re.sub(r'[^\S \n]', ' ', text)
|
|
||||||
text = re.sub(r' +', ' ', text)
|
|
||||||
text = re.sub(r'(?<=\n) +(?=\n)', '', text)
|
|
||||||
text = re.sub(r'\bD[Rr]\.(?= [A-Z])', 'Doctor', text)
|
|
||||||
text = re.sub(r'\b(?:Mr\.|MR\.(?= [A-Z]))', 'Mister', text)
|
|
||||||
text = re.sub(r'\b(?:Ms\.|MS\.(?= [A-Z]))', 'Miss', text)
|
|
||||||
text = re.sub(r'\b(?:Mrs\.|MRS\.(?= [A-Z]))', 'Mrs', text)
|
|
||||||
text = re.sub(r'\betc\.(?! [A-Z])', 'etc', text)
|
|
||||||
text = re.sub(r'(?i)\b(y)eah?\b', r"\1e'a", text)
|
|
||||||
text = re.sub(r'\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)', split_num, text)
|
|
||||||
text = re.sub(r'(?<=\d),(?=\d)', '', text)
|
|
||||||
text = re.sub(r'(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b', flip_money, text)
|
|
||||||
text = re.sub(r'\d*\.\d+', point_num, text)
|
|
||||||
text = re.sub(r'(?<=\d)-(?=\d)', ' to ', text)
|
|
||||||
text = re.sub(r'(?<=\d)S', ' S', text)
|
|
||||||
text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
|
|
||||||
text = re.sub(r"(?<=X')S\b", 's', text)
|
|
||||||
text = re.sub(r'(?:[A-Za-z]\.){2,} [a-z]', lambda m: m.group().replace('.', '-'), text)
|
|
||||||
text = re.sub(r'(?i)(?<=[A-Z])\.(?=[A-Z])', '-', text)
|
|
||||||
return text.strip()
|
|
||||||
|
|
||||||
def get_vocab():
|
|
||||||
_pad = "$"
|
|
||||||
_punctuation = ';:,.!?¡¿—…"«»“” '
|
|
||||||
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
|
|
||||||
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
|
|
||||||
symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
|
|
||||||
dicts = {}
|
|
||||||
for i in range(len((symbols))):
|
|
||||||
dicts[symbols[i]] = i
|
|
||||||
return dicts
|
|
||||||
|
|
||||||
VOCAB = get_vocab()
|
|
||||||
def tokenize(ps):
|
|
||||||
return [i for i in map(VOCAB.get, ps) if i is not None]
|
|
||||||
|
|
||||||
phonemizers = dict(
|
|
||||||
a=phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True),
|
|
||||||
b=phonemizer.backend.EspeakBackend(language='en-gb', preserve_punctuation=True, with_stress=True),
|
|
||||||
)
|
|
||||||
def phonemize(text, lang, norm=True):
|
|
||||||
if norm:
|
|
||||||
text = normalize_text(text)
|
|
||||||
ps = phonemizers[lang].phonemize([text])
|
|
||||||
ps = ps[0] if ps else ''
|
|
||||||
# https://en.wiktionary.org/wiki/kokoro#English
|
|
||||||
ps = ps.replace('kəkˈoːɹoʊ', 'kˈoʊkəɹoʊ').replace('kəkˈɔːɹəʊ', 'kˈəʊkəɹəʊ')
|
|
||||||
ps = ps.replace('ʲ', 'j').replace('r', 'ɹ').replace('x', 'k').replace('ɬ', 'l')
|
|
||||||
ps = re.sub(r'(?<=[a-zɹː])(?=hˈʌndɹɪd)', ' ', ps)
|
|
||||||
ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', 'z', ps)
|
|
||||||
if lang == 'a':
|
|
||||||
ps = re.sub(r'(?<=nˈaɪn)ti(?!ː)', 'di', ps)
|
|
||||||
ps = ''.join(filter(lambda p: p in VOCAB, ps))
|
|
||||||
return ps.strip()
|
|
||||||
|
|
||||||
def length_to_mask(lengths):
|
|
||||||
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
|
|
||||||
mask = torch.gt(mask+1, lengths.unsqueeze(1))
|
|
||||||
return mask
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def forward(model, tokens, ref_s, speed):
|
|
||||||
device = ref_s.device
|
|
||||||
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
|
|
||||||
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
|
|
||||||
text_mask = length_to_mask(input_lengths).to(device)
|
|
||||||
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
|
|
||||||
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
|
|
||||||
s = ref_s[:, 128:]
|
|
||||||
d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
|
|
||||||
x, _ = model.predictor.lstm(d)
|
|
||||||
duration = model.predictor.duration_proj(x)
|
|
||||||
duration = torch.sigmoid(duration).sum(axis=-1) / speed
|
|
||||||
pred_dur = torch.round(duration).clamp(min=1).long()
|
|
||||||
pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
|
|
||||||
c_frame = 0
|
|
||||||
for i in range(pred_aln_trg.size(0)):
|
|
||||||
pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1
|
|
||||||
c_frame += pred_dur[0,i].item()
|
|
||||||
en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
|
|
||||||
F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
|
|
||||||
t_en = model.text_encoder(tokens, input_lengths, text_mask)
|
|
||||||
asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
|
|
||||||
return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
|
|
||||||
|
|
||||||
def generate(model, text, voicepack, lang='a', speed=1, ps=None):
|
|
||||||
ps = ps or phonemize(text, lang)
|
|
||||||
tokens = tokenize(ps)
|
|
||||||
if not tokens:
|
|
||||||
return None
|
|
||||||
elif len(tokens) > 510:
|
|
||||||
tokens = tokens[:510]
|
|
||||||
print('Truncated to 510 tokens')
|
|
||||||
ref_s = voicepack[len(tokens)]
|
|
||||||
out = forward(model, tokens, ref_s, speed)
|
|
||||||
ps = ''.join(next(k for k, v in VOCAB.items() if i == v) for i in tokens)
|
|
||||||
return out, ps
|
|
||||||
@@ -1,12 +1,12 @@
|
|||||||
# https://github.com/yl4579/StyleTTS2/blob/main/Modules/istftnet.py
|
# https://github.com/yl4579/StyleTTS2/blob/main/Modules/istftnet.py
|
||||||
from scipy.signal import get_window
|
from scipy.signal import get_window
|
||||||
from torch.nn import Conv1d, ConvTranspose1d
|
from torch.nn.utils import weight_norm
|
||||||
from torch.nn.utils import weight_norm, remove_weight_norm
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
# https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py
|
# https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py
|
||||||
def init_weights(m, mean=0.0, std=0.01):
|
def init_weights(m, mean=0.0, std=0.01):
|
||||||
classname = m.__class__.__name__
|
classname = m.__class__.__name__
|
||||||
@@ -16,7 +16,6 @@ def init_weights(m, mean=0.0, std=0.01):
|
|||||||
def get_padding(kernel_size, dilation=1):
|
def get_padding(kernel_size, dilation=1):
|
||||||
return int((kernel_size*dilation - dilation)/2)
|
return int((kernel_size*dilation - dilation)/2)
|
||||||
|
|
||||||
LRELU_SLOPE = 0.1
|
|
||||||
|
|
||||||
class AdaIN1d(nn.Module):
|
class AdaIN1d(nn.Module):
|
||||||
def __init__(self, style_dim, num_features):
|
def __init__(self, style_dim, num_features):
|
||||||
@@ -30,45 +29,41 @@ class AdaIN1d(nn.Module):
|
|||||||
gamma, beta = torch.chunk(h, chunks=2, dim=1)
|
gamma, beta = torch.chunk(h, chunks=2, dim=1)
|
||||||
return (1 + gamma) * self.norm(x) + beta
|
return (1 + gamma) * self.norm(x) + beta
|
||||||
|
|
||||||
class AdaINResBlock1(torch.nn.Module):
|
|
||||||
|
class AdaINResBlock1(nn.Module):
|
||||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
|
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
|
||||||
super(AdaINResBlock1, self).__init__()
|
super(AdaINResBlock1, self).__init__()
|
||||||
self.convs1 = nn.ModuleList([
|
self.convs1 = nn.ModuleList([
|
||||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
||||||
padding=get_padding(kernel_size, dilation[0]))),
|
padding=get_padding(kernel_size, dilation[0]))),
|
||||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
||||||
padding=get_padding(kernel_size, dilation[1]))),
|
padding=get_padding(kernel_size, dilation[1]))),
|
||||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
|
weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
|
||||||
padding=get_padding(kernel_size, dilation[2])))
|
padding=get_padding(kernel_size, dilation[2])))
|
||||||
])
|
])
|
||||||
self.convs1.apply(init_weights)
|
self.convs1.apply(init_weights)
|
||||||
|
|
||||||
self.convs2 = nn.ModuleList([
|
self.convs2 = nn.ModuleList([
|
||||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
||||||
padding=get_padding(kernel_size, 1))),
|
padding=get_padding(kernel_size, 1))),
|
||||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
||||||
padding=get_padding(kernel_size, 1))),
|
padding=get_padding(kernel_size, 1))),
|
||||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
||||||
padding=get_padding(kernel_size, 1)))
|
padding=get_padding(kernel_size, 1)))
|
||||||
])
|
])
|
||||||
self.convs2.apply(init_weights)
|
self.convs2.apply(init_weights)
|
||||||
|
|
||||||
self.adain1 = nn.ModuleList([
|
self.adain1 = nn.ModuleList([
|
||||||
AdaIN1d(style_dim, channels),
|
AdaIN1d(style_dim, channels),
|
||||||
AdaIN1d(style_dim, channels),
|
AdaIN1d(style_dim, channels),
|
||||||
AdaIN1d(style_dim, channels),
|
AdaIN1d(style_dim, channels),
|
||||||
])
|
])
|
||||||
|
|
||||||
self.adain2 = nn.ModuleList([
|
self.adain2 = nn.ModuleList([
|
||||||
AdaIN1d(style_dim, channels),
|
AdaIN1d(style_dim, channels),
|
||||||
AdaIN1d(style_dim, channels),
|
AdaIN1d(style_dim, channels),
|
||||||
AdaIN1d(style_dim, channels),
|
AdaIN1d(style_dim, channels),
|
||||||
])
|
])
|
||||||
|
|
||||||
self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
|
self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
|
||||||
self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])
|
self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x, s):
|
def forward(self, x, s):
|
||||||
for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
|
for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
|
||||||
xt = n1(x, s)
|
xt = n1(x, s)
|
||||||
@@ -80,13 +75,8 @@ class AdaINResBlock1(torch.nn.Module):
|
|||||||
x = xt + x
|
x = xt + x
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
|
||||||
for l in self.convs1:
|
|
||||||
remove_weight_norm(l)
|
|
||||||
for l in self.convs2:
|
|
||||||
remove_weight_norm(l)
|
|
||||||
|
|
||||||
class TorchSTFT(torch.nn.Module):
|
class TorchSTFT(nn.Module):
|
||||||
def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
|
def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.filter_length = filter_length
|
self.filter_length = filter_length
|
||||||
@@ -99,14 +89,12 @@ class TorchSTFT(torch.nn.Module):
|
|||||||
input_data,
|
input_data,
|
||||||
self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device),
|
self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device),
|
||||||
return_complex=True)
|
return_complex=True)
|
||||||
|
|
||||||
return torch.abs(forward_transform), torch.angle(forward_transform)
|
return torch.abs(forward_transform), torch.angle(forward_transform)
|
||||||
|
|
||||||
def inverse(self, magnitude, phase):
|
def inverse(self, magnitude, phase):
|
||||||
inverse_transform = torch.istft(
|
inverse_transform = torch.istft(
|
||||||
magnitude * torch.exp(phase * 1j),
|
magnitude * torch.exp(phase * 1j),
|
||||||
self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device))
|
self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device))
|
||||||
|
|
||||||
return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation
|
return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation
|
||||||
|
|
||||||
def forward(self, input_data):
|
def forward(self, input_data):
|
||||||
@@ -114,7 +102,8 @@ class TorchSTFT(torch.nn.Module):
|
|||||||
reconstruction = self.inverse(self.magnitude, self.phase)
|
reconstruction = self.inverse(self.magnitude, self.phase)
|
||||||
return reconstruction
|
return reconstruction
|
||||||
|
|
||||||
class SineGen(torch.nn.Module):
|
|
||||||
|
class SineGen(nn.Module):
|
||||||
""" Definition of sine generator
|
""" Definition of sine generator
|
||||||
SineGen(samp_rate, harmonic_num = 0,
|
SineGen(samp_rate, harmonic_num = 0,
|
||||||
sine_amp = 0.1, noise_std = 0.003,
|
sine_amp = 0.1, noise_std = 0.003,
|
||||||
@@ -129,7 +118,6 @@ class SineGen(torch.nn.Module):
|
|||||||
Note: when flag_for_pulse is True, the first time step of a voiced
|
Note: when flag_for_pulse is True, the first time step of a voiced
|
||||||
segment is always sin(np.pi) or cos(0)
|
segment is always sin(np.pi) or cos(0)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
|
def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
|
||||||
sine_amp=0.1, noise_std=0.003,
|
sine_amp=0.1, noise_std=0.003,
|
||||||
voiced_threshold=0,
|
voiced_threshold=0,
|
||||||
@@ -156,52 +144,25 @@ class SineGen(torch.nn.Module):
|
|||||||
# convert to F0 in rad. The interger part n can be ignored
|
# convert to F0 in rad. The interger part n can be ignored
|
||||||
# because 2 * np.pi * n doesn't affect phase
|
# because 2 * np.pi * n doesn't affect phase
|
||||||
rad_values = (f0_values / self.sampling_rate) % 1
|
rad_values = (f0_values / self.sampling_rate) % 1
|
||||||
|
|
||||||
# initial phase noise (no noise for fundamental component)
|
# initial phase noise (no noise for fundamental component)
|
||||||
rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
|
rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
|
||||||
device=f0_values.device)
|
|
||||||
rand_ini[:, 0] = 0
|
rand_ini[:, 0] = 0
|
||||||
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
|
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
|
||||||
|
|
||||||
# instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
|
# instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
|
||||||
if not self.flag_for_pulse:
|
if not self.flag_for_pulse:
|
||||||
# # for normal case
|
rad_values = F.interpolate(rad_values.transpose(1, 2), scale_factor=1/self.upsample_scale, mode="linear").transpose(1, 2)
|
||||||
|
|
||||||
# # To prevent torch.cumsum numerical overflow,
|
|
||||||
# # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
|
|
||||||
# # Buffer tmp_over_one_idx indicates the time step to add -1.
|
|
||||||
# # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
|
|
||||||
# tmp_over_one = torch.cumsum(rad_values, 1) % 1
|
|
||||||
# tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
|
|
||||||
# cumsum_shift = torch.zeros_like(rad_values)
|
|
||||||
# cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
|
|
||||||
|
|
||||||
# phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
|
|
||||||
rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
|
|
||||||
scale_factor=1/self.upsample_scale,
|
|
||||||
mode="linear").transpose(1, 2)
|
|
||||||
|
|
||||||
# tmp_over_one = torch.cumsum(rad_values, 1) % 1
|
|
||||||
# tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
|
|
||||||
# cumsum_shift = torch.zeros_like(rad_values)
|
|
||||||
# cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
|
|
||||||
|
|
||||||
phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
|
phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
|
||||||
phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
|
phase = F.interpolate(phase.transpose(1, 2) * self.upsample_scale, scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
|
||||||
scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
|
|
||||||
sines = torch.sin(phase)
|
sines = torch.sin(phase)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# If necessary, make sure that the first time step of every
|
# If necessary, make sure that the first time step of every
|
||||||
# voiced segments is sin(pi) or cos(0)
|
# voiced segments is sin(pi) or cos(0)
|
||||||
# This is used for pulse-train generation
|
# This is used for pulse-train generation
|
||||||
|
|
||||||
# identify the last time step in unvoiced segments
|
# identify the last time step in unvoiced segments
|
||||||
uv = self._f02uv(f0_values)
|
uv = self._f02uv(f0_values)
|
||||||
uv_1 = torch.roll(uv, shifts=-1, dims=1)
|
uv_1 = torch.roll(uv, shifts=-1, dims=1)
|
||||||
uv_1[:, -1, :] = 1
|
uv_1[:, -1, :] = 1
|
||||||
u_loc = (uv < 1) * (uv_1 > 0)
|
u_loc = (uv < 1) * (uv_1 > 0)
|
||||||
|
|
||||||
# get the instantanouse phase
|
# get the instantanouse phase
|
||||||
tmp_cumsum = torch.cumsum(rad_values, dim=1)
|
tmp_cumsum = torch.cumsum(rad_values, dim=1)
|
||||||
# different batch needs to be processed differently
|
# different batch needs to be processed differently
|
||||||
@@ -212,11 +173,9 @@ class SineGen(torch.nn.Module):
|
|||||||
# each voiced segments
|
# each voiced segments
|
||||||
tmp_cumsum[idx, :, :] = 0
|
tmp_cumsum[idx, :, :] = 0
|
||||||
tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
|
tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
|
||||||
|
|
||||||
# rad_values - tmp_cumsum: remove the accumulation of i.phase
|
# rad_values - tmp_cumsum: remove the accumulation of i.phase
|
||||||
# within the previous voiced segment.
|
# within the previous voiced segment.
|
||||||
i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
|
i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
|
||||||
|
|
||||||
# get the sines
|
# get the sines
|
||||||
sines = torch.cos(i_phase * 2 * np.pi)
|
sines = torch.cos(i_phase * 2 * np.pi)
|
||||||
return sines
|
return sines
|
||||||
@@ -228,32 +187,27 @@ class SineGen(torch.nn.Module):
|
|||||||
output sine_tensor: tensor(batchsize=1, length, dim)
|
output sine_tensor: tensor(batchsize=1, length, dim)
|
||||||
output uv: tensor(batchsize=1, length, 1)
|
output uv: tensor(batchsize=1, length, 1)
|
||||||
"""
|
"""
|
||||||
f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
|
f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
|
||||||
device=f0.device)
|
|
||||||
# fundamental component
|
# fundamental component
|
||||||
fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
|
fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
|
||||||
|
|
||||||
# generate sine waveforms
|
# generate sine waveforms
|
||||||
sine_waves = self._f02sine(fn) * self.sine_amp
|
sine_waves = self._f02sine(fn) * self.sine_amp
|
||||||
|
|
||||||
# generate uv signal
|
# generate uv signal
|
||||||
# uv = torch.ones(f0.shape)
|
# uv = torch.ones(f0.shape)
|
||||||
# uv = uv * (f0 > self.voiced_threshold)
|
# uv = uv * (f0 > self.voiced_threshold)
|
||||||
uv = self._f02uv(f0)
|
uv = self._f02uv(f0)
|
||||||
|
|
||||||
# noise: for unvoiced should be similar to sine_amp
|
# noise: for unvoiced should be similar to sine_amp
|
||||||
# std = self.sine_amp/3 -> max value ~ self.sine_amp
|
# std = self.sine_amp/3 -> max value ~ self.sine_amp
|
||||||
# . for voiced regions is self.noise_std
|
# for voiced regions is self.noise_std
|
||||||
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
||||||
noise = noise_amp * torch.randn_like(sine_waves)
|
noise = noise_amp * torch.randn_like(sine_waves)
|
||||||
|
|
||||||
# first: set the unvoiced part to 0 by uv
|
# first: set the unvoiced part to 0 by uv
|
||||||
# then: additive noise
|
# then: additive noise
|
||||||
sine_waves = sine_waves * uv + noise
|
sine_waves = sine_waves * uv + noise
|
||||||
return sine_waves, uv, noise
|
return sine_waves, uv, noise
|
||||||
|
|
||||||
|
|
||||||
class SourceModuleHnNSF(torch.nn.Module):
|
class SourceModuleHnNSF(nn.Module):
|
||||||
""" SourceModule for hn-nsf
|
""" SourceModule for hn-nsf
|
||||||
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
|
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
|
||||||
add_noise_std=0.003, voiced_threshod=0)
|
add_noise_std=0.003, voiced_threshod=0)
|
||||||
@@ -270,21 +224,17 @@ class SourceModuleHnNSF(torch.nn.Module):
|
|||||||
noise_source (batchsize, length 1)
|
noise_source (batchsize, length 1)
|
||||||
uv (batchsize, length, 1)
|
uv (batchsize, length, 1)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
|
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
|
||||||
add_noise_std=0.003, voiced_threshod=0):
|
add_noise_std=0.003, voiced_threshod=0):
|
||||||
super(SourceModuleHnNSF, self).__init__()
|
super(SourceModuleHnNSF, self).__init__()
|
||||||
|
|
||||||
self.sine_amp = sine_amp
|
self.sine_amp = sine_amp
|
||||||
self.noise_std = add_noise_std
|
self.noise_std = add_noise_std
|
||||||
|
|
||||||
# to produce sine waveforms
|
# to produce sine waveforms
|
||||||
self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
|
self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
|
||||||
sine_amp, add_noise_std, voiced_threshod)
|
sine_amp, add_noise_std, voiced_threshod)
|
||||||
|
|
||||||
# to merge source harmonics into a single excitation
|
# to merge source harmonics into a single excitation
|
||||||
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
|
self.l_linear = nn.Linear(harmonic_num + 1, 1)
|
||||||
self.l_tanh = torch.nn.Tanh()
|
self.l_tanh = nn.Tanh()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
"""
|
"""
|
||||||
@@ -297,80 +247,63 @@ class SourceModuleHnNSF(torch.nn.Module):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
sine_wavs, uv, _ = self.l_sin_gen(x)
|
sine_wavs, uv, _ = self.l_sin_gen(x)
|
||||||
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
||||||
|
|
||||||
# source for noise branch, in the same shape as uv
|
# source for noise branch, in the same shape as uv
|
||||||
noise = torch.randn_like(uv) * self.sine_amp / 3
|
noise = torch.randn_like(uv) * self.sine_amp / 3
|
||||||
return sine_merge, noise, uv
|
return sine_merge, noise, uv
|
||||||
def padDiff(x):
|
|
||||||
return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0)
|
|
||||||
|
|
||||||
|
|
||||||
class Generator(torch.nn.Module):
|
class Generator(nn.Module):
|
||||||
def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size):
|
def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size):
|
||||||
super(Generator, self).__init__()
|
super(Generator, self).__init__()
|
||||||
|
|
||||||
self.num_kernels = len(resblock_kernel_sizes)
|
self.num_kernels = len(resblock_kernel_sizes)
|
||||||
self.num_upsamples = len(upsample_rates)
|
self.num_upsamples = len(upsample_rates)
|
||||||
resblock = AdaINResBlock1
|
|
||||||
|
|
||||||
self.m_source = SourceModuleHnNSF(
|
self.m_source = SourceModuleHnNSF(
|
||||||
sampling_rate=24000,
|
sampling_rate=24000,
|
||||||
upsample_scale=np.prod(upsample_rates) * gen_istft_hop_size,
|
upsample_scale=np.prod(upsample_rates) * gen_istft_hop_size,
|
||||||
harmonic_num=8, voiced_threshod=10)
|
harmonic_num=8, voiced_threshod=10)
|
||||||
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * gen_istft_hop_size)
|
self.f0_upsamp = nn.Upsample(scale_factor=np.prod(upsample_rates) * gen_istft_hop_size)
|
||||||
self.noise_convs = nn.ModuleList()
|
self.noise_convs = nn.ModuleList()
|
||||||
self.noise_res = nn.ModuleList()
|
self.noise_res = nn.ModuleList()
|
||||||
|
|
||||||
self.ups = nn.ModuleList()
|
self.ups = nn.ModuleList()
|
||||||
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||||
self.ups.append(weight_norm(
|
self.ups.append(weight_norm(
|
||||||
ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
|
nn.ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
|
||||||
k, u, padding=(k-u)//2)))
|
k, u, padding=(k-u)//2)))
|
||||||
|
|
||||||
self.resblocks = nn.ModuleList()
|
self.resblocks = nn.ModuleList()
|
||||||
for i in range(len(self.ups)):
|
for i in range(len(self.ups)):
|
||||||
ch = upsample_initial_channel//(2**(i+1))
|
ch = upsample_initial_channel//(2**(i+1))
|
||||||
for j, (k, d) in enumerate(zip(resblock_kernel_sizes,resblock_dilation_sizes)):
|
for j, (k, d) in enumerate(zip(resblock_kernel_sizes,resblock_dilation_sizes)):
|
||||||
self.resblocks.append(resblock(ch, k, d, style_dim))
|
self.resblocks.append(AdaINResBlock1(ch, k, d, style_dim))
|
||||||
|
|
||||||
c_cur = upsample_initial_channel // (2 ** (i + 1))
|
c_cur = upsample_initial_channel // (2 ** (i + 1))
|
||||||
|
if i + 1 < len(upsample_rates):
|
||||||
if i + 1 < len(upsample_rates): #
|
|
||||||
stride_f0 = np.prod(upsample_rates[i + 1:])
|
stride_f0 = np.prod(upsample_rates[i + 1:])
|
||||||
self.noise_convs.append(Conv1d(
|
self.noise_convs.append(nn.Conv1d(
|
||||||
gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
|
gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
|
||||||
self.noise_res.append(resblock(c_cur, 7, [1,3,5], style_dim))
|
self.noise_res.append(AdaINResBlock1(c_cur, 7, [1,3,5], style_dim))
|
||||||
else:
|
else:
|
||||||
self.noise_convs.append(Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1))
|
self.noise_convs.append(nn.Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1))
|
||||||
self.noise_res.append(resblock(c_cur, 11, [1,3,5], style_dim))
|
self.noise_res.append(AdaINResBlock1(c_cur, 11, [1,3,5], style_dim))
|
||||||
|
|
||||||
|
|
||||||
self.post_n_fft = gen_istft_n_fft
|
self.post_n_fft = gen_istft_n_fft
|
||||||
self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
|
self.conv_post = weight_norm(nn.Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
|
||||||
self.ups.apply(init_weights)
|
self.ups.apply(init_weights)
|
||||||
self.conv_post.apply(init_weights)
|
self.conv_post.apply(init_weights)
|
||||||
self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
|
self.reflection_pad = nn.ReflectionPad1d((1, 0))
|
||||||
self.stft = TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
|
self.stft = TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x, s, f0):
|
def forward(self, x, s, f0):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
||||||
|
|
||||||
har_source, noi_source, uv = self.m_source(f0)
|
har_source, noi_source, uv = self.m_source(f0)
|
||||||
har_source = har_source.transpose(1, 2).squeeze(1)
|
har_source = har_source.transpose(1, 2).squeeze(1)
|
||||||
har_spec, har_phase = self.stft.transform(har_source)
|
har_spec, har_phase = self.stft.transform(har_source)
|
||||||
har = torch.cat([har_spec, har_phase], dim=1)
|
har = torch.cat([har_spec, har_phase], dim=1)
|
||||||
|
|
||||||
for i in range(self.num_upsamples):
|
for i in range(self.num_upsamples):
|
||||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
x = F.leaky_relu(x, negative_slope=0.1)
|
||||||
x_source = self.noise_convs[i](har)
|
x_source = self.noise_convs[i](har)
|
||||||
x_source = self.noise_res[i](x_source, s)
|
x_source = self.noise_res[i](x_source, s)
|
||||||
|
|
||||||
x = self.ups[i](x)
|
x = self.ups[i](x)
|
||||||
if i == self.num_upsamples - 1:
|
if i == self.num_upsamples - 1:
|
||||||
x = self.reflection_pad(x)
|
x = self.reflection_pad(x)
|
||||||
|
|
||||||
x = x + x_source
|
x = x + x_source
|
||||||
xs = None
|
xs = None
|
||||||
for j in range(self.num_kernels):
|
for j in range(self.num_kernels):
|
||||||
@@ -385,37 +318,21 @@ class Generator(torch.nn.Module):
|
|||||||
phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
|
phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
|
||||||
return self.stft.inverse(spec, phase)
|
return self.stft.inverse(spec, phase)
|
||||||
|
|
||||||
def fw_phase(self, x, s):
|
|
||||||
for i in range(self.num_upsamples):
|
|
||||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
|
||||||
x = self.ups[i](x)
|
|
||||||
xs = None
|
|
||||||
for j in range(self.num_kernels):
|
|
||||||
if xs is None:
|
|
||||||
xs = self.resblocks[i*self.num_kernels+j](x, s)
|
|
||||||
else:
|
|
||||||
xs += self.resblocks[i*self.num_kernels+j](x, s)
|
|
||||||
x = xs / self.num_kernels
|
|
||||||
x = F.leaky_relu(x)
|
|
||||||
x = self.reflection_pad(x)
|
|
||||||
x = self.conv_post(x)
|
|
||||||
spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
|
|
||||||
phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
|
|
||||||
return spec, phase
|
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
class UpSample1d(nn.Module):
|
||||||
print('Removing weight norm...')
|
def __init__(self, layer_type):
|
||||||
for l in self.ups:
|
super().__init__()
|
||||||
remove_weight_norm(l)
|
self.layer_type = layer_type
|
||||||
for l in self.resblocks:
|
|
||||||
l.remove_weight_norm()
|
def forward(self, x):
|
||||||
remove_weight_norm(self.conv_pre)
|
if self.layer_type == 'none':
|
||||||
remove_weight_norm(self.conv_post)
|
return x
|
||||||
|
else:
|
||||||
|
return F.interpolate(x, scale_factor=2, mode='nearest')
|
||||||
|
|
||||||
|
|
||||||
class AdainResBlk1d(nn.Module):
|
class AdainResBlk1d(nn.Module):
|
||||||
def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
|
def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2), upsample='none', dropout_p=0.0):
|
||||||
upsample='none', dropout_p=0.0):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.actv = actv
|
self.actv = actv
|
||||||
self.upsample_type = upsample
|
self.upsample_type = upsample
|
||||||
@@ -423,13 +340,11 @@ class AdainResBlk1d(nn.Module):
|
|||||||
self.learned_sc = dim_in != dim_out
|
self.learned_sc = dim_in != dim_out
|
||||||
self._build_weights(dim_in, dim_out, style_dim)
|
self._build_weights(dim_in, dim_out, style_dim)
|
||||||
self.dropout = nn.Dropout(dropout_p)
|
self.dropout = nn.Dropout(dropout_p)
|
||||||
|
|
||||||
if upsample == 'none':
|
if upsample == 'none':
|
||||||
self.pool = nn.Identity()
|
self.pool = nn.Identity()
|
||||||
else:
|
else:
|
||||||
self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
|
self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
|
||||||
|
|
||||||
|
|
||||||
def _build_weights(self, dim_in, dim_out, style_dim):
|
def _build_weights(self, dim_in, dim_out, style_dim):
|
||||||
self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
|
self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
|
||||||
self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
|
self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
|
||||||
@@ -459,45 +374,25 @@ class AdainResBlk1d(nn.Module):
|
|||||||
out = (out + self._shortcut(x)) / np.sqrt(2)
|
out = (out + self._shortcut(x)) / np.sqrt(2)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
class UpSample1d(nn.Module):
|
|
||||||
def __init__(self, layer_type):
|
|
||||||
super().__init__()
|
|
||||||
self.layer_type = layer_type
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if self.layer_type == 'none':
|
|
||||||
return x
|
|
||||||
else:
|
|
||||||
return F.interpolate(x, scale_factor=2, mode='nearest')
|
|
||||||
|
|
||||||
class Decoder(nn.Module):
|
class Decoder(nn.Module):
|
||||||
def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80,
|
def __init__(self, dim_in, style_dim, dim_out,
|
||||||
resblock_kernel_sizes = [3,7,11],
|
resblock_kernel_sizes,
|
||||||
upsample_rates = [10, 6],
|
upsample_rates,
|
||||||
upsample_initial_channel=512,
|
upsample_initial_channel,
|
||||||
resblock_dilation_sizes=[[1,3,5], [1,3,5], [1,3,5]],
|
resblock_dilation_sizes,
|
||||||
upsample_kernel_sizes=[20, 12],
|
upsample_kernel_sizes,
|
||||||
gen_istft_n_fft=20, gen_istft_hop_size=5):
|
gen_istft_n_fft, gen_istft_hop_size):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.decode = nn.ModuleList()
|
|
||||||
|
|
||||||
self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
|
self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
|
||||||
|
self.decode = nn.ModuleList()
|
||||||
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
|
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
|
||||||
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
|
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
|
||||||
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
|
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
|
||||||
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
|
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
|
||||||
|
|
||||||
self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
|
self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
|
||||||
|
|
||||||
self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
|
self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
|
||||||
|
self.asr_res = nn.Sequential(weight_norm(nn.Conv1d(512, 64, kernel_size=1)))
|
||||||
self.asr_res = nn.Sequential(
|
|
||||||
weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates,
|
self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates,
|
||||||
upsample_initial_channel, resblock_dilation_sizes,
|
upsample_initial_channel, resblock_dilation_sizes,
|
||||||
upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size)
|
upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size)
|
||||||
@@ -505,12 +400,9 @@ class Decoder(nn.Module):
|
|||||||
def forward(self, asr, F0_curve, N, s):
|
def forward(self, asr, F0_curve, N, s):
|
||||||
F0 = self.F0_conv(F0_curve.unsqueeze(1))
|
F0 = self.F0_conv(F0_curve.unsqueeze(1))
|
||||||
N = self.N_conv(N.unsqueeze(1))
|
N = self.N_conv(N.unsqueeze(1))
|
||||||
|
|
||||||
x = torch.cat([asr, F0, N], axis=1)
|
x = torch.cat([asr, F0, N], axis=1)
|
||||||
x = self.encode(x, s)
|
x = self.encode(x, s)
|
||||||
|
|
||||||
asr_res = self.asr_res(asr)
|
asr_res = self.asr_res(asr)
|
||||||
|
|
||||||
res = True
|
res = True
|
||||||
for block in self.decode:
|
for block in self.decode:
|
||||||
if res:
|
if res:
|
||||||
@@ -518,6 +410,5 @@ class Decoder(nn.Module):
|
|||||||
x = block(x, s)
|
x = block(x, s)
|
||||||
if block.upsample_type != "none":
|
if block.upsample_type != "none":
|
||||||
res = False
|
res = False
|
||||||
|
|
||||||
x = self.generator(x, s, F0_curve)
|
x = self.generator(x, s, F0_curve)
|
||||||
return x
|
return x
|
||||||
|
|||||||
283
kokoro/models.py
283
kokoro/models.py
@@ -1,35 +1,28 @@
|
|||||||
# https://github.com/yl4579/StyleTTS2/blob/main/models.py
|
# https://github.com/yl4579/StyleTTS2/blob/main/models.py
|
||||||
from istftnet import AdaIN1d, Decoder
|
from .istftnet import AdaIN1d, AdainResBlk1d, Decoder
|
||||||
from munch import Munch
|
from torch.nn.utils import weight_norm
|
||||||
from pathlib import Path
|
from transformers import AlbertConfig, AlbertModel
|
||||||
from plbert import load_plbert
|
|
||||||
from torch.nn.utils import weight_norm, spectral_norm
|
|
||||||
import json
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import os
|
|
||||||
import os.path as osp
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
class LinearNorm(torch.nn.Module):
|
|
||||||
|
class LinearNorm(nn.Module):
|
||||||
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
|
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
|
||||||
super(LinearNorm, self).__init__()
|
super(LinearNorm, self).__init__()
|
||||||
self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
|
self.linear_layer = nn.Linear(in_dim, out_dim, bias=bias)
|
||||||
|
nn.init.xavier_uniform_(self.linear_layer.weight, gain=nn.init.calculate_gain(w_init_gain))
|
||||||
torch.nn.init.xavier_uniform_(
|
|
||||||
self.linear_layer.weight,
|
|
||||||
gain=torch.nn.init.calculate_gain(w_init_gain))
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.linear_layer(x)
|
return self.linear_layer(x)
|
||||||
|
|
||||||
|
|
||||||
class LayerNorm(nn.Module):
|
class LayerNorm(nn.Module):
|
||||||
def __init__(self, channels, eps=1e-5):
|
def __init__(self, channels, eps=1e-5):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
|
|
||||||
self.gamma = nn.Parameter(torch.ones(channels))
|
self.gamma = nn.Parameter(torch.ones(channels))
|
||||||
self.beta = nn.Parameter(torch.zeros(channels))
|
self.beta = nn.Parameter(torch.zeros(channels))
|
||||||
|
|
||||||
@@ -38,11 +31,11 @@ class LayerNorm(nn.Module):
|
|||||||
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
|
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
|
||||||
return x.transpose(1, -1)
|
return x.transpose(1, -1)
|
||||||
|
|
||||||
|
|
||||||
class TextEncoder(nn.Module):
|
class TextEncoder(nn.Module):
|
||||||
def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
|
def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embedding = nn.Embedding(n_symbols, channels)
|
self.embedding = nn.Embedding(n_symbols, channels)
|
||||||
|
|
||||||
padding = (kernel_size - 1) // 2
|
padding = (kernel_size - 1) // 2
|
||||||
self.cnn = nn.ModuleList()
|
self.cnn = nn.ModuleList()
|
||||||
for _ in range(depth):
|
for _ in range(depth):
|
||||||
@@ -52,8 +45,6 @@ class TextEncoder(nn.Module):
|
|||||||
actv,
|
actv,
|
||||||
nn.Dropout(0.2),
|
nn.Dropout(0.2),
|
||||||
))
|
))
|
||||||
# self.cnn = nn.Sequential(*self.cnn)
|
|
||||||
|
|
||||||
self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)
|
self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)
|
||||||
|
|
||||||
def forward(self, x, input_lengths, m):
|
def forward(self, x, input_lengths, m):
|
||||||
@@ -61,234 +52,110 @@ class TextEncoder(nn.Module):
|
|||||||
x = x.transpose(1, 2) # [B, emb, T]
|
x = x.transpose(1, 2) # [B, emb, T]
|
||||||
m = m.to(input_lengths.device).unsqueeze(1)
|
m = m.to(input_lengths.device).unsqueeze(1)
|
||||||
x.masked_fill_(m, 0.0)
|
x.masked_fill_(m, 0.0)
|
||||||
|
|
||||||
for c in self.cnn:
|
for c in self.cnn:
|
||||||
x = c(x)
|
x = c(x)
|
||||||
x.masked_fill_(m, 0.0)
|
x.masked_fill_(m, 0.0)
|
||||||
|
|
||||||
x = x.transpose(1, 2) # [B, T, chn]
|
x = x.transpose(1, 2) # [B, T, chn]
|
||||||
|
|
||||||
input_lengths = input_lengths.cpu().numpy()
|
input_lengths = input_lengths.cpu().numpy()
|
||||||
x = nn.utils.rnn.pack_padded_sequence(
|
x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True, enforce_sorted=False)
|
||||||
x, input_lengths, batch_first=True, enforce_sorted=False)
|
|
||||||
|
|
||||||
self.lstm.flatten_parameters()
|
self.lstm.flatten_parameters()
|
||||||
x, _ = self.lstm(x)
|
x, _ = self.lstm(x)
|
||||||
x, _ = nn.utils.rnn.pad_packed_sequence(
|
x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
|
||||||
x, batch_first=True)
|
|
||||||
|
|
||||||
x = x.transpose(-1, -2)
|
x = x.transpose(-1, -2)
|
||||||
x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
|
x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
|
||||||
|
|
||||||
x_pad[:, :, :x.shape[-1]] = x
|
x_pad[:, :, :x.shape[-1]] = x
|
||||||
x = x_pad.to(x.device)
|
x = x_pad.to(x.device)
|
||||||
|
|
||||||
x.masked_fill_(m, 0.0)
|
x.masked_fill_(m, 0.0)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def inference(self, x):
|
|
||||||
x = self.embedding(x)
|
|
||||||
x = x.transpose(1, 2)
|
|
||||||
x = self.cnn(x)
|
|
||||||
x = x.transpose(1, 2)
|
|
||||||
self.lstm.flatten_parameters()
|
|
||||||
x, _ = self.lstm(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def length_to_mask(self, lengths):
|
|
||||||
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
|
|
||||||
mask = torch.gt(mask+1, lengths.unsqueeze(1))
|
|
||||||
return mask
|
|
||||||
|
|
||||||
|
|
||||||
class UpSample1d(nn.Module):
|
|
||||||
def __init__(self, layer_type):
|
|
||||||
super().__init__()
|
|
||||||
self.layer_type = layer_type
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if self.layer_type == 'none':
|
|
||||||
return x
|
|
||||||
else:
|
|
||||||
return F.interpolate(x, scale_factor=2, mode='nearest')
|
|
||||||
|
|
||||||
class AdainResBlk1d(nn.Module):
|
|
||||||
def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
|
|
||||||
upsample='none', dropout_p=0.0):
|
|
||||||
super().__init__()
|
|
||||||
self.actv = actv
|
|
||||||
self.upsample_type = upsample
|
|
||||||
self.upsample = UpSample1d(upsample)
|
|
||||||
self.learned_sc = dim_in != dim_out
|
|
||||||
self._build_weights(dim_in, dim_out, style_dim)
|
|
||||||
self.dropout = nn.Dropout(dropout_p)
|
|
||||||
|
|
||||||
if upsample == 'none':
|
|
||||||
self.pool = nn.Identity()
|
|
||||||
else:
|
|
||||||
self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
|
|
||||||
|
|
||||||
|
|
||||||
def _build_weights(self, dim_in, dim_out, style_dim):
|
|
||||||
self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
|
|
||||||
self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
|
|
||||||
self.norm1 = AdaIN1d(style_dim, dim_in)
|
|
||||||
self.norm2 = AdaIN1d(style_dim, dim_out)
|
|
||||||
if self.learned_sc:
|
|
||||||
self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
|
|
||||||
|
|
||||||
def _shortcut(self, x):
|
|
||||||
x = self.upsample(x)
|
|
||||||
if self.learned_sc:
|
|
||||||
x = self.conv1x1(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def _residual(self, x, s):
|
|
||||||
x = self.norm1(x, s)
|
|
||||||
x = self.actv(x)
|
|
||||||
x = self.pool(x)
|
|
||||||
x = self.conv1(self.dropout(x))
|
|
||||||
x = self.norm2(x, s)
|
|
||||||
x = self.actv(x)
|
|
||||||
x = self.conv2(self.dropout(x))
|
|
||||||
return x
|
|
||||||
|
|
||||||
def forward(self, x, s):
|
|
||||||
out = self._residual(x, s)
|
|
||||||
out = (out + self._shortcut(x)) / np.sqrt(2)
|
|
||||||
return out
|
|
||||||
|
|
||||||
class AdaLayerNorm(nn.Module):
|
class AdaLayerNorm(nn.Module):
|
||||||
def __init__(self, style_dim, channels, eps=1e-5):
|
def __init__(self, style_dim, channels, eps=1e-5):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
|
|
||||||
self.fc = nn.Linear(style_dim, channels*2)
|
self.fc = nn.Linear(style_dim, channels*2)
|
||||||
|
|
||||||
def forward(self, x, s):
|
def forward(self, x, s):
|
||||||
x = x.transpose(-1, -2)
|
x = x.transpose(-1, -2)
|
||||||
x = x.transpose(1, -1)
|
x = x.transpose(1, -1)
|
||||||
|
|
||||||
h = self.fc(s)
|
h = self.fc(s)
|
||||||
h = h.view(h.size(0), h.size(1), 1)
|
h = h.view(h.size(0), h.size(1), 1)
|
||||||
gamma, beta = torch.chunk(h, chunks=2, dim=1)
|
gamma, beta = torch.chunk(h, chunks=2, dim=1)
|
||||||
gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
|
gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
|
||||||
|
|
||||||
|
|
||||||
x = F.layer_norm(x, (self.channels,), eps=self.eps)
|
x = F.layer_norm(x, (self.channels,), eps=self.eps)
|
||||||
x = (1 + gamma) * x + beta
|
x = (1 + gamma) * x + beta
|
||||||
return x.transpose(1, -1).transpose(-1, -2)
|
return x.transpose(1, -1).transpose(-1, -2)
|
||||||
|
|
||||||
class ProsodyPredictor(nn.Module):
|
|
||||||
|
|
||||||
|
class ProsodyPredictor(nn.Module):
|
||||||
def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
|
def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.text_encoder = DurationEncoder(sty_dim=style_dim, d_model=d_hid,nlayers=nlayers, dropout=dropout)
|
||||||
self.text_encoder = DurationEncoder(sty_dim=style_dim,
|
|
||||||
d_model=d_hid,
|
|
||||||
nlayers=nlayers,
|
|
||||||
dropout=dropout)
|
|
||||||
|
|
||||||
self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
|
self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
|
||||||
self.duration_proj = LinearNorm(d_hid, max_dur)
|
self.duration_proj = LinearNorm(d_hid, max_dur)
|
||||||
|
|
||||||
self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
|
self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
|
||||||
self.F0 = nn.ModuleList()
|
self.F0 = nn.ModuleList()
|
||||||
self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
|
self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
|
||||||
self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
|
self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
|
||||||
self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
|
self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
|
||||||
|
|
||||||
self.N = nn.ModuleList()
|
self.N = nn.ModuleList()
|
||||||
self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
|
self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
|
||||||
self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
|
self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
|
||||||
self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
|
self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
|
||||||
|
|
||||||
self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
|
self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
|
||||||
self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
|
self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, texts, style, text_lengths, alignment, m):
|
def forward(self, texts, style, text_lengths, alignment, m):
|
||||||
d = self.text_encoder(texts, style, text_lengths, m)
|
d = self.text_encoder(texts, style, text_lengths, m)
|
||||||
|
|
||||||
batch_size = d.shape[0]
|
batch_size = d.shape[0]
|
||||||
text_size = d.shape[1]
|
text_size = d.shape[1]
|
||||||
|
|
||||||
# predict duration
|
|
||||||
input_lengths = text_lengths.cpu().numpy()
|
input_lengths = text_lengths.cpu().numpy()
|
||||||
x = nn.utils.rnn.pack_padded_sequence(
|
x = nn.utils.rnn.pack_padded_sequence(d, input_lengths, batch_first=True, enforce_sorted=False)
|
||||||
d, input_lengths, batch_first=True, enforce_sorted=False)
|
|
||||||
|
|
||||||
m = m.to(text_lengths.device).unsqueeze(1)
|
m = m.to(text_lengths.device).unsqueeze(1)
|
||||||
|
|
||||||
self.lstm.flatten_parameters()
|
self.lstm.flatten_parameters()
|
||||||
x, _ = self.lstm(x)
|
x, _ = self.lstm(x)
|
||||||
x, _ = nn.utils.rnn.pad_packed_sequence(
|
x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
|
||||||
x, batch_first=True)
|
|
||||||
|
|
||||||
x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
|
x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
|
||||||
|
|
||||||
x_pad[:, :x.shape[1], :] = x
|
x_pad[:, :x.shape[1], :] = x
|
||||||
x = x_pad.to(x.device)
|
x = x_pad.to(x.device)
|
||||||
|
duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=False))
|
||||||
duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))
|
|
||||||
|
|
||||||
en = (d.transpose(-1, -2) @ alignment)
|
en = (d.transpose(-1, -2) @ alignment)
|
||||||
|
|
||||||
return duration.squeeze(-1), en
|
return duration.squeeze(-1), en
|
||||||
|
|
||||||
def F0Ntrain(self, x, s):
|
def F0Ntrain(self, x, s):
|
||||||
x, _ = self.shared(x.transpose(-1, -2))
|
x, _ = self.shared(x.transpose(-1, -2))
|
||||||
|
|
||||||
F0 = x.transpose(-1, -2)
|
F0 = x.transpose(-1, -2)
|
||||||
for block in self.F0:
|
for block in self.F0:
|
||||||
F0 = block(F0, s)
|
F0 = block(F0, s)
|
||||||
F0 = self.F0_proj(F0)
|
F0 = self.F0_proj(F0)
|
||||||
|
|
||||||
N = x.transpose(-1, -2)
|
N = x.transpose(-1, -2)
|
||||||
for block in self.N:
|
for block in self.N:
|
||||||
N = block(N, s)
|
N = block(N, s)
|
||||||
N = self.N_proj(N)
|
N = self.N_proj(N)
|
||||||
|
|
||||||
return F0.squeeze(1), N.squeeze(1)
|
return F0.squeeze(1), N.squeeze(1)
|
||||||
|
|
||||||
def length_to_mask(self, lengths):
|
|
||||||
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
|
|
||||||
mask = torch.gt(mask+1, lengths.unsqueeze(1))
|
|
||||||
return mask
|
|
||||||
|
|
||||||
class DurationEncoder(nn.Module):
|
class DurationEncoder(nn.Module):
|
||||||
|
|
||||||
def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
|
def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.lstms = nn.ModuleList()
|
self.lstms = nn.ModuleList()
|
||||||
for _ in range(nlayers):
|
for _ in range(nlayers):
|
||||||
self.lstms.append(nn.LSTM(d_model + sty_dim,
|
self.lstms.append(nn.LSTM(d_model + sty_dim, d_model // 2, num_layers=1, batch_first=True, bidirectional=True, dropout=dropout))
|
||||||
d_model // 2,
|
|
||||||
num_layers=1,
|
|
||||||
batch_first=True,
|
|
||||||
bidirectional=True,
|
|
||||||
dropout=dropout))
|
|
||||||
self.lstms.append(AdaLayerNorm(sty_dim, d_model))
|
self.lstms.append(AdaLayerNorm(sty_dim, d_model))
|
||||||
|
|
||||||
|
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
self.d_model = d_model
|
self.d_model = d_model
|
||||||
self.sty_dim = sty_dim
|
self.sty_dim = sty_dim
|
||||||
|
|
||||||
def forward(self, x, style, text_lengths, m):
|
def forward(self, x, style, text_lengths, m):
|
||||||
masks = m.to(text_lengths.device)
|
masks = m.to(text_lengths.device)
|
||||||
|
|
||||||
x = x.permute(2, 0, 1)
|
x = x.permute(2, 0, 1)
|
||||||
s = style.expand(x.shape[0], x.shape[1], -1)
|
s = style.expand(x.shape[0], x.shape[1], -1)
|
||||||
x = torch.cat([x, s], axis=-1)
|
x = torch.cat([x, s], axis=-1)
|
||||||
x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
|
x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
|
||||||
|
|
||||||
x = x.transpose(0, 1)
|
x = x.transpose(0, 1)
|
||||||
input_lengths = text_lengths.cpu().numpy()
|
input_lengths = text_lengths.cpu().numpy()
|
||||||
x = x.transpose(-1, -2)
|
x = x.transpose(-1, -2)
|
||||||
|
|
||||||
for block in self.lstms:
|
for block in self.lstms:
|
||||||
if isinstance(block, AdaLayerNorm):
|
if isinstance(block, AdaLayerNorm):
|
||||||
x = block(x.transpose(-1, -2), style).transpose(-1, -2)
|
x = block(x.transpose(-1, -2), style).transpose(-1, -2)
|
||||||
@@ -302,71 +169,73 @@ class DurationEncoder(nn.Module):
|
|||||||
x, _ = block(x)
|
x, _ = block(x)
|
||||||
x, _ = nn.utils.rnn.pad_packed_sequence(
|
x, _ = nn.utils.rnn.pad_packed_sequence(
|
||||||
x, batch_first=True)
|
x, batch_first=True)
|
||||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
x = F.dropout(x, p=self.dropout, training=False)
|
||||||
x = x.transpose(-1, -2)
|
x = x.transpose(-1, -2)
|
||||||
|
|
||||||
x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
|
x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
|
||||||
|
|
||||||
x_pad[:, :, :x.shape[-1]] = x
|
x_pad[:, :, :x.shape[-1]] = x
|
||||||
x = x_pad.to(x.device)
|
x = x_pad.to(x.device)
|
||||||
|
|
||||||
return x.transpose(-1, -2)
|
return x.transpose(-1, -2)
|
||||||
|
|
||||||
def inference(self, x, style):
|
|
||||||
x = self.embedding(x.transpose(-1, -2)) * np.sqrt(self.d_model)
|
|
||||||
style = style.expand(x.shape[0], x.shape[1], -1)
|
|
||||||
x = torch.cat([x, style], axis=-1)
|
|
||||||
src = self.pos_encoder(x)
|
|
||||||
output = self.transformer_encoder(src).transpose(0, 1)
|
|
||||||
return output
|
|
||||||
|
|
||||||
def length_to_mask(self, lengths):
|
# https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py
|
||||||
|
class CustomAlbert(AlbertModel):
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
outputs = super().forward(*args, **kwargs)
|
||||||
|
return outputs.last_hidden_state
|
||||||
|
|
||||||
|
|
||||||
|
class KModel(nn.Module):
|
||||||
|
def __init__(self, config, path):
|
||||||
|
super().__init__()
|
||||||
|
self.bert = CustomAlbert(AlbertConfig(vocab_size=config['n_token'], **config['plbert']))
|
||||||
|
self.bert_encoder = nn.Linear(self.bert.config.hidden_size, config['hidden_dim'])
|
||||||
|
self.predictor = ProsodyPredictor(
|
||||||
|
style_dim=config['style_dim'], d_hid=config['hidden_dim'],
|
||||||
|
nlayers=config['n_layer'], max_dur=config['max_dur'], dropout=config['dropout']
|
||||||
|
)
|
||||||
|
self.text_encoder = TextEncoder(
|
||||||
|
channels=config['hidden_dim'], kernel_size=config['text_encoder_kernel_size'],
|
||||||
|
depth=config['n_layer'], n_symbols=config['n_token']
|
||||||
|
)
|
||||||
|
self.decoder = Decoder(
|
||||||
|
dim_in=config['hidden_dim'], style_dim=config['style_dim'],
|
||||||
|
dim_out=config['n_mels'], **config['istftnet']
|
||||||
|
)
|
||||||
|
for key, state_dict in torch.load(path, map_location='cpu', weights_only=True).items():
|
||||||
|
assert hasattr(self, key), key
|
||||||
|
try:
|
||||||
|
getattr(self, key).load_state_dict(state_dict)
|
||||||
|
except:
|
||||||
|
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
||||||
|
getattr(self, key).load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def length_to_mask(cls, lengths):
|
||||||
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
|
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
|
||||||
mask = torch.gt(mask+1, lengths.unsqueeze(1))
|
mask = torch.gt(mask+1, lengths.unsqueeze(1))
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
# https://github.com/yl4579/StyleTTS2/blob/main/utils.py
|
@torch.no_grad()
|
||||||
def recursive_munch(d):
|
def forward(self, input_ids, ref_s, speed):
|
||||||
if isinstance(d, dict):
|
device = ref_s.device
|
||||||
return Munch((k, recursive_munch(v)) for k, v in d.items())
|
input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(device)
|
||||||
elif isinstance(d, list):
|
input_lengths = torch.LongTensor([input_ids.shape[-1]]).to(device)
|
||||||
return [recursive_munch(v) for v in d]
|
text_mask = type(self).length_to_mask(input_lengths).to(device)
|
||||||
else:
|
bert_dur = self.bert(input_ids, attention_mask=(~text_mask).int())
|
||||||
return d
|
d_en = self.bert_encoder(bert_dur).transpose(-1, -2)
|
||||||
|
s = ref_s[:, 128:]
|
||||||
def build_model(path, device):
|
d = self.predictor.text_encoder(d_en, s, input_lengths, text_mask)
|
||||||
config = Path(__file__).parent / 'config.json'
|
x, _ = self.predictor.lstm(d)
|
||||||
assert config.exists(), f'Config path incorrect: config.json not found at {config}'
|
duration = self.predictor.duration_proj(x)
|
||||||
with open(config, 'r') as r:
|
duration = torch.sigmoid(duration).sum(axis=-1) / speed
|
||||||
args = recursive_munch(json.load(r))
|
pred_dur = torch.round(duration).clamp(min=1).long()
|
||||||
assert args.decoder.type == 'istftnet', f'Unknown decoder type: {args.decoder.type}'
|
pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
|
||||||
decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
|
c_frame = 0
|
||||||
resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
|
for i in range(pred_aln_trg.size(0)):
|
||||||
upsample_rates = args.decoder.upsample_rates,
|
pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1
|
||||||
upsample_initial_channel=args.decoder.upsample_initial_channel,
|
c_frame += pred_dur[0,i].item()
|
||||||
resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
|
en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
|
||||||
upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
|
F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
|
||||||
gen_istft_n_fft=args.decoder.gen_istft_n_fft, gen_istft_hop_size=args.decoder.gen_istft_hop_size)
|
t_en = self.text_encoder(input_ids, input_lengths, text_mask)
|
||||||
text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
|
asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
|
||||||
predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
|
return self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
|
||||||
bert = load_plbert()
|
|
||||||
bert_encoder = nn.Linear(bert.config.hidden_size, args.hidden_dim)
|
|
||||||
for parent in [bert, bert_encoder, predictor, decoder, text_encoder]:
|
|
||||||
for child in parent.children():
|
|
||||||
if isinstance(child, nn.RNNBase):
|
|
||||||
child.flatten_parameters()
|
|
||||||
model = Munch(
|
|
||||||
bert=bert.to(device).eval(),
|
|
||||||
bert_encoder=bert_encoder.to(device).eval(),
|
|
||||||
predictor=predictor.to(device).eval(),
|
|
||||||
decoder=decoder.to(device).eval(),
|
|
||||||
text_encoder=text_encoder.to(device).eval(),
|
|
||||||
)
|
|
||||||
for key, state_dict in torch.load(path, map_location='cpu', weights_only=True)['net'].items():
|
|
||||||
assert key in model, key
|
|
||||||
try:
|
|
||||||
model[key].load_state_dict(state_dict)
|
|
||||||
except:
|
|
||||||
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
|
||||||
model[key].load_state_dict(state_dict, strict=False)
|
|
||||||
return model
|
|
||||||
|
|||||||
103
kokoro/pipeline.py
Normal file
103
kokoro/pipeline.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
from .models import KModel
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from misaki import en, espeak
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import torch
|
||||||
|
|
||||||
|
LANG_CODES = dict(
|
||||||
|
a='American English',
|
||||||
|
b='British English',
|
||||||
|
)
|
||||||
|
REPO_ID = 'hexgrad/Kokoro-82M'
|
||||||
|
|
||||||
|
class KPipeline:
|
||||||
|
def __init__(self, lang_code='a', config_path=None, model_path=None, trf=False):
|
||||||
|
assert lang_code in LANG_CODES, (lang_code, LANG_CODES)
|
||||||
|
self.lang_code = lang_code
|
||||||
|
if config_path is None:
|
||||||
|
config_path = hf_hub_download(repo_id=REPO_ID, filename='config.json')
|
||||||
|
assert os.path.exists(config_path)
|
||||||
|
with open(config_path, 'r') as r:
|
||||||
|
config = json.load(r)
|
||||||
|
if model_path is None:
|
||||||
|
model_path = hf_hub_download(repo_id=REPO_ID, filename='kokoro-v1_0.pth')
|
||||||
|
assert os.path.exists(model_path)
|
||||||
|
try:
|
||||||
|
fallback = espeak.EspeakFallback(british=lang_code=='b')
|
||||||
|
except Exception as e:
|
||||||
|
print('WARNING: EspeakFallback not enabled. Out-of-dictionary words will be skipped.', e)
|
||||||
|
fallback = None
|
||||||
|
self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback)
|
||||||
|
self.vocab = config['vocab']
|
||||||
|
self.model = KModel(config, model_path)
|
||||||
|
self.voices = {}
|
||||||
|
|
||||||
|
def load_voice(self, voice):
|
||||||
|
if voice in self.voices:
|
||||||
|
return
|
||||||
|
v = voice.split('/')[-1]
|
||||||
|
if not v.startswith(self.lang_code):
|
||||||
|
v = LANG_CODES.get(v, voice)
|
||||||
|
p = LANG_CODES.get(self.lang_code, self.lang_code)
|
||||||
|
print(f'WARNING: Loading {v} voice into {p} pipeline. Phonemes may be mismatched.')
|
||||||
|
voice_path = voice if voice.endswith('.pt') else hf_hub_download(repo_id=REPO_ID, filename=f'voices/{voice}.pt')
|
||||||
|
assert os.path.exists(voice_path)
|
||||||
|
self.voices[voice] = torch.load(voice_path, weights_only=True)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def waterfall_last(cls, pairs, next_count, waterfall=['!.?…', ':;', ',—'], bumps={')', '”'}):
|
||||||
|
for w in waterfall:
|
||||||
|
z = next((i for i, (_, ps) in reversed(list(enumerate(pairs))) if ps.strip() in set(w)), None)
|
||||||
|
if z is not None:
|
||||||
|
z += 1
|
||||||
|
if z < len(pairs) and pairs[z][1].strip() in bumps:
|
||||||
|
z += 1
|
||||||
|
_, ps = zip(*pairs[:z])
|
||||||
|
if next_count - len(''.join(ps)) <= 510:
|
||||||
|
return z
|
||||||
|
return len(pairs)
|
||||||
|
|
||||||
|
def tokenize(self, tokens):
|
||||||
|
pairs = []
|
||||||
|
count = 0
|
||||||
|
for w in tokens:
|
||||||
|
for t in (w if isinstance(w, list) else [w]):
|
||||||
|
if t.phonemes is None:
|
||||||
|
continue
|
||||||
|
next_ps = ' ' if t.prespace and pairs and not pairs[-1][1].endswith(' ') and t.phonemes else ''
|
||||||
|
next_ps += ''.join(filter(lambda p: p in self.vocab, t.phonemes.replace('ɾ', 'T'))) # American English: ɾ => T
|
||||||
|
next_ps += ' ' if t.whitespace else ''
|
||||||
|
next_count = count + len(next_ps.rstrip())
|
||||||
|
if next_count > 510:
|
||||||
|
z = type(self).waterfall_last(pairs, next_count)
|
||||||
|
text, ps = zip(*pairs[:z])
|
||||||
|
ps = ''.join(ps)
|
||||||
|
yield ''.join(text).strip(), ps.strip()
|
||||||
|
pairs = pairs[z:]
|
||||||
|
count -= len(ps)
|
||||||
|
if not pairs:
|
||||||
|
next_ps = next_ps.lstrip()
|
||||||
|
pairs.append((t.text + t.whitespace, next_ps))
|
||||||
|
count += len(next_ps)
|
||||||
|
if pairs:
|
||||||
|
text, ps = zip(*pairs)
|
||||||
|
yield ''.join(text).strip(), ''.join(ps).strip()
|
||||||
|
|
||||||
|
def __call__(self, text, voice='af', speed=1, split_pattern=r'\n+'):
|
||||||
|
assert isinstance(text, str) or isinstance(text, list), type(text)
|
||||||
|
self.load_voice(voice)
|
||||||
|
if isinstance(text, str) and split_pattern:
|
||||||
|
text = re.split(split_pattern, text.strip())
|
||||||
|
for graphemes in text:
|
||||||
|
_, tokens = self.g2p(graphemes)
|
||||||
|
for gs, ps in self.tokenize(tokens):
|
||||||
|
if not ps:
|
||||||
|
continue
|
||||||
|
elif len(ps) > 510:
|
||||||
|
print('TODO: Unexpected len(ps) > 510', len(ps), ps)
|
||||||
|
continue
|
||||||
|
input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), ps)))
|
||||||
|
assert input_ids and len(input_ids) <= 510, input_ids
|
||||||
|
yield gs, ps, self.model(input_ids, self.voices[voice][len(input_ids)-1], speed)
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
# https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py
|
|
||||||
from transformers import AlbertConfig, AlbertModel
|
|
||||||
|
|
||||||
class CustomAlbert(AlbertModel):
|
|
||||||
def forward(self, *args, **kwargs):
|
|
||||||
# Call the original forward method
|
|
||||||
outputs = super().forward(*args, **kwargs)
|
|
||||||
# Only return the last_hidden_state
|
|
||||||
return outputs.last_hidden_state
|
|
||||||
|
|
||||||
def load_plbert():
|
|
||||||
plbert_config = {'vocab_size': 178, 'hidden_size': 768, 'num_attention_heads': 12, 'intermediate_size': 2048, 'max_position_embeddings': 512, 'num_hidden_layers': 12, 'dropout': 0.1}
|
|
||||||
albert_base_configuration = AlbertConfig(**plbert_config)
|
|
||||||
bert = CustomAlbert(albert_base_configuration)
|
|
||||||
return bert
|
|
||||||
5
setup.py
5
setup.py
@@ -2,11 +2,12 @@ from setuptools import setup, find_packages
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name='kokoro', # Name of the package
|
name='kokoro', # Name of the package
|
||||||
version='0.1.0', # Initial version
|
version='0.2.0', # Initial version
|
||||||
packages=find_packages(), # Automatically finds packages
|
packages=find_packages(), # Automatically finds packages
|
||||||
install_requires=[ # List your dependencies here
|
install_requires=[ # List your dependencies here
|
||||||
|
'huggingface_hub',
|
||||||
|
'misaki[en]>=0.6.1',
|
||||||
'numpy',
|
'numpy',
|
||||||
'phonemizer',
|
|
||||||
'scipy',
|
'scipy',
|
||||||
'torch',
|
'torch',
|
||||||
'transformers',
|
'transformers',
|
||||||
|
|||||||
Reference in New Issue
Block a user