Commit ecc7b52

format: revert
1 parent 8fcc0cd commit ecc7b52

11 files changed: +800 -266 lines

ChatTTS/core.py (+89 -84)
@@ -1,6 +1,7 @@
 import os
 import logging
 import tempfile
+import uuid
 from dataclasses import dataclass, asdict
 from typing import Literal, Optional, List, Tuple, Dict, Union
 from json import load
@@ -200,9 +201,10 @@ def infer(
         do_homophone_replacement=True,
         params_refine_text=RefineTextParams(),
         params_infer_code=InferCodeParams(),
+        stream_batch_size=16,
     ):
         self.context.set(False)
-        res_gen = self._infer(
+        return self._infer(
             text,
             stream,
             lang,
@@ -213,11 +215,8 @@ def infer(
             do_homophone_replacement,
             params_refine_text,
             params_infer_code,
+            stream_batch_size,
         )
-        if stream:
-            return res_gen
-        else:
-            return next(res_gen)

     def interrupt(self):
         self.context.set(True)
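
Note on the call path: infer() now always returns the async generator produced by _infer(), where the removed branch used to call next(res_gen) for non-streaming use, so callers consume both modes the same way. A minimal caller sketch, assuming a loaded `chat` instance; the `synthesize` helper is illustrative, not part of ChatTTS:

    import asyncio

    async def synthesize(chat, texts):
        # infer() hands back an async generator of waveform chunks.
        chunks = []
        async for wavs in chat.infer(texts, stream=True, stream_batch_size=16):
            chunks.append(wavs)
        return chunks

    # asyncio.run(synthesize(chat, ["Hello from ChatTTS."]))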
@@ -339,7 +338,7 @@ def _load(

         return self.has_loaded()

-    def _infer(
+    async def _infer(
         self,
         text,
         stream=False,
@@ -351,6 +350,7 @@ def _infer(
         do_homophone_replacement=True,
         params_refine_text=RefineTextParams(),
         params_infer_code=InferCodeParams(),
+        stream_batch_size=16,
     ):

         assert self.has_loaded(use_decoder=use_decoder)
@@ -384,41 +384,38 @@ def _infer(
             yield text
             return

-        if stream:
-            length = 0
-            pass_batch_count = 0
-        for result in self._infer_code(
+        length = 0
+        async for result in self._infer_code(
             text,
             stream,
             self.device,
             use_decoder,
             params_infer_code,
+            stream_batch_size,
         ):
             wavs = self._decode_to_wavs(
                 result.hiddens if use_decoder else result.ids,
                 use_decoder,
             )
-            result.destroy()
-            if stream:
-                pass_batch_count += 1
-                if pass_batch_count <= params_infer_code.pass_first_n_batches:
-                    continue
-                a = length
-                b = a + params_infer_code.stream_speed
-                if b > wavs.shape[1]:
-                    b = wavs.shape[1]
-                new_wavs = wavs[:, a:b]
-                length = b
-                yield new_wavs
+            if result.finished:
+                yield wavs[:, length:]
             else:
-                yield wavs
-        if stream:
-            new_wavs = wavs[:, length:]
-            # Identify rows with non-zero elements using np.any
-            # keep_rows = np.any(array != 0, axis=1)
-            keep_cols = np.sum(new_wavs != 0, axis=0) > 0
-            # Filter both rows and columns using slicing
-            yield new_wavs[:][:, keep_cols]
+                # Hacker:Check if there are any silent segments; if so, take the last segment. Otherwise, try waiting for another loop.
+                keep_cols = np.sum(abs(wavs[0][length:]) > 1e-6, axis=0) > 0
+
+                import librosa
+                silence_intervals = librosa.effects.split(wavs[0][length:], top_db=10)
+                silence_left = 0
+                if len(silence_intervals) == 0:
+                    silence_left = len(wavs[0])
+                else:
+                    for i in range(len(silence_intervals)):
+                        silence_left = silence_intervals[i][0]
+                if silence_left <= 0:
+                    continue
+                new_wavs = wavs[:, length : length + silence_left]
+                length += len(new_wavs[0])
+                yield new_wavs

     @torch.inference_mode()
     def _vocos_decode(self, spec: torch.Tensor) -> np.ndarray:
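
The streaming loop above now cuts chunks at silence boundaries instead of a fixed stream_speed stride. librosa.effects.split returns the sample ranges of non-silent intervals (despite the variable name silence_intervals), and the for loop leaves silence_left at the start of the last voiced run, so each yielded chunk ends in silence rather than mid-word; note that the computed keep_cols is left unused. A standalone sketch of that boundary trick on a synthetic signal, assuming only numpy and librosa:

    import numpy as np
    import librosa

    sr = 24000
    tone = np.sin(2 * np.pi * 440 * np.arange(sr) / sr).astype(np.float32)
    gap = np.zeros(sr // 4, dtype=np.float32)
    wav = np.concatenate([tone, gap, tone])

    # [start, end) sample indices of non-silent stretches; anything more
    # than top_db below the peak level counts as silence.
    intervals = librosa.effects.split(wav, top_db=10)

    # Cut just before the last voiced run, mirroring the loop in the diff.
    cut = intervals[-1][0] if len(intervals) else len(wav)
    chunk, remainder = wav[:cut], wav[cut:]
    print(len(chunk), len(remainder))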
@@ -457,13 +454,14 @@ def _decode_to_wavs(
         return wavs

     @torch.no_grad()
-    def _infer_code(
+    async def _infer_code(
         self,
         text: Tuple[List[str], str],
         stream: bool,
         device: torch.device,
         return_hidden: bool,
         params: InferCodeParams,
+        stream_batch_size: int,
     ):

         gpt = self.gpt
@@ -504,6 +502,17 @@ def _infer_code(
                 repetition_penalty=params.repetition_penalty,
             )

+        speaker_embedding_param = self.embed(input_ids, text_mask)
+        del text_mask
+        if params.spk_emb is not None:
+            self.speaker.apply(
+                speaker_embedding_param,
+                params.spk_emb,
+                input_ids,
+                self.tokenizer.spk_emb_ids,
+                self.gpt.device_gpt,
+            )
+
         if gpt.is_vllm:
             from .model.velocity import SamplingParams

@@ -522,62 +531,58 @@ def _infer_code(
             result = gpt.llm.generate(
                 None,
                 sample_params,
-                input_ids,
+                uuid.uuid4(),
+                speaker_embedding_param,
+                input_ids[0]
             )
-
-            token_ids = []
-            hidden_states = []
-            for i in result:
-                token_ids.append(torch.tensor(i.outputs[0].token_ids))
-                hidden_states.append(
-                    i.outputs[0].hidden_states.to(torch.float32).to(self.device)
-                )
-
-            del text_mask, input_ids
-
-            return [
-                GPT.GenerationOutputs(
-                    ids=token_ids,
-                    hiddens=hidden_states,
-                    attentions=[],
-                ),
-            ]
-
-        emb = self.embed(input_ids, text_mask)
-
-        del text_mask
-
-        if params.spk_emb is not None:
-            self.speaker.apply(
-                emb,
-                params.spk_emb,
+            async for i in result:
+                token_ids = []
+                hidden_states = []
+                if (stream and len(i.outputs[0].token_ids) % stream_batch_size == 0) or i.finished:
+                    token_ids.append(torch.tensor(i.outputs[0].token_ids))
+                    hidden_states.append(
+                        i.outputs[0].hidden_states.to(torch.float32).to(self.device)
+                    )
+                    yield GPT.GenerationOutputs(
+                        ids=token_ids,
+                        finished=i.finished,
+                        hiddens=hidden_states,
+                        attentions=[],
+                    )
+        else:
+            result = gpt.generate(
+                speaker_embedding_param,
                 input_ids,
-                self.tokenizer.spk_emb_ids,
-                self.gpt.device_gpt,
+                temperature=torch.tensor(temperature, device=device),
+                eos_token=num_code,
+                attention_mask=attention_mask,
+                max_new_token=params.max_new_token,
+                min_new_token=params.min_new_token,
+                logits_processors=(*logits_processors, *logits_warpers),
+                infer_text=False,
+                return_hidden=return_hidden,
+                stream=stream,
+                show_tqdm=params.show_tqdm,
+                ensure_non_empty=params.ensure_non_empty,
+                stream_batch=params.stream_batch,
+                manual_seed=params.manual_seed,
+                context=self.context,
             )
-
-        result = gpt.generate(
-            emb,
-            input_ids,
-            temperature=torch.tensor(temperature, device=device),
-            eos_token=num_code,
-            attention_mask=attention_mask,
-            max_new_token=params.max_new_token,
-            min_new_token=params.min_new_token,
-            logits_processors=(*logits_processors, *logits_warpers),
-            infer_text=False,
-            return_hidden=return_hidden,
-            stream=stream,
-            show_tqdm=params.show_tqdm,
-            ensure_non_empty=params.ensure_non_empty,
-            stream_batch=params.stream_batch,
-            manual_seed=params.manual_seed,
-            context=self.context,
-        )
-
-        del emb, input_ids
-
-        return result
+            del speaker_embedding_param, input_ids
+            async for i in result:
+                token_ids = []
+                hidden_states = []
+                if (stream and len(i.ids[0]) % stream_batch_size == 0) or i.finished:
+                    token_ids.append(i.ids[0])
+                    hidden_states.append(
+                        i.hiddens[0].to(torch.float32).to(self.device)
+                    )
+                    yield GPT.GenerationOutputs(
+                        ids=token_ids,
+                        finished=i.finished,
+                        hiddens=hidden_states,
+                        attentions=[],
+                    )

     @torch.no_grad()
     def _refine_text(
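
Both the vLLM and native branches now share one emission rule: while streaming, yield a partial GenerationOutputs whenever the accumulated token count is an exact multiple of stream_batch_size, and always yield the final batch with finished=True. A toy sketch of the rule with made-up producer names, not the ChatTTS API:

    import asyncio

    async def produce(n):
        # Stand-in for the model loop: yields (tokens so far, finished?).
        for t in range(1, n + 1):
            yield list(range(t)), t == n

    async def consume(stream=True, stream_batch_size=16):
        async for tokens, finished in produce(40):
            if (stream and len(tokens) % stream_batch_size == 0) or finished:
                print(f"emit {len(tokens)} tokens, finished={finished}")

    asyncio.run(consume())  # emits at 16, 32, and finally 40 (finished)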

ChatTTS/model/gpt.py (+7 -1)
@@ -68,6 +68,7 @@ def from_pretrained(
                 num_audio_tokens=self.num_audio_tokens,
                 num_text_tokens=self.num_text_tokens,
                 post_model_path=embed_file_path,
+                dtype="float32"
             )
             self.logger.info("vLLM model loaded")
             return
@@ -273,6 +274,7 @@ class GenerationOutputs:
         ids: List[torch.Tensor]
         attentions: List[Optional[Tuple[torch.FloatTensor, ...]]]
         hiddens: List[torch.Tensor]
+        finished: bool

         def destroy(self):
             del_all(self.ids)
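
Since finished is declared without a default, every construction site must pass it, which is why _prepare_generation_outputs gains a matching parameter in the next hunk. A trimmed sketch of the new contract; the attentions field is omitted and defaults are added here only to keep the snippet self-contained:

    from dataclasses import dataclass, field
    from typing import List
    import torch

    @dataclass
    class GenerationOutputs:
        ids: List[torch.Tensor] = field(default_factory=list)
        hiddens: List[torch.Tensor] = field(default_factory=list)
        finished: bool = False

    out = GenerationOutputs(ids=[torch.tensor([1, 2, 3])], finished=False)
    print("partial" if not out.finished else "final", out.ids[0].shape)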
@@ -288,6 +290,7 @@ def _prepare_generation_outputs(
         attentions: List[Optional[Tuple[torch.FloatTensor, ...]]],
         hiddens: List[torch.Tensor],
         infer_text: bool,
+        finished: bool,
     ) -> GenerationOutputs:
         inputs_ids = [
             inputs_ids[idx].narrow(0, start_idx, i) for idx, i in enumerate(end_idx)
@@ -305,10 +308,11 @@ def _prepare_generation_outputs(
             ids=inputs_ids,
             attentions=attentions,
             hiddens=hiddens,
+            finished=finished,
         )

     @torch.no_grad()
-    def generate(
+    async def generate(
         self,
         emb: torch.Tensor,
         inputs_ids: torch.Tensor,
@@ -581,6 +585,7 @@ def generate(
                     attentions,
                     hiddens,
                     infer_text,
+                    False
                 )
                 del not_finished

@@ -610,4 +615,5 @@ def generate(
             attentions,
             hiddens,
             infer_text,
+            True
         )
