Replace aioredis with redis-py

aioredis has been merged into redis-py and will no longer be maintained
as a separate project.
master
Tom Hacohen 2 years ago committed by Xiretza
parent 2f1f95fea9
commit 8c6d04e8d3

@ -1,5 +1,5 @@
import typing as t import typing as t
import aioredis from redis import asyncio as aioredis
from etebase_server.django import app_settings from etebase_server.django import app_settings
@ -12,12 +12,11 @@ class RedisWrapper:
async def setup(self): async def setup(self):
if self.redis_uri is not None: if self.redis_uri is not None:
self.redis = await aioredis.create_redis_pool(self.redis_uri) self.redis = await aioredis.from_url(self.redis_uri)
async def close(self): async def close(self):
if hasattr(self, "redis"): if hasattr(self, "redis"):
self.redis.close() await self.redis.close()
await self.redis.wait_closed()
@property @property
def is_active(self): def is_active(self):

@ -1,7 +1,8 @@
import asyncio import asyncio
import typing as t import typing as t
import aioredis from redis import asyncio as aioredis
from redis.exceptions import ConnectionError
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from django.db.models import QuerySet from django.db.models import QuerySet
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect, status from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect, status
@ -51,7 +52,7 @@ async def get_ticket(
uid = nacl.encoding.URLSafeBase64Encoder.encode(nacl.utils.random(32)) uid = nacl.encoding.URLSafeBase64Encoder.encode(nacl.utils.random(32))
ticket_model = TicketInner(user=user.id, req=ticket_request) ticket_model = TicketInner(user=user.id, req=ticket_request)
ticket_raw = msgpack_encode(ticket_model.dict()) ticket_raw = msgpack_encode(ticket_model.dict())
await redisw.redis.set(uid, ticket_raw, expire=TICKET_VALIDITY_SECONDS * 1000) await redisw.redis.set(uid, ticket_raw, ex=TICKET_VALIDITY_SECONDS * 1000)
return TicketOut(ticket=uid) return TicketOut(ticket=uid)
@ -103,9 +104,9 @@ async def send_item_updates(
async def redis_connector(websocket: WebSocket, ticket_model: TicketInner, user: UserType, stoken: t.Optional[str]): async def redis_connector(websocket: WebSocket, ticket_model: TicketInner, user: UserType, stoken: t.Optional[str]):
async def producer_handler(r: aioredis.Redis, ws: WebSocket): async def producer_handler(r: aioredis.Redis, ws: WebSocket):
pubsub = r.pubsub()
channel_name = f"col.{ticket_model.req.collection}" channel_name = f"col.{ticket_model.req.collection}"
(channel,) = await r.psubscribe(channel_name) await pubsub.subscribe(channel_name)
assert isinstance(channel, aioredis.Channel)
# Send missing items if we are not up to date # Send missing items if we are not up to date
queryset: QuerySet[models.Collection] = get_collection_queryset(user) queryset: QuerySet[models.Collection] = get_collection_queryset(user)
@ -117,12 +118,20 @@ async def redis_connector(websocket: WebSocket, ticket_model: TicketInner, user:
return return
await send_item_updates(websocket, collection, user, stoken) await send_item_updates(websocket, collection, user, stoken)
async def handle_message():
msg = await pubsub.get_message(ignore_subscribe_messages=True, timeout=20)
message_raw = t.cast(t.Optional[t.Tuple[str, bytes]], msg)
if message_raw:
_, message = message_raw
await ws.send_bytes(message)
try: try:
while True: while True:
# We wait on the websocket so we fail if web sockets fail or get data # We wait on the websocket so we fail if web sockets fail or get data
receive = asyncio.create_task(websocket.receive()) receive = asyncio.create_task(websocket.receive())
done, pending = await asyncio.wait( done, pending = await asyncio.wait(
{receive, channel.wait_message()}, return_when=asyncio.FIRST_COMPLETED {receive, handle_message()},
return_when=asyncio.FIRST_COMPLETED,
) )
for task in pending: for task in pending:
task.cancel() task.cancel()
@ -131,12 +140,7 @@ async def redis_connector(websocket: WebSocket, ticket_model: TicketInner, user:
await websocket.close(code=status.WS_1008_POLICY_VIOLATION) await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
return return
message_raw = t.cast(t.Optional[t.Tuple[str, bytes]], await channel.get()) except ConnectionError:
if message_raw:
_, message = message_raw
await ws.send_bytes(message)
except aioredis.errors.ConnectionClosedError:
await websocket.close(code=status.WS_1012_SERVICE_RESTART) await websocket.close(code=status.WS_1012_SERVICE_RESTART)
except WebSocketDisconnect: except WebSocketDisconnect:
pass pass

@ -5,4 +5,4 @@ fastapi
typing_extensions typing_extensions
uvicorn[standard] uvicorn[standard]
aiofiles aiofiles
aioredis redis>=4.2.0rc1

@ -6,8 +6,6 @@
# #
aiofiles==0.8.0 aiofiles==0.8.0
# via -r requirements.in/base.txt # via -r requirements.in/base.txt
aioredis==2.0.1
# via -r requirements.in/base.txt
anyio==3.5.0 anyio==3.5.0
# via # via
# starlette # starlette
@ -17,11 +15,13 @@ asgiref==3.5.0
# django # django
# uvicorn # uvicorn
async-timeout==4.0.2 async-timeout==4.0.2
# via aioredis # via redis
cffi==1.15.0 cffi==1.15.0
# via pynacl # via pynacl
click==8.0.4 click==8.0.4
# via uvicorn # via uvicorn
deprecated==1.2.13
# via redis
django==3.2.13 django==3.2.13
# via -r requirements.in/base.txt # via -r requirements.in/base.txt
fastapi==0.75.0 fastapi==0.75.0
@ -34,18 +34,24 @@ idna==3.3
# via anyio # via anyio
msgpack==1.0.3 msgpack==1.0.3
# via -r requirements.in/base.txt # via -r requirements.in/base.txt
packaging==21.3
# via redis
pycparser==2.21 pycparser==2.21
# via cffi # via cffi
pydantic==1.9.0 pydantic==1.9.0
# via fastapi # via fastapi
pynacl==1.5.0 pynacl==1.5.0
# via -r requirements.in/base.txt # via -r requirements.in/base.txt
pyparsing==3.0.9
# via packaging
python-dotenv==0.19.2 python-dotenv==0.19.2
# via uvicorn # via uvicorn
pytz==2022.1 pytz==2022.1
# via django # via django
pyyaml==6.0 pyyaml==6.0
# via uvicorn # via uvicorn
redis==4.3.4
# via -r requirements.in/base.txt
sniffio==1.2.0 sniffio==1.2.0
# via anyio # via anyio
sqlparse==0.4.2 sqlparse==0.4.2
@ -55,7 +61,6 @@ starlette==0.17.1
typing-extensions==4.1.1 typing-extensions==4.1.1
# via # via
# -r requirements.in/base.txt # -r requirements.in/base.txt
# aioredis
# pydantic # pydantic
uvicorn[standard]==0.17.6 uvicorn[standard]==0.17.6
# via -r requirements.in/base.txt # via -r requirements.in/base.txt
@ -65,3 +70,5 @@ watchgod==0.8.1
# via uvicorn # via uvicorn
websockets==10.2 websockets==10.2
# via uvicorn # via uvicorn
wrapt==1.14.1
# via deprecated

Loading…
Cancel
Save