|
|
@ -2,7 +2,6 @@ import dataclasses
|
|
|
|
import typing as t
|
|
|
|
import typing as t
|
|
|
|
from datetime import datetime
|
|
|
|
from datetime import datetime
|
|
|
|
from functools import cached_property
|
|
|
|
from functools import cached_property
|
|
|
|
from django.core import exceptions as django_exceptions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import nacl
|
|
|
|
import nacl
|
|
|
|
import nacl.encoding
|
|
|
|
import nacl.encoding
|
|
|
@ -12,6 +11,7 @@ import nacl.signing
|
|
|
|
from asgiref.sync import sync_to_async
|
|
|
|
from asgiref.sync import sync_to_async
|
|
|
|
from django.conf import settings
|
|
|
|
from django.conf import settings
|
|
|
|
from django.contrib.auth import get_user_model, user_logged_out, user_logged_in
|
|
|
|
from django.contrib.auth import get_user_model, user_logged_out, user_logged_in
|
|
|
|
|
|
|
|
from django.core import exceptions as django_exceptions
|
|
|
|
from django.db import transaction
|
|
|
|
from django.db import transaction
|
|
|
|
from django.utils import timezone
|
|
|
|
from django.utils import timezone
|
|
|
|
from fastapi import APIRouter, Depends, status, Request, Response
|
|
|
|
from fastapi import APIRouter, Depends, status, Request, Response
|
|
|
@ -21,7 +21,6 @@ from pydantic import BaseModel
|
|
|
|
from django_etebase import app_settings, models
|
|
|
|
from django_etebase import app_settings, models
|
|
|
|
from django_etebase.exceptions import EtebaseValidationError
|
|
|
|
from django_etebase.exceptions import EtebaseValidationError
|
|
|
|
from django_etebase.models import UserInfo
|
|
|
|
from django_etebase.models import UserInfo
|
|
|
|
from django_etebase.serializers import UserSerializer
|
|
|
|
|
|
|
|
from django_etebase.signals import user_signed_up
|
|
|
|
from django_etebase.signals import user_signed_up
|
|
|
|
from django_etebase.token_auth.models import AuthToken
|
|
|
|
from django_etebase.token_auth.models import AuthToken
|
|
|
|
from django_etebase.token_auth.models import get_default_expiry
|
|
|
|
from django_etebase.token_auth.models import get_default_expiry
|
|
|
@ -43,10 +42,16 @@ class AuthData:
|
|
|
|
token: AuthToken
|
|
|
|
token: AuthToken
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoginChallengeData(BaseModel):
|
|
|
|
class LoginChallengeIn(BaseModel):
|
|
|
|
username: str
|
|
|
|
username: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoginChallengeOut(BaseModel):
|
|
|
|
|
|
|
|
salt: bytes
|
|
|
|
|
|
|
|
challenge: bytes
|
|
|
|
|
|
|
|
version: int
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoginResponse(BaseModel):
|
|
|
|
class LoginResponse(BaseModel):
|
|
|
|
username: str
|
|
|
|
username: str
|
|
|
|
challenge: bytes
|
|
|
|
challenge: bytes
|
|
|
@ -54,6 +59,26 @@ class LoginResponse(BaseModel):
|
|
|
|
action: t.Literal["login", "changePassword"]
|
|
|
|
action: t.Literal["login", "changePassword"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class UserOut(BaseModel):
|
|
|
|
|
|
|
|
pubkey: bytes
|
|
|
|
|
|
|
|
encryptedContent: bytes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
def from_orm(cls: t.Type["UserOut"], obj: User) -> "UserOut":
|
|
|
|
|
|
|
|
return cls(pubkey=obj.userinfo.pubkey, encryptedContent=obj.userinfo.encryptedContent)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoginOut(BaseModel):
|
|
|
|
|
|
|
|
token: str
|
|
|
|
|
|
|
|
user: UserOut
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
def from_orm(cls: t.Type["LoginOut"], obj: User) -> "LoginOut":
|
|
|
|
|
|
|
|
token = AuthToken.objects.create(user=obj).key
|
|
|
|
|
|
|
|
user = UserOut.from_orm(obj)
|
|
|
|
|
|
|
|
return cls(token=token, user=user)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Authentication(BaseModel):
|
|
|
|
class Authentication(BaseModel):
|
|
|
|
class Config:
|
|
|
|
class Config:
|
|
|
|
keep_untouched = (cached_property,)
|
|
|
|
keep_untouched = (cached_property,)
|
|
|
@ -145,7 +170,7 @@ def __get_login_user(username: str) -> User:
|
|
|
|
raise AuthenticationFailed(code="user_not_found", detail="User not found")
|
|
|
|
raise AuthenticationFailed(code="user_not_found", detail="User not found")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_login_user(challenge: LoginChallengeData) -> User:
|
|
|
|
async def get_login_user(challenge: LoginChallengeIn) -> User:
|
|
|
|
user = await __get_login_user(challenge.username)
|
|
|
|
user = await __get_login_user(challenge.username)
|
|
|
|
return user
|
|
|
|
return user
|
|
|
|
|
|
|
|
|
|
|
@ -161,7 +186,6 @@ def get_encryption_key(salt):
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@sync_to_async
|
|
|
|
|
|
|
|
def save_changed_password(data: ChangePassword, user: User):
|
|
|
|
def save_changed_password(data: ChangePassword, user: User):
|
|
|
|
response_data = data.response_data
|
|
|
|
response_data = data.response_data
|
|
|
|
user_info: UserInfo = user.userinfo
|
|
|
|
user_info: UserInfo = user.userinfo
|
|
|
@ -170,24 +194,6 @@ def save_changed_password(data: ChangePassword, user: User):
|
|
|
|
user_info.save()
|
|
|
|
user_info.save()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@sync_to_async
|
|
|
|
|
|
|
|
def login_response_data(user: User):
|
|
|
|
|
|
|
|
return {
|
|
|
|
|
|
|
|
"token": AuthToken.objects.create(user=user).key,
|
|
|
|
|
|
|
|
"user": UserSerializer(user).data,
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@sync_to_async
|
|
|
|
|
|
|
|
def send_user_logged_in_async(user: User, request: Request):
|
|
|
|
|
|
|
|
user_logged_in.send(sender=user.__class__, request=request, user=user)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@sync_to_async
|
|
|
|
|
|
|
|
def send_user_logged_out_async(user: User, request: Request):
|
|
|
|
|
|
|
|
user_logged_out.send(sender=user.__class__, request=request, user=user)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@sync_to_async
|
|
|
|
@sync_to_async
|
|
|
|
def validate_login_request(
|
|
|
|
def validate_login_request(
|
|
|
|
validated_data: LoginResponse,
|
|
|
|
validated_data: LoginResponse,
|
|
|
@ -195,39 +201,26 @@ def validate_login_request(
|
|
|
|
user: User,
|
|
|
|
user: User,
|
|
|
|
expected_action: str,
|
|
|
|
expected_action: str,
|
|
|
|
host_from_request: str,
|
|
|
|
host_from_request: str,
|
|
|
|
) -> t.Optional[MsgpackResponse]:
|
|
|
|
):
|
|
|
|
|
|
|
|
|
|
|
|
enc_key = get_encryption_key(bytes(user.userinfo.salt))
|
|
|
|
enc_key = get_encryption_key(bytes(user.userinfo.salt))
|
|
|
|
box = nacl.secret.SecretBox(enc_key)
|
|
|
|
box = nacl.secret.SecretBox(enc_key)
|
|
|
|
challenge_data = msgpack_decode(box.decrypt(validated_data.challenge))
|
|
|
|
challenge_data = msgpack_decode(box.decrypt(validated_data.challenge))
|
|
|
|
now = int(datetime.now().timestamp())
|
|
|
|
now = int(datetime.now().timestamp())
|
|
|
|
if validated_data.action != expected_action:
|
|
|
|
if validated_data.action != expected_action:
|
|
|
|
content = {
|
|
|
|
raise ValidationError("wrong_action", f'Expected "{challenge_sent_to_user.response}" but got something else')
|
|
|
|
"code": "wrong_action",
|
|
|
|
|
|
|
|
"detail": 'Expected "{}" but got something else'.format(challenge_sent_to_user.response),
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return MsgpackResponse(content, status_code=status.HTTP_400_BAD_REQUEST)
|
|
|
|
|
|
|
|
elif now - challenge_data["timestamp"] > app_settings.CHALLENGE_VALID_SECONDS:
|
|
|
|
elif now - challenge_data["timestamp"] > app_settings.CHALLENGE_VALID_SECONDS:
|
|
|
|
content = {"code": "challenge_expired", "detail": "Login challenge has expired"}
|
|
|
|
raise ValidationError("challenge_expired", "Login challenge has expired")
|
|
|
|
return MsgpackResponse(content, status_code=status.HTTP_400_BAD_REQUEST)
|
|
|
|
|
|
|
|
elif challenge_data["userId"] != user.id:
|
|
|
|
elif challenge_data["userId"] != user.id:
|
|
|
|
content = {"code": "wrong_user", "detail": "This challenge is for the wrong user"}
|
|
|
|
raise ValidationError("wrong_user", "This challenge is for the wrong user")
|
|
|
|
return MsgpackResponse(content, status_code=status.HTTP_400_BAD_REQUEST)
|
|
|
|
|
|
|
|
elif not settings.DEBUG and validated_data.host.split(":", 1)[0] != host_from_request:
|
|
|
|
elif not settings.DEBUG and validated_data.host.split(":", 1)[0] != host_from_request:
|
|
|
|
detail = 'Found wrong host name. Got: "{}" expected: "{}"'.format(validated_data.host, host_from_request)
|
|
|
|
raise ValidationError(
|
|
|
|
content = {"code": "wrong_host", "detail": detail}
|
|
|
|
"wrong_host", f'Found wrong host name. Got: "{validated_data.host}" expected: "{host_from_request}"'
|
|
|
|
return MsgpackResponse(content, status_code=status.HTTP_400_BAD_REQUEST)
|
|
|
|
)
|
|
|
|
verify_key = nacl.signing.VerifyKey(bytes(user.userinfo.loginPubkey), encoder=nacl.encoding.RawEncoder)
|
|
|
|
verify_key = nacl.signing.VerifyKey(bytes(user.userinfo.loginPubkey), encoder=nacl.encoding.RawEncoder)
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
verify_key.verify(challenge_sent_to_user.response, challenge_sent_to_user.signature)
|
|
|
|
verify_key.verify(challenge_sent_to_user.response, challenge_sent_to_user.signature)
|
|
|
|
except nacl.exceptions.BadSignatureError:
|
|
|
|
except nacl.exceptions.BadSignatureError:
|
|
|
|
return MsgpackResponse(
|
|
|
|
raise ValidationError("login_bad_signature", "Wrong password for user.", status.HTTP_401_UNAUTHORIZED)
|
|
|
|
{"code": "login_bad_signature", "detail": "Wrong password for user."},
|
|
|
|
|
|
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@authentication_router.post("/login_challenge/")
|
|
|
|
@authentication_router.post("/login_challenge/")
|
|
|
@ -239,35 +232,34 @@ async def login_challenge(user: User = Depends(get_login_user)):
|
|
|
|
"userId": user.id,
|
|
|
|
"userId": user.id,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
challenge = bytes(box.encrypt(msgpack_encode(challenge_data), encoder=nacl.encoding.RawEncoder))
|
|
|
|
challenge = bytes(box.encrypt(msgpack_encode(challenge_data), encoder=nacl.encoding.RawEncoder))
|
|
|
|
return MsgpackResponse({"salt": user.userinfo.salt, "version": user.userinfo.version, "challenge": challenge})
|
|
|
|
return MsgpackResponse(
|
|
|
|
|
|
|
|
LoginChallengeOut(salt=user.userinfo.salt, challenge=challenge, version=user.userinfo.version)
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@authentication_router.post("/login/")
|
|
|
|
@authentication_router.post("/login/")
|
|
|
|
async def login(data: Login, request: Request):
|
|
|
|
async def login(data: Login, request: Request):
|
|
|
|
user = await get_login_user(LoginChallengeData(username=data.response_data.username))
|
|
|
|
user = await get_login_user(LoginChallengeIn(username=data.response_data.username))
|
|
|
|
host = request.headers.get("Host")
|
|
|
|
host = request.headers.get("Host")
|
|
|
|
bad_login_response = await validate_login_request(data.response_data, data, user, "login", host)
|
|
|
|
await validate_login_request(data.response_data, data, user, "login", host)
|
|
|
|
if bad_login_response is not None:
|
|
|
|
data = await sync_to_async(LoginOut.from_orm)(user)
|
|
|
|
return bad_login_response
|
|
|
|
await sync_to_async(user_logged_in.send)(sender=user.__class__, request=None, user=user)
|
|
|
|
data = await login_response_data(user)
|
|
|
|
return MsgpackResponse(content=data, status_code=status.HTTP_200_OK)
|
|
|
|
await send_user_logged_in_async(user, request)
|
|
|
|
|
|
|
|
return MsgpackResponse(data, status_code=status.HTTP_200_OK)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@authentication_router.post("/logout/")
|
|
|
|
@authentication_router.post("/logout/")
|
|
|
|
async def logout(request: Request, auth_data: AuthData = Depends(get_auth_data)):
|
|
|
|
async def logout(request: Request, auth_data: AuthData = Depends(get_auth_data)):
|
|
|
|
await sync_to_async(auth_data.token.delete)()
|
|
|
|
await sync_to_async(auth_data.token.delete)()
|
|
|
|
await send_user_logged_out_async(auth_data.user, request)
|
|
|
|
# XXX-TOM
|
|
|
|
|
|
|
|
await sync_to_async(user_logged_out.send)(sender=auth_data.user.__class__, request=None, user=auth_data.user)
|
|
|
|
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
|
|
|
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@authentication_router.post("/change_password/")
|
|
|
|
@authentication_router.post("/change_password/")
|
|
|
|
async def change_password(data: ChangePassword, request: Request, user: User = Depends(get_authenticated_user)):
|
|
|
|
async def change_password(data: ChangePassword, request: Request, user: User = Depends(get_authenticated_user)):
|
|
|
|
host = request.headers.get("Host")
|
|
|
|
host = request.headers.get("Host")
|
|
|
|
bad_login_response = await validate_login_request(data.response_data, data, user, "changePassword", host)
|
|
|
|
await validate_login_request(data.response_data, data, user, "changePassword", host)
|
|
|
|
if bad_login_response is not None:
|
|
|
|
await sync_to_async(save_changed_password)(data, user)
|
|
|
|
return bad_login_response
|
|
|
|
|
|
|
|
await save_changed_password(data, user)
|
|
|
|
|
|
|
|
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
|
|
|
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -300,15 +292,10 @@ def signup_save(data: SignupIn) -> User:
|
|
|
|
return instance
|
|
|
|
return instance
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@sync_to_async
|
|
|
|
|
|
|
|
def send_user_signed_up_async(user: User, request):
|
|
|
|
|
|
|
|
user_signed_up.send(sender=user.__class__, request=request, user=user)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@authentication_router.post("/signup/")
|
|
|
|
@authentication_router.post("/signup/")
|
|
|
|
async def signup(data: SignupIn):
|
|
|
|
async def signup(data: SignupIn):
|
|
|
|
user = await sync_to_async(signup_save)(data)
|
|
|
|
user = await sync_to_async(signup_save)(data)
|
|
|
|
# XXX-TOM
|
|
|
|
# XXX-TOM
|
|
|
|
data = await login_response_data(user)
|
|
|
|
data = await sync_to_async(LoginOut.from_orm)(user)
|
|
|
|
await send_user_signed_up_async(user, None)
|
|
|
|
await sync_to_async(user_signed_up.send)(sender=user.__class__, request=None, user=user)
|
|
|
|
return MsgpackResponse(content=data, status_code=status.HTTP_201_CREATED)
|
|
|
|
return MsgpackResponse(content=data, status_code=status.HTTP_201_CREATED)
|
|
|
|