From 61bd82f1e3dc6a993d823b0941fbb6f7a864f3fa Mon Sep 17 00:00:00 2001 From: Tom Hacohen Date: Mon, 11 Jan 2021 18:39:01 +0200 Subject: [PATCH] Subscriptions: stream missing items if user passed an old stoken. --- etebase_fastapi/routers/websocket.py | 35 ++++++++++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/etebase_fastapi/routers/websocket.py b/etebase_fastapi/routers/websocket.py index 2d599db..ad3331b 100644 --- a/etebase_fastapi/routers/websocket.py +++ b/etebase_fastapi/routers/websocket.py @@ -2,6 +2,7 @@ import asyncio import typing as t import aioredis +from asgiref.sync import sync_to_async from django.db.models import QuerySet from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect, status import nacl.encoding @@ -11,6 +12,7 @@ from django_etebase import models from django_etebase.utils import CallbackContext, get_user_queryset from myauth.models import UserType, get_typed_user_model +from ..dependencies import get_collection_queryset, get_item_queryset from ..exceptions import NotSupported from ..msgpack import MsgpackRoute, msgpack_decode, msgpack_encode from ..redis import redisw @@ -72,20 +74,49 @@ def get_websocket_user(websocket: WebSocket, ticket_model: t.Optional[TicketInne @websocket_router.websocket("/{ticket}/") async def websocket_endpoint( websocket: WebSocket, + stoken: t.Optional[str], user: t.Optional[UserType] = Depends(get_websocket_user), ticket_model: TicketInner = Depends(load_websocket_ticket), ): if user is None: return await websocket.accept() - await redis_connector(websocket, ticket_model) + await redis_connector(websocket, ticket_model, user, stoken) -async def redis_connector(websocket: WebSocket, ticket_model: TicketInner): +async def send_item_updates( + websocket: WebSocket, + collection: models.Collection, + user: UserType, + stoken: t.Optional[str], +): + from .collection import item_list_common + + done = False + while not done: + queryset = await sync_to_async(get_item_queryset)(collection) + response = await sync_to_async(item_list_common)(queryset, user, stoken, limit=50, prefetch="auto") + done = response.done + if len(response.data) > 0: + await websocket.send_bytes(msgpack_encode(response.dict())) + + +async def redis_connector(websocket: WebSocket, ticket_model: TicketInner, user: UserType, stoken: t.Optional[str]): async def producer_handler(r: aioredis.Redis, ws: WebSocket): channel_name = f"col.{ticket_model.req.collection}" (channel,) = await r.psubscribe(channel_name) assert isinstance(channel, aioredis.Channel) + + # Send missing items if we are not up to date + queryset: QuerySet[models.Collection] = get_collection_queryset(user) + collection: t.Optional[models.Collection] = await sync_to_async( + queryset.filter(uid=ticket_model.req.collection).first + )() + if collection is None: + await websocket.close(code=status.WS_1008_POLICY_VIOLATION) + return + await send_item_updates(websocket, collection, user, stoken) + try: while True: # We wait on the websocket so we fail if web sockets fail or get data