@@ -106,7 +106,9 @@ def set_block_size(self, block_size: int) -> None:
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> tuple[list[list[int]], list[list[int]], InputMetadata, list[int], list[Tensor]]:
+    ) -> tuple[
+        list[list[int]], list[list[int]], InputMetadata, list[int], list[Tensor]
+    ]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
@@ -359,17 +361,23 @@ def _prepare_sample(
     def prepare_input_tensors(
         self,
         seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata, list[torch.Tensor]]:
+    ) -> Tuple[
+        torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata, list[torch.Tensor]
+    ]:
         speaker_embedding = None
         if self.is_driver_worker:
             # NOTE: We assume that all sequences in the group are all prompts or
             # all decodes.
             is_prompt = seq_group_metadata_list[0].is_prompt
             # Prepare input tensors.
             if is_prompt:
-                (input_tokens, input_positions, input_metadata, prompt_lens, speaker_embedding) = (
-                    self._prepare_prompt(seq_group_metadata_list)
-                )
+                (
+                    input_tokens,
+                    input_positions,
+                    input_metadata,
+                    prompt_lens,
+                    speaker_embedding,
+                ) = self._prepare_prompt(seq_group_metadata_list)
             else:
                 (input_tokens, input_positions, input_metadata) = self._prepare_decode(
                     seq_group_metadata_list
@@ -461,7 +469,13 @@ def get_size_or_none(x: Optional[torch.Tensor]):
                 perform_sampling=False,
             )

-        return input_tokens, input_positions, input_metadata, sampling_metadata, speaker_embedding
+        return (
+            input_tokens,
+            input_positions,
+            input_metadata,
+            sampling_metadata,
+            speaker_embedding,
+        )

     @torch.inference_mode()
     def execute_model(
@@ -470,9 +484,13 @@ def execute_model(
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
     ) -> Optional[SamplerOutput]:

-        input_tokens, input_positions, input_metadata, sampling_metadata, speaker_embedding = (
-            self.prepare_input_tensors(seq_group_metadata_list)
-        )
+        (
+            input_tokens,
+            input_positions,
+            input_metadata,
+            sampling_metadata,
+            speaker_embedding,
+        ) = self.prepare_input_tensors(seq_group_metadata_list)
         # print(sampling_metadata.seq_data)
         seq_groups = []
         for i, rtn in enumerate(sampling_metadata.seq_groups):
@@ -521,7 +539,9 @@ def execute_model(
                 if speaker_embedding_params is None:
                     speaker_embedding_params = speaker_embedding[i]
                 else:
-                    speaker_embedding_params = torch.cat((speaker_embedding_params, speaker_embedding[i]))
+                    speaker_embedding_params = torch.cat(
+                        (speaker_embedding_params, speaker_embedding[i])
+                    )

         else:
             speaker_embedding_params = self.post_model(input_tokens, text_mask)
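The hunk above only rewraps the incremental `torch.cat` that batches per-sequence speaker embeddings into a single tensor. As a standalone illustration of that accumulation pattern (the `(1, 768)` shapes here are hypothetical, chosen to match the hidden size used elsewhere in this file, and are not taken from the patch):

```python
import torch

# Hypothetical per-sequence speaker embeddings, one (1, 768) row each.
speaker_embedding = [torch.randn(1, 768) for _ in range(3)]

speaker_embedding_params = None
for i in range(len(speaker_embedding)):
    if speaker_embedding_params is None:
        speaker_embedding_params = speaker_embedding[i]
    else:
        # Same concatenation as in the patch: grow the batch row by row.
        speaker_embedding_params = torch.cat(
            (speaker_embedding_params, speaker_embedding[i])
        )

assert speaker_embedding_params.shape == (3, 768)  # one row per sequence
```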
@@ -559,7 +579,7 @@ def execute_model(
         #     sampling_metadata=sampling_metadata,
         # )
         results = []
-        for i,val in enumerate(seq_groups):
+        for i, val in enumerate(seq_groups):
             idx_next_i = idx_next[i, 0, :].tolist()
             logprob_i = logprob[i].tolist()
             tmp_hidden_states = hidden_states[i]
@@ -780,7 +800,9 @@ def _make_tensor_with_pad(
     for x_i in x:
         pad_i = pad
         if isinstance(x[0][0], list):
-            pad_i = [0,] * len(x[0][0])
+            pad_i = [
+                0,
+            ] * len(x[0][0])
         elif isinstance(x[0][0], tuple):
             pad_i = (0,) * len(x[0][0])
         padded_x.append(_pad_to_max(x_i, max_len, pad_i))
@@ -790,6 +812,7 @@ def _make_tensor_with_pad(
         device=device,
     )

+
 def _make_with_pad(
     x: List[torch.Tensor],
     max_len: int,
@@ -804,11 +827,15 @@ def _make_with_pad(
             padded_x.append(x_i)
         else:
             padded_x.append(
-                torch.cat((torch.zeros(1, max_len - x_i.shape[-2], 768).to(device), x_i), dim=1)
+                torch.cat(
+                    (torch.zeros(1, max_len - x_i.shape[-2], 768).to(device), x_i),
+                    dim=1,
+                )
             )

     return padded_x

+
 def _get_graph_batch_size(batch_size: int) -> int:
     if batch_size <= 2:
         return batch_size
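`_make_with_pad`, touched in the final hunk, left-pads each `(1, seq_len, 768)` embedding with zeros up to `max_len`. A minimal runnable sketch of that behaviour (the helper name and CPU device below are ours; the 768 hidden size comes from the diff):

```python
import torch
from typing import List

def left_pad_embeddings(
    x: List[torch.Tensor], max_len: int, device: str = "cpu"
) -> List[torch.Tensor]:
    # Mirrors _make_with_pad: tensors already at max_len pass through,
    # shorter ones get zeros prepended along the sequence dimension.
    padded_x = []
    for x_i in x:
        if x_i.shape[-2] == max_len:
            padded_x.append(x_i)
        else:
            zeros = torch.zeros(1, max_len - x_i.shape[-2], 768).to(device)
            padded_x.append(torch.cat((zeros, x_i), dim=1))
    return padded_x

# Example: pad a 3-step and a 5-step embedding to a common length of 5.
out = left_pad_embeddings(
    [torch.randn(1, 3, 768), torch.randn(1, 5, 768)], max_len=5
)
assert all(t.shape == (1, 5, 768) for t in out)
```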