Skip to content

Commit 966641f

Browse files
Add serialization registry. (#401)
Co-authored-by: sevdog <setti.davide89@gmail.com>
1 parent 13cef45 commit 966641f

File tree

6 files changed

+371
-75
lines changed

6 files changed

+371
-75
lines changed

README.rst

+47
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ to 10, and all ``websocket.send!`` channels to 20:
171171
If you want to enforce a matching order, use an ``OrderedDict`` as the
172172
argument; channels will then be matched in the order the dict provides them.
173173

174+
.. _encryption
174175
``symmetric_encryption_keys``
175176
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
176177

@@ -237,6 +238,52 @@ And then in your channels consumer, you can implement the handler:
237238
async def redis_disconnect(self, *args):
238239
# Handle disconnect
239240
241+
242+
243+
``serializer_format``
244+
~~~~~~~~~~~~~~~~~~~~~~
245+
246+
By default every message sent to redis is encoded using `msgpack <https://msgpack.org/>`_ (_currently ``msgpack`` is a mandatory dependency of this package, it may become optional in a future release_).
247+
It is also possible to switch to `JSON <http://www.json.org/>`_:
248+
249+
.. code-block:: python
250+
251+
CHANNEL_LAYERS = {
252+
"default": {
253+
"BACKEND": "channels_redis.core.RedisChannelLayer",
254+
"CONFIG": {
255+
"hosts": ["redis://:password@127.0.0.1:6379/0"],
256+
"serializer_format": "json",
257+
},
258+
},
259+
}
260+
261+
262+
Custom serializers can be defined by:
263+
264+
- extending ``channels_redis.serializers.BaseMessageSerializer``, implementing ``as_bytes `` and ``from_bytes`` methods
265+
- using any class which accepts generic keyword arguments and provides ``serialize``/``deserialize`` methods
266+
267+
Then it may be registered (or can be overriden) by using ``channels_redis.serializers.registry``:
268+
269+
.. code-block:: python
270+
271+
from channels_redis.serializers import registry
272+
273+
class MyFormatSerializer:
274+
def serialize(self, message):
275+
...
276+
def deserialize(self, message):
277+
...
278+
279+
registry.register_serializer('myformat', MyFormatSerializer)
280+
281+
**NOTE**: the registry allows you to override the serializer class used for a specific format without any check nor constraint. Thus it is recommended that to pay particular attention to the order-of-imports when using third-party serializers which may override a built-in format.
282+
283+
284+
Serializers are also responsible for encryption using *symmetric_encryption_keys*. When extending ``channels_redis.serializers.BaseMessageSerializer`` encryption is already configured in the base class, unless you override the ``serialize``/``deserialize`` methods: in this case you should call ``self.crypter.encrypt`` in serialization and ``self.crypter.decrypt`` in deserialization process. When using a fully custom serializer, expect an optional sequence of keys to be passed via ``symmetric_encryption_keys``.
285+
286+
240287
Dependencies
241288
------------
242289

channels_redis/core.py

+13-48
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,17 @@
11
import asyncio
2-
import base64
32
import collections
43
import functools
5-
import hashlib
64
import itertools
75
import logging
8-
import random
96
import time
107
import uuid
118

12-
import msgpack
139
from redis import asyncio as aioredis
1410

1511
from channels.exceptions import ChannelFull
1612
from channels.layers import BaseChannelLayer
1713

14+
from .serializers import registry
1815
from .utils import (
1916
_close_redis,
2017
_consistent_hash,
@@ -115,6 +112,8 @@ def __init__(
115112
capacity=100,
116113
channel_capacity=None,
117114
symmetric_encryption_keys=None,
115+
random_prefix_length=12,
116+
serializer_format="msgpack",
118117
):
119118
# Store basic information
120119
self.expiry = expiry
@@ -126,15 +125,21 @@ def __init__(
126125
# Configure the host objects
127126
self.hosts = decode_hosts(hosts)
128127
self.ring_size = len(self.hosts)
128+
# serialization
129+
self._serializer = registry.get_serializer(
130+
serializer_format,
131+
# As we use an sorted set to expire messages we need to guarantee uniqueness, with 12 bytes.
132+
random_prefix_length=random_prefix_length,
133+
expiry=self.expiry,
134+
symmetric_encryption_keys=symmetric_encryption_keys,
135+
)
129136
# Cached redis connection pools and the event loop they are from
130137
self._layers = {}
131138
# Normal channels choose a host index by cycling through the available hosts
132139
self._receive_index_generator = itertools.cycle(range(len(self.hosts)))
133140
self._send_index_generator = itertools.cycle(range(len(self.hosts)))
134141
# Decide on a unique client prefix to use in ! sections
135142
self.client_prefix = uuid.uuid4().hex
136-
# Set up any encryption objects
137-
self._setup_encryption(symmetric_encryption_keys)
138143
# Number of coroutines trying to receive right now
139144
self.receive_count = 0
140145
# The receive lock
@@ -154,24 +159,6 @@ def __init__(
154159
def create_pool(self, index):
155160
return create_pool(self.hosts[index])
156161

157-
def _setup_encryption(self, symmetric_encryption_keys):
158-
# See if we can do encryption if they asked
159-
if symmetric_encryption_keys:
160-
if isinstance(symmetric_encryption_keys, (str, bytes)):
161-
raise ValueError(
162-
"symmetric_encryption_keys must be a list of possible keys"
163-
)
164-
try:
165-
from cryptography.fernet import MultiFernet
166-
except ImportError:
167-
raise ValueError(
168-
"Cannot run with encryption without 'cryptography' installed."
169-
)
170-
sub_fernets = [self.make_fernet(key) for key in symmetric_encryption_keys]
171-
self.crypter = MultiFernet(sub_fernets)
172-
else:
173-
self.crypter = None
174-
175162
### Channel layer API ###
176163

177164
extensions = ["groups", "flush"]
@@ -656,41 +643,19 @@ def serialize(self, message):
656643
"""
657644
Serializes message to a byte string.
658645
"""
659-
value = msgpack.packb(message, use_bin_type=True)
660-
if self.crypter:
661-
value = self.crypter.encrypt(value)
662-
663-
# As we use an sorted set to expire messages we need to guarantee uniqueness, with 12 bytes.
664-
random_prefix = random.getrandbits(8 * 12).to_bytes(12, "big")
665-
return random_prefix + value
646+
return self._serializer.serialize(message)
666647

667648
def deserialize(self, message):
668649
"""
669650
Deserializes from a byte string.
670651
"""
671-
# Removes the random prefix
672-
message = message[12:]
673-
674-
if self.crypter:
675-
message = self.crypter.decrypt(message, self.expiry + 10)
676-
return msgpack.unpackb(message, raw=False)
652+
return self._serializer.deserialize(message)
677653

678654
### Internal functions ###
679655

680656
def consistent_hash(self, value):
681657
return _consistent_hash(value, self.ring_size)
682658

683-
def make_fernet(self, key):
684-
"""
685-
Given a single encryption key, returns a Fernet instance using it.
686-
"""
687-
from cryptography.fernet import Fernet
688-
689-
if isinstance(key, str):
690-
key = key.encode("utf8")
691-
formatted_key = base64.urlsafe_b64encode(hashlib.sha256(key).digest())
692-
return Fernet(formatted_key)
693-
694659
def __str__(self):
695660
return f"{self.__class__.__name__}(hosts={self.hosts})"
696661

channels_redis/pubsub.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import logging
44
import uuid
55

6-
import msgpack
76
from redis import asyncio as aioredis
87

8+
from .serializers import registry
99
from .utils import (
1010
_close_redis,
1111
_consistent_hash,
@@ -25,10 +25,21 @@ async def _async_proxy(obj, name, *args, **kwargs):
2525

2626

2727
class RedisPubSubChannelLayer:
28-
def __init__(self, *args, **kwargs) -> None:
28+
def __init__(
29+
self,
30+
*args,
31+
symmetric_encryption_keys=None,
32+
serializer_format="msgpack",
33+
**kwargs,
34+
) -> None:
2935
self._args = args
3036
self._kwargs = kwargs
3137
self._layers = {}
38+
# serialization
39+
self._serializer = registry.get_serializer(
40+
serializer_format,
41+
symmetric_encryption_keys=symmetric_encryption_keys,
42+
)
3243

3344
def __getattr__(self, name):
3445
if name in (
@@ -48,13 +59,13 @@ def serialize(self, message):
4859
"""
4960
Serializes message to a byte string.
5061
"""
51-
return msgpack.packb(message)
62+
return self._serializer.serialize(message)
5263

5364
def deserialize(self, message):
5465
"""
5566
Deserializes from a byte string.
5667
"""
57-
return msgpack.unpackb(message)
68+
return self._serializer.deserialize(message)
5869

5970
def _get_layer(self):
6071
loop = asyncio.get_running_loop()

0 commit comments

Comments
 (0)