@@ -1,6 +1,7 @@
 import os
 import logging
 import tempfile
+import uuid
 from dataclasses import dataclass, asdict
 from typing import Literal, Optional, List, Tuple, Dict, Union
 from json import load
@@ -200,9 +201,10 @@ def infer(
         do_homophone_replacement=True,
         params_refine_text=RefineTextParams(),
         params_infer_code=InferCodeParams(),
+        stream_batch_size=16,
     ):
         self.context.set(False)
-        res_gen = self._infer(
+        return self._infer(
             text,
             stream,
             lang,
@@ -213,11 +215,8 @@ def infer(
             do_homophone_replacement,
             params_refine_text,
             params_infer_code,
+            stream_batch_size,
         )
-        if stream:
-            return res_gen
-        else:
-            return next(res_gen)
 
     def interrupt(self):
         self.context.set(True)
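With this change `infer` no longer materializes the first batch for non-streaming callers; it always returns the async generator produced by `_infer`, so both modes are consumed with `async for`. A minimal consumption sketch, assuming the upstream `ChatTTS.Chat` loading API; the text and parameter values are illustrative:

import asyncio
import numpy as np
import ChatTTS

async def main():
    chat = ChatTTS.Chat()
    chat.load()
    chunks = []
    # stream_batch_size controls how many decoded tokens accumulate
    # before a partial waveform batch is yielded.
    async for wavs in chat.infer(
        "Streaming synthesis test.", stream=True, stream_batch_size=16
    ):
        chunks.append(wavs)  # each wavs is the next [1, n_samples] slice
    audio = np.concatenate(chunks, axis=1)
    print(audio.shape)

asyncio.run(main())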
@@ -339,7 +338,7 @@ def _load(
 
         return self.has_loaded()
 
-    def _infer(
+    async def _infer(
         self,
         text,
         stream=False,
@@ -351,6 +350,7 @@ def _infer(
         do_homophone_replacement=True,
         params_refine_text=RefineTextParams(),
         params_infer_code=InferCodeParams(),
+        stream_batch_size=16,
     ):
 
         assert self.has_loaded(use_decoder=use_decoder)
@@ -384,41 +384,38 @@ def _infer(
             yield text
             return
 
-        if stream:
-            length = 0
-            pass_batch_count = 0
-        for result in self._infer_code(
+        length = 0
+        async for result in self._infer_code(
             text,
             stream,
             self.device,
             use_decoder,
             params_infer_code,
+            stream_batch_size,
         ):
             wavs = self._decode_to_wavs(
                 result.hiddens if use_decoder else result.ids,
                 use_decoder,
             )
-            result.destroy()
-            if stream:
-                pass_batch_count += 1
-                if pass_batch_count <= params_infer_code.pass_first_n_batches:
-                    continue
-                a = length
-                b = a + params_infer_code.stream_speed
-                if b > wavs.shape[1]:
-                    b = wavs.shape[1]
-                new_wavs = wavs[:, a:b]
-                length = b
-                yield new_wavs
+            if result.finished:
+                yield wavs[:, length:]
             else:
-                yield wavs
-        if stream:
-            new_wavs = wavs[:, length:]
-            # Identify rows with non-zero elements using np.any
-            # keep_rows = np.any(array != 0, axis=1)
-            keep_cols = np.sum(new_wavs != 0, axis=0) > 0
-            # Filter both rows and columns using slicing
-            yield new_wavs[:][:, keep_cols]
+                # Check for silent gaps; if any exist, emit audio up to the start of the last non-silent segment, otherwise wait for the next batch.
+                keep_cols = np.sum(abs(wavs[0][length:]) > 1e-6, axis=0) > 0
+
+                import librosa
+                silence_intervals = librosa.effects.split(wavs[0][length:], top_db=10)
+                silence_left = 0
+                if len(silence_intervals) == 0:
+                    silence_left = len(wavs[0])
+                else:
+                    for i in range(len(silence_intervals)):
+                        silence_left = silence_intervals[i][0]
+                if silence_left <= 0:
+                    continue
+                new_wavs = wavs[:, length : length + silence_left]
+                length += len(new_wavs[0])
+                yield new_wavs
 
     @torch.inference_mode()
     def _vocos_decode(self, spec: torch.Tensor) -> np.ndarray:
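The new boundary logic cuts each streamed chunk at the start of the last non-silent interval, so an utterance that is still being generated is held back until the next iteration. A standalone sketch of how librosa.effects.split drives that cut, using a synthetic signal (exact indices are frame-quantized, so the printed values are approximate):

import numpy as np
import librosa

sr = 24000
tone = 0.5 * np.sin(2 * np.pi * 440 * np.arange(sr // 2) / sr)
gap = np.zeros(sr // 4)
y = np.concatenate([tone, gap, tone])  # speech, silence, speech

# Non-silent runs as [start, end) sample indices; anything more than
# 10 dB below the signal peak counts as silence (top_db=10).
intervals = librosa.effects.split(y, top_db=10)
cut = intervals[-1][0]  # start of the last non-silent run
print(intervals)  # roughly [[0 12000], [18000 30000]]
print(cut)        # samples before `cut` are safe to emit now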
@@ -457,13 +454,14 @@ def _decode_to_wavs(
         return wavs
 
     @torch.no_grad()
-    def _infer_code(
+    async def _infer_code(
         self,
         text: Tuple[List[str], str],
         stream: bool,
         device: torch.device,
         return_hidden: bool,
         params: InferCodeParams,
+        stream_batch_size: int,
     ):
 
         gpt = self.gpt
@@ -504,6 +502,17 @@ def _infer_code(
             repetition_penalty=params.repetition_penalty,
         )
 
+        speaker_embedding_param = self.embed(input_ids, text_mask)
+        del text_mask
+        if params.spk_emb is not None:
+            self.speaker.apply(
+                speaker_embedding_param,
+                params.spk_emb,
+                input_ids,
+                self.tokenizer.spk_emb_ids,
+                self.gpt.device_gpt,
+            )
+
         if gpt.is_vllm:
             from .model.velocity import SamplingParams
 
@@ -522,62 +531,58 @@ def _infer_code(
             result = gpt.llm.generate(
                 None,
                 sample_params,
-                input_ids,
+                uuid.uuid4(),
+                speaker_embedding_param,
+                input_ids[0]
             )
-
-            token_ids = []
-            hidden_states = []
-            for i in result:
-                token_ids.append(torch.tensor(i.outputs[0].token_ids))
-                hidden_states.append(
-                    i.outputs[0].hidden_states.to(torch.float32).to(self.device)
-                )
-
-            del text_mask, input_ids
-
-            return [
-                GPT.GenerationOutputs(
-                    ids=token_ids,
-                    hiddens=hidden_states,
-                    attentions=[],
-                ),
-            ]
-
-        emb = self.embed(input_ids, text_mask)
-
-        del text_mask
-
-        if params.spk_emb is not None:
-            self.speaker.apply(
-                emb,
-                params.spk_emb,
+            async for i in result:
+                token_ids = []
+                hidden_states = []
+                if (stream and len(i.outputs[0].token_ids) % stream_batch_size == 0) or i.finished:
+                    token_ids.append(torch.tensor(i.outputs[0].token_ids))
+                    hidden_states.append(
+                        i.outputs[0].hidden_states.to(torch.float32).to(self.device)
+                    )
+                    yield GPT.GenerationOutputs(
+                        ids=token_ids,
+                        finished=i.finished,
+                        hiddens=hidden_states,
+                        attentions=[],
+                    )
+        else:
+            result = gpt.generate(
+                speaker_embedding_param,
                 input_ids,
-                self.tokenizer.spk_emb_ids,
-                self.gpt.device_gpt,
+                temperature=torch.tensor(temperature, device=device),
+                eos_token=num_code,
+                attention_mask=attention_mask,
+                max_new_token=params.max_new_token,
+                min_new_token=params.min_new_token,
+                logits_processors=(*logits_processors, *logits_warpers),
+                infer_text=False,
+                return_hidden=return_hidden,
+                stream=stream,
+                show_tqdm=params.show_tqdm,
+                ensure_non_empty=params.ensure_non_empty,
+                stream_batch=params.stream_batch,
+                manual_seed=params.manual_seed,
+                context=self.context,
             )
-
-        result = gpt.generate(
-            emb,
-            input_ids,
-            temperature=torch.tensor(temperature, device=device),
-            eos_token=num_code,
-            attention_mask=attention_mask,
-            max_new_token=params.max_new_token,
-            min_new_token=params.min_new_token,
-            logits_processors=(*logits_processors, *logits_warpers),
-            infer_text=False,
-            return_hidden=return_hidden,
-            stream=stream,
-            show_tqdm=params.show_tqdm,
-            ensure_non_empty=params.ensure_non_empty,
-            stream_batch=params.stream_batch,
-            manual_seed=params.manual_seed,
-            context=self.context,
-        )
-
-        del emb, input_ids
-
-        return result
+            del speaker_embedding_param, input_ids
+            async for i in result:
+                token_ids = []
+                hidden_states = []
+                if (stream and len(i.ids[0]) % stream_batch_size == 0) or i.finished:
+                    token_ids.append(i.ids[0])
+                    hidden_states.append(
+                        i.hiddens[0].to(torch.float32).to(self.device)
+                    )
+                    yield GPT.GenerationOutputs(
+                        ids=token_ids,
+                        finished=i.finished,
+                        hiddens=hidden_states,
+                        attentions=[],
+                    )
 
     @torch.no_grad()
     def _refine_text(
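Both branches gate their yields the same way: emit a partial GPT.GenerationOutputs whenever the accumulated token count reaches a multiple of stream_batch_size, and always emit the final finished state. The pattern in isolation, with a hypothetical Step type standing in for the generator's per-step output (not the real types):

import asyncio
from dataclasses import dataclass
from typing import AsyncIterator, List

@dataclass
class Step:
    token_ids: List[int]
    finished: bool

async def gated(steps: AsyncIterator[Step], stream: bool, stream_batch_size: int):
    async for s in steps:
        # Emit a partial result every stream_batch_size tokens, and always
        # emit the final state; intermediate steps are otherwise skipped.
        if (stream and len(s.token_ids) % stream_batch_size == 0) or s.finished:
            yield s

async def fake_steps():
    toks: List[int] = []
    for t in range(1, 38):
        toks.append(t)
        yield Step(list(toks), finished=(t == 37))

async def main():
    async for s in gated(fake_steps(), stream=True, stream_batch_size=16):
        print(len(s.token_ids), s.finished)  # 16 False / 32 False / 37 True

asyncio.run(main())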