MPS and XPU support #1075

Open · wants to merge 19 commits into main

58 changes: 36 additions & 22 deletions cosyvoice/bin/inference.py

@@ -16,7 +16,8 @@
 import argparse
 import logging
-logging.getLogger('matplotlib').setLevel(logging.WARNING)
+
+logging.getLogger("matplotlib").setLevel(logging.WARNING)
 import os
 import torch
 from torch.utils.data import DataLoader

@@ -53,13 +54,20 @@ def get_args():
 def main():
     args = get_args()
-    logging.basicConfig(level=logging.DEBUG,
-                        format='%(asctime)s %(levelname)s %(message)s')
-    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
+    logging.basicConfig(level=logging.DEBUG, format="%(asctime)s %(levelname)s %(message)s")
+    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
 
     # Init cosyvoice models from configs
-    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
-    device = torch.device('cuda' if use_cuda else 'cpu')
+    if torch.cuda.is_available():
+        device = torch.device("cuda:{}".format(args.gpu))
+    elif torch.backends.mps.is_available():
+        device = torch.device("mps")
+    elif torch.xpu.is_available():
+        device = torch.device("xpu")
+    else:
+        device = torch.device("cpu")
+
     try:
         with open(args.config, 'r') as f:
             configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': args.qwen_pretrain_path})

@@ -74,15 +82,14 @@ def main():
     model.load(args.llm_model, args.flow_model, args.hifigan_model)
 
-    test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
-                           tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
+    test_dataset = Dataset(args.prompt_data, data_pipeline=configs["data_pipeline"], mode="inference", shuffle=False, partition=False, tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
     test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
 
     sample_rate = configs['sample_rate']
     del configs
     os.makedirs(args.result_dir, exist_ok=True)
-    fn = os.path.join(args.result_dir, 'wav.scp')
-    f = open(fn, 'w')
+    fn = os.path.join(args.result_dir, "wav.scp")
+    f = open(fn, "w")
     with torch.no_grad():
         for _, batch in tqdm(enumerate(test_data_loader)):
             utts = batch["utts"]

@@ -98,28 +105,35 @@ def main():
             speech_feat_len = batch["speech_feat_len"].to(device)
             utt_embedding = batch["utt_embedding"].to(device)
             spk_embedding = batch["spk_embedding"].to(device)
-            if args.mode == 'sft':
-                model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
-                               'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
+            if args.mode == "sft":
+                model_input = {"text": tts_text_token, "text_len": tts_text_token_len, "llm_embedding": spk_embedding, "flow_embedding": spk_embedding}
             else:
-                model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
-                               'prompt_text': text_token, 'prompt_text_len': text_token_len,
-                               'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
-                               'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
-                               'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
-                               'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
+                model_input = {
+                    "text": tts_text_token,
+                    "text_len": tts_text_token_len,
+                    "prompt_text": text_token,
+                    "prompt_text_len": text_token_len,
+                    "llm_prompt_speech_token": speech_token,
+                    "llm_prompt_speech_token_len": speech_token_len,
+                    "flow_prompt_speech_token": speech_token,
+                    "flow_prompt_speech_token_len": speech_token_len,
+                    "prompt_speech_feat": speech_feat,
+                    "prompt_speech_feat_len": speech_feat_len,
+                    "llm_embedding": utt_embedding,
+                    "flow_embedding": utt_embedding,
+                }
             tts_speeches = []
             for model_output in model.tts(**model_input):
-                tts_speeches.append(model_output['tts_speech'])
+                tts_speeches.append(model_output["tts_speech"])
             tts_speeches = torch.concat(tts_speeches, dim=1)
             tts_key = '{}_{}'.format(utts[0], tts_index[0])
             tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
             torchaudio.save(tts_fn, tts_speeches, sample_rate=sample_rate, backend='soundfile')
             f.write('{} {}\n'.format(tts_key, tts_fn))
             f.flush()
     f.close()
-    logging.info('Result wav.scp saved in {}'.format(fn))
+    logging.info("Result wav.scp saved in {}".format(fn))
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
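
Note on the new device resolution: the old args.gpu >= 0 CPU escape hatch is gone, so any visible CUDA device now wins, then MPS, then XPU, then CPU. Also, CUDA_VISIBLE_DEVICES is set to args.gpu just above, which renumbers the remaining card to index 0, so "cuda:{}".format(args.gpu) only resolves for args.gpu == 0. A self-contained sketch for checking what a given machine resolves to (the hasattr guard is an addition here, since torch.xpu only exists on PyTorch builds that ship XPU support):

import torch

def resolve_device(gpu_index: int = 0) -> torch.device:
    # Mirrors the fallback order used throughout this PR: CUDA -> MPS -> XPU -> CPU.
    if torch.cuda.is_available():
        return torch.device("cuda:{}".format(gpu_index))
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu")
    return torch.device("cpu")

device = resolve_device()
print("resolved device:", device)
x = torch.ones(2, 2, device=device)  # tensor round-trip as a quick backend smoke test
print((x + x).cpu())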
10 changes: 8 additions & 2 deletions cosyvoice/cli/cosyvoice.py

@@ -23,7 +23,6 @@
 from cosyvoice.utils.file_utils import logging
 from cosyvoice.utils.class_utils import get_model_type
-
 
 class CosyVoice:
 
     def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):

@@ -45,7 +44,8 @@ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
                                               '{}/spk2info.pt'.format(model_dir),
                                               configs['allowed_special'])
         self.sample_rate = configs['sample_rate']
-        if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
+
+        if not self.gpu_available() and (load_jit is True or load_trt is True or fp16 is True):
             load_jit, load_trt, fp16 = False, False, False
             logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
         self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)

@@ -62,6 +62,12 @@ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
                                     self.fp16)
         del configs
 
+    def gpu_available(self) -> bool:
+        """Check whether any torch GPU backend (CUDA, MPS, or XPU) is available."""
+        if torch.cuda.is_available() or torch.backends.mps.is_available() or torch.xpu.is_available():
+            return True
+        return False
+
     def list_available_spks(self):
         spks = list(self.frontend.spk2info.keys())
         return spks
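
The constructor guard only makes sense negated: the warning fires when no accelerator backend is found, and the CUDA-oriented load_jit/load_trt/fp16 options are forced off. A tiny behavioral check, as a sketch (maybe_disable is a stand-in name for illustration, not code from this PR):

def maybe_disable(gpu: bool, load_jit: bool, load_trt: bool, fp16: bool):
    # Same shape as the constructor guard: no GPU backend means the
    # accelerator-only options are forced off, with a warning.
    if not gpu and (load_jit or load_trt or fp16):
        load_jit, load_trt, fp16 = False, False, False
        print("no cuda device, set load_jit/load_trt/fp16 to False")
    return load_jit, load_trt, fp16

assert maybe_disable(False, True, True, True) == (False, False, False)  # all forced off
assert maybe_disable(True, True, False, False) == (True, False, False)  # left as requested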
15 changes: 14 additions & 1 deletion cosyvoice/cli/frontend.py

@@ -35,6 +35,16 @@
 from cosyvoice.utils.file_utils import logging
 from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
 
+def set_device() -> torch.device:
+    """Assign a GPU device (CUDA, MPS, or XPU) if possible, else CPU."""
+    if torch.cuda.is_available():
+        return torch.device("cuda")
+    elif torch.backends.mps.is_available():
+        return torch.device("mps")
+    elif torch.xpu.is_available():
+        return torch.device("xpu")
+    else:
+        return torch.device("cpu")
 
 class CosyVoiceFrontEnd:

@@ -45,9 +55,12 @@ def __init__(self,
                  speech_tokenizer_model: str,
                  spk2info: str = '',
                  allowed_special: str = 'all'):
+
+        self.device = set_device()
+
         self.tokenizer = get_tokenizer()
         self.feat_extractor = feat_extractor
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
         option = onnxruntime.SessionOptions()
         option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
         option.intra_op_num_threads = 1
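
set_device() here is the same helper that cosyvoice/cli/model.py defines below. If the duplication becomes a maintenance concern, one option is to hoist it into a shared module; a sketch, where cosyvoice/utils/device.py is a hypothetical location and not a file in this PR:

# cosyvoice/utils/device.py (hypothetical new module, not part of this PR)
import torch

def set_device() -> torch.device:
    """Single shared copy of the CUDA -> MPS -> XPU -> CPU fallback."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu")
    return torch.device("cpu")

# Call sites would then share one import:
# from cosyvoice.utils.device import set_device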
47 changes: 42 additions & 5 deletions cosyvoice/cli/model.py

@@ -24,14 +24,38 @@
 from cosyvoice.utils.file_utils import convert_onnx_to_trt
 
 
+def set_device() -> torch.device:
+    """Assign a GPU device (CUDA, MPS, or XPU) if possible, else CPU."""
+    if torch.cuda.is_available():
+        return torch.device("cuda")
+    elif torch.backends.mps.is_available():
+        return torch.device("mps")
+    elif torch.xpu.is_available():
+        return torch.device("xpu")
+    else:
+        return torch.device("cpu")
+
+
+def clear_cache() -> None:
+    """Empty the cache of whichever accelerator backend is active."""
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    elif torch.backends.mps.is_available():
+        torch.mps.empty_cache()
+    elif torch.xpu.is_available():
+        torch.xpu.empty_cache()
+
+
 class CosyVoiceModel:
 
     def __init__(self,
                  llm: torch.nn.Module,
                  flow: torch.nn.Module,
                  hift: torch.nn.Module,
                  fp16: bool = False):
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+        self.device = set_device()
+
         self.llm = llm
         self.flow = flow
         self.hift = hift

@@ -81,7 +105,7 @@ def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
         self.flow.encoder = flow_encoder
 
     def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
-        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
+        assert torch.cuda.is_available() or torch.backends.mps.is_available() or torch.xpu.is_available(), 'tensorrt only supports gpu!'
         if not os.path.exists(flow_decoder_estimator_model):
             convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
         if os.path.getsize(flow_decoder_estimator_model) == 0:

@@ -231,7 +255,10 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
             self.flow_cache_dict.pop(this_uuid)
-        torch.cuda.empty_cache()
+        import gc
+        gc.collect()
+        clear_cache()
 
 
 class CosyVoice2Model(CosyVoiceModel):

@@ -242,7 +269,16 @@ def __init__(self,
                  hift: torch.nn.Module,
                  fp16: bool = False,
                  use_flow_cache: bool = False):
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+        if torch.cuda.is_available():
+            self.device = torch.device('cuda')
+        elif torch.backends.mps.is_available():
+            self.device = torch.device('mps')
+        elif torch.xpu.is_available():
+            self.device = torch.device('xpu')
+        else:
+            self.device = torch.device('cpu')
+
         self.llm = llm
         self.flow = flow
         self.hift = hift

@@ -405,4 +441,5 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
             self.llm_end_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
             self.flow_cache_dict.pop(this_uuid)
-        torch.cuda.empty_cache()
+
+        clear_cache()
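
The teardown now runs a Python GC pass before emptying the backend cache, which is the useful order: empty_cache() can only return blocks whose tensors are already unreferenced, and gc.collect() is what frees tensors still held alive by reference cycles. A standalone sketch of the same pattern (assumes a PyTorch build that exposes the relevant backend; torch.mps.empty_cache and torch.xpu.empty_cache need reasonably recent releases):

import gc
import torch

def release_memory() -> None:
    gc.collect()  # drop dead Python references so their blocks become reclaimable
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # hand cached, now-free blocks back to the driver
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()

release_memory()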
8 changes: 7 additions & 1 deletion cosyvoice/utils/common.py

@@ -152,7 +152,13 @@ def set_all_random_seed(seed):
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+    elif torch.backends.mps.is_available():
+        torch.mps.manual_seed(seed)
+    elif torch.xpu.is_available():
+        torch.xpu.manual_seed(seed)
 
 
 def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
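
A quick reproducibility check for the per-backend seeding: seed, draw, reseed, draw again, and compare. The helper is inlined below as a local stand-in so the snippet runs without the repo on the path; it mirrors the updated cosyvoice.utils.common code, plus a hasattr guard for builds without torch.xpu:

import random
import numpy as np
import torch

def set_all_random_seed(seed):
    # Local stand-in mirroring cosyvoice.utils.common.set_all_random_seed.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    elif torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.manual_seed(seed)

set_all_random_seed(1234)
a = torch.rand(3)
set_all_random_seed(1234)
b = torch.rand(3)
assert torch.equal(a, b)  # identical draws after reseeding on the same backend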
8 changes: 7 additions & 1 deletion cosyvoice/utils/executor.py

@@ -30,7 +30,13 @@ def __init__(self, gan: bool = False):
         self.step = 0
         self.epoch = 0
         self.rank = int(os.environ.get('RANK', 0))
-        self.device = torch.device('cuda:{}'.format(self.rank))
+        self.device = torch.device("cpu")
+        if torch.cuda.is_available():
+            self.device = torch.device('cuda:{}'.format(self.rank))
+        elif torch.backends.mps.is_available():
+            self.device = torch.device('mps')
+        elif torch.xpu.is_available():
+            self.device = torch.device('xpu')
 
     def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join):
         ''' Train one epoch
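
One distributed-training wrinkle: the old line bound each rank to its own card via 'cuda:{rank}', while the fallback above maps every rank to the same 'mps' or 'xpu' device. If multi-card XPU hosts matter, a rank-aware variant is possible; a sketch under the assumption that the build exposes indexed XPU devices (MPS genuinely has only one device):

import os
import torch

def rank_device(rank: int) -> torch.device:
    if torch.cuda.is_available():
        return torch.device("cuda:{}".format(rank))
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        # XPU devices are indexable like CUDA ones on multi-card hosts.
        return torch.device("xpu:{}".format(rank % torch.xpu.device_count()))
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

print(rank_device(int(os.environ.get("RANK", 0))))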