|
|
|
@ -1,7 +1,8 @@
|
|
|
|
|
import asyncio
|
|
|
|
|
import typing as t
|
|
|
|
|
|
|
|
|
|
import aioredis
|
|
|
|
|
from redis import asyncio as aioredis
|
|
|
|
|
from redis.exceptions import ConnectionError
|
|
|
|
|
from asgiref.sync import sync_to_async
|
|
|
|
|
from django.db.models import QuerySet
|
|
|
|
|
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect, status
|
|
|
|
@ -51,7 +52,7 @@ async def get_ticket(
|
|
|
|
|
uid = nacl.encoding.URLSafeBase64Encoder.encode(nacl.utils.random(32))
|
|
|
|
|
ticket_model = TicketInner(user=user.id, req=ticket_request)
|
|
|
|
|
ticket_raw = msgpack_encode(ticket_model.dict())
|
|
|
|
|
await redisw.redis.set(uid, ticket_raw, expire=TICKET_VALIDITY_SECONDS * 1000)
|
|
|
|
|
await redisw.redis.set(uid, ticket_raw, ex=TICKET_VALIDITY_SECONDS * 1000)
|
|
|
|
|
return TicketOut(ticket=uid)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -103,9 +104,9 @@ async def send_item_updates(
|
|
|
|
|
|
|
|
|
|
async def redis_connector(websocket: WebSocket, ticket_model: TicketInner, user: UserType, stoken: t.Optional[str]):
|
|
|
|
|
async def producer_handler(r: aioredis.Redis, ws: WebSocket):
|
|
|
|
|
pubsub = r.pubsub()
|
|
|
|
|
channel_name = f"col.{ticket_model.req.collection}"
|
|
|
|
|
(channel,) = await r.psubscribe(channel_name)
|
|
|
|
|
assert isinstance(channel, aioredis.Channel)
|
|
|
|
|
await pubsub.subscribe(channel_name)
|
|
|
|
|
|
|
|
|
|
# Send missing items if we are not up to date
|
|
|
|
|
queryset: QuerySet[models.Collection] = get_collection_queryset(user)
|
|
|
|
@ -117,12 +118,20 @@ async def redis_connector(websocket: WebSocket, ticket_model: TicketInner, user:
|
|
|
|
|
return
|
|
|
|
|
await send_item_updates(websocket, collection, user, stoken)
|
|
|
|
|
|
|
|
|
|
async def handle_message():
|
|
|
|
|
msg = await pubsub.get_message(ignore_subscribe_messages=True, timeout=20)
|
|
|
|
|
message_raw = t.cast(t.Optional[t.Tuple[str, bytes]], msg)
|
|
|
|
|
if message_raw:
|
|
|
|
|
_, message = message_raw
|
|
|
|
|
await ws.send_bytes(message)
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
while True:
|
|
|
|
|
# We wait on the websocket so we fail if web sockets fail or get data
|
|
|
|
|
receive = asyncio.create_task(websocket.receive())
|
|
|
|
|
done, pending = await asyncio.wait(
|
|
|
|
|
{receive, channel.wait_message()}, return_when=asyncio.FIRST_COMPLETED
|
|
|
|
|
{receive, handle_message()},
|
|
|
|
|
return_when=asyncio.FIRST_COMPLETED,
|
|
|
|
|
)
|
|
|
|
|
for task in pending:
|
|
|
|
|
task.cancel()
|
|
|
|
@ -131,12 +140,7 @@ async def redis_connector(websocket: WebSocket, ticket_model: TicketInner, user:
|
|
|
|
|
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
message_raw = t.cast(t.Optional[t.Tuple[str, bytes]], await channel.get())
|
|
|
|
|
if message_raw:
|
|
|
|
|
_, message = message_raw
|
|
|
|
|
await ws.send_bytes(message)
|
|
|
|
|
|
|
|
|
|
except aioredis.errors.ConnectionClosedError:
|
|
|
|
|
except ConnectionError:
|
|
|
|
|
await websocket.close(code=status.WS_1012_SERVICE_RESTART)
|
|
|
|
|
except WebSocketDisconnect:
|
|
|
|
|
pass
|
|
|
|
|