Django db cleanup: explicitly add it to dependencies.

We can't really add it manually, because some of the deps are auto included as
parameters. These were not being decorated which in turn meeant issues.
master
Tom Hacohen 4 years ago
parent 5b8f667e55
commit 2e21fe4994

@ -11,6 +11,7 @@ from django_etebase.token_auth.models import AuthToken, get_default_expiry
from myauth.models import UserType, get_typed_user_model from myauth.models import UserType, get_typed_user_model
from .exceptions import AuthenticationFailed from .exceptions import AuthenticationFailed
from .utils import get_object_or_404 from .utils import get_object_or_404
from .db_hack import django_db_cleanup_decorator
User = get_typed_user_model() User = get_typed_user_model()
@ -55,25 +56,30 @@ def __get_authenticated_user(api_token: str):
return token.user, token return token.user, token
@django_db_cleanup_decorator
def get_auth_data(api_token: str = Depends(token_scheme)) -> AuthData: def get_auth_data(api_token: str = Depends(token_scheme)) -> AuthData:
user, token = __get_authenticated_user(api_token) user, token = __get_authenticated_user(api_token)
return AuthData(user, token) return AuthData(user, token)
@django_db_cleanup_decorator
def get_authenticated_user(api_token: str = Depends(token_scheme)) -> UserType: def get_authenticated_user(api_token: str = Depends(token_scheme)) -> UserType:
user, _ = __get_authenticated_user(api_token) user, _ = __get_authenticated_user(api_token)
return user return user
@django_db_cleanup_decorator
def get_collection_queryset(user: UserType = Depends(get_authenticated_user)) -> QuerySet: def get_collection_queryset(user: UserType = Depends(get_authenticated_user)) -> QuerySet:
default_queryset: QuerySet = models.Collection.objects.all() default_queryset: QuerySet = models.Collection.objects.all()
return default_queryset.filter(members__user=user) return default_queryset.filter(members__user=user)
@django_db_cleanup_decorator
def get_collection(collection_uid: str, queryset: QuerySet = Depends(get_collection_queryset)) -> models.Collection: def get_collection(collection_uid: str, queryset: QuerySet = Depends(get_collection_queryset)) -> models.Collection:
return get_object_or_404(queryset, uid=collection_uid) return get_object_or_404(queryset, uid=collection_uid)
@django_db_cleanup_decorator
def get_item_queryset(collection: models.Collection = Depends(get_collection)) -> QuerySet: def get_item_queryset(collection: models.Collection = Depends(get_collection)) -> QuerySet:
default_item_queryset: QuerySet = models.CollectionItem.objects.all() default_item_queryset: QuerySet = models.CollectionItem.objects.all()
# XXX Potentially add this for performance: .prefetch_related('revisions__chunks') # XXX Potentially add this for performance: .prefetch_related('revisions__chunks')

@ -1,6 +1,5 @@
import typing as t import typing as t
from fastapi import params
from fastapi.routing import APIRoute, get_request_handler from fastapi.routing import APIRoute, get_request_handler
from pydantic import BaseModel from pydantic import BaseModel
from starlette.requests import Request from starlette.requests import Request
@ -38,21 +37,9 @@ class MsgpackRoute(APIRoute):
# keep track of content-type -> response classes # keep track of content-type -> response classes
ROUTES_HANDLERS_CLASSES = {MsgpackResponse.media_type: MsgpackResponse} ROUTES_HANDLERS_CLASSES = {MsgpackResponse.media_type: MsgpackResponse}
def __init__( def __init__(self, path: str, endpoint: t.Callable[..., t.Any], *args, **kwargs):
self,
path: str,
endpoint: t.Callable[..., t.Any],
*args,
dependencies: t.Optional[t.Sequence[params.Depends]] = None,
**kwargs
):
if dependencies is not None:
dependencies = [
params.Depends(django_db_cleanup_decorator(dep.dependency), use_cache=dep.use_cache)
for dep in dependencies
]
endpoint = django_db_cleanup_decorator(endpoint) endpoint = django_db_cleanup_decorator(endpoint)
super().__init__(path, endpoint, *args, dependencies=dependencies, **kwargs) super().__init__(path, endpoint, *args, **kwargs)
def _get_media_type_route_handler(self, media_type): def _get_media_type_route_handler(self, media_type):
return get_request_handler( return get_request_handler(

@ -26,6 +26,7 @@ from ..utils import (
) )
from ..dependencies import get_collection_queryset, get_item_queryset, get_collection from ..dependencies import get_collection_queryset, get_item_queryset, get_collection
from ..sendfile import sendfile from ..sendfile import sendfile
from ..db_hack import django_db_cleanup_decorator
collection_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) collection_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
item_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) item_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
@ -222,6 +223,7 @@ def collection_list_common(
# permissions # permissions
@django_db_cleanup_decorator
def verify_collection_admin( def verify_collection_admin(
collection: models.Collection = Depends(get_collection), user: UserType = Depends(get_authenticated_user) collection: models.Collection = Depends(get_collection), user: UserType = Depends(get_authenticated_user)
): ):
@ -229,6 +231,7 @@ def verify_collection_admin(
raise PermissionDenied("admin_access_required", "Only collection admins can perform this operation.") raise PermissionDenied("admin_access_required", "Only collection admins can perform this operation.")
@django_db_cleanup_decorator
def has_write_access( def has_write_access(
collection: models.Collection = Depends(get_collection), user: UserType = Depends(get_authenticated_user) collection: models.Collection = Depends(get_collection), user: UserType = Depends(get_authenticated_user)
): ):

@ -19,6 +19,7 @@ from ..utils import (
PERMISSIONS_READ, PERMISSIONS_READ,
PERMISSIONS_READWRITE, PERMISSIONS_READWRITE,
) )
from ..db_hack import django_db_cleanup_decorator
User = get_typed_user_model() User = get_typed_user_model()
invitation_incoming_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) invitation_incoming_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
@ -86,10 +87,12 @@ class InvitationListResponse(BaseModel):
done: bool done: bool
@django_db_cleanup_decorator
def get_incoming_queryset(user: UserType = Depends(get_authenticated_user)): def get_incoming_queryset(user: UserType = Depends(get_authenticated_user)):
return default_queryset.filter(user=user) return default_queryset.filter(user=user)
@django_db_cleanup_decorator
def get_outgoing_queryset(user: UserType = Depends(get_authenticated_user)): def get_outgoing_queryset(user: UserType = Depends(get_authenticated_user)):
return default_queryset.filter(fromMember__user=user) return default_queryset.filter(fromMember__user=user)

@ -10,6 +10,7 @@ from .authentication import get_authenticated_user
from ..msgpack import MsgpackRoute from ..msgpack import MsgpackRoute
from ..utils import get_object_or_404, BaseModel, permission_responses, PERMISSIONS_READ, PERMISSIONS_READWRITE from ..utils import get_object_or_404, BaseModel, permission_responses, PERMISSIONS_READ, PERMISSIONS_READWRITE
from ..stoken_handler import filter_by_stoken_and_limit from ..stoken_handler import filter_by_stoken_and_limit
from ..db_hack import django_db_cleanup_decorator
from .collection import get_collection, verify_collection_admin from .collection import get_collection, verify_collection_admin
@ -19,10 +20,12 @@ MemberQuerySet = QuerySet[models.CollectionMember]
default_queryset: MemberQuerySet = models.CollectionMember.objects.all() default_queryset: MemberQuerySet = models.CollectionMember.objects.all()
@django_db_cleanup_decorator
def get_queryset(collection: models.Collection = Depends(get_collection)) -> MemberQuerySet: def get_queryset(collection: models.Collection = Depends(get_collection)) -> MemberQuerySet:
return default_queryset.filter(collection=collection) return default_queryset.filter(collection=collection)
@django_db_cleanup_decorator
def get_member(username: str, queryset: MemberQuerySet = Depends(get_queryset)) -> models.CollectionMember: def get_member(username: str, queryset: MemberQuerySet = Depends(get_queryset)) -> models.CollectionMember:
return get_object_or_404(queryset, user__username__iexact=username) return get_object_or_404(queryset, user__username__iexact=username)

Loading…
Cancel
Save