Pass generic context to callbacks instead of the whole view

master
Tom Hacohen 4 years ago
parent 5a6c8a1d05
commit c2eb4fd30c

@ -20,7 +20,7 @@ from django.contrib.auth import get_user_model
from django.db import IntegrityError, transaction from django.db import IntegrityError, transaction
from rest_framework import serializers, status from rest_framework import serializers, status
from . import models from . import models
from .utils import get_user_queryset, create_user from .utils import get_user_queryset, create_user, CallbackContext
from .exceptions import EtebaseValidationError from .exceptions import EtebaseValidationError
@ -102,7 +102,7 @@ class CollectionTypeField(BinaryBase64Field):
class UserSlugRelatedField(serializers.SlugRelatedField): class UserSlugRelatedField(serializers.SlugRelatedField):
def get_queryset(self): def get_queryset(self):
view = self.context.get("view", None) view = self.context.get("view", None)
return get_user_queryset(super().get_queryset(), view) return get_user_queryset(super().get_queryset(), context=CallbackContext(view.kwargs))
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(slug_field=User.USERNAME_FIELD, **kwargs) super().__init__(slug_field=User.USERNAME_FIELD, **kwargs)
@ -515,12 +515,17 @@ class AuthenticationSignupSerializer(BetterErrorsMixin, serializers.Serializer):
with transaction.atomic(): with transaction.atomic():
try: try:
view = self.context.get("view", None) view = self.context.get("view", None)
user_queryset = get_user_queryset(User.objects.all(), view) user_queryset = get_user_queryset(User.objects.all(), context=CallbackContext(view.kwargs))
instance = user_queryset.get(**{User.USERNAME_FIELD: user_data["username"].lower()}) instance = user_queryset.get(**{User.USERNAME_FIELD: user_data["username"].lower()})
except User.DoesNotExist: except User.DoesNotExist:
# Create the user and save the casing the user chose as the first name # Create the user and save the casing the user chose as the first name
try: try:
instance = create_user(**user_data, password=None, first_name=user_data["username"], view=view) instance = create_user(
**user_data,
password=None,
first_name=user_data["username"],
context=CallbackContext(view.kwargs)
)
instance.full_clean() instance.full_clean()
except EtebaseValidationError as e: except EtebaseValidationError as e:
raise e raise e

@ -1,3 +1,6 @@
import typing as t
from dataclasses import dataclass
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.core.exceptions import PermissionDenied from django.core.exceptions import PermissionDenied
@ -7,18 +10,24 @@ from . import app_settings
User = get_user_model() User = get_user_model()
def get_user_queryset(queryset, view): @dataclass
class CallbackContext:
"""Class for passing extra context to callbacks"""
url_kwargs: t.Dict[str, t.Any]
def get_user_queryset(queryset, context: CallbackContext):
custom_func = app_settings.GET_USER_QUERYSET_FUNC custom_func = app_settings.GET_USER_QUERYSET_FUNC
if custom_func is not None: if custom_func is not None:
return custom_func(queryset, view) return custom_func(queryset, context)
return queryset return queryset
def create_user(*args, **kwargs): def create_user(context: CallbackContext, *args, **kwargs):
custom_func = app_settings.CREATE_USER_FUNC custom_func = app_settings.CREATE_USER_FUNC
if custom_func is not None: if custom_func is not None:
return custom_func(*args, **kwargs) return custom_func(*args, **kwargs)
_ = kwargs.pop("view")
return User.objects.create_user(*args, **kwargs) return User.objects.create_user(*args, **kwargs)

@ -73,7 +73,7 @@ from .serializers import (
UserInfoPubkeySerializer, UserInfoPubkeySerializer,
UserSerializer, UserSerializer,
) )
from .utils import get_user_queryset from .utils import get_user_queryset, CallbackContext
from .exceptions import EtebaseValidationError from .exceptions import EtebaseValidationError
from .parsers import ChunkUploadParser from .parsers import ChunkUploadParser
from .signals import user_signed_up from .signals import user_signed_up
@ -598,7 +598,7 @@ class InvitationOutgoingViewSet(InvitationBaseViewSet):
def fetch_user_profile(self, request, *args, **kwargs): def fetch_user_profile(self, request, *args, **kwargs):
username = request.GET.get("username") username = request.GET.get("username")
kwargs = {User.USERNAME_FIELD: username.lower()} kwargs = {User.USERNAME_FIELD: username.lower()}
user = get_object_or_404(get_user_queryset(User.objects.all(), self), **kwargs) user = get_object_or_404(get_user_queryset(User.objects.all(), CallbackContext(self.kwargs)), **kwargs)
user_info = get_object_or_404(UserInfo.objects.all(), owner=user) user_info = get_object_or_404(UserInfo.objects.all(), owner=user)
serializer = UserInfoPubkeySerializer(user_info) serializer = UserInfoPubkeySerializer(user_info)
return Response(serializer.data) return Response(serializer.data)
@ -642,7 +642,7 @@ class AuthenticationViewSet(viewsets.ViewSet):
) )
def get_queryset(self): def get_queryset(self):
return get_user_queryset(User.objects.all(), self) return get_user_queryset(User.objects.all(), CallbackContext(self.kwargs))
def get_serializer_context(self): def get_serializer_context(self):
return {"request": self.request, "format": self.format_kwarg, "view": self} return {"request": self.request, "format": self.format_kwarg, "view": self}
@ -837,7 +837,7 @@ class TestAuthenticationViewSet(viewsets.ViewSet):
return HttpResponseBadRequest("Only allowed in debug mode.") return HttpResponseBadRequest("Only allowed in debug mode.")
with transaction.atomic(): with transaction.atomic():
user_queryset = get_user_queryset(User.objects.all(), self) user_queryset = get_user_queryset(User.objects.all(), CallbackContext(self.kwargs))
user = get_object_or_404(user_queryset, username=request.data.get("user").get("username")) user = get_object_or_404(user_queryset, username=request.data.get("user").get("username"))
# Only allow test users for extra safety # Only allow test users for extra safety

Loading…
Cancel
Save