Commit 5b7863f

chore(format): run black on dev
1 parent ecc7b52 commit 5b7863f

5 files changed: +57 -37 lines

ChatTTS/core.py (+6 -9)

@@ -404,6 +404,7 @@ async def _infer(
         keep_cols = np.sum(abs(wavs[0][length:]) > 1e-6, axis=0) > 0

         import librosa
+
         silence_intervals = librosa.effects.split(wavs[0][length:], top_db=10)
         silence_left = 0
         if len(silence_intervals) == 0:
@@ -529,16 +530,14 @@ async def _infer_code(
         input_ids = [i.tolist() for i in input_ids]

         result = gpt.llm.generate(
-            None,
-            sample_params,
-            uuid.uuid4(),
-            speaker_embedding_param,
-            input_ids[0]
+            None, sample_params, uuid.uuid4(), speaker_embedding_param, input_ids[0]
         )
         async for i in result:
             token_ids = []
             hidden_states = []
-            if (stream and len(i.outputs[0].token_ids) % stream_batch_size == 0) or i.finished:
+            if (
+                stream and len(i.outputs[0].token_ids) % stream_batch_size == 0
+            ) or i.finished:
                 token_ids.append(torch.tensor(i.outputs[0].token_ids))
                 hidden_states.append(
                     i.outputs[0].hidden_states.to(torch.float32).to(self.device)
@@ -574,9 +573,7 @@ async def _infer_code(
             hidden_states = []
             if (stream and len(i.ids[0]) % stream_batch_size == 0) or i.finished:
                 token_ids.append(i.ids[0])
-                hidden_states.append(
-                    i.hiddens[0].to(torch.float32).to(self.device)
-                )
+                hidden_states.append(i.hiddens[0].to(torch.float32).to(self.device))
             yield GPT.GenerationOutputs(
                 ids=token_ids,
                 finished=i.finished,
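
As context for the silence-trimming hunk above: librosa.effects.split returns [start, end) sample intervals whose level is within top_db of the signal peak, so everything outside those intervals can be treated as silence. A minimal standalone sketch, not from this repo; the 24 kHz rate and the test tone are assumptions:

    import librosa
    import numpy as np

    sr = 24000  # assumed sample rate, for illustration only
    t = np.linspace(0, 1, sr, dtype=np.float32)
    # Half a second of silence followed by a 440 Hz tone.
    wav = np.concatenate(
        [np.zeros(sr // 2, dtype=np.float32), 0.5 * np.sin(2 * np.pi * 440 * t)]
    )

    # Each row is a [start, end) interval of non-silent samples.
    intervals = librosa.effects.split(wav, top_db=10)
    silence_left = intervals[0][0] if len(intervals) else len(wav)
    print(intervals, silence_left)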

ChatTTS/model/gpt.py (+3 -9)

@@ -68,7 +68,7 @@ def from_pretrained(
             num_audio_tokens=self.num_audio_tokens,
             num_text_tokens=self.num_text_tokens,
             post_model_path=embed_file_path,
-            dtype="float32"
+            dtype="float32",
         )
         self.logger.info("vLLM model loaded")
         return
@@ -585,7 +585,7 @@ async def generate(
                     attentions,
                     hiddens,
                     infer_text,
-                    False
+                    False,
                 )
                 del not_finished

@@ -609,11 +609,5 @@ async def generate(
         del finish, inputs_ids_buf

         yield self._prepare_generation_outputs(
-            inputs_ids,
-            start_idx,
-            end_idx,
-            attentions,
-            hiddens,
-            infer_text,
-            True
+            inputs_ids, start_idx, end_idx, attentions, hiddens, infer_text, True
         )
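
The pattern in these hunks is black's magic trailing comma: a comma after the last argument tells black to keep the call expanded one argument per line, while a call without one is collapsed onto a single line when it fits. An illustrative sketch, not from this repo:

    def prepare(a, b, c):
        return (a, b, c)

    # No trailing comma: black collapses the call when it fits on one line.
    out = prepare(1, 2, 3)

    # Magic trailing comma: black keeps the call expanded.
    out = prepare(
        1,
        2,
        3,
    )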

ChatTTS/model/velocity/llm.py (+6 -6)

@@ -112,12 +112,12 @@ def __init__(
         self.request_counter = Counter()

     async def generate(
-            self,
-            prompt: Optional[str],
-            sampling_params: SamplingParams,
-            request_id: str,
-            speaker_embedding_param: torch.Tensor,
-            prompt_token_ids: Optional[List[int]] = None,
+        self,
+        prompt: Optional[str],
+        sampling_params: SamplingParams,
+        request_id: str,
+        speaker_embedding_param: torch.Tensor,
+        prompt_token_ids: Optional[List[int]] = None,
     ) -> AsyncIterator[RequestOutput]:
         """Generate outputs for a request.
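
generate here is an async generator, which is why core.py above drives it with async for. A minimal self-contained sketch of that pattern, with toy names rather than the real engine:

    import asyncio
    from typing import AsyncIterator

    async def generate(n: int) -> AsyncIterator[int]:
        for i in range(n):
            await asyncio.sleep(0)  # hand control back to the event loop
            yield i  # a real engine would yield RequestOutput objects

    async def main() -> None:
        async for out in generate(3):
            print(out)

    asyncio.run(main())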

ChatTTS/model/velocity/model_runner.py (+40 -13)

@@ -106,7 +106,9 @@ def set_block_size(self, block_size: int) -> None:
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> tuple[list[list[int]], list[list[int]], InputMetadata, list[int], list[Tensor]]:
+    ) -> tuple[
+        list[list[int]], list[list[int]], InputMetadata, list[int], list[Tensor]
+    ]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
@@ -359,17 +361,23 @@ def _prepare_sample(
     def prepare_input_tensors(
         self,
         seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata, list[torch.Tensor]]:
+    ) -> Tuple[
+        torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata, list[torch.Tensor]
+    ]:
         speaker_embedding = None
         if self.is_driver_worker:
             # NOTE: We assume that all sequences in the group are all prompts or
             # all decodes.
             is_prompt = seq_group_metadata_list[0].is_prompt
             # Prepare input tensors.
             if is_prompt:
-                (input_tokens, input_positions, input_metadata, prompt_lens, speaker_embedding) = (
-                    self._prepare_prompt(seq_group_metadata_list)
-                )
+                (
+                    input_tokens,
+                    input_positions,
+                    input_metadata,
+                    prompt_lens,
+                    speaker_embedding,
+                ) = self._prepare_prompt(seq_group_metadata_list)
             else:
                 (input_tokens, input_positions, input_metadata) = self._prepare_decode(
                     seq_group_metadata_list
@@ -461,7 +469,13 @@ def get_size_or_none(x: Optional[torch.Tensor]):
             perform_sampling=False,
         )

-        return input_tokens, input_positions, input_metadata, sampling_metadata, speaker_embedding
+        return (
+            input_tokens,
+            input_positions,
+            input_metadata,
+            sampling_metadata,
+            speaker_embedding,
+        )

     @torch.inference_mode()
     def execute_model(
@@ -470,9 +484,13 @@ def execute_model(
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
     ) -> Optional[SamplerOutput]:

-        input_tokens, input_positions, input_metadata, sampling_metadata, speaker_embedding = (
-            self.prepare_input_tensors(seq_group_metadata_list)
-        )
+        (
+            input_tokens,
+            input_positions,
+            input_metadata,
+            sampling_metadata,
+            speaker_embedding,
+        ) = self.prepare_input_tensors(seq_group_metadata_list)
         # print(sampling_metadata.seq_data)
         seq_groups = []
         for i, rtn in enumerate(sampling_metadata.seq_groups):
@@ -521,7 +539,9 @@ def execute_model(
                 if speaker_embedding_params is None:
                     speaker_embedding_params = speaker_embedding[i]
                 else:
-                    speaker_embedding_params = torch.cat((speaker_embedding_params, speaker_embedding[i]))
+                    speaker_embedding_params = torch.cat(
+                        (speaker_embedding_params, speaker_embedding[i])
+                    )

         else:
             speaker_embedding_params = self.post_model(input_tokens, text_mask)
@@ -559,7 +579,7 @@ def execute_model(
         # sampling_metadata=sampling_metadata,
         # )
         results = []
-        for i,val in enumerate(seq_groups):
+        for i, val in enumerate(seq_groups):
             idx_next_i = idx_next[i, 0, :].tolist()
             logprob_i = logprob[i].tolist()
             tmp_hidden_states = hidden_states[i]
@@ -780,7 +800,9 @@ def _make_tensor_with_pad(
     for x_i in x:
         pad_i = pad
         if isinstance(x[0][0], list):
-            pad_i = [0,] * len(x[0][0])
+            pad_i = [
+                0,
+            ] * len(x[0][0])
         elif isinstance(x[0][0], tuple):
             pad_i = (0,) * len(x[0][0])
         padded_x.append(_pad_to_max(x_i, max_len, pad_i))
@@ -790,6 +812,7 @@ def _make_tensor_with_pad(
         device=device,
     )

+
 def _make_with_pad(
     x: List[torch.Tensor],
     max_len: int,
@@ -804,11 +827,15 @@ def _make_with_pad(
             padded_x.append(x_i)
         else:
             padded_x.append(
-                torch.cat((torch.zeros(1, max_len-x_i.shape[-2], 768).to(device), x_i), dim=1)
+                torch.cat(
+                    (torch.zeros(1, max_len - x_i.shape[-2], 768).to(device), x_i),
+                    dim=1,
+                )
             )

     return padded_x

+
 def _get_graph_batch_size(batch_size: int) -> int:
     if batch_size <= 2:
         return batch_size
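
For context on _make_with_pad above: each hidden-state tensor of shape (1, seq_len, 768) is left-padded with zeros along the time axis so every tensor reaches max_len. A standalone sketch, with the shapes assumed from the diff:

    from typing import List

    import torch

    def left_pad_to_max(xs: List[torch.Tensor], max_len: int) -> List[torch.Tensor]:
        padded = []
        for x_i in xs:
            missing = max_len - x_i.shape[-2]
            if missing == 0:
                padded.append(x_i)
            else:
                # Prepend zeros on the time axis (dim=1).
                zeros = torch.zeros(1, missing, 768)
                padded.append(torch.cat((zeros, x_i), dim=1))
        return padded

    xs = [torch.ones(1, 3, 768), torch.ones(1, 5, 768)]
    print([p.shape for p in left_pad_to_max(xs, 5)])  # both torch.Size([1, 5, 768])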

tools/audio/np.py (+2)

@@ -12,9 +12,11 @@ def float_to_int16(audio: np.ndarray) -> np.ndarray:
     am = 32767 * 32768 // am
     return np.multiply(audio, am).astype(np.int16)

+
 def pcm_to_bytes(pcm_data: np.ndarray) -> bytes:
     return float_to_int16(pcm_data).tobytes()

+
 def pcm_to_wav_bytes(pcm_data: np.ndarray) -> bytes:
     buf = io.BytesIO()
     with wave.open(buf, "wb") as wf:
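
These helpers convert float PCM to int16 and wrap it in a WAV container; the visible lines of float_to_int16 scale by a peak-derived factor before casting. A simplified standalone sketch, where the clipping-based scaling and the 24 kHz mono parameters are assumptions:

    import io
    import wave

    import numpy as np

    def float_to_int16_sketch(audio: np.ndarray) -> np.ndarray:
        # Simplified: clip to [-1, 1] and scale into the int16 range;
        # the repo's version derives the scale from the signal peak instead.
        return (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)

    tone = 0.5 * np.sin(2 * np.pi * 440 * np.linspace(0, 1, 24000))
    pcm = float_to_int16_sketch(tone)

    buf = io.BytesIO()
    with wave.open(buf, "wb") as wf:
        wf.setnchannels(1)    # mono
        wf.setsampwidth(2)    # int16 -> 2 bytes per sample
        wf.setframerate(24000)
        wf.writeframes(pcm.tobytes())
    wav_bytes = buf.getvalue()
    print(len(wav_bytes))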
