diff --git a/channels/consumer.py b/channels/consumer.py index fc065432..c241e8fd 100644 --- a/channels/consumer.py +++ b/channels/consumer.py @@ -47,6 +47,15 @@ async def __call__(self, scope, receive, send): self.channel_receive = functools.partial( self.channel_layer.receive, self.channel_name ) + # Handler to call when dispatch task is cancelled + cancel_callback = None + try: + if callable(self.channel_layer.clean_channel): + cancel_callback = functools.partial( + self.channel_layer.clean_channel, self.channel_name + ) + except AttributeError: + pass # Store send function if self._sync: self.base_send = async_to_sync(send) @@ -56,7 +65,9 @@ async def __call__(self, scope, receive, send): try: if self.channel_layer is not None: await await_many_dispatch( - [receive, self.channel_receive], self.dispatch + [receive, self.channel_receive], + self.dispatch, + cancel_callback=cancel_callback, ) else: await await_many_dispatch([receive], self.dispatch) diff --git a/channels/utils.py b/channels/utils.py index 72cd9ca3..0478f328 100644 --- a/channels/utils.py +++ b/channels/utils.py @@ -29,7 +29,7 @@ def name_that_thing(thing): return repr(thing) -async def await_many_dispatch(consumer_callables, dispatch): +async def await_many_dispatch(consumer_callables, dispatch, cancel_callback=None): """ Given a set of consumer callables, awaits on them all and passes results from them to the dispatch awaitable as they come in. @@ -56,4 +56,5 @@ async def await_many_dispatch(consumer_callables, dispatch): try: await task except asyncio.CancelledError: - pass + if cancel_callback: + await cancel_callback() diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..e7654da3 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,20 @@ +import asyncio +from unittest import mock + +import async_timeout +import pytest + +from channels.utils import await_many_dispatch + + +async def sleep_task(*args): + await asyncio.sleep(10) + + +@pytest.mark.asyncio +async def test_cancel_callback_called(): + cancel_callback = mock.AsyncMock() + with pytest.raises(asyncio.TimeoutError): + async with async_timeout.timeout(0): + await await_many_dispatch([sleep_task], sleep_task, cancel_callback) + assert cancel_callback.called