From 49abd3379bcfd19c45a6f156f2a33a79736db4b8 Mon Sep 17 00:00:00 2001 From: Seth Foster Date: Fri, 4 Nov 2022 14:42:40 -0400 Subject: [PATCH 1/4] call clean_channel in cancel --- channels/consumer.py | 6 +++++- channels/utils.py | 5 +++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/channels/consumer.py b/channels/consumer.py index fc065432b..b7b78b1c0 100644 --- a/channels/consumer.py +++ b/channels/consumer.py @@ -47,6 +47,10 @@ async def __call__(self, scope, receive, send): self.channel_receive = functools.partial( self.channel_layer.receive, self.channel_name ) + if getattr(self.channel_layer, "clean_channel", None) and callable(self.channel_layer.clean_channel): + cancel_callback = functools.partial(self.channel_layer.clean_channel, self.channel_name) + else: + cancel_callback = None # Store send function if self._sync: self.base_send = async_to_sync(send) @@ -56,7 +60,7 @@ async def __call__(self, scope, receive, send): try: if self.channel_layer is not None: await await_many_dispatch( - [receive, self.channel_receive], self.dispatch + [receive, self.channel_receive], self.dispatch, cancel_callback=cancel_callback ) else: await await_many_dispatch([receive], self.dispatch) diff --git a/channels/utils.py b/channels/utils.py index 72cd9ca30..0478f3283 100644 --- a/channels/utils.py +++ b/channels/utils.py @@ -29,7 +29,7 @@ def name_that_thing(thing): return repr(thing) -async def await_many_dispatch(consumer_callables, dispatch): +async def await_many_dispatch(consumer_callables, dispatch, cancel_callback=None): """ Given a set of consumer callables, awaits on them all and passes results from them to the dispatch awaitable as they come in. @@ -56,4 +56,5 @@ async def await_many_dispatch(consumer_callables, dispatch): try: await task except asyncio.CancelledError: - pass + if cancel_callback: + await cancel_callback() From 485f9a97742662f269679f34daaa839679237e0a Mon Sep 17 00:00:00 2001 From: Seth Foster Date: Mon, 14 Nov 2022 17:06:31 -0500 Subject: [PATCH 2/4] Use try except instead of getattr --- channels/consumer.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/channels/consumer.py b/channels/consumer.py index b7b78b1c0..f934cc9cd 100644 --- a/channels/consumer.py +++ b/channels/consumer.py @@ -47,10 +47,13 @@ async def __call__(self, scope, receive, send): self.channel_receive = functools.partial( self.channel_layer.receive, self.channel_name ) - if getattr(self.channel_layer, "clean_channel", None) and callable(self.channel_layer.clean_channel): - cancel_callback = functools.partial(self.channel_layer.clean_channel, self.channel_name) - else: - cancel_callback = None + # Handler to call when dispatch task is cancelled + cancel_callback = None + try: + if callable(self.channel_layer.clean_channel): + cancel_callback = functools.partial(self.channel_layer.clean_channel, self.channel_name) + except AttributeError: + pass # Store send function if self._sync: self.base_send = async_to_sync(send) From 847f73580bef9575aa74dde97427939b3890e1f6 Mon Sep 17 00:00:00 2001 From: Seth Foster Date: Tue, 15 Nov 2022 12:00:23 -0500 Subject: [PATCH 3/4] fix lint --- channels/consumer.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/channels/consumer.py b/channels/consumer.py index f934cc9cd..c241e8fd6 100644 --- a/channels/consumer.py +++ b/channels/consumer.py @@ -50,10 +50,12 @@ async def __call__(self, scope, receive, send): # Handler to call when dispatch task is cancelled cancel_callback = None try: - if callable(self.channel_layer.clean_channel): - cancel_callback = functools.partial(self.channel_layer.clean_channel, self.channel_name) + if callable(self.channel_layer.clean_channel): + cancel_callback = functools.partial( + self.channel_layer.clean_channel, self.channel_name + ) except AttributeError: - pass + pass # Store send function if self._sync: self.base_send = async_to_sync(send) @@ -63,7 +65,9 @@ async def __call__(self, scope, receive, send): try: if self.channel_layer is not None: await await_many_dispatch( - [receive, self.channel_receive], self.dispatch, cancel_callback=cancel_callback + [receive, self.channel_receive], + self.dispatch, + cancel_callback=cancel_callback, ) else: await await_many_dispatch([receive], self.dispatch) From 70a4cfc1e0a9e3d55515e5cab93f32100dbcd8ba Mon Sep 17 00:00:00 2001 From: Seth Foster Date: Thu, 24 Nov 2022 15:11:33 -0500 Subject: [PATCH 4/4] Add unit test for cancel_callback --- tests/test_utils.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 tests/test_utils.py diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 000000000..e7654da3b --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,20 @@ +import asyncio +from unittest import mock + +import async_timeout +import pytest + +from channels.utils import await_many_dispatch + + +async def sleep_task(*args): + await asyncio.sleep(10) + + +@pytest.mark.asyncio +async def test_cancel_callback_called(): + cancel_callback = mock.AsyncMock() + with pytest.raises(asyncio.TimeoutError): + async with async_timeout.timeout(0): + await await_many_dispatch([sleep_task], sleep_task, cancel_callback) + assert cancel_callback.called