From 709bc6c1fc2be28b6d2ad9e4bc7bd12de2828dfd Mon Sep 17 00:00:00 2001 From: Tom Hacohen Date: Tue, 29 Dec 2020 17:18:09 +0200 Subject: [PATCH] Improve typing information. --- django_etebase/models.py | 47 ++++++++++++++++++++------- django_etebase/utils.py | 5 +-- etebase_fastapi/routers/collection.py | 18 +++++----- etebase_fastapi/routers/invitation.py | 19 ++++++----- etebase_fastapi/routers/member.py | 9 ++--- etebase_fastapi/utils.py | 7 ++-- myauth/models.py | 4 +-- 7 files changed, 70 insertions(+), 39 deletions(-) diff --git a/django_etebase/models.py b/django_etebase/models.py index 3060fa4..7725a19 100644 --- a/django_etebase/models.py +++ b/django_etebase/models.py @@ -12,6 +12,7 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +import typing as t from pathlib import Path from django.db import models, transaction @@ -28,7 +29,7 @@ from . import app_settings UidValidator = RegexValidator(regex=r"^[a-zA-Z0-9\-_]{20,}$", message="Not a valid UID") -def stoken_annotation_builder(stoken_id_fields): +def stoken_annotation_builder(stoken_id_fields: t.List[str]): aggr_fields = [Coalesce(Max(field), V(0)) for field in stoken_id_fields] return Greatest(*aggr_fields) if len(aggr_fields) > 1 else aggr_fields[0] @@ -37,6 +38,8 @@ class CollectionType(models.Model): owner = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE) uid = models.BinaryField(editable=True, blank=False, null=False, db_index=True, unique=True) + objects: models.manager.BaseManager["CollectionType"] + class Collection(models.Model): main_item = models.OneToOneField("CollectionItem", related_name="parent", null=True, on_delete=models.SET_NULL) @@ -46,19 +49,21 @@ class Collection(models.Model): stoken_annotation = stoken_annotation_builder(["items__revisions__stoken", "members__stoken"]) + objects: models.manager.BaseManager["Collection"] + def __str__(self): return self.uid @property - def content(self): + def content(self) -> "CollectionItemRevision": return self.main_item.content @property - def etag(self): + def etag(self) -> str: return self.content.uid @cached_property - def stoken(self): + def stoken(self) -> str: stoken_id = ( self.__class__.objects.filter(main_item=self.main_item) .annotate(max_stoken=self.stoken_annotation) @@ -80,6 +85,8 @@ class CollectionItem(models.Model): stoken_annotation = stoken_annotation_builder(["revisions__stoken"]) + objects: models.manager.BaseManager["CollectionItem"] + class Meta: unique_together = ("uid", "collection") @@ -87,23 +94,23 @@ class CollectionItem(models.Model): return "{} {}".format(self.uid, self.collection.uid) @cached_property - def content(self): + def content(self) -> "CollectionItemRevision": return self.revisions.get(current=True) @property - def etag(self): + def etag(self) -> str: return self.content.uid -def chunk_directory_path(instance, filename): +def chunk_directory_path(instance: "CollectionItemChunk", filename: str) -> Path: custom_func = app_settings.CHUNK_PATH_FUNC if custom_func is not None: return custom_func(instance, filename) - col = instance.collection - user_id = col.owner.id - uid_prefix = instance.uid[:2] - uid_rest = instance.uid[2:] + col: Collection = instance.collection + user_id: int = col.owner.id + uid_prefix: str = instance.uid[:2] + uid_rest: str = instance.uid[2:] return Path("user_{}".format(user_id), col.uid, uid_prefix, uid_rest) @@ -112,6 +119,8 @@ class CollectionItemChunk(models.Model): collection = models.ForeignKey(Collection, related_name="chunks", on_delete=models.CASCADE) chunkFile = models.FileField(upload_to=chunk_directory_path, max_length=150, unique=True) + objects: models.manager.BaseManager["CollectionItemChunk"] + def __str__(self): return self.uid @@ -135,6 +144,8 @@ class Stoken(models.Model): validators=[UidValidator], ) + objects: models.manager.BaseManager["Stoken"] + class CollectionItemRevision(models.Model): stoken = models.OneToOneField(Stoken, on_delete=models.PROTECT) @@ -146,6 +157,8 @@ class CollectionItemRevision(models.Model): current = models.BooleanField(db_index=True, default=True, null=True) deleted = models.BooleanField(default=False) + objects: models.manager.BaseManager["CollectionItemRevision"] + class Meta: unique_together = ("item", "current") @@ -157,6 +170,8 @@ class RevisionChunkRelation(models.Model): chunk = models.ForeignKey(CollectionItemChunk, related_name="revisions_relation", on_delete=models.CASCADE) revision = models.ForeignKey(CollectionItemRevision, related_name="chunks_relation", on_delete=models.CASCADE) + objects: models.manager.BaseManager["RevisionChunkRelation"] + class Meta: ordering = ("id",) @@ -180,6 +195,8 @@ class CollectionMember(models.Model): stoken_annotation = stoken_annotation_builder(["stoken"]) + objects: models.manager.BaseManager["CollectionMember"] + class Meta: unique_together = ("user", "collection") @@ -204,6 +221,8 @@ class CollectionMemberRemoved(models.Model): collection = models.ForeignKey(Collection, related_name="removed_members", on_delete=models.CASCADE) user = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE) + objects: models.manager.BaseManager["CollectionMemberRemoved"] + class Meta: unique_together = ("user", "collection") @@ -225,6 +244,8 @@ class CollectionInvitation(models.Model): default=AccessLevels.READ_ONLY, ) + objects: models.manager.BaseManager["CollectionInvitation"] + class Meta: unique_together = ("user", "fromMember") @@ -232,7 +253,7 @@ class CollectionInvitation(models.Model): return "{} {}".format(self.fromMember.collection.uid, self.user) @cached_property - def collection(self): + def collection(self) -> Collection: return self.fromMember.collection @@ -244,5 +265,7 @@ class UserInfo(models.Model): encryptedContent = models.BinaryField(editable=True, blank=False, null=False) salt = models.BinaryField(editable=True, blank=False, null=False) + objects: models.manager.BaseManager["UserInfo"] + def __str__(self): return "UserInfo<{}>".format(self.owner) diff --git a/django_etebase/utils.py b/django_etebase/utils.py index d812ae3..3a05fd4 100644 --- a/django_etebase/utils.py +++ b/django_etebase/utils.py @@ -1,6 +1,7 @@ import typing as t from dataclasses import dataclass +from django.db.models import QuerySet from django.core.exceptions import PermissionDenied from myauth.models import UserType, get_typed_user_model @@ -18,14 +19,14 @@ class CallbackContext: user: t.Optional[UserType] = None -def get_user_queryset(queryset, context: CallbackContext): +def get_user_queryset(queryset: QuerySet[UserType], context: CallbackContext) -> QuerySet[UserType]: custom_func = app_settings.GET_USER_QUERYSET_FUNC if custom_func is not None: return custom_func(queryset, context) return queryset -def create_user(context: CallbackContext, *args, **kwargs): +def create_user(context: CallbackContext, *args, **kwargs) -> UserType: custom_func = app_settings.CREATE_USER_FUNC if custom_func is not None: return custom_func(context, *args, **kwargs) diff --git a/etebase_fastapi/routers/collection.py b/etebase_fastapi/routers/collection.py index 56afd7b..4825626 100644 --- a/etebase_fastapi/routers/collection.py +++ b/etebase_fastapi/routers/collection.py @@ -30,6 +30,8 @@ from ..sendfile import sendfile User = get_typed_user_model collection_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) item_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) +CollectionQuerySet = QuerySet[models.Collection] +CollectionItemQuerySet = QuerySet[models.CollectionItem] class ListMulti(BaseModel): @@ -187,7 +189,7 @@ class ItemBatchIn(BaseModel): @sync_to_async def collection_list_common( - queryset: QuerySet, + queryset: CollectionQuerySet, user: UserType, stoken: t.Optional[str], limit: int, @@ -249,7 +251,7 @@ async def list_multi( data: ListMulti, stoken: t.Optional[str] = None, limit: int = 50, - queryset: QuerySet = Depends(get_collection_queryset), + queryset: CollectionQuerySet = Depends(get_collection_queryset), user: UserType = Depends(get_authenticated_user), prefetch: Prefetch = PrefetchQuery, ): @@ -267,7 +269,7 @@ async def collection_list( limit: int = 50, prefetch: Prefetch = PrefetchQuery, user: UserType = Depends(get_authenticated_user), - queryset: QuerySet = Depends(get_collection_queryset), + queryset: CollectionQuerySet = Depends(get_collection_queryset), ): return await collection_list_common(queryset, user, stoken, limit, prefetch) @@ -395,7 +397,7 @@ def item_create(item_model: CollectionItemIn, collection: models.Collection, val @item_router.get("/item/{item_uid}/", response_model=CollectionItemOut, dependencies=PERMISSIONS_READ) def item_get( item_uid: str, - queryset: QuerySet = Depends(get_item_queryset), + queryset: CollectionItemQuerySet = Depends(get_item_queryset), user: UserType = Depends(get_authenticated_user), prefetch: Prefetch = PrefetchQuery, ): @@ -405,7 +407,7 @@ def item_get( @sync_to_async def item_list_common( - queryset: QuerySet, + queryset: CollectionItemQuerySet, user: UserType, stoken: t.Optional[str], limit: int, @@ -422,7 +424,7 @@ def item_list_common( @item_router.get("/item/", response_model=CollectionItemListResponse, dependencies=PERMISSIONS_READ) async def item_list( - queryset: QuerySet = Depends(get_item_queryset), + queryset: CollectionItemQuerySet = Depends(get_item_queryset), stoken: t.Optional[str] = None, limit: int = 50, prefetch: Prefetch = PrefetchQuery, @@ -471,7 +473,7 @@ def item_revisions( iterator: t.Optional[str] = None, prefetch: Prefetch = PrefetchQuery, user: UserType = Depends(get_authenticated_user), - items: QuerySet = Depends(get_item_queryset), + items: CollectionItemQuerySet = Depends(get_item_queryset), ): item = get_object_or_404(items, uid=item_uid) @@ -505,7 +507,7 @@ def fetch_updates( stoken: t.Optional[str] = None, prefetch: Prefetch = PrefetchQuery, user: UserType = Depends(get_authenticated_user), - queryset: QuerySet = Depends(get_item_queryset), + queryset: CollectionItemQuerySet = Depends(get_item_queryset), ): # FIXME: make configurable? item_limit = 200 diff --git a/etebase_fastapi/routers/invitation.py b/etebase_fastapi/routers/invitation.py index 6a06c60..aceb05d 100644 --- a/etebase_fastapi/routers/invitation.py +++ b/etebase_fastapi/routers/invitation.py @@ -23,7 +23,8 @@ from ..utils import ( User = get_typed_user_model() invitation_incoming_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) invitation_outgoing_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) -default_queryset: QuerySet = models.CollectionInvitation.objects.all() +InvitationQuerySet = QuerySet[models.CollectionInvitation] +default_queryset: InvitationQuerySet = models.CollectionInvitation.objects.all() class UserInfoOut(BaseModel): @@ -94,7 +95,7 @@ def get_outgoing_queryset(user: UserType = Depends(get_authenticated_user)): def list_common( - queryset: QuerySet, + queryset: InvitationQuerySet, iterator: t.Optional[str], limit: int, ) -> InvitationListResponse: @@ -125,7 +126,7 @@ def list_common( def incoming_list( iterator: t.Optional[str] = None, limit: int = 50, - queryset: QuerySet = Depends(get_incoming_queryset), + queryset: InvitationQuerySet = Depends(get_incoming_queryset), ): return list_common(queryset, iterator, limit) @@ -135,7 +136,7 @@ def incoming_list( ) def incoming_get( invitation_uid: str, - queryset: QuerySet = Depends(get_incoming_queryset), + queryset: InvitationQuerySet = Depends(get_incoming_queryset), ): obj = get_object_or_404(queryset, uid=invitation_uid) return CollectionInvitationOut.from_orm(obj) @@ -146,7 +147,7 @@ def incoming_get( ) def incoming_delete( invitation_uid: str, - queryset: QuerySet = Depends(get_incoming_queryset), + queryset: InvitationQuerySet = Depends(get_incoming_queryset), ): obj = get_object_or_404(queryset, uid=invitation_uid) obj.delete() @@ -158,7 +159,7 @@ def incoming_delete( def incoming_accept( invitation_uid: str, data: CollectionInvitationAcceptIn, - queryset: QuerySet = Depends(get_incoming_queryset), + queryset: InvitationQuerySet = Depends(get_incoming_queryset), ): invitation = get_object_or_404(queryset, uid=invitation_uid) @@ -201,7 +202,7 @@ def outgoing_create( with transaction.atomic(): try: - ret = models.CollectionInvitation.objects.create( + models.CollectionInvitation.objects.create( **data.dict(exclude={"collection", "username"}), user=to_user, fromMember=member ) except IntegrityError: @@ -212,7 +213,7 @@ def outgoing_create( def outgoing_list( iterator: t.Optional[str] = None, limit: int = 50, - queryset: QuerySet = Depends(get_outgoing_queryset), + queryset: InvitationQuerySet = Depends(get_outgoing_queryset), ): return list_common(queryset, iterator, limit) @@ -222,7 +223,7 @@ def outgoing_list( ) def outgoing_delete( invitation_uid: str, - queryset: QuerySet = Depends(get_outgoing_queryset), + queryset: InvitationQuerySet = Depends(get_outgoing_queryset), ): obj = get_object_or_404(queryset, uid=invitation_uid) obj.delete() diff --git a/etebase_fastapi/routers/member.py b/etebase_fastapi/routers/member.py index 210374c..41393bf 100644 --- a/etebase_fastapi/routers/member.py +++ b/etebase_fastapi/routers/member.py @@ -15,14 +15,15 @@ from .collection import get_collection, verify_collection_admin User = get_typed_user_model() member_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) -default_queryset: QuerySet = models.CollectionMember.objects.all() +MemberQuerySet = QuerySet[models.CollectionMember] +default_queryset: MemberQuerySet = models.CollectionMember.objects.all() -def get_queryset(collection: models.Collection = Depends(get_collection)) -> QuerySet: +def get_queryset(collection: models.Collection = Depends(get_collection)) -> MemberQuerySet: return default_queryset.filter(collection=collection) -def get_member(username: str, queryset: QuerySet = Depends(get_queryset)) -> QuerySet: +def get_member(username: str, queryset: MemberQuerySet = Depends(get_queryset)) -> models.CollectionMember: return get_object_or_404(queryset, user__username__iexact=username) @@ -54,7 +55,7 @@ class MemberListResponse(BaseModel): def member_list( iterator: t.Optional[str] = None, limit: int = 50, - queryset: QuerySet = Depends(get_queryset), + queryset: MemberQuerySet = Depends(get_queryset), ): queryset = queryset.order_by("id") result, new_stoken_obj, done = filter_by_stoken_and_limit( diff --git a/etebase_fastapi/utils.py b/etebase_fastapi/utils.py index 03f1a7d..c9db61c 100644 --- a/etebase_fastapi/utils.py +++ b/etebase_fastapi/utils.py @@ -7,7 +7,7 @@ import base64 from fastapi import status, Query, Depends from pydantic import BaseModel as PyBaseModel -from django.db.models import QuerySet +from django.db.models import Model, QuerySet from django.core.exceptions import ObjectDoesNotExist from django_etebase import app_settings @@ -22,6 +22,9 @@ Prefetch = Literal["auto", "medium"] PrefetchQuery = Query(default="auto") +T = t.TypeVar("T", bound=Model, covariant=True) + + class BaseModel(PyBaseModel): class Config: json_encoders = { @@ -35,7 +38,7 @@ class Context: prefetch: t.Optional[Prefetch] -def get_object_or_404(queryset: QuerySet, **kwargs): +def get_object_or_404(queryset: QuerySet[T], **kwargs) -> T: try: return queryset.get(**kwargs) except ObjectDoesNotExist as e: diff --git a/myauth/models.py b/myauth/models.py index c9298a4..89b94b4 100644 --- a/myauth/models.py +++ b/myauth/models.py @@ -15,7 +15,7 @@ class UnicodeUsernameValidator(validators.RegexValidator): class UserManager(DjangoUserManager): - def get_by_natural_key(self, username): + def get_by_natural_key(self, username: str): return self.get(**{self.model.USERNAME_FIELD + "__iexact": username}) @@ -37,7 +37,7 @@ class User(AbstractUser): ) @classmethod - def normalize_username(cls, username): + def normalize_username(cls, username: str): return super().normalize_username(username).lower()