@@ -95,8 +95,7 @@ async def __post_create__(self):
95
95
AssignerActor .gen_uid (self ._session_id ), address = self .address
96
96
)
97
97
98
async def _get_task_api(self) -> TaskAPI:
    """Build a ``TaskAPI`` for this actor's session.

    Awaits ``TaskAPI.create`` with ``self._session_id`` and ``self.address``
    on every call; this definition carries no caching decorator.

    Returns:
        The object produced by ``TaskAPI.create``.
    """
    task_api = await TaskAPI.create(self._session_id, self.address)
    return task_api
101
100
102
101
def _put_subtask_with_priority (self , subtask : Subtask , priority : Tuple = None ):
@@ -272,21 +271,47 @@ async def update_subtask_priorities(
272
271
273
272
@alru_cache(maxsize=10000)
async def _get_execution_ref(self, address: str):
    """Look up the ``SubtaskExecutionActor`` reference at ``address``.

    The coroutine is wrapped in ``alru_cache`` with ``maxsize=10000``.

    Args:
        address: Forwarded as ``address=`` to ``mo.actor_ref``.

    Returns:
        The result of awaiting ``mo.actor_ref`` for the actor's default uid.
    """
    # Imported at call time rather than at module import time.
    from ..worker.execution import SubtaskExecutionActor

    actor_uid = SubtaskExecutionActor.default_uid()
    execution_ref = await mo.actor_ref(actor_uid, address=address)
    return execution_ref
278
277
279
# Diff hunk: the removed (`-`) lines are the old `finish_subtasks`; the added
# (`+`) lines define `set_subtask_results`, which walks
# `zip(subtask_results, source_bands)` (zip stops at the shorter input).
#
# Added side, as far as the visible statements show:
#   * For a result whose status is `SubtaskStatus.errored`, the tracked info is
#     read with `.get()` (not removed). If it exists, its subtask is
#     `retryable`, `num_reschedules < max_reschedules`, and `result.error` is a
#     `MarsError` or `OSError`, the counter is incremented, a warning is logged
#     and the subtask is sent again through `submit_subtasks.tell(...)`.
#   * Otherwise the info is popped from `self._subtask_infos`; when present, a
#     finished summary is stored in `self._subtask_summaries`.
#   * `task_api.set_subtask_result.delay(result)` calls are collected in
#     `delays` and issued together via `set_subtask_result.batch(*delays)`.
#
# NOTE(review): indentation is lost in this rendering, so the exact nesting of
# `continue` and of `delays.append(...)` cannot be determined from here.
# Presumably `continue` belongs to the retry branch and `delays.append(...)`
# runs once per non-retried result -- confirm against the real file.
- async def finish_subtasks (self , subtask_ids : List [str ], schedule_next : bool = True ):
280
- band_tasks = defaultdict (lambda : 0 )
281
- for subtask_id in subtask_ids :
282
- subtask_info = self ._subtask_infos .pop (subtask_id , None )
278
+ async def set_subtask_results (
279
+ self , subtask_results : List [SubtaskResult ], source_bands : List [BandType ]
280
+ ):
281
+ delays = []
282
+ task_api = await self ._get_task_api ()
283
+ for result , band in zip (subtask_results , source_bands ):
284
+ if result .status == SubtaskStatus .errored :
285
+ subtask_info = self ._subtask_infos .get (result .subtask_id )
286
+ if (
287
+ subtask_info is not None
288
+ and subtask_info .subtask .retryable
289
+ and subtask_info .num_reschedules < subtask_info .max_reschedules
290
+ and isinstance (result .error , (MarsError , OSError ))
291
+ ):
292
+ subtask_info .num_reschedules += 1
293
+ logger .warning (
294
+ "Resubmit subtask %s at attempt %d" ,
295
+ subtask_info .subtask .subtask_id ,
296
+ subtask_info .num_reschedules ,
297
+ )
298
# band[0] is passed as the address to _get_execution_ref; band[1] is
# forwarded as the last positional argument of submit_subtasks.tell below.
+ execution_ref = await self ._get_execution_ref (band [0 ])
299
+ await execution_ref .submit_subtasks .tell (
300
+ [subtask_info .subtask ],
301
+ [subtask_info .priority ],
302
+ self .address ,
303
+ band [1 ],
304
+ )
305
+ continue
306
+
307
+ subtask_info = self ._subtask_infos .pop (result .subtask_id , None )
283
308
if subtask_info is not None :
284
- self ._subtask_summaries [subtask_id ] = subtask_info .to_summary (
309
+ self ._subtask_summaries [result . subtask_id ] = subtask_info .to_summary (
285
310
is_finished = True
286
311
)
287
- if schedule_next :
288
- for band in subtask_info . submitted_bands :
289
- band_tasks [ band ] += 1
312
+ delays . append ( task_api . set_subtask_result . delay ( result ))
313
+
314
+ await task_api . set_subtask_result . batch ( * delays )
290
315
291
316
def _get_subtasks_by_ids (self , subtask_ids : List [str ]) -> List [Optional [Subtask ]]:
292
317
subtasks = []
0 commit comments