diff --git a/etebase_fastapi/__init__.py b/etebase_fastapi/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/etebase_fastapi/app.py b/etebase_fastapi/app.py new file mode 100644 index 0000000..0ee7aae --- /dev/null +++ b/etebase_fastapi/app.py @@ -0,0 +1,29 @@ +import os + +from django.core.wsgi import get_wsgi_application +from fastapi.middleware.cors import CORSMiddleware + +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "etebase_server.settings") +application = get_wsgi_application() +from fastapi import FastAPI, Request + +from .execptions import CustomHttpException +from .authentication import authentication_router +from .msgpack import MsgpackResponse + +app = FastAPI() +app.include_router(authentication_router, prefix="/api/v1/authentication") +app.add_middleware( + CORSMiddleware, allow_origin_regex="https?://.*", allow_credentials=True, allow_methods=["*"], allow_headers=["*"] +) + + +@app.exception_handler(CustomHttpException) +async def custom_exception_handler(request: Request, exc: CustomHttpException): + return MsgpackResponse(status_code=exc.status_code, content=exc.as_dict) + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="0.0.0.0", port=8080) diff --git a/etebase_fastapi/authentication.py b/etebase_fastapi/authentication.py new file mode 100644 index 0000000..b1a3272 --- /dev/null +++ b/etebase_fastapi/authentication.py @@ -0,0 +1,251 @@ +import dataclasses +import typing as t +from datetime import datetime +from functools import cached_property + +import nacl +import nacl.encoding +import nacl.hash +import nacl.secret +import nacl.signing +from asgiref.sync import sync_to_async +from django.conf import settings +from django.contrib.auth import get_user_model, user_logged_out, user_logged_in +from django.utils import timezone +from fastapi import APIRouter, Depends, status, Request, Response +from fastapi.security import APIKeyHeader +from pydantic import BaseModel + +from django_etebase import app_settings +from django_etebase.models import UserInfo +from django_etebase.serializers import UserSerializer +from django_etebase.token_auth.models import AuthToken +from django_etebase.token_auth.models import get_default_expiry +from django_etebase.views import msgpack_encode, msgpack_decode +from .execptions import AuthenticationFailed +from .msgpack import MsgpackResponse, MsgpackRoute + +User = get_user_model() +token_scheme = APIKeyHeader(name="Authorization") +AUTO_REFRESH = True +MIN_REFRESH_INTERVAL = 60 +authentication_router = APIRouter(route_class=MsgpackRoute) + + +@dataclasses.dataclass(frozen=True) +class AuthData: + user: User + token: AuthToken + + +class LoginChallengeData(BaseModel): + username: str + + +class LoginResponse(BaseModel): + username: str + challenge: bytes + host: str + action: t.Literal["login", "changePassword"] + + +class Authentication(BaseModel): + response: bytes + signature: bytes + + +class Login(Authentication): + @cached_property + def response_data(self) -> LoginResponse: + return LoginResponse(**msgpack_decode(self.response)) + + +class ChangePasswordResponse(LoginResponse): + loginPubkey: bytes + encryptedContent: bytes + + +class ChangePassword(Authentication): + @cached_property + def response_data(self) -> ChangePasswordResponse: + return ChangePasswordResponse(**msgpack_decode(self.response)) + + +def __renew_token(auth_token: AuthToken): + current_expiry = auth_token.expiry + new_expiry = get_default_expiry() + # Throttle refreshing of token to avoid db writes + delta = (new_expiry - current_expiry).total_seconds() + if delta > MIN_REFRESH_INTERVAL: + auth_token.expiry = new_expiry + auth_token.save(update_fields=("expiry",)) + + +@sync_to_async +def __get_authenticated_user(api_token: str): + api_token = api_token.split()[1] + try: + token: AuthToken = AuthToken.objects.select_related("user").get(key=api_token) + except AuthToken.DoesNotExist: + raise AuthenticationFailed(detail="Invalid token.") + if not token.user.is_active: + raise AuthenticationFailed(detail="User inactive or deleted.") + + if token.expiry is not None: + if token.expiry < timezone.now(): + token.delete() + raise AuthenticationFailed(detail="Invalid token.") + + if AUTO_REFRESH: + __renew_token(token) + + return token.user, token + + +async def get_auth_data(api_token: str = Depends(token_scheme)) -> AuthData: + user, token = await __get_authenticated_user(api_token) + return AuthData(user, token) + + +async def get_authenticated_user(api_token: str = Depends(token_scheme)) -> User: + user, token = await __get_authenticated_user(api_token) + return user + + +@sync_to_async +def __get_login_user(username: str) -> User: + kwargs = {User.USERNAME_FIELD + "__iexact": username.lower()} + try: + user = User.objects.get(**kwargs) + if not hasattr(user, "userinfo"): + raise AuthenticationFailed(code="user_not_init", detail="User not properly init") + return user + except User.DoesNotExist: + raise AuthenticationFailed(code="user_not_found", detail="User not found") + + +async def get_login_user(challenge: LoginChallengeData) -> User: + user = await __get_login_user(challenge.username) + return user + + +def get_encryption_key(salt): + key = nacl.hash.blake2b(settings.SECRET_KEY.encode(), encoder=nacl.encoding.RawEncoder) + return nacl.hash.blake2b( + b"", + key=key, + salt=salt[: nacl.hash.BLAKE2B_SALTBYTES], + person=b"etebase-auth", + encoder=nacl.encoding.RawEncoder, + ) + + +@sync_to_async +def save_changed_password(data: ChangePassword, user: User): + response_data = data.response_data + user_info: UserInfo = user.userinfo + user_info.loginPubkey = response_data.loginPubkey + user_info.encryptedContent = response_data.encryptedContent + user_info.save() + + +@sync_to_async +def login_response_data(user: User): + return { + "token": AuthToken.objects.create(user=user).key, + "user": UserSerializer(user).data, + } + + +@sync_to_async +def send_user_logged_in_async(user: User, request: Request): + user_logged_in.send(sender=user.__class__, request=request, user=user) + + +@sync_to_async +def send_user_logged_out_async(user: User, request: Request): + user_logged_out.send(sender=user.__class__, request=request, user=user) + + +@sync_to_async +def validate_login_request( + validated_data: LoginResponse, + challenge_sent_to_user: Authentication, + user: User, + expected_action: str, + host_from_request: str, +) -> t.Optional[MsgpackResponse]: + + enc_key = get_encryption_key(bytes(user.userinfo.salt)) + box = nacl.secret.SecretBox(enc_key) + challenge_data = msgpack_decode(box.decrypt(validated_data.challenge)) + now = int(datetime.now().timestamp()) + if validated_data.action != expected_action: + content = { + "code": "wrong_action", + "detail": 'Expected "{}" but got something else'.format(challenge_sent_to_user.response), + } + return MsgpackResponse(content, status_code=status.HTTP_400_BAD_REQUEST) + elif now - challenge_data["timestamp"] > app_settings.CHALLENGE_VALID_SECONDS: + content = {"code": "challenge_expired", "detail": "Login challenge has expired"} + return MsgpackResponse(content, status_code=status.HTTP_400_BAD_REQUEST) + elif challenge_data["userId"] != user.id: + content = {"code": "wrong_user", "detail": "This challenge is for the wrong user"} + return MsgpackResponse(content, status_code=status.HTTP_400_BAD_REQUEST) + elif not settings.DEBUG and validated_data.host.split(":", 1)[0] != host_from_request: + detail = 'Found wrong host name. Got: "{}" expected: "{}"'.format(validated_data.host, host_from_request) + content = {"code": "wrong_host", "detail": detail} + return MsgpackResponse(content, status_code=status.HTTP_400_BAD_REQUEST) + + verify_key = nacl.signing.VerifyKey(bytes(user.userinfo.loginPubkey), encoder=nacl.encoding.RawEncoder) + + try: + verify_key.verify(challenge_sent_to_user.response, challenge_sent_to_user.signature) + except nacl.exceptions.BadSignatureError: + return MsgpackResponse( + {"code": "login_bad_signature", "detail": "Wrong password for user."}, + status_code=status.HTTP_401_UNAUTHORIZED, + ) + + return None + + +@authentication_router.post("/login_challenge/") +async def login_challenge(user: User = Depends(get_login_user)): + enc_key = get_encryption_key(user.userinfo.salt) + box = nacl.secret.SecretBox(enc_key) + challenge_data = { + "timestamp": int(datetime.now().timestamp()), + "userId": user.id, + } + challenge = bytes(box.encrypt(msgpack_encode(challenge_data), encoder=nacl.encoding.RawEncoder)) + return MsgpackResponse({"salt": user.userinfo.salt, "version": user.userinfo.version, "challenge": challenge}) + + +@authentication_router.post("/login/") +async def login(data: Login, request: Request): + user = await get_login_user(LoginChallengeData(username=data.response_data.username)) + host = request.headers.get("Host") + bad_login_response = await validate_login_request(data.response_data, data, user, "login", host) + if bad_login_response is not None: + return bad_login_response + data = await login_response_data(user) + await send_user_logged_in_async(user, request) + return MsgpackResponse(data, status_code=status.HTTP_200_OK) + + +@authentication_router.post("/logout/") +async def logout(request: Request, auth_data: AuthData = Depends(get_auth_data)): + await sync_to_async(auth_data.token.delete)() + await send_user_logged_out_async(auth_data.user, request) + return Response(status_code=status.HTTP_204_NO_CONTENT) + + +@authentication_router.post("/change_password/") +async def change_password(data: ChangePassword, request: Request, user: User = Depends(get_authenticated_user)): + host = request.headers.get("Host") + bad_login_response = await validate_login_request(data.response_data, data, user, "changePassword", host) + if bad_login_response is not None: + return bad_login_response + await save_changed_password(data, user) + return Response(status_code=status.HTTP_204_NO_CONTENT) diff --git a/etebase_fastapi/collections.py b/etebase_fastapi/collections.py new file mode 100644 index 0000000..e69de29 diff --git a/etebase_fastapi/execptions.py b/etebase_fastapi/execptions.py new file mode 100644 index 0000000..8808f5d --- /dev/null +++ b/etebase_fastapi/execptions.py @@ -0,0 +1,42 @@ +from fastapi import status + + +class CustomHttpException(Exception): + def __init__(self, code: str, detail: str, status_code: int = status.HTTP_400_BAD_REQUEST): + self.status_code = status_code + self.code = code + self.detail = detail + + @property + def as_dict(self) -> dict: + return {"code": self.code, "detail": self.detail} + + +class AuthenticationFailed(CustomHttpException): + def __init__( + self, + code="authentication_failed", + detail: str = "Incorrect authentication credentials.", + status_code: int = status.HTTP_401_UNAUTHORIZED, + ): + super().__init__(code=code, detail=detail, status_code=status_code) + + +class NotAuthenticated(CustomHttpException): + def __init__( + self, + code="not_authenticated", + detail: str = "Authentication credentials were not provided.", + status_code: int = status.HTTP_401_UNAUTHORIZED, + ): + super().__init__(code=code, detail=detail, status_code=status_code) + + +class PermissionDenied(CustomHttpException): + def __init__( + self, + code="permission_denied", + detail: str = "You do not have permission to perform this action.", + status_code: int = status.HTTP_403_FORBIDDEN, + ): + super().__init__(code=code, detail=detail, status_code=status_code) diff --git a/etebase_fastapi/msgpack.py b/etebase_fastapi/msgpack.py new file mode 100644 index 0000000..53e18cb --- /dev/null +++ b/etebase_fastapi/msgpack.py @@ -0,0 +1,63 @@ +import typing as t +import msgpack +from fastapi.routing import APIRoute, get_request_handler +from starlette.requests import Request +from starlette.responses import Response + + +class MsgpackRequest(Request): + media_type = "application/msgpack" + + async def json(self) -> bytes: + if not hasattr(self, "_json"): + body = await super().body() + self._json = msgpack.unpackb(body, raw=False) + return self._json + + +class MsgpackResponse(Response): + media_type = "application/msgpack" + + def render(self, content: t.Any) -> bytes: + return msgpack.packb(content, use_bin_type=True) + + +class MsgpackRoute(APIRoute): + # keep track of content-type -> request classes + REQUESTS_CLASSES = {MsgpackRequest.media_type: MsgpackRequest} + # keep track of content-type -> response classes + ROUTES_HANDLERS_CLASSES = {MsgpackResponse.media_type: MsgpackResponse} + + def _get_media_type_route_handler(self, media_type): + return get_request_handler( + dependant=self.dependant, + body_field=self.body_field, + status_code=self.status_code, + # use custom response class or fallback on default self.response_class + response_class=self.ROUTES_HANDLERS_CLASSES.get(media_type, self.response_class), + response_field=self.secure_cloned_response_field, + response_model_include=self.response_model_include, + response_model_exclude=self.response_model_exclude, + response_model_by_alias=self.response_model_by_alias, + response_model_exclude_unset=self.response_model_exclude_unset, + response_model_exclude_defaults=self.response_model_exclude_defaults, + response_model_exclude_none=self.response_model_exclude_none, + dependency_overrides_provider=self.dependency_overrides_provider, + ) + + def get_route_handler(self) -> t.Callable: + async def custom_route_handler(request: Request) -> Response: + + content_type = request.headers.get("Content-Type") + try: + request_cls = self.REQUESTS_CLASSES[content_type] + request = request_cls(request.scope, request.receive) + except KeyError: + # nothing registered to handle content_type, process given requests as-is + pass + + accept = request.headers.get("Accept") + route_handler = self._get_media_type_route_handler(accept) + return await route_handler(request) + + return custom_route_handler diff --git a/requirements.in/base.txt b/requirements.in/base.txt index 7d5bf7e..ca8dd94 100644 --- a/requirements.in/base.txt +++ b/requirements.in/base.txt @@ -5,3 +5,5 @@ drf-nested-routers msgpack psycopg2-binary pynacl +fastapi +uvicorn \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index f6c8ed4..3d19eaf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,16 +4,22 @@ # # pip-compile --output-file=requirements.txt requirements.in/base.txt # -asgiref==3.2.10 # via django -cffi==1.14.0 # via pynacl -django-cors-headers==3.2.1 # via -r requirements.in/base.txt -django==3.1.1 # via -r requirements.in/base.txt, django-cors-headers, djangorestframework, drf-nested-routers -djangorestframework==3.11.0 # via -r requirements.in/base.txt, drf-nested-routers -drf-nested-routers==0.91 # via -r requirements.in/base.txt -msgpack==1.0.0 # via -r requirements.in/base.txt -psycopg2-binary==2.8.4 # via -r requirements.in/base.txt +asgiref==3.3.1 # via django +cffi==1.14.4 # via pynacl +click==7.1.2 # via uvicorn +django-cors-headers==3.6.0 # via -r requirements.in/base.txt +django==3.1.4 # via -r requirements.in/base.txt, django-cors-headers, djangorestframework, drf-nested-routers +djangorestframework==3.12.2 # via -r requirements.in/base.txt, drf-nested-routers +drf-nested-routers==0.92.5 # via -r requirements.in/base.txt +fastapi==0.63.0 # via -r requirements.in/base.txt +h11==0.11.0 # via uvicorn +msgpack==1.0.2 # via -r requirements.in/base.txt +psycopg2-binary==2.8.6 # via -r requirements.in/base.txt pycparser==2.20 # via cffi -pynacl==1.3.0 # via -r requirements.in/base.txt -pytz==2019.3 # via django -six==1.14.0 # via pynacl -sqlparse==0.3.0 # via django +pydantic==1.7.3 # via fastapi +pynacl==1.4.0 # via -r requirements.in/base.txt +pytz==2020.4 # via django +six==1.15.0 # via pynacl +sqlparse==0.4.1 # via django +starlette==0.13.6 # via fastapi +uvicorn==0.13.2 # via -r requirements.in/base.txt