diff --git a/bellows/ezsp/fragmentation.py b/bellows/ezsp/fragmentation.py
new file mode 100644
index 00000000..0e84b2e2
--- /dev/null
+++ b/bellows/ezsp/fragmentation.py
@@ -0,0 +1,120 @@
+"""Implements APS fragmentation reassembly on the EZSP Host side,
+mirroring the logic from fragmentation.c in the EmberZNet stack.
+"""
+
+import asyncio
+import logging
+from typing import Dict, Optional, Tuple
+
+LOGGER = logging.getLogger(__name__)
+
+# The maximum time (in seconds) we wait for all fragments of a given message.
+# If not all fragments arrive within this time, we discard the partial data.
+FRAGMENT_TIMEOUT = 10
+
+# Store partial data keyed by (sender, aps_sequence, profile_id, cluster_id)
+FragmentKey = Tuple[int, int, int, int]
+
+
+class _FragmentEntry:
+    def __init__(self, fragment_count: int):
+        self.fragment_count = fragment_count
+        self.fragments_received = 0
+        self.fragment_data = {}
+        self.start_time = asyncio.get_event_loop().time()
+
+    def add_fragment(self, index: int, data: bytes) -> None:
+        if index not in self.fragment_data:
+            self.fragment_data[index] = data
+            self.fragments_received += 1
+
+    def is_complete(self) -> bool:
+        return self.fragments_received == self.fragment_count
+
+    def assemble(self) -> bytes:
+        return b"".join(
+            self.fragment_data[i] for i in sorted(self.fragment_data.keys())
+        )
+
+
+class FragmentManager:
+    def __init__(self):
+        self._partial: Dict[FragmentKey, _FragmentEntry] = {}
+        self._cleanup_timers: Dict[FragmentKey, asyncio.TimerHandle] = {}
+
+    def handle_incoming_fragment(
+        self,
+        sender_nwk: int,
+        aps_sequence: int,
+        profile_id: int,
+        cluster_id: int,
+        fragment_count: int,
+        fragment_index: int,
+        payload: bytes,
+    ) -> Tuple[bool, Optional[bytes], int, int]:
+        """Handle a newly received fragment.
+
+        :param sender_nwk: NWK address (short ID) of the sender.
+        :param aps_sequence: The APS sequence from the incoming APS frame.
+        :param profile_id: The APS frame's profileId.
+        :param cluster_id: The APS frame's clusterId.
+        :param fragment_count: The total number of expected message fragments.
+        :param fragment_index: The index of the current fragment being processed.
+        :param payload: The fragment of data for this message.
+        :return: (complete, reassembled_data, fragment_count, fragment_index)
+            complete = True if we have all fragments now, else False
+            reassembled_data = the final complete payload (bytes) if complete is True
+            fragment_count = the total number of fragments holding the complete packet
+            fragment_index = the index of the current received fragment
+        """
+
+        key: FragmentKey = (sender_nwk, aps_sequence, profile_id, cluster_id)
+
+        # If we have never seen this message, create a reassembly entry.
+        if key not in self._partial:
+            entry = _FragmentEntry(fragment_count)
+            self._partial[key] = entry
+        else:
+            entry = self._partial[key]
+
+        LOGGER.debug(
+            "Received fragment %d/%d from %s (APS seq=%d, cluster=0x%04X)",
+            fragment_index + 1,
+            fragment_count,
+            sender_nwk,
+            aps_sequence,
+            cluster_id,
+        )
+
+        entry.add_fragment(fragment_index, payload)
+
+        loop = asyncio.get_running_loop()
+        self._cleanup_timers[key] = loop.call_later(
+            FRAGMENT_TIMEOUT, self.cleanup_partial, key
+        )
+
+        if entry.is_complete():
+            reassembled = entry.assemble()
+            del self._partial[key]
+            timer = self._cleanup_timers.pop(key, None)
+            if timer:
+                timer.cancel()
+            LOGGER.debug(
+                "Message reassembly complete. Total length=%d", len(reassembled)
+            )
+            return (True, reassembled, fragment_count, fragment_index)
+        else:
+            return (False, None, fragment_count, fragment_index)
+
+    def cleanup_partial(self, key: FragmentKey):
+        # Called when FRAGMENT_TIMEOUT passes with no new fragments for that key.
+        LOGGER.debug(
+            "Timeout for partial reassembly of fragmented message, discarding key=%s",
+            key,
+        )
+        self._partial.pop(key, None)
+        self._cleanup_timers.pop(key, None)
+
+
+# Create a single global manager instance
+fragment_manager = FragmentManager()
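Reviewer note (not part of the patch): a minimal sketch of how the new FragmentManager behaves when driven by hand. The addresses and payloads are invented, and a running event loop is required because handle_incoming_fragment schedules its cleanup timer via asyncio.

import asyncio

from bellows.ezsp.fragmentation import FragmentManager


async def demo() -> None:
    manager = FragmentManager()

    # First half of a two-fragment message: nothing is reported complete yet.
    complete, data, count, index = manager.handle_incoming_fragment(
        sender_nwk=0x1D6F,
        aps_sequence=0xEE,
        profile_id=0x0104,
        cluster_id=0xFF01,
        fragment_count=2,
        fragment_index=0,
        payload=b"Hello, ",
    )
    assert not complete and data is None

    # Second half: the reassembled payload is returned and partial state is dropped.
    complete, data, count, index = manager.handle_incoming_fragment(
        sender_nwk=0x1D6F,
        aps_sequence=0xEE,
        profile_id=0x0104,
        cluster_id=0xFF01,
        fragment_count=2,
        fragment_index=1,
        payload=b"fragments!",
    )
    assert complete and data == b"Hello, fragments!"


asyncio.run(demo())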
diff --git a/bellows/ezsp/protocol.py b/bellows/ezsp/protocol.py
index b03df772..d7d968dd 100644
--- a/bellows/ezsp/protocol.py
+++ b/bellows/ezsp/protocol.py
@@ -181,6 +181,57 @@ def __call__(self, data: bytes) -> None:
         if data:
             LOGGER.debug("Frame contains trailing data: %s", data)
 
+        if (
+            frame_name == "incomingMessageHandler" and result[1].options & 0x8000
+        ):  # incoming message with APS_OPTION_FRAGMENT set
+            from bellows.ezsp.fragmentation import fragment_manager
+
+            # Extract received APS frame and sender
+            aps_frame = result[1]
+            sender = result[4]
+
+            group_id = aps_frame.groupId
+            profile_id = aps_frame.profileId
+            cluster_id = aps_frame.clusterId
+            aps_seq = aps_frame.sequence
+
+            fragment_count = (group_id >> 8) & 0xFF
+            fragment_index = group_id & 0xFF
+
+            (
+                complete,
+                reassembled,
+                frag_count,
+                frag_index,
+            ) = fragment_manager.handle_incoming_fragment(
+                sender_nwk=sender,
+                aps_sequence=aps_seq,
+                profile_id=profile_id,
+                cluster_id=cluster_id,
+                fragment_count=fragment_count,
+                fragment_index=fragment_index,
+                payload=result[7],
+            )
+            if not hasattr(self, "_ack_tasks"):
+                self._ack_tasks = set()
+            ack_task = asyncio.create_task(
+                self._send_fragment_ack(sender, aps_frame, frag_count, frag_index)
+            )  # APS Ack
+            self._ack_tasks.add(ack_task)
+            ack_task.add_done_callback(lambda t: self._ack_tasks.discard(t))
+
+            if not complete:
+                # Do not pass partial data up the stack
+                LOGGER.debug("Fragment reassembly not complete. Waiting for more data.")
+                return
+            else:
+                # Replace partial data with fully reassembled data
+                result[7] = reassembled
+
+                LOGGER.debug(
+                    "Reassembled fragmented message. Proceeding with normal handling."
+                )
+
         if sequence in self._awaiting:
             expected_id, schema, future = self._awaiting.pop(sequence)
             try:
@@ -205,6 +256,32 @@ def __call__(self, data: bytes) -> None:
         else:
             self._handle_callback(frame_name, result)
 
+    async def _send_fragment_ack(
+        self,
+        sender: int,
+        incoming_aps: t.EmberApsFrame,
+        fragment_count: int,
+        fragment_index: int,
+    ) -> t.EmberStatus:
+        ackFrame = t.EmberApsFrame(
+            profileId=incoming_aps.profileId,
+            clusterId=incoming_aps.clusterId,
+            sourceEndpoint=incoming_aps.destinationEndpoint,
+            destinationEndpoint=incoming_aps.sourceEndpoint,
+            options=incoming_aps.options,
+            groupId=((0xFF00) | (fragment_index & 0xFF)),
+            sequence=incoming_aps.sequence,
+        )
+
+        LOGGER.debug(
+            "Sending fragment ack to 0x%04X for fragment index=%d/%d",
+            sender,
+            fragment_index + 1,
+            fragment_count,
+        )
+        status = await self.sendReply(sender, ackFrame, b"")
+        return status[0]
+
     def __getattr__(self, name: str) -> Callable:
         if name not in self.COMMANDS:
             raise AttributeError(f"{name} not found in COMMANDS")
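Reviewer note (not part of the patch): the hunk above reuses the APS frame's groupId field for fragmentation bookkeeping. On receive, the high byte carries the total fragment count and the low byte the fragment index; the ack sent back sets the high byte to 0xFF and echoes the index. The helper names below are illustrative only:

def unpack_fragment_group_id(group_id: int) -> tuple:
    """Incoming fragment: high byte = total fragment count, low byte = fragment index."""
    return (group_id >> 8) & 0xFF, group_id & 0xFF


def ack_group_id(fragment_index: int) -> int:
    """Fragment ack: high byte 0xFF, low byte = index of the fragment being acknowledged."""
    return 0xFF00 | (fragment_index & 0xFF)


assert unpack_fragment_group_id(0x0200) == (2, 0)  # groupId=512: first of two fragments
assert unpack_fragment_group_id(0x0201) == (2, 1)  # groupId=513: second of two fragments
assert ack_group_id(1) == 0xFF01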
diff --git a/tests/test_ezsp_protocol.py b/tests/test_ezsp_protocol.py
index 98f5678e..a6c2e3c1 100644
--- a/tests/test_ezsp_protocol.py
+++ b/tests/test_ezsp_protocol.py
@@ -133,3 +133,160 @@ async def test_parsing_schema_response(prot_hndl_v9):
 
     rsp = await coro
     assert rsp == GetTokenDataRsp(status=t.EmberStatus.LIBRARY_NOT_PRESENT)
+
+
+@pytest.mark.asyncio
+async def test_send_fragment_ack(prot_hndl, caplog):
+    """Test the _send_fragment_ack method."""
+    sender = 0x1D6F
+    incoming_aps = t.EmberApsFrame(
+        profileId=260,
+        clusterId=65281,
+        sourceEndpoint=2,
+        destinationEndpoint=2,
+        options=33088,
+        groupId=512,
+        sequence=238,
+    )
+    fragment_count = 2
+    fragment_index = 0
+
+    expected_ack_frame = t.EmberApsFrame(
+        profileId=260,
+        clusterId=65281,
+        sourceEndpoint=2,
+        destinationEndpoint=2,
+        options=33088,
+        groupId=((0xFF00) | (fragment_index & 0xFF)),
+        sequence=238,
+    )
+
+    with patch.object(prot_hndl, "sendReply", new=AsyncMock()) as mock_send_reply:
+        mock_send_reply.return_value = (t.EmberStatus.SUCCESS,)
+
+        caplog.set_level(logging.DEBUG)
+        status = await prot_hndl._send_fragment_ack(
+            sender, incoming_aps, fragment_count, fragment_index
+        )
+
+        # Assertions
+        assert status == t.EmberStatus.SUCCESS
+        assert (
+            "Sending fragment ack to 0x1d6f for fragment index=1/2".lower()
+            in caplog.text.lower()
+        )
+        mock_send_reply.assert_called_once_with(sender, expected_ack_frame, b"")
+
+
+@pytest.mark.asyncio
+async def test_incoming_fragmented_message_incomplete(prot_hndl, caplog):
+    """Test handling of an incomplete fragmented message."""
+    packet = b"\x90\x01\x45\x00\x05\x01\x01\xff\x02\x02\x40\x81\x00\x02\xee\xff\xf8\x6f\x1d\xff\xff\x01\xdd"
+
+    # Parse packet manually to extract parameters for assertions
+    sender = 0x1D6F
+    aps_frame = t.EmberApsFrame(
+        profileId=261,  # 0x0105
+        clusterId=65281,  # 0xFF01
+        sourceEndpoint=2,  # 0x02
+        destinationEndpoint=2,  # 0x02
+        options=33088,  # 0x8140 (APS_OPTION_FRAGMENT + others)
+        groupId=512,  # 0x0200 (fragment_count=2, fragment_index=0)
+        sequence=238,  # 0xEE
+    )
+
+    with patch.object(prot_hndl, "_send_fragment_ack", new=AsyncMock()) as mock_ack:
+        mock_ack.return_value = None
+
+        caplog.set_level(logging.DEBUG)
+        prot_hndl(packet)
+
+        assert hasattr(prot_hndl, "_ack_tasks")
+        assert len(prot_hndl._ack_tasks) == 1
+        ack_task = next(iter(prot_hndl._ack_tasks))
+        await asyncio.gather(ack_task)  # Ensure task completes and triggers callback
+        assert len(prot_hndl._ack_tasks) == 0, "Done callback should have removed task"
+
+        prot_hndl._handle_callback.assert_not_called()
+        assert "Fragment reassembly not complete. Waiting for more data." in caplog.text
+        mock_ack.assert_called_once_with(sender, aps_frame, 2, 0)
+
+
+@pytest.mark.asyncio
+async def test_incoming_fragmented_message_complete(prot_hndl, caplog):
+    """Test handling of a complete fragmented message."""
+    packet1 = (
+        b"\x90\x01\x45\x00\x04\x01\x01\xff\x02\x02\x40\x81\x00\x02\xee\xff\xf8\x6f\x1d\xff\xff\x09"
+        + b"complete "
+    )  # fragment index 0
+    packet2 = (
+        b"\x90\x01\x45\x00\x04\x01\x01\xff\x02\x02\x40\x81\x01\x02\xee\xff\xf8\x6f\x1d\xff\xff\x07"
+        + b"message"
+    )  # fragment index 1
+    sender = 0x1D6F
+
+    aps_frame_1 = t.EmberApsFrame(
+        profileId=260,
+        clusterId=65281,
+        sourceEndpoint=2,
+        destinationEndpoint=2,
+        options=33088,  # Includes APS_OPTION_FRAGMENT
+        groupId=512,  # fragment_count=2, fragment_index=0
+        sequence=238,
+    )
+    aps_frame_2 = t.EmberApsFrame(
+        profileId=260,
+        clusterId=65281,
+        sourceEndpoint=2,
+        destinationEndpoint=2,
+        options=33088,
+        groupId=513,  # fragment_count=2, fragment_index=1
+        sequence=238,
+    )
+    reassembled = b"complete message"
+
+    with patch.object(prot_hndl, "_send_fragment_ack", new=AsyncMock()) as mock_ack:
+        mock_ack.return_value = None
+        caplog.set_level(logging.DEBUG)
+
+        # Packet 1
+        prot_hndl(packet1)
+        assert hasattr(prot_hndl, "_ack_tasks")
+        assert len(prot_hndl._ack_tasks) == 1
+        ack_task = next(iter(prot_hndl._ack_tasks))
+        await asyncio.gather(ack_task)  # Ensure task completes and triggers callback
+        assert len(prot_hndl._ack_tasks) == 0, "Done callback should have removed task"
+
+        prot_hndl._handle_callback.assert_not_called()
+        assert (
+            "Reassembled fragmented message. Proceeding with normal handling."
+            not in caplog.text
+        )
+        mock_ack.assert_called_with(sender, aps_frame_1, 2, 0)
+
+        # Packet 2
+        prot_hndl(packet2)
+        assert hasattr(prot_hndl, "_ack_tasks")
+        assert len(prot_hndl._ack_tasks) == 1
+        ack_task = next(iter(prot_hndl._ack_tasks))
+        await asyncio.gather(ack_task)  # Ensure task completes and triggers callback
+        assert len(prot_hndl._ack_tasks) == 0, "Done callback should have removed task"
+
+        prot_hndl._handle_callback.assert_called_once_with(
+            "incomingMessageHandler",
+            [
+                t.EmberIncomingMessageType.INCOMING_UNICAST,  # 0x00
+                aps_frame_2,  # Parsed APS frame
+                255,  # lastHopLqi: 0xFF
+                -8,  # lastHopRssi: 0xF8
+                sender,  # 0x1D6F
+                255,  # bindingIndex: 0xFF
+                255,  # addressIndex: 0xFF
+                reassembled,  # Reassembled payload
+            ],
+        )
+        assert (
+            "Reassembled fragmented message. Proceeding with normal handling."
+            in caplog.text
+        )
+        mock_ack.assert_called_with(sender, aps_frame_2, 2, 1)
diff --git a/tests/test_fragmentation.py b/tests/test_fragmentation.py
new file mode 100644
index 00000000..35a816b7
--- /dev/null
+++ b/tests/test_fragmentation.py
@@ -0,0 +1,216 @@
+from unittest.mock import MagicMock
+
+import pytest
+
+from bellows.ezsp.fragmentation import fragment_manager
+
+
+@pytest.fixture
+def frag_manager():
+    """Return the module-level FragmentManager instance."""
+    return fragment_manager
+
+
+@pytest.mark.asyncio
+async def test_single_fragment_complete(frag_manager):
+    """
+    If we receive a single-fragment message (fragment_count=1, fragment_index=0),
+    the manager should immediately report completion.
+    """
+    key = (0x1234, 0xAB, 0x1234, 0x5678)
+    fragment_count = 1
+    fragment_index = 0
+    payload = b"Hello single fragment"
+
+    (
+        complete,
+        reassembled,
+        returned_frag_count,
+        returned_frag_index,
+    ) = frag_manager.handle_incoming_fragment(
+        sender_nwk=key[0],
+        aps_sequence=key[1],
+        profile_id=key[2],
+        cluster_id=key[3],
+        fragment_count=fragment_count,
+        fragment_index=fragment_index,
+        payload=payload,
+    )
+
+    assert complete is True
+    assert reassembled == payload
+    assert returned_frag_count == fragment_count
+    assert returned_frag_index == fragment_index
+    # Make sure it's no longer tracked as partial
+    assert key not in frag_manager._partial
+    assert key not in frag_manager._cleanup_timers
+
+
+@pytest.mark.asyncio
+async def test_two_fragments_in_order(frag_manager):
+    """
+    A two-fragment message should remain partial until we've received both pieces.
+    """
+    key = (0x1111, 0x01, 0x9999, 0x2222)
+    fragment_count = 2
+
+    # First fragment
+    (
+        complete,
+        reassembled,
+        returned_frag_count,
+        returned_frag_index,
+    ) = frag_manager.handle_incoming_fragment(
+        sender_nwk=key[0],
+        aps_sequence=key[1],
+        profile_id=key[2],
+        cluster_id=key[3],
+        fragment_count=fragment_count,
+        fragment_index=0,
+        payload=b"Frag0-",
+    )
+    assert complete is False
+    assert reassembled is None
+    assert key in frag_manager._partial
+    assert frag_manager._partial[key].fragments_received == 1
+
+    # Second fragment
+    (
+        complete,
+        reassembled,
+        returned_frag_count,
+        returned_frag_index,
+    ) = frag_manager.handle_incoming_fragment(
+        sender_nwk=key[0],
+        aps_sequence=key[1],
+        profile_id=key[2],
+        cluster_id=key[3],
+        fragment_count=fragment_count,
+        fragment_index=1,
+        payload=b"Frag1",
+    )
+    assert complete is True
+    assert reassembled == b"Frag0-Frag1"
+    # It's removed from partials after completion
+    assert key not in frag_manager._partial
+    assert key not in frag_manager._cleanup_timers
+
+
+@pytest.mark.asyncio
+async def test_out_of_order_fragments(frag_manager):
+    """
+    Receiving fragments in reverse order should still produce the correct reassembly once all arrive.
+    """
+    key = (0x9999, 0xCD, 0x1234, 0xABCD)
+    fragment_count = 2
+
+    # Second fragment arrives first
+    (
+        complete,
+        reassembled,
+        returned_frag_count,
+        returned_frag_index,
+    ) = frag_manager.handle_incoming_fragment(
+        sender_nwk=key[0],
+        aps_sequence=key[1],
+        profile_id=key[2],
+        cluster_id=key[3],
+        fragment_count=fragment_count,
+        fragment_index=1,
+        payload=b"World",
+    )
+    assert not complete
+    assert reassembled is None
+
+    # Then the first fragment
+    (
+        complete,
+        reassembled,
+        returned_frag_count,
+        returned_frag_index,
+    ) = frag_manager.handle_incoming_fragment(
+        sender_nwk=key[0],
+        aps_sequence=key[1],
+        profile_id=key[2],
+        cluster_id=key[3],
+        fragment_count=fragment_count,
+        fragment_index=0,
+        payload=b"Hello ",
+    )
+    assert complete
+    assert reassembled == b"Hello World"
+
+
+@pytest.mark.asyncio
+async def test_repeated_fragments_ignored(frag_manager):
+    """
+    Ensure repeated arrivals of the same fragment index do not double-count or break the logic.
+    """
+    key = (0xAAA, 0xBB, 0xCCC, 0xDDD)
+    fragment_count = 2
+
+    # First fragment
+    (
+        complete,
+        reassembled,
+        returned_frag_count,
+        returned_frag_index,
+    ) = frag_manager.handle_incoming_fragment(
+        sender_nwk=key[0],
+        aps_sequence=key[1],
+        profile_id=key[2],
+        cluster_id=key[3],
+        fragment_count=fragment_count,
+        fragment_index=0,
+        payload=b"first",
+    )
+    assert not complete
+    assert frag_manager._partial[key].fragments_received == 1
+
+    # Repeat the same fragment index
+    (
+        complete,
+        reassembled,
+        returned_frag_count,
+        returned_frag_index,
+    ) = frag_manager.handle_incoming_fragment(
+        sender_nwk=key[0],
+        aps_sequence=key[1],
+        profile_id=key[2],
+        cluster_id=key[3],
+        fragment_count=fragment_count,
+        fragment_index=0,
+        payload=b"first",
+    )
+    assert not complete
+    assert frag_manager._partial[key].fragments_received == 1, "Should not increment"
+
+    # Second fragment completes
+    (
+        complete,
+        reassembled,
+        returned_frag_count,
+        returned_frag_index,
+    ) = frag_manager.handle_incoming_fragment(
+        sender_nwk=key[0],
+        aps_sequence=key[1],
+        profile_id=key[2],
+        cluster_id=key[3],
+        fragment_count=fragment_count,
+        fragment_index=1,
+        payload=b"second",
+    )
+    assert complete
+    assert reassembled == b"firstsecond"
+
+
+@pytest.mark.asyncio
+async def test_cleanup_partial(frag_manager, caplog):
+    key = (0x1234, 0xAB, 0x1234, 0x5678)
+
+    frag_manager._partial[key] = MagicMock()
+    frag_manager._cleanup_timers[key] = MagicMock()
+    frag_manager.cleanup_partial(key)
+
+    assert key not in frag_manager._partial
+    assert key not in frag_manager._cleanup_timers
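Reviewer note (not part of the patch): cleanup_partial is only exercised above by calling it directly with mocks. A possible follow-up test that lets the real call_later timer fire is sketched below; it assumes pytest-asyncio (already used in this suite) and that shrinking the module-level FRAGMENT_TIMEOUT via monkeypatch is acceptable.

import asyncio

import pytest

import bellows.ezsp.fragmentation as fragmentation


@pytest.mark.asyncio
async def test_partial_discarded_after_timeout(monkeypatch):
    """Partial reassembly state should be dropped once FRAGMENT_TIMEOUT elapses."""
    monkeypatch.setattr(fragmentation, "FRAGMENT_TIMEOUT", 0.05)
    manager = fragmentation.FragmentManager()
    key = (0x1234, 0x42, 0x0104, 0x0006)

    complete, reassembled, _, _ = manager.handle_incoming_fragment(
        sender_nwk=key[0],
        aps_sequence=key[1],
        profile_id=key[2],
        cluster_id=key[3],
        fragment_count=2,
        fragment_index=0,
        payload=b"half",
    )
    assert not complete
    assert key in manager._partial

    # Let the (shrunken) timeout fire; cleanup_partial should discard the entry.
    await asyncio.sleep(0.1)
    assert key not in manager._partial
    assert key not in manager._cleanup_timers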