Fix collection list and how we return API responses.

master
Tom Hacohen 4 years ago
parent ee4e7cf498
commit 80d69a5663

@ -16,7 +16,6 @@ from django.db import transaction
from django.utils import timezone from django.utils import timezone
from fastapi import APIRouter, Depends, status, Request, Response from fastapi import APIRouter, Depends, status, Request, Response
from fastapi.security import APIKeyHeader from fastapi.security import APIKeyHeader
from pydantic import BaseModel
from django_etebase import app_settings, models from django_etebase import app_settings, models
from django_etebase.exceptions import EtebaseValidationError from django_etebase.exceptions import EtebaseValidationError
@ -27,7 +26,8 @@ from django_etebase.token_auth.models import get_default_expiry
from django_etebase.utils import create_user, get_user_queryset, CallbackContext from django_etebase.utils import create_user, get_user_queryset, CallbackContext
from django_etebase.views import msgpack_encode, msgpack_decode from django_etebase.views import msgpack_encode, msgpack_decode
from .exceptions import AuthenticationFailed, transform_validation_error, ValidationError from .exceptions import AuthenticationFailed, transform_validation_error, ValidationError
from .msgpack import MsgpackResponse, MsgpackRoute from .msgpack import MsgpackRoute
from .utils import BaseModel
User = get_user_model() User = get_user_model()
token_scheme = APIKeyHeader(name="Authorization") token_scheme = APIKeyHeader(name="Authorization")
@ -225,10 +225,10 @@ def validate_login_request(
@authentication_router.get("/is_etebase/") @authentication_router.get("/is_etebase/")
async def is_etebase(): async def is_etebase():
return MsgpackResponse({}) pass
@authentication_router.post("/login_challenge/") @authentication_router.post("/login_challenge/", response_model=LoginChallengeOut)
async def login_challenge(user: User = Depends(get_login_user)): async def login_challenge(user: User = Depends(get_login_user)):
enc_key = get_encryption_key(user.userinfo.salt) enc_key = get_encryption_key(user.userinfo.salt)
box = nacl.secret.SecretBox(enc_key) box = nacl.secret.SecretBox(enc_key)
@ -237,35 +237,31 @@ async def login_challenge(user: User = Depends(get_login_user)):
"userId": user.id, "userId": user.id,
} }
challenge = bytes(box.encrypt(msgpack_encode(challenge_data), encoder=nacl.encoding.RawEncoder)) challenge = bytes(box.encrypt(msgpack_encode(challenge_data), encoder=nacl.encoding.RawEncoder))
return MsgpackResponse( return LoginChallengeOut(salt=user.userinfo.salt, challenge=challenge, version=user.userinfo.version)
LoginChallengeOut(salt=user.userinfo.salt, challenge=challenge, version=user.userinfo.version)
)
@authentication_router.post("/login/") @authentication_router.post("/login/", response_model=LoginOut)
async def login(data: Login, request: Request): async def login(data: Login, request: Request):
user = await get_login_user(LoginChallengeIn(username=data.response_data.username)) user = await get_login_user(LoginChallengeIn(username=data.response_data.username))
host = request.headers.get("Host") host = request.headers.get("Host")
await validate_login_request(data.response_data, data, user, "login", host) await validate_login_request(data.response_data, data, user, "login", host)
data = await sync_to_async(LoginOut.from_orm)(user) data = await sync_to_async(LoginOut.from_orm)(user)
await sync_to_async(user_logged_in.send)(sender=user.__class__, request=None, user=user) await sync_to_async(user_logged_in.send)(sender=user.__class__, request=None, user=user)
return MsgpackResponse(content=data, status_code=status.HTTP_200_OK) return data
@authentication_router.post("/logout/") @authentication_router.post("/logout/", status_code=status.HTTP_204_NO_CONTENT)
async def logout(request: Request, auth_data: AuthData = Depends(get_auth_data)): async def logout(request: Request, auth_data: AuthData = Depends(get_auth_data)):
await sync_to_async(auth_data.token.delete)() await sync_to_async(auth_data.token.delete)()
# XXX-TOM # XXX-TOM
await sync_to_async(user_logged_out.send)(sender=auth_data.user.__class__, request=None, user=auth_data.user) await sync_to_async(user_logged_out.send)(sender=auth_data.user.__class__, request=None, user=auth_data.user)
return Response(status_code=status.HTTP_204_NO_CONTENT)
@authentication_router.post("/change_password/") @authentication_router.post("/change_password/", status_code=status.HTTP_204_NO_CONTENT)
async def change_password(data: ChangePassword, request: Request, user: User = Depends(get_authenticated_user)): async def change_password(data: ChangePassword, request: Request, user: User = Depends(get_authenticated_user)):
host = request.headers.get("Host") host = request.headers.get("Host")
await validate_login_request(data.response_data, data, user, "changePassword", host) await validate_login_request(data.response_data, data, user, "changePassword", host)
await sync_to_async(save_changed_password)(data, user) await sync_to_async(save_changed_password)(data, user)
return Response(status_code=status.HTTP_204_NO_CONTENT)
@authentication_router.post("/dashboard_url/") @authentication_router.post("/dashboard_url/")
@ -278,7 +274,7 @@ def dashboard_url(user: User = Depends(get_authenticated_user)):
ret = { ret = {
"url": get_dashboard_url(request, *args, **kwargs), "url": get_dashboard_url(request, *args, **kwargs),
} }
return MsgpackResponse(ret) return ret
def signup_save(data: SignupIn, request: Request) -> User: def signup_save(data: SignupIn, request: Request) -> User:
@ -311,10 +307,10 @@ def signup_save(data: SignupIn, request: Request) -> User:
return instance return instance
@authentication_router.post("/signup/") @authentication_router.post("/signup/", response_model=LoginOut, status_code=status.HTTP_201_CREATED)
async def signup(data: SignupIn, request: Request): async def signup(data: SignupIn, request: Request):
user = await sync_to_async(signup_save)(data, request) user = await sync_to_async(signup_save)(data, request)
# XXX-TOM # XXX-TOM
data = await sync_to_async(LoginOut.from_orm)(user) data = await sync_to_async(LoginOut.from_orm)(user)
await sync_to_async(user_signed_up.send)(sender=user.__class__, request=None, user=user) await sync_to_async(user_signed_up.send)(sender=user.__class__, request=None, user=user)
return MsgpackResponse(content=data, status_code=status.HTTP_201_CREATED) return data

@ -8,14 +8,13 @@ from django.db import transaction
from django.db.models import Q from django.db.models import Q
from django.db.models import QuerySet from django.db.models import QuerySet
from fastapi import APIRouter, Depends, status from fastapi import APIRouter, Depends, status
from pydantic import BaseModel
from django_etebase import models from django_etebase import models
from .authentication import get_authenticated_user from .authentication import get_authenticated_user
from .exceptions import ValidationError, transform_validation_error, PermissionDenied from .exceptions import ValidationError, transform_validation_error, PermissionDenied
from .msgpack import MsgpackRoute, MsgpackResponse from .msgpack import MsgpackRoute
from .stoken_handler import filter_by_stoken_and_limit, filter_by_stoken, get_stoken_obj, get_queryset_stoken from .stoken_handler import filter_by_stoken_and_limit, filter_by_stoken, get_stoken_obj, get_queryset_stoken
from .utils import get_object_or_404, Context, Prefetch, PrefetchQuery, is_collection_admin from .utils import get_object_or_404, Context, Prefetch, PrefetchQuery, is_collection_admin, BaseModel
User = get_user_model() User = get_user_model()
collection_router = APIRouter(route_class=MsgpackRoute, tags=["collection"]) collection_router = APIRouter(route_class=MsgpackRoute, tags=["collection"])
@ -169,7 +168,7 @@ def collection_list_common(
stoken: t.Optional[str], stoken: t.Optional[str],
limit: int, limit: int,
prefetch: Prefetch, prefetch: Prefetch,
) -> MsgpackResponse: ) -> CollectionListResponse:
result, new_stoken_obj, done = filter_by_stoken_and_limit( result, new_stoken_obj, done = filter_by_stoken_and_limit(
stoken, limit, queryset, models.Collection.stoken_annotation stoken, limit, queryset, models.Collection.stoken_annotation
) )
@ -192,7 +191,7 @@ def collection_list_common(
if len(remed) > 0: if len(remed) > 0:
ret.removedMemberships = [{"uid": x} for x in remed] ret.removedMemberships = [{"uid": x} for x in remed]
return MsgpackResponse(content=ret) return ret
def get_collection_queryset(user: User = Depends(get_authenticated_user)) -> QuerySet: def get_collection_queryset(user: User = Depends(get_authenticated_user)) -> QuerySet:
@ -230,7 +229,7 @@ def has_write_access(
# paths # paths
@collection_router.post("/list_multi/") @collection_router.post("/list_multi/", response_model=CollectionListResponse, response_model_exclude_unset=True)
async def list_multi( async def list_multi(
data: ListMulti, data: ListMulti,
stoken: t.Optional[str] = None, stoken: t.Optional[str] = None,
@ -247,7 +246,7 @@ async def list_multi(
return await collection_list_common(queryset, user, stoken, limit, prefetch) return await collection_list_common(queryset, user, stoken, limit, prefetch)
@collection_router.post("/list/") @collection_router.get("/", response_model=CollectionListResponse)
async def collection_list( async def collection_list(
stoken: t.Optional[str] = None, stoken: t.Optional[str] = None,
limit: int = 50, limit: int = 50,
@ -323,20 +322,18 @@ def _create(data: CollectionIn, user: User):
).save() ).save()
@collection_router.post("/") @collection_router.post("/", status_code=status.HTTP_201_CREATED)
async def create(data: CollectionIn, user: User = Depends(get_authenticated_user)): async def create(data: CollectionIn, user: User = Depends(get_authenticated_user)):
await sync_to_async(_create)(data, user) await sync_to_async(_create)(data, user)
return MsgpackResponse({}, status_code=status.HTTP_201_CREATED)
@collection_router.get("/{collection_uid}/") @collection_router.get("/{collection_uid}/", response_model=CollectionOut)
def collection_get( def collection_get(
obj: models.Collection = Depends(get_collection), obj: models.Collection = Depends(get_collection),
user: User = Depends(get_authenticated_user), user: User = Depends(get_authenticated_user),
prefetch: Prefetch = PrefetchQuery prefetch: Prefetch = PrefetchQuery
): ):
ret = CollectionOut.from_orm_context(obj, Context(user, prefetch)) return CollectionOut.from_orm_context(obj, Context(user, prefetch))
return MsgpackResponse(ret)
def item_create(item_model: CollectionItemIn, collection: models.Collection, validate_etag: bool): def item_create(item_model: CollectionItemIn, collection: models.Collection, validate_etag: bool):
@ -379,15 +376,14 @@ def item_create(item_model: CollectionItemIn, collection: models.Collection, val
return instance return instance
@item_router.get("/item/{item_uid}/") @item_router.get("/item/{item_uid}/", response_model=CollectionItemOut)
def item_get( def item_get(
item_uid: str, item_uid: str,
queryset: QuerySet = Depends(get_item_queryset), queryset: QuerySet = Depends(get_item_queryset),
user: User = Depends(get_authenticated_user), prefetch: Prefetch = PrefetchQuery, user: User = Depends(get_authenticated_user), prefetch: Prefetch = PrefetchQuery,
): ):
obj = queryset.get(uid=item_uid) obj = queryset.get(uid=item_uid)
ret = CollectionItemOut.from_orm_context(obj, Context(user, prefetch)) return CollectionItemOut.from_orm_context(obj, Context(user, prefetch))
return MsgpackResponse(ret)
@sync_to_async @sync_to_async
@ -397,18 +393,17 @@ def item_list_common(
stoken: t.Optional[str], stoken: t.Optional[str],
limit: int, limit: int,
prefetch: Prefetch, prefetch: Prefetch,
) -> MsgpackResponse: ) -> CollectionItemListResponse:
result, new_stoken_obj, done = filter_by_stoken_and_limit( result, new_stoken_obj, done = filter_by_stoken_and_limit(
stoken, limit, queryset, models.CollectionItem.stoken_annotation stoken, limit, queryset, models.CollectionItem.stoken_annotation
) )
new_stoken = new_stoken_obj and new_stoken_obj.uid new_stoken = new_stoken_obj and new_stoken_obj.uid
context = Context(user, prefetch) context = Context(user, prefetch)
data: t.List[CollectionItemOut] = [CollectionItemOut.from_orm_context(item, context) for item in result] data: t.List[CollectionItemOut] = [CollectionItemOut.from_orm_context(item, context) for item in result]
ret = CollectionItemListResponse(data=data, stoken=new_stoken, done=done) return CollectionItemListResponse(data=data, stoken=new_stoken, done=done)
return MsgpackResponse(content=ret)
@item_router.get("/item/") @item_router.get("/item/", response_model=CollectionItemListResponse)
async def item_list( async def item_list(
queryset: QuerySet = Depends(get_item_queryset), queryset: QuerySet = Depends(get_item_queryset),
stoken: t.Optional[str] = None, stoken: t.Optional[str] = None,
@ -437,10 +432,10 @@ def item_bulk_common(data: ItemBatchIn, user: User, stoken: t.Optional[str], uid
for item in data.items: for item in data.items:
item_create(item, collection_object, validate_etag) item_create(item, collection_object, validate_etag)
return MsgpackResponse({}) return None
@item_router.get("/item/{item_uid}/revision/") @item_router.get("/item/{item_uid}/revision/", response_model=CollectionItemRevisionListResponse)
def item_revisions( def item_revisions(
item_uid: str, item_uid: str,
limit: int = 50, limit: int = 50,
@ -468,15 +463,14 @@ def item_revisions(
ret_data = [CollectionItemRevisionInOut.from_orm_context(revision, context) for revision in result] ret_data = [CollectionItemRevisionInOut.from_orm_context(revision, context) for revision in result]
iterator = ret_data[-1].uid if len(result) > 0 else None iterator = ret_data[-1].uid if len(result) > 0 else None
ret = CollectionItemRevisionListResponse( return CollectionItemRevisionListResponse(
data=ret_data, data=ret_data,
iterator=iterator, iterator=iterator,
done=done, done=done,
) )
return MsgpackResponse(ret)
@item_router.post("/item/fetch_updates/") @item_router.post("/item/fetch_updates/", response_model=CollectionItemListResponse)
def fetch_updates( def fetch_updates(
data: t.List[CollectionItemBulkGetIn], data: t.List[CollectionItemBulkGetIn],
stoken: t.Optional[str] = None, stoken: t.Optional[str] = None,
@ -502,12 +496,11 @@ def fetch_updates(
new_stoken = new_stoken or stoken new_stoken = new_stoken or stoken
context = Context(user, prefetch) context = Context(user, prefetch)
ret = CollectionItemListResponse( return CollectionItemListResponse(
data=[CollectionItemOut.from_orm_context(item, context) for item in queryset], data=[CollectionItemOut.from_orm_context(item, context) for item in queryset],
stoken=new_stoken, stoken=new_stoken,
done=True, # we always return all the items, so it's always done done=True, # we always return all the items, so it's always done
) )
return MsgpackResponse(ret)
@item_router.post("/item/transaction/", dependencies=[Depends(has_write_access)]) @item_router.post("/item/transaction/", dependencies=[Depends(has_write_access)])

@ -4,14 +4,13 @@ from django.contrib.auth import get_user_model
from django.db import transaction, IntegrityError from django.db import transaction, IntegrityError
from django.db.models import QuerySet from django.db.models import QuerySet
from fastapi import APIRouter, Depends, status, Request from fastapi import APIRouter, Depends, status, Request
from pydantic import BaseModel
from django_etebase import models from django_etebase import models
from django_etebase.utils import get_user_queryset, CallbackContext from django_etebase.utils import get_user_queryset, CallbackContext
from .authentication import get_authenticated_user from .authentication import get_authenticated_user
from .exceptions import ValidationError, PermissionDenied from .exceptions import ValidationError, PermissionDenied
from .msgpack import MsgpackRoute, MsgpackResponse from .msgpack import MsgpackRoute
from .utils import get_object_or_404, Context, is_collection_admin from .utils import get_object_or_404, Context, is_collection_admin, BaseModel
User = get_user_model() User = get_user_model()
invitation_incoming_router = APIRouter(route_class=MsgpackRoute, tags=["incoming invitation"]) invitation_incoming_router = APIRouter(route_class=MsgpackRoute, tags=["incoming invitation"])
@ -85,7 +84,7 @@ def list_common(
queryset: QuerySet, queryset: QuerySet,
iterator: t.Optional[str], iterator: t.Optional[str],
limit: int, limit: int,
) -> MsgpackResponse: ) -> InvitationListResponse:
queryset = queryset.order_by("id") queryset = queryset.order_by("id")
if iterator is not None: if iterator is not None:
@ -102,12 +101,11 @@ def list_common(
ret_data = result ret_data = result
iterator = ret_data[-1].uid if len(result) > 0 else None iterator = ret_data[-1].uid if len(result) > 0 else None
ret = InvitationListResponse( return InvitationListResponse(
data=ret_data, data=ret_data,
iterator=iterator, iterator=iterator,
done=done, done=done,
) )
return MsgpackResponse(ret)
@invitation_incoming_router.get("/", response_model=InvitationListResponse) @invitation_incoming_router.get("/", response_model=InvitationListResponse)
@ -125,8 +123,7 @@ def incoming_get(
queryset: QuerySet = Depends(get_incoming_queryset), queryset: QuerySet = Depends(get_incoming_queryset),
): ):
obj = get_object_or_404(queryset, uid=invitation_uid) obj = get_object_or_404(queryset, uid=invitation_uid)
ret = CollectionInvitationOut.from_orm(obj) return CollectionInvitationOut.from_orm(obj)
return MsgpackResponse(ret)
@invitation_incoming_router.delete("/{invitation_uid}/", status_code=status.HTTP_204_NO_CONTENT) @invitation_incoming_router.delete("/{invitation_uid}/", status_code=status.HTTP_204_NO_CONTENT)
@ -191,8 +188,6 @@ def outgoing_create(
except IntegrityError: except IntegrityError:
raise ValidationError("invitation_exists", "Invitation already exists") raise ValidationError("invitation_exists", "Invitation already exists")
return MsgpackResponse(CollectionInvitationOut.from_orm(ret), status_code=status.HTTP_201_CREATED)
@invitation_outgoing_router.get("/", response_model=InvitationListResponse) @invitation_outgoing_router.get("/", response_model=InvitationListResponse)
def outgoing_list( def outgoing_list(
@ -221,5 +216,4 @@ def outgoing_fetch_user_profile(
kwargs = {User.USERNAME_FIELD: username.lower()} kwargs = {User.USERNAME_FIELD: username.lower()}
user = get_object_or_404(get_user_queryset(User.objects.all(), CallbackContext(request.path_params)), **kwargs) user = get_object_or_404(get_user_queryset(User.objects.all(), CallbackContext(request.path_params)), **kwargs)
user_info = get_object_or_404(models.UserInfo.objects.all(), owner=user) user_info = get_object_or_404(models.UserInfo.objects.all(), owner=user)
ret = UserInfoOut.from_orm(user_info) return UserInfoOut.from_orm(user_info)
return MsgpackResponse(ret)

@ -4,12 +4,11 @@ from django.contrib.auth import get_user_model
from django.db import transaction from django.db import transaction
from django.db.models import QuerySet from django.db.models import QuerySet
from fastapi import APIRouter, Depends, status from fastapi import APIRouter, Depends, status
from pydantic import BaseModel
from django_etebase import models from django_etebase import models
from .authentication import get_authenticated_user from .authentication import get_authenticated_user
from .msgpack import MsgpackRoute, MsgpackResponse from .msgpack import MsgpackRoute
from .utils import get_object_or_404 from .utils import get_object_or_404, BaseModel
from .stoken_handler import filter_by_stoken_and_limit from .stoken_handler import filter_by_stoken_and_limit
from .collection import get_collection, verify_collection_admin from .collection import get_collection, verify_collection_admin
@ -61,12 +60,11 @@ def member_list(
) )
new_stoken = new_stoken_obj and new_stoken_obj.uid new_stoken = new_stoken_obj and new_stoken_obj.uid
ret = MemberListResponse( return MemberListResponse(
data=[CollectionMemberOut.from_orm(item) for item in result], data=[CollectionMemberOut.from_orm(item) for item in result],
iterator=new_stoken, iterator=new_stoken,
done=done, done=done,
) )
return MsgpackResponse(ret)
@member_router.delete( @member_router.delete(

@ -24,7 +24,7 @@ class MsgpackResponse(Response):
return b"" return b""
if isinstance(content, BaseModel): if isinstance(content, BaseModel):
content = content.dict(exclude_unset=True) content = content.dict()
return msgpack.packb(content, use_bin_type=True) return msgpack.packb(content, use_bin_type=True)

@ -2,6 +2,7 @@ import dataclasses
import typing as t import typing as t
from fastapi import status, Query from fastapi import status, Query
from pydantic import BaseModel as PyBaseModel
from django.db.models import QuerySet from django.db.models import QuerySet
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
@ -17,6 +18,13 @@ Prefetch = t.Literal["auto", "medium"]
PrefetchQuery = Query(default="auto") PrefetchQuery = Query(default="auto")
class BaseModel(PyBaseModel):
class Config:
json_encoders = {
bytes: lambda x: x,
}
@dataclasses.dataclass @dataclasses.dataclass
class Context: class Context:
user: t.Optional[User] user: t.Optional[User]

Loading…
Cancel
Save