Move common dependencies to their own file.

master
Tom Hacohen 4 years ago
parent 3e39aa88a1
commit c2a2e710c9

@ -1,4 +1,3 @@
import dataclasses
import typing as t import typing as t
from datetime import datetime from datetime import datetime
from functools import cached_property from functools import cached_property
@ -13,33 +12,22 @@ from django.conf import settings
from django.contrib.auth import get_user_model, user_logged_out, user_logged_in from django.contrib.auth import get_user_model, user_logged_out, user_logged_in
from django.core import exceptions as django_exceptions from django.core import exceptions as django_exceptions
from django.db import transaction from django.db import transaction
from django.utils import timezone
from fastapi import APIRouter, Depends, status, Request from fastapi import APIRouter, Depends, status, Request
from fastapi.security import APIKeyHeader
from django_etebase import app_settings, models from django_etebase import app_settings, models
from django_etebase.token_auth.models import AuthToken
from django_etebase.models import UserInfo from django_etebase.models import UserInfo
from django_etebase.signals import user_signed_up from django_etebase.signals import user_signed_up
from django_etebase.token_auth.models import AuthToken
from django_etebase.token_auth.models import get_default_expiry
from django_etebase.utils import create_user, get_user_queryset, CallbackContext from django_etebase.utils import create_user, get_user_queryset, CallbackContext
from .exceptions import AuthenticationFailed, transform_validation_error, HttpError from .exceptions import AuthenticationFailed, transform_validation_error, HttpError
from .msgpack import MsgpackRoute from .msgpack import MsgpackRoute
from .utils import BaseModel, permission_responses, msgpack_encode, msgpack_decode from .utils import BaseModel, permission_responses, msgpack_encode, msgpack_decode
from .dependencies import AuthData, get_auth_data, get_authenticated_user
User = get_user_model() User = get_user_model()
token_scheme = APIKeyHeader(name="Authorization")
AUTO_REFRESH = True
MIN_REFRESH_INTERVAL = 60
authentication_router = APIRouter(route_class=MsgpackRoute) authentication_router = APIRouter(route_class=MsgpackRoute)
@dataclasses.dataclass(frozen=True)
class AuthData:
user: User
token: AuthToken
class LoginChallengeIn(BaseModel): class LoginChallengeIn(BaseModel):
username: str username: str
@ -115,47 +103,6 @@ class SignupIn(BaseModel):
encryptedContent: bytes encryptedContent: bytes
def __renew_token(auth_token: AuthToken):
current_expiry = auth_token.expiry
new_expiry = get_default_expiry()
# Throttle refreshing of token to avoid db writes
delta = (new_expiry - current_expiry).total_seconds()
if delta > MIN_REFRESH_INTERVAL:
auth_token.expiry = new_expiry
auth_token.save(update_fields=("expiry",))
@sync_to_async
def __get_authenticated_user(api_token: str):
api_token = api_token.split()[1]
try:
token: AuthToken = AuthToken.objects.select_related("user").get(key=api_token)
except AuthToken.DoesNotExist:
raise AuthenticationFailed(detail="Invalid token.")
if not token.user.is_active:
raise AuthenticationFailed(detail="User inactive or deleted.")
if token.expiry is not None:
if token.expiry < timezone.now():
token.delete()
raise AuthenticationFailed(detail="Invalid token.")
if AUTO_REFRESH:
__renew_token(token)
return token.user, token
async def get_auth_data(api_token: str = Depends(token_scheme)) -> AuthData:
user, token = await __get_authenticated_user(api_token)
return AuthData(user, token)
async def get_authenticated_user(api_token: str = Depends(token_scheme)) -> User:
user, token = await __get_authenticated_user(api_token)
return user
@sync_to_async @sync_to_async
def __get_login_user(username: str) -> User: def __get_login_user(username: str) -> User:
kwargs = {User.USERNAME_FIELD + "__iexact": username.lower()} kwargs = {User.USERNAME_FIELD + "__iexact": username.lower()}

@ -5,8 +5,7 @@ from django.contrib.auth import get_user_model
from django.core import exceptions as django_exceptions from django.core import exceptions as django_exceptions
from django.core.files.base import ContentFile from django.core.files.base import ContentFile
from django.db import transaction from django.db import transaction
from django.db.models import Q from django.db.models import Q, QuerySet
from django.db.models import QuerySet
from fastapi import APIRouter, Depends, status from fastapi import APIRouter, Depends, status
from django_etebase import models from django_etebase import models
@ -25,12 +24,11 @@ from .utils import (
PERMISSIONS_READ, PERMISSIONS_READ,
PERMISSIONS_READWRITE, PERMISSIONS_READWRITE,
) )
from .dependencies import get_collection_queryset, get_item_queryset, get_collection
User = get_user_model() User = get_user_model()
collection_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) collection_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
item_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) item_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
default_queryset: QuerySet = models.Collection.objects.all()
default_item_queryset: QuerySet = models.CollectionItem.objects.all()
class ListMulti(BaseModel): class ListMulti(BaseModel):
@ -203,21 +201,6 @@ def collection_list_common(
return ret return ret
def get_collection_queryset(user: User = Depends(get_authenticated_user)) -> QuerySet:
return default_queryset.filter(members__user=user)
def get_collection(collection_uid: str, queryset: QuerySet = Depends(get_collection_queryset)) -> models.Collection:
return get_object_or_404(queryset, uid=collection_uid)
def get_item_queryset(collection: models.Collection = Depends(get_collection)) -> QuerySet:
# XXX Potentially add this for performance: .prefetch_related('revisions__chunks')
queryset = default_item_queryset.filter(collection__pk=collection.pk, revisions__current=True)
return queryset
# permissions # permissions

@ -0,0 +1,82 @@
import dataclasses
from fastapi import Depends
from fastapi.security import APIKeyHeader
from django.contrib.auth import get_user_model
from django.utils import timezone
from django.db.models import QuerySet
from django_etebase import models
from django_etebase.token_auth.models import AuthToken, get_default_expiry
from .exceptions import AuthenticationFailed
from .utils import get_object_or_404
User = get_user_model()
token_scheme = APIKeyHeader(name="Authorization")
AUTO_REFRESH = True
MIN_REFRESH_INTERVAL = 60
@dataclasses.dataclass(frozen=True)
class AuthData:
user: User
token: AuthToken
def __renew_token(auth_token: AuthToken):
current_expiry = auth_token.expiry
new_expiry = get_default_expiry()
# Throttle refreshing of token to avoid db writes
delta = (new_expiry - current_expiry).total_seconds()
if delta > MIN_REFRESH_INTERVAL:
auth_token.expiry = new_expiry
auth_token.save(update_fields=("expiry",))
def __get_authenticated_user(api_token: str):
api_token = api_token.split()[1]
try:
token: AuthToken = AuthToken.objects.select_related("user").get(key=api_token)
except AuthToken.DoesNotExist:
raise AuthenticationFailed(detail="Invalid token.")
if not token.user.is_active:
raise AuthenticationFailed(detail="User inactive or deleted.")
if token.expiry is not None:
if token.expiry < timezone.now():
token.delete()
raise AuthenticationFailed(detail="Invalid token.")
if AUTO_REFRESH:
__renew_token(token)
return token.user, token
def get_auth_data(api_token: str = Depends(token_scheme)) -> AuthData:
user, token = __get_authenticated_user(api_token)
return AuthData(user, token)
def get_authenticated_user(api_token: str = Depends(token_scheme)) -> User:
user, _ = __get_authenticated_user(api_token)
return user
def get_collection_queryset(user: User = Depends(get_authenticated_user)) -> QuerySet:
default_queryset: QuerySet = models.Collection.objects.all()
return default_queryset.filter(members__user=user)
def get_collection(collection_uid: str, queryset: QuerySet = Depends(get_collection_queryset)) -> models.Collection:
return get_object_or_404(queryset, uid=collection_uid)
def get_item_queryset(collection: models.Collection = Depends(get_collection)) -> QuerySet:
default_item_queryset: QuerySet = models.CollectionItem.objects.all()
# XXX Potentially add this for performance: .prefetch_related('revisions__chunks')
queryset = default_item_queryset.filter(collection__pk=collection.pk, revisions__current=True)
return queryset
Loading…
Cancel
Save