Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactored Redis connection utilities to share between layers. #352

Merged
merged 1 commit into from
Mar 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 3 additions & 42 deletions channels_redis/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from channels.exceptions import ChannelFull
from channels.layers import BaseChannelLayer

from .utils import _consistent_hash, _wrap_close
from .utils import _consistent_hash, _wrap_close, create_pool, decode_hosts

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -118,7 +118,7 @@ def __init__(
self.prefix = prefix
assert isinstance(self.prefix, str), "Prefix must be unicode"
# Configure the host objects
self.hosts = self.decode_hosts(hosts)
self.hosts = decode_hosts(hosts)
self.ring_size = len(self.hosts)
# Cached redis connection pools and the event loop they are from
self._layers = {}
Expand Down Expand Up @@ -146,46 +146,7 @@ def __init__(
self.receive_clean_locks = ChannelLock()

def create_pool(self, index):
host = self.hosts[index]

if "address" in host:
return aioredis.ConnectionPool.from_url(host["address"])
elif "master_name" in host:
sentinels = host.pop("sentinels")
master_name = host.pop("master_name")
sentinel_kwargs = host.pop("sentinel_kwargs", None)
return aioredis.sentinel.SentinelConnectionPool(
master_name,
aioredis.sentinel.Sentinel(sentinels, sentinel_kwargs=sentinel_kwargs),
**host,
)
else:
return aioredis.ConnectionPool(**host)

def decode_hosts(self, hosts):
"""
Takes the value of the "hosts" argument passed to the class and returns
a list of kwargs to use for the Redis connection constructor.
"""
# If no hosts were provided, return a default value
if not hosts:
return [{"address": "redis://localhost:6379"}]
# If they provided just a string, scold them.
if isinstance(hosts, (str, bytes)):
raise ValueError(
"You must pass a list of Redis hosts, even if there is only one."
)

# Decode each hosts entry into a kwargs dict
result = []
for entry in hosts:
if isinstance(entry, dict):
result.append(entry)
elif isinstance(entry, tuple):
result.append({"host": entry[0], "port": entry[1]})
else:
result.append({"address": entry})
return result
return create_pool(self.hosts[index])

def _setup_encryption(self, symmetric_encryption_keys):
# See if we can do encryption if they asked
Expand Down
29 changes: 6 additions & 23 deletions channels_redis/pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import msgpack
from redis import asyncio as aioredis

from .utils import _consistent_hash, _wrap_close
from .utils import _consistent_hash, _wrap_close, create_pool, decode_hosts

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -81,12 +81,6 @@ def __init__(
channel_layer=None,
**kwargs,
):
if hosts is None:
hosts = ["redis://localhost:6379"]
assert (
isinstance(hosts, list) and len(hosts) > 0
), "`hosts` must be a list with at least one Redis server"

self.prefix = prefix

self.on_disconnect = on_disconnect
Expand All @@ -102,7 +96,9 @@ def __init__(
self.groups = {}

# For each host, we create a `RedisSingleShardConnection` to manage the connection to that host.
self._shards = [RedisSingleShardConnection(host, self) for host in hosts]
self._shards = [
RedisSingleShardConnection(host, self) for host in decode_hosts(hosts)
]

def _get_shard(self, channel_or_group_name):
"""
Expand Down Expand Up @@ -247,9 +243,7 @@ async def flush(self):

class RedisSingleShardConnection:
def __init__(self, host, channel_layer):
self.host = host.copy() if type(host) is dict else {"address": host}
self.master_name = self.host.pop("master_name", None)
self.sentinel_kwargs = self.host.pop("sentinel_kwargs", None)
self.host = host
self.channel_layer = channel_layer
self._subscribed_to = set()
self._lock = asyncio.Lock()
Expand Down Expand Up @@ -331,18 +325,7 @@ def _receive_message(self, message):

def _ensure_redis(self):
if self._redis is None:
if self.master_name is None:
pool = aioredis.ConnectionPool.from_url(self.host["address"])
else:
# aioredis default timeout is way too low
pool = aioredis.sentinel.SentinelConnectionPool(
self.master_name,
aioredis.sentinel.Sentinel(
self.host["sentinels"],
socket_timeout=2,
sentinel_kwargs=self.sentinel_kwargs,
),
)
pool = create_pool(self.host)
self._redis = aioredis.Redis(connection_pool=pool)
self._pubsub = self._redis.pubsub()

Expand Down
52 changes: 52 additions & 0 deletions channels_redis/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import binascii
import types

from redis import asyncio as aioredis


def _consistent_hash(value, ring_size):
"""
Expand Down Expand Up @@ -31,3 +33,53 @@ def _wrapper(self, *args, **kwargs):
return self.close(*args, **kwargs)

loop.close = types.MethodType(_wrapper, loop)


def decode_hosts(hosts):
"""
Takes the value of the "hosts" argument and returns
a list of kwargs to use for the Redis connection constructor.
"""
# If no hosts were provided, return a default value
if not hosts:
return [{"address": "redis://localhost:6379"}]
# If they provided just a string, scold them.
if isinstance(hosts, (str, bytes)):
raise ValueError(
"You must pass a list of Redis hosts, even if there is only one."
)

# Decode each hosts entry into a kwargs dict
result = []
for entry in hosts:
if isinstance(entry, dict):
result.append(entry)
elif isinstance(entry, (tuple, list)):
result.append({"host": entry[0], "port": entry[1]})
else:
result.append({"address": entry})
return result


def create_pool(host):
"""
Takes the value of the "host" argument and returns a suited connection pool to
the corresponding redis instance.
"""
# avoid side-effects from modifying host
host = host.copy()
if "address" in host:
address = host.pop("address")
return aioredis.ConnectionPool.from_url(address, **host)

master_name = host.pop("master_name", None)
if master_name is not None:
sentinels = host.pop("sentinels")
sentinel_kwargs = host.pop("sentinel_kwargs", None)
return aioredis.sentinel.SentinelConnectionPool(
master_name,
aioredis.sentinel.Sentinel(sentinels, sentinel_kwargs=sentinel_kwargs),
**host
)

return aioredis.ConnectionPool(**host)