From 9d22103540c08b70d6843c31e60f3cf71c4fa929 Mon Sep 17 00:00:00 2001 From: Xwdit <44023235+Xwdit@users.noreply.github.com> Date: Sat, 23 Mar 2024 11:48:26 +0100 Subject: [PATCH] Init webui implement --- models/megatts2.py | 32 ++++++++++++++------------------ requirements.txt | 3 ++- webui.py | 44 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 60 insertions(+), 19 deletions(-) create mode 100644 webui.py diff --git a/models/megatts2.py b/models/megatts2.py index 357fce3..c9804b0 100644 --- a/models/megatts2.py +++ b/models/megatts2.py @@ -1,3 +1,5 @@ +from typing import List + import torch import torch.nn as nn import torch.nn.functional as F @@ -323,18 +325,16 @@ def __init__( self.hifi_gan.eval() def forward( - self, - wavs_dir: str, - text: str, + self, + audio_paths: List[str], + text: str ): mels_prompt = None # Make mrte mels - wavs = glob.glob(f'{wavs_dir}/*.wav') mels = torch.empty(0) - for wav in wavs: - y = librosa.load(wav, sr=HIFIGAN_SR)[0] + for audio_path in audio_paths: + y = librosa.load(audio_path, sr=HIFIGAN_SR)[0] y = librosa.util.normalize(y) - # y = librosa.effects.trim(y, top_db=20)[0] y = torch.from_numpy(y) mel_spec = extract_mel_spec(y).transpose(0, 1) @@ -346,30 +346,26 @@ def forward( mels = mels.unsqueeze(0) # G2P - phone_tokens = self.ttc.phone2token( - self.tt.tokenize_lty(self.tt.tokenize(text))) + phone_tokens = self.ttc.phone2token(self.tt.tokenize_lty(self.tt.tokenize(text))) phone_tokens = phone_tokens.unsqueeze(0) with torch.no_grad(): tc_latent = self.generator.mrte.tc_latent(phone_tokens, mels) dt = self.adm.infer(tc_latent)[..., 0] tc_latent_expand = self.lr(tc_latent, dt) - tc_latent = F.max_pool1d(tc_latent_expand.transpose( - 1, 2), 8, ceil_mode=True).transpose(1, 2) + tc_latent = F.max_pool1d(tc_latent_expand.transpose(1, 2), 8, ceil_mode=True).transpose(1, 2) p_codes = self.plm.infer(tc_latent) zq = self.generator.vqpe.vq.decode(p_codes.unsqueeze(0)) - zq = rearrange( - zq, "B D T -> B T D").unsqueeze(2).contiguous().expand(-1, -1, 8, -1) + zq = rearrange(zq, "B D T -> B T D").unsqueeze(2).contiguous().expand(-1, -1, 8, -1) zq = rearrange(zq, "B T S D -> B (T S) D") - x = torch.cat( - [tc_latent_expand, zq[:, :tc_latent_expand.shape[1], :]], dim=-1) + x = torch.cat([tc_latent_expand, zq[:, :tc_latent_expand.shape[1], :]], dim=-1) x = rearrange(x, 'B T D -> B D T') x = self.generator.decoder(x) audio = self.hifi_gan.decode_batch(x.cpu()) - audio_prompt = self.hifi_gan.decode_batch( - mels_prompt.unsqueeze(0).transpose(1, 2).cpu()) + audio_prompt = self.hifi_gan.decode_batch(mels_prompt.unsqueeze(0).transpose(1, 2).cpu()) audio = torch.cat([audio_prompt, audio], dim=-1) - torchaudio.save('test.wav', audio[0], HIFIGAN_SR) + return audio.squeeze() + diff --git a/requirements.txt b/requirements.txt index c1fd3a7..75c322e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,5 @@ torchaudio==2.1.0+cu118 torchvision==0.16.0+cu118 lightning==2.1.2 lhotse==1.17.0 -h5py \ No newline at end of file +h5py +gradio \ No newline at end of file diff --git a/webui.py b/webui.py new file mode 100644 index 0000000..aee2a46 --- /dev/null +++ b/webui.py @@ -0,0 +1,44 @@ +import gradio as gr +from models.megatts import Megatts +from modules.tokenizer import HIFIGAN_SR + +megatts = Megatts( + g_ckpt='generator.ckpt', + g_config='configs/config_gan.yaml', + plm_ckpt='plm.ckpt', + plm_config='configs/config_plm.yaml', + adm_ckpt='adm.ckpt', + adm_config='configs/config_adm.yaml', + symbol_table='unique_text_tokens.k2symbols' +) +megatts.eval() + +def generate_audio( + audio_files, + text +): + audio_paths = [audio_file.name for audio_file in audio_files] + audio_tensor = megatts.forward(audio_paths, text) + audio_numpy = audio_tensor.cpu().numpy() + return audio_numpy, HIFIGAN_SR + +iface = gr.Interface( + fn=generate_audio, + inputs=[ + gr.inputs.File( + type="file", + label="Upload Audio Files", + multiple=True, + filetype="audio/wav" + ), + gr.inputs.Textbox(lines=2, label="Input Text") + ], + outputs=[ + gr.outputs.Audio(type="numpy", label="Generated Audio") + ], + title="MegaTTS2 Speech Synthesis", + description="Upload your audio files (only .wav format) and enter text to generate speech." +) + +if __name__ == "__main__": + iface.launch()