From de2acfcc8a78b3eceeff4ef8a5a4c08d9deb1bb7 Mon Sep 17 00:00:00 2001 From: hexgrad <166769057+hexgrad@users.noreply.github.com> Date: Mon, 27 Jan 2025 21:02:59 -0800 Subject: [PATCH] Dev (#11) * KPipeline * misaki[en]>=0.6.0 * Fix typo * Iteration over recursion * Expose KModel * Rename model --- kokoro/__init__.py | 4 + kokoro/infer.py | 149 ----------------------- kokoro/istftnet.py | 245 +++++++++++-------------------------- kokoro/models.py | 295 +++++++++++++-------------------------------- kokoro/pipeline.py | 103 ++++++++++++++++ kokoro/plbert.py | 15 --- setup.py | 5 +- 7 files changed, 260 insertions(+), 556 deletions(-) create mode 100644 kokoro/__init__.py delete mode 100644 kokoro/infer.py create mode 100644 kokoro/pipeline.py delete mode 100644 kokoro/plbert.py diff --git a/kokoro/__init__.py b/kokoro/__init__.py new file mode 100644 index 0000000..c11f3fa --- /dev/null +++ b/kokoro/__init__.py @@ -0,0 +1,4 @@ +__version__ = '0.2.0' + +from .models import KModel +from .pipeline import KPipeline diff --git a/kokoro/infer.py b/kokoro/infer.py deleted file mode 100644 index d33cb95..0000000 --- a/kokoro/infer.py +++ /dev/null @@ -1,149 +0,0 @@ -import phonemizer -import re -import torch - -def split_num(num): - num = num.group() - if '.' in num: - return num - elif ':' in num: - h, m = [int(n) for n in num.split(':')] - if m == 0: - return f"{h} o'clock" - elif m < 10: - return f'{h} oh {m}' - return f'{h} {m}' - year = int(num[:4]) - if year < 1100 or year % 1000 < 10: - return num - left, right = num[:2], int(num[2:4]) - s = 's' if num.endswith('s') else '' - if 100 <= year % 1000 <= 999: - if right == 0: - return f'{left} hundred{s}' - elif right < 10: - return f'{left} oh {right}{s}' - return f'{left} {right}{s}' - -def flip_money(m): - m = m.group() - bill = 'dollar' if m[0] == '$' else 'pound' - if m[-1].isalpha(): - return f'{m[1:]} {bill}s' - elif '.' not in m: - s = '' if m[1:] == '1' else 's' - return f'{m[1:]} {bill}{s}' - b, c = m[1:].split('.') - s = '' if b == '1' else 's' - c = int(c.ljust(2, '0')) - coins = f"cent{'' if c == 1 else 's'}" if m[0] == '$' else ('penny' if c == 1 else 'pence') - return f'{b} {bill}{s} and {c} {coins}' - -def point_num(num): - a, b = num.group().split('.') - return ' point '.join([a, ' '.join(b)]) - -def normalize_text(text): - text = text.replace(chr(8216), "'").replace(chr(8217), "'") - text = text.replace('«', chr(8220)).replace('»', chr(8221)) - text = text.replace(chr(8220), '"').replace(chr(8221), '"') - text = text.replace('(', '«').replace(')', '»') - for a, b in zip('、。!,:;?', ',.!,:;?'): - text = text.replace(a, b+' ') - text = re.sub(r'[^\S \n]', ' ', text) - text = re.sub(r' +', ' ', text) - text = re.sub(r'(?<=\n) +(?=\n)', '', text) - text = re.sub(r'\bD[Rr]\.(?= [A-Z])', 'Doctor', text) - text = re.sub(r'\b(?:Mr\.|MR\.(?= [A-Z]))', 'Mister', text) - text = re.sub(r'\b(?:Ms\.|MS\.(?= [A-Z]))', 'Miss', text) - text = re.sub(r'\b(?:Mrs\.|MRS\.(?= [A-Z]))', 'Mrs', text) - text = re.sub(r'\betc\.(?! [A-Z])', 'etc', text) - text = re.sub(r'(?i)\b(y)eah?\b', r"\1e'a", text) - text = re.sub(r'\d*\.\d+|\b\d{4}s?\b|(? 510: - tokens = tokens[:510] - print('Truncated to 510 tokens') - ref_s = voicepack[len(tokens)] - out = forward(model, tokens, ref_s, speed) - ps = ''.join(next(k for k, v in VOCAB.items() if i == v) for i in tokens) - return out, ps diff --git a/kokoro/istftnet.py b/kokoro/istftnet.py index da29481..929c478 100644 --- a/kokoro/istftnet.py +++ b/kokoro/istftnet.py @@ -1,12 +1,12 @@ # https://github.com/yl4579/StyleTTS2/blob/main/Modules/istftnet.py from scipy.signal import get_window -from torch.nn import Conv1d, ConvTranspose1d -from torch.nn.utils import weight_norm, remove_weight_norm +from torch.nn.utils import weight_norm import numpy as np import torch import torch.nn as nn import torch.nn.functional as F + # https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py def init_weights(m, mean=0.0, std=0.01): classname = m.__class__.__name__ @@ -16,7 +16,6 @@ def init_weights(m, mean=0.0, std=0.01): def get_padding(kernel_size, dilation=1): return int((kernel_size*dilation - dilation)/2) -LRELU_SLOPE = 0.1 class AdaIN1d(nn.Module): def __init__(self, style_dim, num_features): @@ -30,45 +29,41 @@ class AdaIN1d(nn.Module): gamma, beta = torch.chunk(h, chunks=2, dim=1) return (1 + gamma) * self.norm(x) + beta -class AdaINResBlock1(torch.nn.Module): + +class AdaINResBlock1(nn.Module): def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64): super(AdaINResBlock1, self).__init__() self.convs1 = nn.ModuleList([ - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], - padding=get_padding(kernel_size, dilation[0]))), - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], - padding=get_padding(kernel_size, dilation[1]))), - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], - padding=get_padding(kernel_size, dilation[2]))) + weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) ]) self.convs1.apply(init_weights) - self.convs2 = nn.ModuleList([ - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, - padding=get_padding(kernel_size, 1))), - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, - padding=get_padding(kernel_size, 1))), - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, - padding=get_padding(kernel_size, 1))) + weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))) ]) self.convs2.apply(init_weights) - self.adain1 = nn.ModuleList([ AdaIN1d(style_dim, channels), AdaIN1d(style_dim, channels), AdaIN1d(style_dim, channels), ]) - self.adain2 = nn.ModuleList([ AdaIN1d(style_dim, channels), AdaIN1d(style_dim, channels), AdaIN1d(style_dim, channels), ]) - self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))]) self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))]) - def forward(self, x, s): for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2): xt = n1(x, s) @@ -80,13 +75,8 @@ class AdaINResBlock1(torch.nn.Module): x = xt + x return x - def remove_weight_norm(self): - for l in self.convs1: - remove_weight_norm(l) - for l in self.convs2: - remove_weight_norm(l) - -class TorchSTFT(torch.nn.Module): + +class TorchSTFT(nn.Module): def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'): super().__init__() self.filter_length = filter_length @@ -99,22 +89,21 @@ class TorchSTFT(torch.nn.Module): input_data, self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device), return_complex=True) - return torch.abs(forward_transform), torch.angle(forward_transform) def inverse(self, magnitude, phase): inverse_transform = torch.istft( magnitude * torch.exp(phase * 1j), self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device)) - return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation def forward(self, input_data): self.magnitude, self.phase = self.transform(input_data) reconstruction = self.inverse(self.magnitude, self.phase) return reconstruction - -class SineGen(torch.nn.Module): + + +class SineGen(nn.Module): """ Definition of sine generator SineGen(samp_rate, harmonic_num = 0, sine_amp = 0.1, noise_std = 0.003, @@ -129,7 +118,6 @@ class SineGen(torch.nn.Module): Note: when flag_for_pulse is True, the first time step of a voiced segment is always sin(np.pi) or cos(0) """ - def __init__(self, samp_rate, upsample_scale, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0, @@ -156,52 +144,25 @@ class SineGen(torch.nn.Module): # convert to F0 in rad. The interger part n can be ignored # because 2 * np.pi * n doesn't affect phase rad_values = (f0_values / self.sampling_rate) % 1 - # initial phase noise (no noise for fundamental component) - rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \ - device=f0_values.device) + rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device) rand_ini[:, 0] = 0 rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini - # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad) if not self.flag_for_pulse: -# # for normal case - -# # To prevent torch.cumsum numerical overflow, -# # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1. -# # Buffer tmp_over_one_idx indicates the time step to add -1. -# # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi -# tmp_over_one = torch.cumsum(rad_values, 1) % 1 -# tmp_over_one_idx = (padDiff(tmp_over_one)) < 0 -# cumsum_shift = torch.zeros_like(rad_values) -# cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 - -# phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi - rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2), - scale_factor=1/self.upsample_scale, - mode="linear").transpose(1, 2) - -# tmp_over_one = torch.cumsum(rad_values, 1) % 1 -# tmp_over_one_idx = (padDiff(tmp_over_one)) < 0 -# cumsum_shift = torch.zeros_like(rad_values) -# cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 - + rad_values = F.interpolate(rad_values.transpose(1, 2), scale_factor=1/self.upsample_scale, mode="linear").transpose(1, 2) phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi - phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale, - scale_factor=self.upsample_scale, mode="linear").transpose(1, 2) + phase = F.interpolate(phase.transpose(1, 2) * self.upsample_scale, scale_factor=self.upsample_scale, mode="linear").transpose(1, 2) sines = torch.sin(phase) - else: # If necessary, make sure that the first time step of every # voiced segments is sin(pi) or cos(0) # This is used for pulse-train generation - # identify the last time step in unvoiced segments uv = self._f02uv(f0_values) uv_1 = torch.roll(uv, shifts=-1, dims=1) uv_1[:, -1, :] = 1 u_loc = (uv < 1) * (uv_1 > 0) - # get the instantanouse phase tmp_cumsum = torch.cumsum(rad_values, dim=1) # different batch needs to be processed differently @@ -212,11 +173,9 @@ class SineGen(torch.nn.Module): # each voiced segments tmp_cumsum[idx, :, :] = 0 tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum - # rad_values - tmp_cumsum: remove the accumulation of i.phase # within the previous voiced segment. i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1) - # get the sines sines = torch.cos(i_phase * 2 * np.pi) return sines @@ -228,32 +187,27 @@ class SineGen(torch.nn.Module): output sine_tensor: tensor(batchsize=1, length, dim) output uv: tensor(batchsize=1, length, 1) """ - f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, - device=f0.device) + f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device) # fundamental component fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)) - # generate sine waveforms sine_waves = self._f02sine(fn) * self.sine_amp - # generate uv signal # uv = torch.ones(f0.shape) # uv = uv * (f0 > self.voiced_threshold) uv = self._f02uv(f0) - # noise: for unvoiced should be similar to sine_amp # std = self.sine_amp/3 -> max value ~ self.sine_amp - # . for voiced regions is self.noise_std + # for voiced regions is self.noise_std noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 noise = noise_amp * torch.randn_like(sine_waves) - # first: set the unvoiced part to 0 by uv # then: additive noise sine_waves = sine_waves * uv + noise return sine_waves, uv, noise -class SourceModuleHnNSF(torch.nn.Module): +class SourceModuleHnNSF(nn.Module): """ SourceModule for hn-nsf SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshod=0) @@ -270,21 +224,17 @@ class SourceModuleHnNSF(torch.nn.Module): noise_source (batchsize, length 1) uv (batchsize, length, 1) """ - def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshod=0): super(SourceModuleHnNSF, self).__init__() - self.sine_amp = sine_amp self.noise_std = add_noise_std - # to produce sine waveforms self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num, sine_amp, add_noise_std, voiced_threshod) - # to merge source harmonics into a single excitation - self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) - self.l_tanh = torch.nn.Tanh() + self.l_linear = nn.Linear(harmonic_num + 1, 1) + self.l_tanh = nn.Tanh() def forward(self, x): """ @@ -297,80 +247,63 @@ class SourceModuleHnNSF(torch.nn.Module): with torch.no_grad(): sine_wavs, uv, _ = self.l_sin_gen(x) sine_merge = self.l_tanh(self.l_linear(sine_wavs)) - # source for noise branch, in the same shape as uv noise = torch.randn_like(uv) * self.sine_amp / 3 return sine_merge, noise, uv -def padDiff(x): - return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0) - -class Generator(torch.nn.Module): + +class Generator(nn.Module): def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size): super(Generator, self).__init__() - self.num_kernels = len(resblock_kernel_sizes) self.num_upsamples = len(upsample_rates) - resblock = AdaINResBlock1 - self.m_source = SourceModuleHnNSF( sampling_rate=24000, upsample_scale=np.prod(upsample_rates) * gen_istft_hop_size, harmonic_num=8, voiced_threshod=10) - self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * gen_istft_hop_size) + self.f0_upsamp = nn.Upsample(scale_factor=np.prod(upsample_rates) * gen_istft_hop_size) self.noise_convs = nn.ModuleList() self.noise_res = nn.ModuleList() - self.ups = nn.ModuleList() for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): self.ups.append(weight_norm( - ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)), - k, u, padding=(k-u)//2))) - + nn.ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)), + k, u, padding=(k-u)//2))) self.resblocks = nn.ModuleList() for i in range(len(self.ups)): ch = upsample_initial_channel//(2**(i+1)) for j, (k, d) in enumerate(zip(resblock_kernel_sizes,resblock_dilation_sizes)): - self.resblocks.append(resblock(ch, k, d, style_dim)) - + self.resblocks.append(AdaINResBlock1(ch, k, d, style_dim)) c_cur = upsample_initial_channel // (2 ** (i + 1)) - - if i + 1 < len(upsample_rates): # + if i + 1 < len(upsample_rates): stride_f0 = np.prod(upsample_rates[i + 1:]) - self.noise_convs.append(Conv1d( + self.noise_convs.append(nn.Conv1d( gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2)) - self.noise_res.append(resblock(c_cur, 7, [1,3,5], style_dim)) + self.noise_res.append(AdaINResBlock1(c_cur, 7, [1,3,5], style_dim)) else: - self.noise_convs.append(Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1)) - self.noise_res.append(resblock(c_cur, 11, [1,3,5], style_dim)) - - + self.noise_convs.append(nn.Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1)) + self.noise_res.append(AdaINResBlock1(c_cur, 11, [1,3,5], style_dim)) self.post_n_fft = gen_istft_n_fft - self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3)) + self.conv_post = weight_norm(nn.Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3)) self.ups.apply(init_weights) self.conv_post.apply(init_weights) - self.reflection_pad = torch.nn.ReflectionPad1d((1, 0)) + self.reflection_pad = nn.ReflectionPad1d((1, 0)) self.stft = TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft) - - + def forward(self, x, s, f0): with torch.no_grad(): f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t - har_source, noi_source, uv = self.m_source(f0) har_source = har_source.transpose(1, 2).squeeze(1) har_spec, har_phase = self.stft.transform(har_source) har = torch.cat([har_spec, har_phase], dim=1) - for i in range(self.num_upsamples): - x = F.leaky_relu(x, LRELU_SLOPE) + x = F.leaky_relu(x, negative_slope=0.1) x_source = self.noise_convs[i](har) x_source = self.noise_res[i](x_source, s) - x = self.ups[i](x) if i == self.num_upsamples - 1: x = self.reflection_pad(x) - x = x + x_source xs = None for j in range(self.num_kernels): @@ -384,38 +317,22 @@ class Generator(torch.nn.Module): spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :]) phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :]) return self.stft.inverse(spec, phase) - - def fw_phase(self, x, s): - for i in range(self.num_upsamples): - x = F.leaky_relu(x, LRELU_SLOPE) - x = self.ups[i](x) - xs = None - for j in range(self.num_kernels): - if xs is None: - xs = self.resblocks[i*self.num_kernels+j](x, s) - else: - xs += self.resblocks[i*self.num_kernels+j](x, s) - x = xs / self.num_kernels - x = F.leaky_relu(x) - x = self.reflection_pad(x) - x = self.conv_post(x) - spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :]) - phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :]) - return spec, phase - def remove_weight_norm(self): - print('Removing weight norm...') - for l in self.ups: - remove_weight_norm(l) - for l in self.resblocks: - l.remove_weight_norm() - remove_weight_norm(self.conv_pre) - remove_weight_norm(self.conv_post) - +class UpSample1d(nn.Module): + def __init__(self, layer_type): + super().__init__() + self.layer_type = layer_type + + def forward(self, x): + if self.layer_type == 'none': + return x + else: + return F.interpolate(x, scale_factor=2, mode='nearest') + + class AdainResBlk1d(nn.Module): - def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2), - upsample='none', dropout_p=0.0): + def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2), upsample='none', dropout_p=0.0): super().__init__() self.actv = actv self.upsample_type = upsample @@ -423,13 +340,11 @@ class AdainResBlk1d(nn.Module): self.learned_sc = dim_in != dim_out self._build_weights(dim_in, dim_out, style_dim) self.dropout = nn.Dropout(dropout_p) - if upsample == 'none': self.pool = nn.Identity() else: self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1)) - - + def _build_weights(self, dim_in, dim_out, style_dim): self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1)) self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1)) @@ -458,59 +373,36 @@ class AdainResBlk1d(nn.Module): out = self._residual(x, s) out = (out + self._shortcut(x)) / np.sqrt(2) return out - -class UpSample1d(nn.Module): - def __init__(self, layer_type): - super().__init__() - self.layer_type = layer_type - def forward(self, x): - if self.layer_type == 'none': - return x - else: - return F.interpolate(x, scale_factor=2, mode='nearest') class Decoder(nn.Module): - def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80, - resblock_kernel_sizes = [3,7,11], - upsample_rates = [10, 6], - upsample_initial_channel=512, - resblock_dilation_sizes=[[1,3,5], [1,3,5], [1,3,5]], - upsample_kernel_sizes=[20, 12], - gen_istft_n_fft=20, gen_istft_hop_size=5): + def __init__(self, dim_in, style_dim, dim_out, + resblock_kernel_sizes, + upsample_rates, + upsample_initial_channel, + resblock_dilation_sizes, + upsample_kernel_sizes, + gen_istft_n_fft, gen_istft_hop_size): super().__init__() - - self.decode = nn.ModuleList() - self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim) - + self.decode = nn.ModuleList() self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim)) self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim)) self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim)) self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True)) - self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1)) - self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1)) - - self.asr_res = nn.Sequential( - weight_norm(nn.Conv1d(512, 64, kernel_size=1)), - ) - - + self.asr_res = nn.Sequential(weight_norm(nn.Conv1d(512, 64, kernel_size=1))) self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size) - + def forward(self, asr, F0_curve, N, s): F0 = self.F0_conv(F0_curve.unsqueeze(1)) N = self.N_conv(N.unsqueeze(1)) - x = torch.cat([asr, F0, N], axis=1) x = self.encode(x, s) - asr_res = self.asr_res(asr) - res = True for block in self.decode: if res: @@ -518,6 +410,5 @@ class Decoder(nn.Module): x = block(x, s) if block.upsample_type != "none": res = False - x = self.generator(x, s, F0_curve) return x diff --git a/kokoro/models.py b/kokoro/models.py index 516ad7d..0344a6e 100644 --- a/kokoro/models.py +++ b/kokoro/models.py @@ -1,35 +1,28 @@ # https://github.com/yl4579/StyleTTS2/blob/main/models.py -from istftnet import AdaIN1d, Decoder -from munch import Munch -from pathlib import Path -from plbert import load_plbert -from torch.nn.utils import weight_norm, spectral_norm -import json +from .istftnet import AdaIN1d, AdainResBlk1d, Decoder +from torch.nn.utils import weight_norm +from transformers import AlbertConfig, AlbertModel import numpy as np -import os -import os.path as osp import torch import torch.nn as nn import torch.nn.functional as F -class LinearNorm(torch.nn.Module): + +class LinearNorm(nn.Module): def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): super(LinearNorm, self).__init__() - self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) - - torch.nn.init.xavier_uniform_( - self.linear_layer.weight, - gain=torch.nn.init.calculate_gain(w_init_gain)) + self.linear_layer = nn.Linear(in_dim, out_dim, bias=bias) + nn.init.xavier_uniform_(self.linear_layer.weight, gain=nn.init.calculate_gain(w_init_gain)) def forward(self, x): return self.linear_layer(x) + class LayerNorm(nn.Module): def __init__(self, channels, eps=1e-5): super().__init__() self.channels = channels self.eps = eps - self.gamma = nn.Parameter(torch.ones(channels)) self.beta = nn.Parameter(torch.zeros(channels)) @@ -37,12 +30,12 @@ class LayerNorm(nn.Module): x = x.transpose(1, -1) x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) return x.transpose(1, -1) - + + class TextEncoder(nn.Module): def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)): super().__init__() self.embedding = nn.Embedding(n_symbols, channels) - padding = (kernel_size - 1) // 2 self.cnn = nn.ModuleList() for _ in range(depth): @@ -52,8 +45,6 @@ class TextEncoder(nn.Module): actv, nn.Dropout(0.2), )) - # self.cnn = nn.Sequential(*self.cnn) - self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True) def forward(self, x, input_lengths, m): @@ -61,234 +52,110 @@ class TextEncoder(nn.Module): x = x.transpose(1, 2) # [B, emb, T] m = m.to(input_lengths.device).unsqueeze(1) x.masked_fill_(m, 0.0) - for c in self.cnn: x = c(x) x.masked_fill_(m, 0.0) - x = x.transpose(1, 2) # [B, T, chn] - input_lengths = input_lengths.cpu().numpy() - x = nn.utils.rnn.pack_padded_sequence( - x, input_lengths, batch_first=True, enforce_sorted=False) - + x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True, enforce_sorted=False) self.lstm.flatten_parameters() x, _ = self.lstm(x) - x, _ = nn.utils.rnn.pad_packed_sequence( - x, batch_first=True) - + x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True) x = x.transpose(-1, -2) x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]]) - x_pad[:, :, :x.shape[-1]] = x x = x_pad.to(x.device) - x.masked_fill_(m, 0.0) - return x - def inference(self, x): - x = self.embedding(x) - x = x.transpose(1, 2) - x = self.cnn(x) - x = x.transpose(1, 2) - self.lstm.flatten_parameters() - x, _ = self.lstm(x) - return x - - def length_to_mask(self, lengths): - mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths) - mask = torch.gt(mask+1, lengths.unsqueeze(1)) - return mask - -class UpSample1d(nn.Module): - def __init__(self, layer_type): - super().__init__() - self.layer_type = layer_type - - def forward(self, x): - if self.layer_type == 'none': - return x - else: - return F.interpolate(x, scale_factor=2, mode='nearest') - -class AdainResBlk1d(nn.Module): - def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2), - upsample='none', dropout_p=0.0): - super().__init__() - self.actv = actv - self.upsample_type = upsample - self.upsample = UpSample1d(upsample) - self.learned_sc = dim_in != dim_out - self._build_weights(dim_in, dim_out, style_dim) - self.dropout = nn.Dropout(dropout_p) - - if upsample == 'none': - self.pool = nn.Identity() - else: - self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1)) - - - def _build_weights(self, dim_in, dim_out, style_dim): - self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1)) - self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1)) - self.norm1 = AdaIN1d(style_dim, dim_in) - self.norm2 = AdaIN1d(style_dim, dim_out) - if self.learned_sc: - self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False)) - - def _shortcut(self, x): - x = self.upsample(x) - if self.learned_sc: - x = self.conv1x1(x) - return x - - def _residual(self, x, s): - x = self.norm1(x, s) - x = self.actv(x) - x = self.pool(x) - x = self.conv1(self.dropout(x)) - x = self.norm2(x, s) - x = self.actv(x) - x = self.conv2(self.dropout(x)) - return x - - def forward(self, x, s): - out = self._residual(x, s) - out = (out + self._shortcut(x)) / np.sqrt(2) - return out - class AdaLayerNorm(nn.Module): def __init__(self, style_dim, channels, eps=1e-5): super().__init__() self.channels = channels self.eps = eps - self.fc = nn.Linear(style_dim, channels*2) def forward(self, x, s): x = x.transpose(-1, -2) x = x.transpose(1, -1) - h = self.fc(s) h = h.view(h.size(0), h.size(1), 1) gamma, beta = torch.chunk(h, chunks=2, dim=1) gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1) - - x = F.layer_norm(x, (self.channels,), eps=self.eps) x = (1 + gamma) * x + beta return x.transpose(1, -1).transpose(-1, -2) + class ProsodyPredictor(nn.Module): - def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1): - super().__init__() - - self.text_encoder = DurationEncoder(sty_dim=style_dim, - d_model=d_hid, - nlayers=nlayers, - dropout=dropout) - + super().__init__() + self.text_encoder = DurationEncoder(sty_dim=style_dim, d_model=d_hid,nlayers=nlayers, dropout=dropout) self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True) self.duration_proj = LinearNorm(d_hid, max_dur) - self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True) self.F0 = nn.ModuleList() self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout)) self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout)) self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout)) - self.N = nn.ModuleList() self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout)) self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout)) self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout)) - self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0) self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0) - def forward(self, texts, style, text_lengths, alignment, m): d = self.text_encoder(texts, style, text_lengths, m) - batch_size = d.shape[0] text_size = d.shape[1] - - # predict duration input_lengths = text_lengths.cpu().numpy() - x = nn.utils.rnn.pack_padded_sequence( - d, input_lengths, batch_first=True, enforce_sorted=False) - + x = nn.utils.rnn.pack_padded_sequence(d, input_lengths, batch_first=True, enforce_sorted=False) m = m.to(text_lengths.device).unsqueeze(1) - self.lstm.flatten_parameters() x, _ = self.lstm(x) - x, _ = nn.utils.rnn.pad_packed_sequence( - x, batch_first=True) - + x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True) x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]]) - x_pad[:, :x.shape[1], :] = x x = x_pad.to(x.device) - - duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training)) - + duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=False)) en = (d.transpose(-1, -2) @ alignment) - return duration.squeeze(-1), en - + def F0Ntrain(self, x, s): x, _ = self.shared(x.transpose(-1, -2)) - F0 = x.transpose(-1, -2) for block in self.F0: F0 = block(F0, s) F0 = self.F0_proj(F0) - N = x.transpose(-1, -2) for block in self.N: N = block(N, s) N = self.N_proj(N) - return F0.squeeze(1), N.squeeze(1) - - def length_to_mask(self, lengths): - mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths) - mask = torch.gt(mask+1, lengths.unsqueeze(1)) - return mask + class DurationEncoder(nn.Module): - def __init__(self, sty_dim, d_model, nlayers, dropout=0.1): super().__init__() self.lstms = nn.ModuleList() for _ in range(nlayers): - self.lstms.append(nn.LSTM(d_model + sty_dim, - d_model // 2, - num_layers=1, - batch_first=True, - bidirectional=True, - dropout=dropout)) + self.lstms.append(nn.LSTM(d_model + sty_dim, d_model // 2, num_layers=1, batch_first=True, bidirectional=True, dropout=dropout)) self.lstms.append(AdaLayerNorm(sty_dim, d_model)) - - self.dropout = dropout self.d_model = d_model self.sty_dim = sty_dim def forward(self, x, style, text_lengths, m): masks = m.to(text_lengths.device) - x = x.permute(2, 0, 1) s = style.expand(x.shape[0], x.shape[1], -1) x = torch.cat([x, s], axis=-1) x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0) - x = x.transpose(0, 1) input_lengths = text_lengths.cpu().numpy() x = x.transpose(-1, -2) - for block in self.lstms: if isinstance(block, AdaLayerNorm): x = block(x.transpose(-1, -2), style).transpose(-1, -2) @@ -302,71 +169,73 @@ class DurationEncoder(nn.Module): x, _ = block(x) x, _ = nn.utils.rnn.pad_packed_sequence( x, batch_first=True) - x = F.dropout(x, p=self.dropout, training=self.training) + x = F.dropout(x, p=self.dropout, training=False) x = x.transpose(-1, -2) - x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]]) - x_pad[:, :, :x.shape[-1]] = x x = x_pad.to(x.device) - return x.transpose(-1, -2) - - def inference(self, x, style): - x = self.embedding(x.transpose(-1, -2)) * np.sqrt(self.d_model) - style = style.expand(x.shape[0], x.shape[1], -1) - x = torch.cat([x, style], axis=-1) - src = self.pos_encoder(x) - output = self.transformer_encoder(src).transpose(0, 1) - return output - - def length_to_mask(self, lengths): + + +# https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py +class CustomAlbert(AlbertModel): + def forward(self, *args, **kwargs): + outputs = super().forward(*args, **kwargs) + return outputs.last_hidden_state + + +class KModel(nn.Module): + def __init__(self, config, path): + super().__init__() + self.bert = CustomAlbert(AlbertConfig(vocab_size=config['n_token'], **config['plbert'])) + self.bert_encoder = nn.Linear(self.bert.config.hidden_size, config['hidden_dim']) + self.predictor = ProsodyPredictor( + style_dim=config['style_dim'], d_hid=config['hidden_dim'], + nlayers=config['n_layer'], max_dur=config['max_dur'], dropout=config['dropout'] + ) + self.text_encoder = TextEncoder( + channels=config['hidden_dim'], kernel_size=config['text_encoder_kernel_size'], + depth=config['n_layer'], n_symbols=config['n_token'] + ) + self.decoder = Decoder( + dim_in=config['hidden_dim'], style_dim=config['style_dim'], + dim_out=config['n_mels'], **config['istftnet'] + ) + for key, state_dict in torch.load(path, map_location='cpu', weights_only=True).items(): + assert hasattr(self, key), key + try: + getattr(self, key).load_state_dict(state_dict) + except: + state_dict = {k[7:]: v for k, v in state_dict.items()} + getattr(self, key).load_state_dict(state_dict, strict=False) + + @classmethod + def length_to_mask(cls, lengths): mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths) mask = torch.gt(mask+1, lengths.unsqueeze(1)) return mask -# https://github.com/yl4579/StyleTTS2/blob/main/utils.py -def recursive_munch(d): - if isinstance(d, dict): - return Munch((k, recursive_munch(v)) for k, v in d.items()) - elif isinstance(d, list): - return [recursive_munch(v) for v in d] - else: - return d - -def build_model(path, device): - config = Path(__file__).parent / 'config.json' - assert config.exists(), f'Config path incorrect: config.json not found at {config}' - with open(config, 'r') as r: - args = recursive_munch(json.load(r)) - assert args.decoder.type == 'istftnet', f'Unknown decoder type: {args.decoder.type}' - decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels, - resblock_kernel_sizes = args.decoder.resblock_kernel_sizes, - upsample_rates = args.decoder.upsample_rates, - upsample_initial_channel=args.decoder.upsample_initial_channel, - resblock_dilation_sizes=args.decoder.resblock_dilation_sizes, - upsample_kernel_sizes=args.decoder.upsample_kernel_sizes, - gen_istft_n_fft=args.decoder.gen_istft_n_fft, gen_istft_hop_size=args.decoder.gen_istft_hop_size) - text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token) - predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout) - bert = load_plbert() - bert_encoder = nn.Linear(bert.config.hidden_size, args.hidden_dim) - for parent in [bert, bert_encoder, predictor, decoder, text_encoder]: - for child in parent.children(): - if isinstance(child, nn.RNNBase): - child.flatten_parameters() - model = Munch( - bert=bert.to(device).eval(), - bert_encoder=bert_encoder.to(device).eval(), - predictor=predictor.to(device).eval(), - decoder=decoder.to(device).eval(), - text_encoder=text_encoder.to(device).eval(), - ) - for key, state_dict in torch.load(path, map_location='cpu', weights_only=True)['net'].items(): - assert key in model, key - try: - model[key].load_state_dict(state_dict) - except: - state_dict = {k[7:]: v for k, v in state_dict.items()} - model[key].load_state_dict(state_dict, strict=False) - return model + @torch.no_grad() + def forward(self, input_ids, ref_s, speed): + device = ref_s.device + input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(device) + input_lengths = torch.LongTensor([input_ids.shape[-1]]).to(device) + text_mask = type(self).length_to_mask(input_lengths).to(device) + bert_dur = self.bert(input_ids, attention_mask=(~text_mask).int()) + d_en = self.bert_encoder(bert_dur).transpose(-1, -2) + s = ref_s[:, 128:] + d = self.predictor.text_encoder(d_en, s, input_lengths, text_mask) + x, _ = self.predictor.lstm(d) + duration = self.predictor.duration_proj(x) + duration = torch.sigmoid(duration).sum(axis=-1) / speed + pred_dur = torch.round(duration).clamp(min=1).long() + pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item()) + c_frame = 0 + for i in range(pred_aln_trg.size(0)): + pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1 + c_frame += pred_dur[0,i].item() + en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device) + F0_pred, N_pred = self.predictor.F0Ntrain(en, s) + t_en = self.text_encoder(input_ids, input_lengths, text_mask) + asr = t_en @ pred_aln_trg.unsqueeze(0).to(device) + return self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy() diff --git a/kokoro/pipeline.py b/kokoro/pipeline.py new file mode 100644 index 0000000..38a803e --- /dev/null +++ b/kokoro/pipeline.py @@ -0,0 +1,103 @@ +from .models import KModel +from huggingface_hub import hf_hub_download +from misaki import en, espeak +import json +import os +import re +import torch + +LANG_CODES = dict( + a='American English', + b='British English', +) +REPO_ID = 'hexgrad/Kokoro-82M' + +class KPipeline: + def __init__(self, lang_code='a', config_path=None, model_path=None, trf=False): + assert lang_code in LANG_CODES, (lang_code, LANG_CODES) + self.lang_code = lang_code + if config_path is None: + config_path = hf_hub_download(repo_id=REPO_ID, filename='config.json') + assert os.path.exists(config_path) + with open(config_path, 'r') as r: + config = json.load(r) + if model_path is None: + model_path = hf_hub_download(repo_id=REPO_ID, filename='kokoro-v1_0.pth') + assert os.path.exists(model_path) + try: + fallback = espeak.EspeakFallback(british=lang_code=='b') + except Exception as e: + print('WARNING: EspeakFallback not enabled. Out-of-dictionary words will be skipped.', e) + fallback = None + self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback) + self.vocab = config['vocab'] + self.model = KModel(config, model_path) + self.voices = {} + + def load_voice(self, voice): + if voice in self.voices: + return + v = voice.split('/')[-1] + if not v.startswith(self.lang_code): + v = LANG_CODES.get(v, voice) + p = LANG_CODES.get(self.lang_code, self.lang_code) + print(f'WARNING: Loading {v} voice into {p} pipeline. Phonemes may be mismatched.') + voice_path = voice if voice.endswith('.pt') else hf_hub_download(repo_id=REPO_ID, filename=f'voices/{voice}.pt') + assert os.path.exists(voice_path) + self.voices[voice] = torch.load(voice_path, weights_only=True) + + @classmethod + def waterfall_last(cls, pairs, next_count, waterfall=['!.?…', ':;', ',—'], bumps={')', '”'}): + for w in waterfall: + z = next((i for i, (_, ps) in reversed(list(enumerate(pairs))) if ps.strip() in set(w)), None) + if z is not None: + z += 1 + if z < len(pairs) and pairs[z][1].strip() in bumps: + z += 1 + _, ps = zip(*pairs[:z]) + if next_count - len(''.join(ps)) <= 510: + return z + return len(pairs) + + def tokenize(self, tokens): + pairs = [] + count = 0 + for w in tokens: + for t in (w if isinstance(w, list) else [w]): + if t.phonemes is None: + continue + next_ps = ' ' if t.prespace and pairs and not pairs[-1][1].endswith(' ') and t.phonemes else '' + next_ps += ''.join(filter(lambda p: p in self.vocab, t.phonemes.replace('ɾ', 'T'))) # American English: ɾ => T + next_ps += ' ' if t.whitespace else '' + next_count = count + len(next_ps.rstrip()) + if next_count > 510: + z = type(self).waterfall_last(pairs, next_count) + text, ps = zip(*pairs[:z]) + ps = ''.join(ps) + yield ''.join(text).strip(), ps.strip() + pairs = pairs[z:] + count -= len(ps) + if not pairs: + next_ps = next_ps.lstrip() + pairs.append((t.text + t.whitespace, next_ps)) + count += len(next_ps) + if pairs: + text, ps = zip(*pairs) + yield ''.join(text).strip(), ''.join(ps).strip() + + def __call__(self, text, voice='af', speed=1, split_pattern=r'\n+'): + assert isinstance(text, str) or isinstance(text, list), type(text) + self.load_voice(voice) + if isinstance(text, str) and split_pattern: + text = re.split(split_pattern, text.strip()) + for graphemes in text: + _, tokens = self.g2p(graphemes) + for gs, ps in self.tokenize(tokens): + if not ps: + continue + elif len(ps) > 510: + print('TODO: Unexpected len(ps) > 510', len(ps), ps) + continue + input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), ps))) + assert input_ids and len(input_ids) <= 510, input_ids + yield gs, ps, self.model(input_ids, self.voices[voice][len(input_ids)-1], speed) diff --git a/kokoro/plbert.py b/kokoro/plbert.py deleted file mode 100644 index ef54f57..0000000 --- a/kokoro/plbert.py +++ /dev/null @@ -1,15 +0,0 @@ -# https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py -from transformers import AlbertConfig, AlbertModel - -class CustomAlbert(AlbertModel): - def forward(self, *args, **kwargs): - # Call the original forward method - outputs = super().forward(*args, **kwargs) - # Only return the last_hidden_state - return outputs.last_hidden_state - -def load_plbert(): - plbert_config = {'vocab_size': 178, 'hidden_size': 768, 'num_attention_heads': 12, 'intermediate_size': 2048, 'max_position_embeddings': 512, 'num_hidden_layers': 12, 'dropout': 0.1} - albert_base_configuration = AlbertConfig(**plbert_config) - bert = CustomAlbert(albert_base_configuration) - return bert diff --git a/setup.py b/setup.py index f30931b..9f05033 100644 --- a/setup.py +++ b/setup.py @@ -2,11 +2,12 @@ from setuptools import setup, find_packages setup( name='kokoro', # Name of the package - version='0.1.0', # Initial version + version='0.2.0', # Initial version packages=find_packages(), # Automatically finds packages install_requires=[ # List your dependencies here + 'huggingface_hub', + 'misaki[en]>=0.6.1', 'numpy', - 'phonemizer', 'scipy', 'torch', 'transformers',