From 46b4f08afa6b6fe96f3fdc3aeb98e96290f3ec05 Mon Sep 17 00:00:00 2001 From: Tom Hacohen Date: Mon, 13 Jul 2020 16:03:34 +0300 Subject: [PATCH] Signup: use the get_user_queryset function when checking if user exists. --- django_etebase/serializers.py | 4 +++- django_etebase/views.py | 18 ++++++++++++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/django_etebase/serializers.py b/django_etebase/serializers.py index a77c037..29e1d4f 100644 --- a/django_etebase/serializers.py +++ b/django_etebase/serializers.py @@ -394,7 +394,9 @@ class AuthenticationSignupSerializer(serializers.Serializer): with transaction.atomic(): try: - instance = User.objects.get_by_natural_key(user_data['username']) + view = self.context.get('view', None) + user_queryset = get_user_queryset(User.objects.all(), view) + instance = user_queryset.get(**{User.USERNAME_FIELD: user_data['username'].lower()}) except User.DoesNotExist: # Create the user and save the casing the user chose as the first name instance = User.objects.create_user(**user_data, password=None, first_name=user_data['username']) diff --git a/django_etebase/views.py b/django_etebase/views.py index 327bc08..8a6ff85 100644 --- a/django_etebase/views.py +++ b/django_etebase/views.py @@ -601,6 +601,13 @@ class AuthenticationViewSet(viewsets.ViewSet): def get_queryset(self): return get_user_queryset(User.objects.all(), self) + def get_serializer_context(self): + return { + 'request': self.request, + 'format': self.format_kwarg, + 'view': self + } + def login_response_data(self, user): return { 'token': AuthToken.objects.create(user=user).key, @@ -612,7 +619,7 @@ class AuthenticationViewSet(viewsets.ViewSet): @action_decorator(detail=False, methods=['POST']) def signup(self, request, *args, **kwargs): - serializer = AuthenticationSignupSerializer(data=request.data) + serializer = AuthenticationSignupSerializer(data=request.data, context=self.get_serializer_context()) serializer.is_valid(raise_exception=True) user = serializer.save() @@ -748,6 +755,13 @@ class TestAuthenticationViewSet(viewsets.ViewSet): renderer_classes = BaseViewSet.renderer_classes parser_classes = BaseViewSet.parser_classes + def get_serializer_context(self): + return { + 'request': self.request, + 'format': self.format_kwarg, + 'view': self + } + def list(self, request, *args, **kwargs): return Response(status=status.HTTP_405_METHOD_NOT_ALLOWED) @@ -768,7 +782,7 @@ class TestAuthenticationViewSet(viewsets.ViewSet): if hasattr(user, 'userinfo'): user.userinfo.delete() - serializer = AuthenticationSignupSerializer(data=request.data) + serializer = AuthenticationSignupSerializer(data=request.data, context=self.get_serializer_context()) serializer.is_valid(raise_exception=True) serializer.save()