Changes from all commits
40 commits
b4f06f3
feat: add flow_cache and hift_cache
weedge Feb 19, 2025
1b096e8
feat: add tts static batch stream generate waveform
weedge Feb 20, 2025
a79c5af
feat: add tts stream inference
weedge Feb 20, 2025
c867728
feat: add tts static batch stream to merge
weedge Feb 20, 2025
46f22e9
streamer skip prefix prompt
weedge Feb 20, 2025
332b5f9
fix: step1lm generated token ids - 65536 to get vq codes
weedge Feb 20, 2025
c5314f2
fix: lock
weedge Feb 20, 2025
0634c43
feat: add stream_factor cmd param and TTS_TEXT env param
weedge Feb 20, 2025
de512a4
feat: add stream-factor cmd param and close debug print
weedge Feb 20, 2025
a62ae84
fix: fast path to check params
weedge Feb 20, 2025
2813963
merge stream
weedge Feb 20, 2025
e9d1c7c
fix typo
weedge Feb 20, 2025
075feef
fix typo
weedge Feb 20, 2025
7ce0a88
fix: tts instance share streamer -> gen session streamer
weedge Feb 21, 2025
e970f3f
merge
weedge Feb 21, 2025
90a23e0
add speaker_file_path
weedge Feb 22, 2025
b337668
add speaker_file_path
weedge Feb 22, 2025
1d8cb6e
add speaker_file_path
weedge Feb 22, 2025
68752ba
add speaker_file_path
weedge Feb 22, 2025
a833d67
feat: add modal run step tts/voice
weedge Feb 22, 2025
698f55d
Merge branch 'feat/flow_hifi_cache' into feat/dev
weedge Feb 23, 2025
3f17784
add ThreadSafeDict for tts lm,flow,hift session gen
weedge Feb 23, 2025
73fb149
fix: token_overlap_len
weedge Feb 23, 2025
b4a2cf6
fix: token_overlap_len
weedge Feb 23, 2025
ec9d5c6
remove print
weedge Feb 23, 2025
71a83d7
feat: add token2wav with flow hift session
weedge Feb 23, 2025
a25e16d
remove python-dotenv
weedge Feb 23, 2025
8ba35a1
fix token2wav
weedge Feb 23, 2025
733e052
fix token2wav
weedge Feb 23, 2025
42ff2bf
fix token2wav
weedge Feb 23, 2025
73897a3
fix token2wav mel
weedge Feb 23, 2025
37c6eee
fix: flow infer return mel float
weedge Feb 23, 2025
5a426dc
fix: lock
weedge Feb 23, 2025
8a2b3f3
token2wav load to cpu
weedge Feb 23, 2025
bb0e08e
feat: add token overlap to gen
weedge Feb 24, 2025
9866e8c
feat: add max_stream_factor for dynamic batch stream
weedge Feb 24, 2025
6b56e20
change device_map default auto
weedge Feb 24, 2025
680e934
change max_batch_size
weedge Feb 24, 2025
376d762
add: stream_factor stream_scale_factor max_stream_factor token_overla…
weedge Feb 24, 2025
7f623b0
fix tts_inference_stream params
weedge Feb 24, 2025
15 changes: 9 additions & 6 deletions cosyvoice/cli/cosyvoice.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import threading
import uuid
import time
from tqdm import tqdm
@@ -23,10 +24,10 @@


class CosyVoice:

def __init__(
self,
model_dir,
token_overlap_len: int = 20,
):
self.model_dir = model_dir
with open("{}/cosyvoice.yaml".format(model_dir), "r") as f:
@@ -36,7 +37,11 @@ def __init__(
"{}/campplus.onnx".format(model_dir),
"{}/speech_tokenizer_v1.onnx".format(model_dir),
)
self.model = CosyVoiceModel(configs["flow"], configs["hift"])
self.model = CosyVoiceModel(
configs["flow"],
configs["hift"],
token_overlap_len=token_overlap_len,
)
self.model.load(
"{}/flow.pt".format(model_dir),
"{}/hift.pt".format(model_dir),
@@ -53,11 +58,9 @@ def token_to_wav_offline(
prompt_token_len,
embedding,
):
tts_mel = self.model.flow.inference(
tts_mel, _ = self.model.flow.inference(
token=speech_token.to(self.model.device),
token_len=torch.tensor([speech_token.size(1)], dtype=torch.int32).to(
self.model.device
),
token_len=torch.tensor([speech_token.size(1)], dtype=torch.int32).to(self.model.device),
prompt_token=prompt_token.to(self.model.device),
prompt_token_len=prompt_token_len.to(self.model.device),
prompt_feat=speech_feat.to(self.model.device),
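For context, a minimal usage sketch of the updated `CosyVoice` constructor (not part of the diff): the checkpoint directory path below is hypothetical, and `token_overlap_len` is the new argument threaded through to `CosyVoiceModel`, defaulting to 20 as in the hunk above.

```python
# Minimal sketch, assuming a locally downloaded checkpoint; the directory
# name is hypothetical and not taken from this PR.
from cosyvoice.cli.cosyvoice import CosyVoice

cosyvoice = CosyVoice(
    "pretrained_models/CosyVoice-300M",  # hypothetical checkpoint directory
    token_overlap_len=20,  # overlap (in speech tokens) between streamed chunks
)
```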
111 changes: 110 additions & 1 deletion cosyvoice/cli/model.py
@@ -11,22 +11,131 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging

import numpy as np
import torch
from torch.nn import functional as F

from cosyvoice.utils.common import fade_in_out, ThreadSafeDict

class CosyVoiceModel:

class CosyVoiceModel:
def __init__(
self,
flow: torch.nn.Module,
hift: torch.nn.Module,
token_overlap_len: int = 20,
):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.flow = flow
self.hift = hift

# dicts used to store per-session state, keyed by session_id
self.mel_overlap_dict = ThreadSafeDict()
self.flow_cache_dict = ThreadSafeDict()
self.hift_cache_dict = ThreadSafeDict()

# mel fade in out
self.mel_overlap_len = int(token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
self.mel_window = np.hamming(2 * self.mel_overlap_len)
# hift cache
self.mel_cache_len = 20
self.source_cache_len = int(self.mel_cache_len * 256)
# speech fade in out
self.speech_window = np.hamming(2 * self.source_cache_len)

def load(self, flow_model, hift_model):
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
self.flow.to(self.device).eval()
self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
self.hift.to(self.device).eval()

def token2wav(
self,
token,
prompt_token,
prompt_feat,
embedding,
session_id,
finalize=False,
speed=1.0,
):
if self.flow_cache_dict.get(session_id) is None:
self.mel_overlap_dict.set(session_id, torch.zeros(1, 80, 0))
self.flow_cache_dict.set(session_id, torch.zeros(1, 80, 0, 2))

tts_mel, flow_cache = self.flow.inference(
token=token.to(self.device),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(
self.device
),
prompt_feat=prompt_feat.to(self.device),
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=embedding.to(self.device),
flow_cache=self.flow_cache_dict.get(session_id),
)
self.flow_cache_dict.set(session_id, flow_cache)

# mel overlap fade in out
if self.mel_overlap_dict.get(session_id).shape[2] != 0:
tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict.get(session_id), self.mel_window)

hift_cache_source = None
if self.hift_cache_dict.get(session_id) is not None:
# append hift cache
hift_cache_mel, hift_cache_source = (
self.hift_cache_dict.get(session_id)["mel"],
self.hift_cache_dict.get(session_id)["source"],
)
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
else:
hift_cache_source = torch.zeros(1, 1, 0)

# keep overlap mel and hift cache
if finalize is False:
self.mel_overlap_dict.set(session_id, tts_mel[:, :, -self.mel_overlap_len :])

tts_mel = tts_mel[:, :, : -self.mel_overlap_len]
tts_speech, tts_source = self.hift.inference(
mel=tts_mel, cache_source=hift_cache_source
)

if self.hift_cache_dict.get(session_id) is not None:
tts_speech = fade_in_out(
tts_speech, self.hift_cache_dict.get(session_id)["speech"], self.speech_window
)
self.hift_cache_dict.set(
session_id,
{
"mel": tts_mel[:, :, -self.mel_cache_len :],
"source": tts_source[:, :, -self.source_cache_len :],
"speech": tts_speech[:, -self.source_cache_len :],
},
)

tts_speech = tts_speech[:, : -self.source_cache_len]

logging.info("tts_speech: {}".format(tts_speech.shape))
else: # finalize
if speed != 1.0:
assert (
self.hift_cache_dict.get(session_id) is None
), "speed change only support non-stream inference mode"
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode="linear")
tts_speech, tts_source = self.hift.inference(
mel=tts_mel, cache_source=hift_cache_source
)
if self.hift_cache_dict.get(session_id) is not None:
tts_speech = fade_in_out(
tts_speech, self.hift_cache_dict.get(session_id)["speech"], self.speech_window
)

self.mel_overlap_dict.pop(session_id)
self.hift_cache_dict.pop(session_id)
self.flow_cache_dict.pop(session_id)
logging.info("finalize tts_speech: {}".format(tts_speech.shape))

return tts_speech.cpu()
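`ThreadSafeDict` is imported from `cosyvoice.utils.common` but its definition is not shown in this diff. Below is a minimal sketch of a lock-guarded dict matching the `get`/`set`/`pop` calls used by `token2wav` above; the actual implementation in the PR may differ.

```python
import threading


class ThreadSafeDict:
    """Sketch only: a dict guarded by a lock, matching the get/set/pop API
    used in CosyVoiceModel.token2wav. The real cosyvoice.utils.common
    implementation may differ."""

    def __init__(self):
        self._data = {}
        self._lock = threading.Lock()

    def get(self, key, default=None):
        with self._lock:
            return self._data.get(key, default)

    def set(self, key, value):
        with self._lock:
            self._data[key] = value

    def pop(self, key, default=None):
        with self._lock:
            return self._data.pop(key, default)
```

Keying each cache by `session_id` is what lets multiple streaming generations share one `CosyVoiceModel` instance without their flow/hift state colliding.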
32 changes: 16 additions & 16 deletions cosyvoice/flow/flow.py
@@ -121,9 +121,7 @@ def forward(
conds = conds.transpose(1, 2)

mask = (~make_pad_mask(feat_len)).to(h)
feat = F.interpolate(
feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest"
).squeeze(dim=1)
feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
loss, _ = self.decoder.compute_loss(
feat.transpose(1, 2).contiguous(),
mask.unsqueeze(1),
@@ -143,6 +141,7 @@ def inference(
prompt_feat,
prompt_feat_len,
embedding,
flow_cache=None,
):
assert token.shape[0] == 1
# xvec projection
@@ -159,11 +158,14 @@
token = self.input_embedding(torch.clamp(token, min=0))
h, _ = self.encoder.inference(token, token_len)
h = self.encoder_proj(h)
mel_len1, mel_len2 = prompt_feat.shape[1], int(
token_len2
/ self.input_frame_rate
* self.mel_feat_conf["sampling_rate"]
/ self.mel_feat_conf["hop_size"]
mel_len1, mel_len2 = (
prompt_feat.shape[1],
int(
token_len2
/ self.input_frame_rate
* self.mel_feat_conf["sampling_rate"]
/ self.mel_feat_conf["hop_size"]
),
)

h, _ = self.length_regulator.inference(
Expand All @@ -174,23 +176,21 @@ def inference(
)

# get conditions
conds = torch.zeros(
[1, mel_len1 + mel_len2, self.output_size], device=token.device
)
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
conds[:, :mel_len1] = prompt_feat
conds = conds.transpose(1, 2)

# mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
mask = torch.ones(
[1, mel_len1 + mel_len2], device=h.device, dtype=torch.bfloat16
)
feat = self.decoder(
mask = torch.ones([1, mel_len1 + mel_len2], device=h.device, dtype=torch.bfloat16)
feat, flow_cache = self.decoder(
mu=h.transpose(1, 2).contiguous(),
mask=mask.unsqueeze(1),
spks=embedding,
cond=conds,
n_timesteps=10,
prompt_len=mel_len1,
flow_cache=flow_cache,
)
feat = feat[:, :, mel_len1:]
assert feat.shape[2] == mel_len2
return feat
return feat.float(), flow_cache
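A worked example of the token-to-mel length mapping computed in `inference()` above, under assumed config values (`input_frame_rate=50`, `sampling_rate=22050`, `hop_size=256`; the real values come from `cosyvoice.yaml`, so treat these numbers as illustrative only):

```python
# Worked example with assumed config values (not taken from the diff).
input_frame_rate = 50   # speech tokens per second (assumption)
sampling_rate = 22050   # mel sampling rate (assumption)
hop_size = 256          # mel hop size (assumption)

token_len2 = 100        # e.g. 100 newly generated speech tokens (~2 s of audio)
mel_len2 = int(token_len2 / input_frame_rate * sampling_rate / hop_size)
print(mel_len2)         # 172 -> roughly 1.72 mel frames per speech token
```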
30 changes: 19 additions & 11 deletions cosyvoice/flow/flow_matching.py
@@ -51,6 +51,9 @@ def forward(
temperature=1.0,
spks=None,
cond=None,
prompt_len=0,
# flow_cache=torch.zeros(1, 80, 0, 2),
flow_cache=None,
):
"""Forward diffusion

@@ -69,13 +72,24 @@
sample: generated mel-spectrogram
shape: (batch_size, n_feats, mel_timesteps)
"""
z = torch.randn_like(mu) * temperature
z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature

if flow_cache is not None:
cache_size = flow_cache.shape[2]
# fix prompt and overlap part mu and z
if cache_size != 0:
z[:, :, :cache_size] = flow_cache[:, :, :, 0]
mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
flow_cache = torch.stack([z_cache, mu_cache], dim=-1)

t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
if self.t_scheduler == "cosine":
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
return self.solve_euler(
z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond
)
), flow_cache

@torch.inference_mode()
def capture_inference(self, seq_len_to_capture=list(range(128, 512, 8))):
@@ -96,9 +110,7 @@ def capture_inference(self, seq_len_to_capture=list(range(128, 512, 8))):
static_mask = torch.ones(
1, 1, seq_len, device=torch.device("cuda"), dtype=torch.bfloat16
)
static_spks = torch.randn(
1, 80, device=torch.device("cuda"), dtype=torch.bfloat16
)
static_spks = torch.randn(1, 80, device=torch.device("cuda"), dtype=torch.bfloat16)
static_cond = torch.randn(
1, 80, seq_len, device=torch.device("cuda"), dtype=torch.float32
)
@@ -231,9 +243,7 @@ def _solve_euler_impl(self, x, t_span, mu, mask, spks, cond):
mu_double = torch.cat([mu, torch.zeros_like(mu)], dim=0)
t_double = torch.cat([t, t], dim=0)
spks_double = (
torch.cat([spks, torch.zeros_like(spks)], dim=0)
if spks is not None
else None
torch.cat([spks, torch.zeros_like(spks)], dim=0) if spks is not None else None
)
cond_double = torch.cat([cond, torch.zeros_like(cond)], dim=0)

@@ -309,7 +319,5 @@ def compute_loss(self, x1, mask, mu, spks=None, cond=None):
cond = cond * cfg_mask.view(-1, 1, 1)

pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (
torch.sum(mask) * u.shape[1]
)
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
return loss, y
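The streaming path in `cosyvoice/cli/model.py` stitches consecutive chunks with `fade_in_out`, which is also imported from `cosyvoice.utils.common` and not shown in this diff. Below is a sketch that assumes it is a Hamming-window cross-fade, together with the overlap arithmetic that appears to motivate the hard-coded `-34` slice in the `forward` hunk of `flow_matching.py` above (with the same assumed config values as before, `int(20 / 50 * 22050 / 256) == 34`):

```python
import numpy as np
import torch


def fade_in_out(fade_in_mel, fade_out_mel, window):
    # Sketch only: cross-fade the head of the new chunk (rising half-window)
    # against the tail of the cached chunk (falling half-window). The real
    # cosyvoice.utils.common.fade_in_out may differ.
    overlap = window.shape[0] // 2
    win = torch.from_numpy(window).float().to(fade_in_mel.device)
    out = fade_in_mel.clone()
    out[..., :overlap] = (
        fade_in_mel[..., :overlap] * win[:overlap]
        + fade_out_mel[..., -overlap:] * win[overlap:]
    )
    return out


# Overlap length used by CosyVoiceModel (assumed config values):
# 20 tokens at 50 tokens/s -> 0.4 s -> 0.4 * 22050 / 256 ≈ 34 mel frames.
mel_overlap_len = int(20 / 50 * 22050 / 256)  # == 34
window = np.hamming(2 * mel_overlap_len)

prev_tail = torch.randn(1, 80, mel_overlap_len)  # cached overlap from previous chunk
new_chunk = torch.randn(1, 80, 100)              # freshly decoded mel chunk
smoothed = fade_in_out(new_chunk, prev_tail, window)
```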