Fix login_challenge to work with get_user_queryset.

master
Tom Hacohen 4 years ago
parent ff55904f49
commit e6b47ae1a9

@ -111,11 +111,13 @@ class SignupIn(BaseModel):
encryptedContent: bytes encryptedContent: bytes
@sync_to_async def get_login_user(request: Request, challenge: LoginChallengeIn) -> UserType:
def __get_login_user(username: str) -> UserType: username = challenge.username
kwargs = {User.USERNAME_FIELD + "__iexact": username.lower()} kwargs = {User.USERNAME_FIELD + "__iexact": username.lower()}
try: try:
user = User.objects.get(**kwargs) user_queryset = get_user_queryset(User.objects.all(), CallbackContext(request.path_params))
user = user_queryset.get(**kwargs)
if not hasattr(user, "userinfo"): if not hasattr(user, "userinfo"):
raise AuthenticationFailed(code="user_not_init", detail="User not properly init") raise AuthenticationFailed(code="user_not_init", detail="User not properly init")
return user return user
@ -123,11 +125,6 @@ def __get_login_user(username: str) -> UserType:
raise AuthenticationFailed(code="user_not_found", detail="User not found") raise AuthenticationFailed(code="user_not_found", detail="User not found")
async def get_login_user(challenge: LoginChallengeIn) -> UserType:
user = await __get_login_user(challenge.username)
return user
def get_encryption_key(salt): def get_encryption_key(salt):
key = nacl.hash.blake2b(settings.SECRET_KEY.encode(), encoder=nacl.encoding.RawEncoder) key = nacl.hash.blake2b(settings.SECRET_KEY.encode(), encoder=nacl.encoding.RawEncoder)
return nacl.hash.blake2b( return nacl.hash.blake2b(
@ -196,7 +193,7 @@ def login_challenge(user: UserType = Depends(get_login_user)):
@authentication_router.post("/login/", response_model=LoginOut) @authentication_router.post("/login/", response_model=LoginOut)
async def login(data: Login, request: Request): async def login(data: Login, request: Request):
user = await get_login_user(LoginChallengeIn(username=data.response_data.username)) user = await sync_to_async(get_login_user)(request, LoginChallengeIn(username=data.response_data.username))
host = request.headers.get("Host") host = request.headers.get("Host")
await validate_login_request(data.response_data, data, user, "login", host) await validate_login_request(data.response_data, data, user, "login", host)
data = await sync_to_async(LoginOut.from_orm)(user) data = await sync_to_async(LoginOut.from_orm)(user)

Loading…
Cancel
Save