Add support for custom user filtering.

master
Tom Hacohen 5 years ago
parent 3680bd53b1
commit f9add36f18

@ -46,6 +46,13 @@ class AppSettings:
ret.append(self.import_from_str(perm)) ret.append(self.import_from_str(perm))
return ret return ret
@property
def GET_USER_QUERYSET(self): # pylint: disable=invalid-name
get_user_queryset = self._setting("GET_USER_QUERYSET", None)
if get_user_queryset is not None:
return self.import_from_str(get_user_queryset)
return None
@property @property
def CHALLENGE_VALID_SECONDS(self): # pylint: disable=invalid-name def CHALLENGE_VALID_SECONDS(self): # pylint: disable=invalid-name
return self._setting("CHALLENGE_VALID_SECONDS", 60) return self._setting("CHALLENGE_VALID_SECONDS", 60)

@ -20,6 +20,7 @@ from django.contrib.auth import get_user_model
from django.db import transaction from django.db import transaction
from rest_framework import serializers from rest_framework import serializers
from . import models from . import models
from .utils import get_user_queryset
User = get_user_model() User = get_user_model()
@ -91,6 +92,15 @@ class CollectionContentField(BinaryBase64Field):
return None return None
class UserSlugRelatedField(serializers.SlugRelatedField):
def get_queryset(self):
view = self.context.get('view', None)
return get_user_queryset(super().get_queryset(), view)
def __init__(self, **kwargs):
super().__init__(slug_field=User.USERNAME_FIELD, **kwargs)
class ChunksField(serializers.RelatedField): class ChunksField(serializers.RelatedField):
def to_representation(self, obj): def to_representation(self, obj):
obj = obj.chunk obj = obj.chunk
@ -252,9 +262,8 @@ class CollectionSerializer(serializers.ModelSerializer):
class CollectionMemberSerializer(serializers.ModelSerializer): class CollectionMemberSerializer(serializers.ModelSerializer):
username = serializers.SlugRelatedField( username = UserSlugRelatedField(
source='user', source='user',
slug_field=User.USERNAME_FIELD,
read_only=True, read_only=True,
) )
@ -278,9 +287,8 @@ class CollectionMemberSerializer(serializers.ModelSerializer):
class CollectionInvitationSerializer(serializers.ModelSerializer): class CollectionInvitationSerializer(serializers.ModelSerializer):
username = serializers.SlugRelatedField( username = UserSlugRelatedField(
source='user', source='user',
slug_field=User.USERNAME_FIELD,
queryset=User.objects queryset=User.objects
) )
collection = serializers.CharField(source='collection.uid') collection = serializers.CharField(source='collection.uid')

@ -0,0 +1,12 @@
from django.contrib.auth import get_user_model
from . import app_settings
User = get_user_model()
def get_user_queryset(queryset, view):
custom_func = app_settings.GET_USER_QUERYSET
if custom_func is not None:
return custom_func(queryset, view)
return queryset

@ -71,6 +71,7 @@ from .serializers import (
UserInfoPubkeySerializer, UserInfoPubkeySerializer,
UserSerializer, UserSerializer,
) )
from .utils import get_user_queryset
User = get_user_model() User = get_user_model()
@ -558,8 +559,9 @@ class InvitationOutgoingViewSet(InvitationBaseViewSet):
@action_decorator(detail=False, allowed_methods=['GET'], methods=['GET']) @action_decorator(detail=False, allowed_methods=['GET'], methods=['GET'])
def fetch_user_profile(self, request, *args, **kwargs): def fetch_user_profile(self, request, *args, **kwargs):
username = request.GET.get('username') username = request.GET.get('username')
kwargs = {'owner__' + User.USERNAME_FIELD: username} kwargs = {User.USERNAME_FIELD: username}
user_info = get_object_or_404(UserInfo.objects.all(), **kwargs) user = get_object_or_404(get_user_queryset(User.objects.all(), self), **kwargs)
user_info = get_object_or_404(UserInfo.objects.all(), owner=user)
serializer = UserInfoPubkeySerializer(user_info) serializer = UserInfoPubkeySerializer(user_info)
return Response(serializer.data) return Response(serializer.data)
@ -597,7 +599,7 @@ class AuthenticationViewSet(viewsets.ViewSet):
encoder=nacl.encoding.RawEncoder) encoder=nacl.encoding.RawEncoder)
def get_queryset(self): def get_queryset(self):
return User.objects.all() return get_user_queryset(User.objects.all(), self)
def login_response_data(self, user): def login_response_data(self, user):
return { return {
@ -756,7 +758,8 @@ class TestAuthenticationViewSet(viewsets.ViewSet):
return HttpResponseBadRequest("Only allowed in debug mode.") return HttpResponseBadRequest("Only allowed in debug mode.")
with transaction.atomic(): with transaction.atomic():
user = get_object_or_404(User.objects.all(), username=request.data.get('user').get('username')) user_queryset = get_user_queryset(User.objects.all(), self)
user = get_object_or_404(user_queryset, username=request.data.get('user').get('username'))
# Only allow test users for extra safety # Only allow test users for extra safety
if not getattr(user, User.USERNAME_FIELD).startswith('test_user'): if not getattr(user, User.USERNAME_FIELD).startswith('test_user'):

Loading…
Cancel
Save