master
Tal Leibman 4 years ago committed by Tom Hacohen
parent 7d86459480
commit c90e92b0f0

@ -2,6 +2,7 @@ import dataclasses
import typing as t import typing as t
from datetime import datetime from datetime import datetime
from functools import cached_property from functools import cached_property
from django.core import exceptions as django_exceptions
import nacl import nacl
import nacl.encoding import nacl.encoding
@ -11,16 +12,19 @@ import nacl.signing
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from django.conf import settings from django.conf import settings
from django.contrib.auth import get_user_model, user_logged_out, user_logged_in from django.contrib.auth import get_user_model, user_logged_out, user_logged_in
from django.db import transaction
from django.utils import timezone from django.utils import timezone
from fastapi import APIRouter, Depends, status, Request, Response from fastapi import APIRouter, Depends, status, Request, Response
from fastapi.security import APIKeyHeader from fastapi.security import APIKeyHeader
from pydantic import BaseModel from pydantic import BaseModel
from django_etebase import app_settings from django_etebase import app_settings
from django_etebase.exceptions import EtebaseValidationError
from django_etebase.models import UserInfo from django_etebase.models import UserInfo
from django_etebase.serializers import UserSerializer from django_etebase.serializers import UserSerializer
from django_etebase.token_auth.models import AuthToken from django_etebase.token_auth.models import AuthToken
from django_etebase.token_auth.models import get_default_expiry from django_etebase.token_auth.models import get_default_expiry
from django_etebase.utils import create_user
from django_etebase.views import msgpack_encode, msgpack_decode from django_etebase.views import msgpack_encode, msgpack_decode
from .execptions import AuthenticationFailed from .execptions import AuthenticationFailed
from .msgpack import MsgpackResponse, MsgpackRoute from .msgpack import MsgpackResponse, MsgpackRoute
@ -74,6 +78,19 @@ class ChangePassword(Authentication):
return ChangePasswordResponse(**msgpack_decode(self.response)) return ChangePasswordResponse(**msgpack_decode(self.response))
class UserSignup(BaseModel):
username: str
email: str
class SignupIn(BaseModel):
user: UserSignup
salt: bytes
loginPubkey: bytes
pubkey: bytes
encryptedContent: bytes
def __renew_token(auth_token: AuthToken): def __renew_token(auth_token: AuthToken):
current_expiry = auth_token.expiry current_expiry = auth_token.expiry
new_expiry = get_default_expiry() new_expiry = get_default_expiry()
@ -252,3 +269,38 @@ async def change_password(data: ChangePassword, request: Request, user: User = D
return bad_login_response return bad_login_response
await save_changed_password(data, user) await save_changed_password(data, user)
return Response(status_code=status.HTTP_204_NO_CONTENT) return Response(status_code=status.HTTP_204_NO_CONTENT)
@sync_to_async
def signup_save(data: SignupIn):
user_data = data.user
with transaction.atomic():
try:
# XXX-TOM
# view = self.context.get("view", None)
# user_queryset = get_user_queryset(User.objects.all(), view)
user_queryset = User.objects.all()
instance = user_queryset.get(**{User.USERNAME_FIELD: user_data.username.lower()})
except User.DoesNotExist:
# Create the user and save the casing the user chose as the first name
try:
# XXX-TOM
instance = create_user(**user_data.dict(), password=None, first_name=user_data.username, view=None)
instance.full_clean()
except EtebaseValidationError as e:
raise e
except django_exceptions.ValidationError as e:
self.transform_validation_error("user", e)
except Exception as e:
raise EtebaseValidationError("generic", str(e))
if hasattr(instance, "userinfo"):
raise EtebaseValidationError("user_exists", "User already exists", status_code=status.HTTP_409_CONFLICT)
models.UserInfo.objects.create(**validated_data, owner=instance)
return instance
@authentication_router.post("/signup/")
async def signup(data: SignupIn):
pass

@ -1,5 +1,7 @@
from fastapi import status from fastapi import status
from django_etebase.exceptions import EtebaseValidationError
class CustomHttpException(Exception): class CustomHttpException(Exception):
def __init__(self, code: str, detail: str, status_code: int = status.HTTP_400_BAD_REQUEST): def __init__(self, code: str, detail: str, status_code: int = status.HTTP_400_BAD_REQUEST):
@ -40,3 +42,47 @@ class PermissionDenied(CustomHttpException):
status_code: int = status.HTTP_403_FORBIDDEN, status_code: int = status.HTTP_403_FORBIDDEN,
): ):
super().__init__(code=code, detail=detail, status_code=status_code) super().__init__(code=code, detail=detail, status_code=status_code)
class ValidationError(CustomHttpException):
def __init__(self, code: str, detail: str, status_code: int = status.HTTP_400_BAD_REQUEST):
super().__init__(code=code, detail=detail, status_code=status_code)
def flatten_errors(field_name, errors):
ret = []
if isinstance(errors, dict):
for error_key in errors:
error = errors[error_key]
ret.extend(flatten_errors("{}.{}".format(field_name, error_key), error))
else:
for error in errors:
if error.messages:
message = error.messages[0]
else:
message = str(error)
ret.append(
{
"field": field_name,
"code": error.code,
"detail": message,
}
)
return ret
def transform_validation_error(prefix, err):
if hasattr(err, "error_dict"):
errors = flatten_errors(prefix, err.error_dict)
elif not hasattr(err, "message"):
errors = flatten_errors(prefix, err.error_list)
else:
raise EtebaseValidationError(err.code, err.message)
raise ValidationError(code="field_errors", detail="Field validations failed.")
raise serializers.ValidationError(
{
"code": "field_errors",
"detail": "Field validations failed.",
"errors": errors,
}
)

Loading…
Cancel
Save