26
26
from ....cluster import ClusterAPI
27
27
from ....core import ActorCallback
28
28
from ....subtask import Subtask , SubtaskAPI , SubtaskResult , SubtaskStatus
29
- from ....task import TaskAPI
30
29
from ..queues import SubtaskPrepareQueueActor , SubtaskExecutionQueueActor
31
30
from ..quota import QuotaActor
32
31
from ..slotmanager import SlotManagerActor
@@ -102,19 +101,37 @@ async def _get_band_quota_ref(
102
101
) -> Union [mo .ActorRef , QuotaActor ]:
103
102
return await mo .actor_ref (QuotaActor .gen_uid (band_name ), address = self .address )
104
103
104
@staticmethod
@alru_cache(cache_exceptions=False)
async def _get_manager_ref(session_id: str, supervisor_address: str):
    """Return an actor reference to the session's ``SubtaskManagerActor``.

    The reference is looked up at ``supervisor_address`` using the uid that
    ``SubtaskManagerActor.gen_uid`` produces for ``session_id``.  Results are
    memoized by ``alru_cache``; ``cache_exceptions=False`` is passed so that,
    presumably, a failed lookup is not cached (confirm against async_lru).
    """
    # Imported at call time rather than module level
    # NOTE(review): likely to avoid a circular import -- confirm.
    from ...supervisor.manager import SubtaskManagerActor

    manager_uid = SubtaskManagerActor.gen_uid(session_id)
    manager_ref = await mo.actor_ref(uid=manager_uid, address=supervisor_address)
    return manager_ref
113
+
105
114
def _build_subtask_info (
106
115
self ,
107
116
subtask : Subtask ,
108
117
priority : Tuple ,
109
118
supervisor_address : str ,
110
119
band_name : str ,
111
120
) -> SubtaskExecutionInfo :
121
+ subtask_max_retries = (
122
+ subtask .extra_config .get ("subtask_max_retries" )
123
+ if subtask .extra_config
124
+ else None
125
+ )
126
+ if subtask_max_retries is None :
127
+ subtask_max_retries = self ._subtask_max_retries
128
+
112
129
subtask_info = SubtaskExecutionInfo (
113
130
subtask ,
114
131
priority ,
115
132
supervisor_address = supervisor_address ,
116
133
band_name = band_name ,
117
- max_retries = self . _subtask_max_retries ,
134
+ max_retries = subtask_max_retries ,
118
135
)
119
136
subtask_info .result = SubtaskResult (
120
137
subtask_id = subtask .subtask_id ,
@@ -252,18 +269,19 @@ async def _dequeue_subtask_ids(self, queue_ref, subtask_ids: List[str]):
252
269
infos_to_report .append (subtask_info )
253
270
await self ._report_subtask_results (infos_to_report )
254
271
255
async def _report_subtask_results(self, subtask_infos: List[SubtaskExecutionInfo]):
    """Send the results of ``subtask_infos`` to the subtask manager actor.

    The manager is resolved from the session id and supervisor address of
    the first entry only.  Nothing is sent when the list is empty or when
    resolving the manager raises ``mo.ActorNotExist``.
    """
    if not subtask_infos:
        return

    head = subtask_infos[0]
    try:
        manager_ref = await self._get_manager_ref(
            head.result.session_id, head.supervisor_address
        )
    except mo.ActorNotExist:
        # The manager actor does not exist, so there is nowhere to report to.
        return

    results = [info.result for info in subtask_infos]
    # Each result is paired with the (worker address, band name) it ran on.
    bands = [(self.address, info.band_name) for info in subtask_infos]
    await manager_ref.set_subtask_results(results, bands)
267
285
268
286
async def cancel_subtasks (
269
287
self , subtask_ids : List [str ], kill_timeout : Optional [int ] = 5
@@ -307,6 +325,25 @@ async def wait_subtasks(self, subtask_ids: List[str]):
307
325
yield asyncio .wait ([info .finish_future for info in infos ])
308
326
raise mo .Return ([info .result for info in infos ])
309
327
328
def _create_subtask_with_exception(self, subtask_id, coro):
    """Run ``coro`` as an asyncio task wrapped in a failure handler.

    If awaiting ``coro`` raises anything, the exception is recorded on the
    subtask's execution info, the result is reported, and the subtask's slot
    is released from both the prepare queue and the execution queue
    (``errors="ignore"``).  The created task is appended to the info's
    ``aio_tasks`` list.

    Raises ``KeyError`` if ``subtask_id`` is not in ``_subtask_executions``.
    """
    subtask_info = self._subtask_executions[subtask_id]

    async def _guarded_run():
        try:
            outcome = await coro
        except:  # noqa: E722  # nosec  # pylint: disable=bare-except
            # Deliberately catches everything: record the failure, report
            # it, then free the slots this subtask may still hold.
            self._fill_result_with_exc(subtask_info)
            await self._report_subtask_results([subtask_info])
            await self._prepare_queue_ref.release_slot(
                subtask_info.subtask.subtask_id, errors="ignore"
            )
            await self._execution_queue_ref.release_slot(
                subtask_info.subtask.subtask_id, errors="ignore"
            )
        else:
            return outcome

    aio_task = asyncio.create_task(_guarded_run())
    subtask_info.aio_tasks.append(aio_task)
346
+
310
347
async def handle_prepare_queue (self , band_name : str ):
311
348
while True :
312
349
try :
@@ -322,8 +359,8 @@ async def handle_prepare_queue(self, band_name: str):
322
359
continue
323
360
324
361
logger .debug (f"Obtained subtask { subtask_id } from prepare queue" )
325
- subtask_info . aio_tasks . append (
326
- asyncio . create_task ( self ._prepare_subtask_with_retry (subtask_info ) )
362
+ self . _create_subtask_with_exception (
363
+ subtask_id , self ._prepare_subtask_with_retry (subtask_info )
327
364
)
328
365
329
366
async def handle_execute_queue (self , band_name : str ):
@@ -355,8 +392,8 @@ async def handle_execute_queue(self, band_name: str):
355
392
c .key in self ._pred_key_mapping_dag
356
393
for c in subtask_info .subtask .chunk_graph .result_chunks
357
394
)
358
- subtask_info . aio_tasks . append (
359
- asyncio . create_task ( self ._execute_subtask_with_retry (subtask_info ) )
395
+ self . _create_subtask_with_exception (
396
+ subtask_id , self ._execute_subtask_with_retry (subtask_info )
360
397
)
361
398
362
399
async def _prepare_subtask_once (self , subtask_info : SubtaskExecutionInfo ):
0 commit comments