Improve typing information.

master
Tom Hacohen 4 years ago
parent 332f7e2332
commit 709bc6c1fc

@ -12,6 +12,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import typing as t
from pathlib import Path from pathlib import Path
from django.db import models, transaction from django.db import models, transaction
@ -28,7 +29,7 @@ from . import app_settings
UidValidator = RegexValidator(regex=r"^[a-zA-Z0-9\-_]{20,}$", message="Not a valid UID") UidValidator = RegexValidator(regex=r"^[a-zA-Z0-9\-_]{20,}$", message="Not a valid UID")
def stoken_annotation_builder(stoken_id_fields): def stoken_annotation_builder(stoken_id_fields: t.List[str]):
aggr_fields = [Coalesce(Max(field), V(0)) for field in stoken_id_fields] aggr_fields = [Coalesce(Max(field), V(0)) for field in stoken_id_fields]
return Greatest(*aggr_fields) if len(aggr_fields) > 1 else aggr_fields[0] return Greatest(*aggr_fields) if len(aggr_fields) > 1 else aggr_fields[0]
@ -37,6 +38,8 @@ class CollectionType(models.Model):
owner = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE) owner = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE)
uid = models.BinaryField(editable=True, blank=False, null=False, db_index=True, unique=True) uid = models.BinaryField(editable=True, blank=False, null=False, db_index=True, unique=True)
objects: models.manager.BaseManager["CollectionType"]
class Collection(models.Model): class Collection(models.Model):
main_item = models.OneToOneField("CollectionItem", related_name="parent", null=True, on_delete=models.SET_NULL) main_item = models.OneToOneField("CollectionItem", related_name="parent", null=True, on_delete=models.SET_NULL)
@ -46,19 +49,21 @@ class Collection(models.Model):
stoken_annotation = stoken_annotation_builder(["items__revisions__stoken", "members__stoken"]) stoken_annotation = stoken_annotation_builder(["items__revisions__stoken", "members__stoken"])
objects: models.manager.BaseManager["Collection"]
def __str__(self): def __str__(self):
return self.uid return self.uid
@property @property
def content(self): def content(self) -> "CollectionItemRevision":
return self.main_item.content return self.main_item.content
@property @property
def etag(self): def etag(self) -> str:
return self.content.uid return self.content.uid
@cached_property @cached_property
def stoken(self): def stoken(self) -> str:
stoken_id = ( stoken_id = (
self.__class__.objects.filter(main_item=self.main_item) self.__class__.objects.filter(main_item=self.main_item)
.annotate(max_stoken=self.stoken_annotation) .annotate(max_stoken=self.stoken_annotation)
@ -80,6 +85,8 @@ class CollectionItem(models.Model):
stoken_annotation = stoken_annotation_builder(["revisions__stoken"]) stoken_annotation = stoken_annotation_builder(["revisions__stoken"])
objects: models.manager.BaseManager["CollectionItem"]
class Meta: class Meta:
unique_together = ("uid", "collection") unique_together = ("uid", "collection")
@ -87,23 +94,23 @@ class CollectionItem(models.Model):
return "{} {}".format(self.uid, self.collection.uid) return "{} {}".format(self.uid, self.collection.uid)
@cached_property @cached_property
def content(self): def content(self) -> "CollectionItemRevision":
return self.revisions.get(current=True) return self.revisions.get(current=True)
@property @property
def etag(self): def etag(self) -> str:
return self.content.uid return self.content.uid
def chunk_directory_path(instance, filename): def chunk_directory_path(instance: "CollectionItemChunk", filename: str) -> Path:
custom_func = app_settings.CHUNK_PATH_FUNC custom_func = app_settings.CHUNK_PATH_FUNC
if custom_func is not None: if custom_func is not None:
return custom_func(instance, filename) return custom_func(instance, filename)
col = instance.collection col: Collection = instance.collection
user_id = col.owner.id user_id: int = col.owner.id
uid_prefix = instance.uid[:2] uid_prefix: str = instance.uid[:2]
uid_rest = instance.uid[2:] uid_rest: str = instance.uid[2:]
return Path("user_{}".format(user_id), col.uid, uid_prefix, uid_rest) return Path("user_{}".format(user_id), col.uid, uid_prefix, uid_rest)
@ -112,6 +119,8 @@ class CollectionItemChunk(models.Model):
collection = models.ForeignKey(Collection, related_name="chunks", on_delete=models.CASCADE) collection = models.ForeignKey(Collection, related_name="chunks", on_delete=models.CASCADE)
chunkFile = models.FileField(upload_to=chunk_directory_path, max_length=150, unique=True) chunkFile = models.FileField(upload_to=chunk_directory_path, max_length=150, unique=True)
objects: models.manager.BaseManager["CollectionItemChunk"]
def __str__(self): def __str__(self):
return self.uid return self.uid
@ -135,6 +144,8 @@ class Stoken(models.Model):
validators=[UidValidator], validators=[UidValidator],
) )
objects: models.manager.BaseManager["Stoken"]
class CollectionItemRevision(models.Model): class CollectionItemRevision(models.Model):
stoken = models.OneToOneField(Stoken, on_delete=models.PROTECT) stoken = models.OneToOneField(Stoken, on_delete=models.PROTECT)
@ -146,6 +157,8 @@ class CollectionItemRevision(models.Model):
current = models.BooleanField(db_index=True, default=True, null=True) current = models.BooleanField(db_index=True, default=True, null=True)
deleted = models.BooleanField(default=False) deleted = models.BooleanField(default=False)
objects: models.manager.BaseManager["CollectionItemRevision"]
class Meta: class Meta:
unique_together = ("item", "current") unique_together = ("item", "current")
@ -157,6 +170,8 @@ class RevisionChunkRelation(models.Model):
chunk = models.ForeignKey(CollectionItemChunk, related_name="revisions_relation", on_delete=models.CASCADE) chunk = models.ForeignKey(CollectionItemChunk, related_name="revisions_relation", on_delete=models.CASCADE)
revision = models.ForeignKey(CollectionItemRevision, related_name="chunks_relation", on_delete=models.CASCADE) revision = models.ForeignKey(CollectionItemRevision, related_name="chunks_relation", on_delete=models.CASCADE)
objects: models.manager.BaseManager["RevisionChunkRelation"]
class Meta: class Meta:
ordering = ("id",) ordering = ("id",)
@ -180,6 +195,8 @@ class CollectionMember(models.Model):
stoken_annotation = stoken_annotation_builder(["stoken"]) stoken_annotation = stoken_annotation_builder(["stoken"])
objects: models.manager.BaseManager["CollectionMember"]
class Meta: class Meta:
unique_together = ("user", "collection") unique_together = ("user", "collection")
@ -204,6 +221,8 @@ class CollectionMemberRemoved(models.Model):
collection = models.ForeignKey(Collection, related_name="removed_members", on_delete=models.CASCADE) collection = models.ForeignKey(Collection, related_name="removed_members", on_delete=models.CASCADE)
user = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE) user = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE)
objects: models.manager.BaseManager["CollectionMemberRemoved"]
class Meta: class Meta:
unique_together = ("user", "collection") unique_together = ("user", "collection")
@ -225,6 +244,8 @@ class CollectionInvitation(models.Model):
default=AccessLevels.READ_ONLY, default=AccessLevels.READ_ONLY,
) )
objects: models.manager.BaseManager["CollectionInvitation"]
class Meta: class Meta:
unique_together = ("user", "fromMember") unique_together = ("user", "fromMember")
@ -232,7 +253,7 @@ class CollectionInvitation(models.Model):
return "{} {}".format(self.fromMember.collection.uid, self.user) return "{} {}".format(self.fromMember.collection.uid, self.user)
@cached_property @cached_property
def collection(self): def collection(self) -> Collection:
return self.fromMember.collection return self.fromMember.collection
@ -244,5 +265,7 @@ class UserInfo(models.Model):
encryptedContent = models.BinaryField(editable=True, blank=False, null=False) encryptedContent = models.BinaryField(editable=True, blank=False, null=False)
salt = models.BinaryField(editable=True, blank=False, null=False) salt = models.BinaryField(editable=True, blank=False, null=False)
objects: models.manager.BaseManager["UserInfo"]
def __str__(self): def __str__(self):
return "UserInfo<{}>".format(self.owner) return "UserInfo<{}>".format(self.owner)

@ -1,6 +1,7 @@
import typing as t import typing as t
from dataclasses import dataclass from dataclasses import dataclass
from django.db.models import QuerySet
from django.core.exceptions import PermissionDenied from django.core.exceptions import PermissionDenied
from myauth.models import UserType, get_typed_user_model from myauth.models import UserType, get_typed_user_model
@ -18,14 +19,14 @@ class CallbackContext:
user: t.Optional[UserType] = None user: t.Optional[UserType] = None
def get_user_queryset(queryset, context: CallbackContext): def get_user_queryset(queryset: QuerySet[UserType], context: CallbackContext) -> QuerySet[UserType]:
custom_func = app_settings.GET_USER_QUERYSET_FUNC custom_func = app_settings.GET_USER_QUERYSET_FUNC
if custom_func is not None: if custom_func is not None:
return custom_func(queryset, context) return custom_func(queryset, context)
return queryset return queryset
def create_user(context: CallbackContext, *args, **kwargs): def create_user(context: CallbackContext, *args, **kwargs) -> UserType:
custom_func = app_settings.CREATE_USER_FUNC custom_func = app_settings.CREATE_USER_FUNC
if custom_func is not None: if custom_func is not None:
return custom_func(context, *args, **kwargs) return custom_func(context, *args, **kwargs)

@ -30,6 +30,8 @@ from ..sendfile import sendfile
User = get_typed_user_model User = get_typed_user_model
collection_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) collection_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
item_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) item_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
CollectionQuerySet = QuerySet[models.Collection]
CollectionItemQuerySet = QuerySet[models.CollectionItem]
class ListMulti(BaseModel): class ListMulti(BaseModel):
@ -187,7 +189,7 @@ class ItemBatchIn(BaseModel):
@sync_to_async @sync_to_async
def collection_list_common( def collection_list_common(
queryset: QuerySet, queryset: CollectionQuerySet,
user: UserType, user: UserType,
stoken: t.Optional[str], stoken: t.Optional[str],
limit: int, limit: int,
@ -249,7 +251,7 @@ async def list_multi(
data: ListMulti, data: ListMulti,
stoken: t.Optional[str] = None, stoken: t.Optional[str] = None,
limit: int = 50, limit: int = 50,
queryset: QuerySet = Depends(get_collection_queryset), queryset: CollectionQuerySet = Depends(get_collection_queryset),
user: UserType = Depends(get_authenticated_user), user: UserType = Depends(get_authenticated_user),
prefetch: Prefetch = PrefetchQuery, prefetch: Prefetch = PrefetchQuery,
): ):
@ -267,7 +269,7 @@ async def collection_list(
limit: int = 50, limit: int = 50,
prefetch: Prefetch = PrefetchQuery, prefetch: Prefetch = PrefetchQuery,
user: UserType = Depends(get_authenticated_user), user: UserType = Depends(get_authenticated_user),
queryset: QuerySet = Depends(get_collection_queryset), queryset: CollectionQuerySet = Depends(get_collection_queryset),
): ):
return await collection_list_common(queryset, user, stoken, limit, prefetch) return await collection_list_common(queryset, user, stoken, limit, prefetch)
@ -395,7 +397,7 @@ def item_create(item_model: CollectionItemIn, collection: models.Collection, val
@item_router.get("/item/{item_uid}/", response_model=CollectionItemOut, dependencies=PERMISSIONS_READ) @item_router.get("/item/{item_uid}/", response_model=CollectionItemOut, dependencies=PERMISSIONS_READ)
def item_get( def item_get(
item_uid: str, item_uid: str,
queryset: QuerySet = Depends(get_item_queryset), queryset: CollectionItemQuerySet = Depends(get_item_queryset),
user: UserType = Depends(get_authenticated_user), user: UserType = Depends(get_authenticated_user),
prefetch: Prefetch = PrefetchQuery, prefetch: Prefetch = PrefetchQuery,
): ):
@ -405,7 +407,7 @@ def item_get(
@sync_to_async @sync_to_async
def item_list_common( def item_list_common(
queryset: QuerySet, queryset: CollectionItemQuerySet,
user: UserType, user: UserType,
stoken: t.Optional[str], stoken: t.Optional[str],
limit: int, limit: int,
@ -422,7 +424,7 @@ def item_list_common(
@item_router.get("/item/", response_model=CollectionItemListResponse, dependencies=PERMISSIONS_READ) @item_router.get("/item/", response_model=CollectionItemListResponse, dependencies=PERMISSIONS_READ)
async def item_list( async def item_list(
queryset: QuerySet = Depends(get_item_queryset), queryset: CollectionItemQuerySet = Depends(get_item_queryset),
stoken: t.Optional[str] = None, stoken: t.Optional[str] = None,
limit: int = 50, limit: int = 50,
prefetch: Prefetch = PrefetchQuery, prefetch: Prefetch = PrefetchQuery,
@ -471,7 +473,7 @@ def item_revisions(
iterator: t.Optional[str] = None, iterator: t.Optional[str] = None,
prefetch: Prefetch = PrefetchQuery, prefetch: Prefetch = PrefetchQuery,
user: UserType = Depends(get_authenticated_user), user: UserType = Depends(get_authenticated_user),
items: QuerySet = Depends(get_item_queryset), items: CollectionItemQuerySet = Depends(get_item_queryset),
): ):
item = get_object_or_404(items, uid=item_uid) item = get_object_or_404(items, uid=item_uid)
@ -505,7 +507,7 @@ def fetch_updates(
stoken: t.Optional[str] = None, stoken: t.Optional[str] = None,
prefetch: Prefetch = PrefetchQuery, prefetch: Prefetch = PrefetchQuery,
user: UserType = Depends(get_authenticated_user), user: UserType = Depends(get_authenticated_user),
queryset: QuerySet = Depends(get_item_queryset), queryset: CollectionItemQuerySet = Depends(get_item_queryset),
): ):
# FIXME: make configurable? # FIXME: make configurable?
item_limit = 200 item_limit = 200

@ -23,7 +23,8 @@ from ..utils import (
User = get_typed_user_model() User = get_typed_user_model()
invitation_incoming_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) invitation_incoming_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
invitation_outgoing_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) invitation_outgoing_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
default_queryset: QuerySet = models.CollectionInvitation.objects.all() InvitationQuerySet = QuerySet[models.CollectionInvitation]
default_queryset: InvitationQuerySet = models.CollectionInvitation.objects.all()
class UserInfoOut(BaseModel): class UserInfoOut(BaseModel):
@ -94,7 +95,7 @@ def get_outgoing_queryset(user: UserType = Depends(get_authenticated_user)):
def list_common( def list_common(
queryset: QuerySet, queryset: InvitationQuerySet,
iterator: t.Optional[str], iterator: t.Optional[str],
limit: int, limit: int,
) -> InvitationListResponse: ) -> InvitationListResponse:
@ -125,7 +126,7 @@ def list_common(
def incoming_list( def incoming_list(
iterator: t.Optional[str] = None, iterator: t.Optional[str] = None,
limit: int = 50, limit: int = 50,
queryset: QuerySet = Depends(get_incoming_queryset), queryset: InvitationQuerySet = Depends(get_incoming_queryset),
): ):
return list_common(queryset, iterator, limit) return list_common(queryset, iterator, limit)
@ -135,7 +136,7 @@ def incoming_list(
) )
def incoming_get( def incoming_get(
invitation_uid: str, invitation_uid: str,
queryset: QuerySet = Depends(get_incoming_queryset), queryset: InvitationQuerySet = Depends(get_incoming_queryset),
): ):
obj = get_object_or_404(queryset, uid=invitation_uid) obj = get_object_or_404(queryset, uid=invitation_uid)
return CollectionInvitationOut.from_orm(obj) return CollectionInvitationOut.from_orm(obj)
@ -146,7 +147,7 @@ def incoming_get(
) )
def incoming_delete( def incoming_delete(
invitation_uid: str, invitation_uid: str,
queryset: QuerySet = Depends(get_incoming_queryset), queryset: InvitationQuerySet = Depends(get_incoming_queryset),
): ):
obj = get_object_or_404(queryset, uid=invitation_uid) obj = get_object_or_404(queryset, uid=invitation_uid)
obj.delete() obj.delete()
@ -158,7 +159,7 @@ def incoming_delete(
def incoming_accept( def incoming_accept(
invitation_uid: str, invitation_uid: str,
data: CollectionInvitationAcceptIn, data: CollectionInvitationAcceptIn,
queryset: QuerySet = Depends(get_incoming_queryset), queryset: InvitationQuerySet = Depends(get_incoming_queryset),
): ):
invitation = get_object_or_404(queryset, uid=invitation_uid) invitation = get_object_or_404(queryset, uid=invitation_uid)
@ -201,7 +202,7 @@ def outgoing_create(
with transaction.atomic(): with transaction.atomic():
try: try:
ret = models.CollectionInvitation.objects.create( models.CollectionInvitation.objects.create(
**data.dict(exclude={"collection", "username"}), user=to_user, fromMember=member **data.dict(exclude={"collection", "username"}), user=to_user, fromMember=member
) )
except IntegrityError: except IntegrityError:
@ -212,7 +213,7 @@ def outgoing_create(
def outgoing_list( def outgoing_list(
iterator: t.Optional[str] = None, iterator: t.Optional[str] = None,
limit: int = 50, limit: int = 50,
queryset: QuerySet = Depends(get_outgoing_queryset), queryset: InvitationQuerySet = Depends(get_outgoing_queryset),
): ):
return list_common(queryset, iterator, limit) return list_common(queryset, iterator, limit)
@ -222,7 +223,7 @@ def outgoing_list(
) )
def outgoing_delete( def outgoing_delete(
invitation_uid: str, invitation_uid: str,
queryset: QuerySet = Depends(get_outgoing_queryset), queryset: InvitationQuerySet = Depends(get_outgoing_queryset),
): ):
obj = get_object_or_404(queryset, uid=invitation_uid) obj = get_object_or_404(queryset, uid=invitation_uid)
obj.delete() obj.delete()

@ -15,14 +15,15 @@ from .collection import get_collection, verify_collection_admin
User = get_typed_user_model() User = get_typed_user_model()
member_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) member_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
default_queryset: QuerySet = models.CollectionMember.objects.all() MemberQuerySet = QuerySet[models.CollectionMember]
default_queryset: MemberQuerySet = models.CollectionMember.objects.all()
def get_queryset(collection: models.Collection = Depends(get_collection)) -> QuerySet: def get_queryset(collection: models.Collection = Depends(get_collection)) -> MemberQuerySet:
return default_queryset.filter(collection=collection) return default_queryset.filter(collection=collection)
def get_member(username: str, queryset: QuerySet = Depends(get_queryset)) -> QuerySet: def get_member(username: str, queryset: MemberQuerySet = Depends(get_queryset)) -> models.CollectionMember:
return get_object_or_404(queryset, user__username__iexact=username) return get_object_or_404(queryset, user__username__iexact=username)
@ -54,7 +55,7 @@ class MemberListResponse(BaseModel):
def member_list( def member_list(
iterator: t.Optional[str] = None, iterator: t.Optional[str] = None,
limit: int = 50, limit: int = 50,
queryset: QuerySet = Depends(get_queryset), queryset: MemberQuerySet = Depends(get_queryset),
): ):
queryset = queryset.order_by("id") queryset = queryset.order_by("id")
result, new_stoken_obj, done = filter_by_stoken_and_limit( result, new_stoken_obj, done = filter_by_stoken_and_limit(

@ -7,7 +7,7 @@ import base64
from fastapi import status, Query, Depends from fastapi import status, Query, Depends
from pydantic import BaseModel as PyBaseModel from pydantic import BaseModel as PyBaseModel
from django.db.models import QuerySet from django.db.models import Model, QuerySet
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
from django_etebase import app_settings from django_etebase import app_settings
@ -22,6 +22,9 @@ Prefetch = Literal["auto", "medium"]
PrefetchQuery = Query(default="auto") PrefetchQuery = Query(default="auto")
T = t.TypeVar("T", bound=Model, covariant=True)
class BaseModel(PyBaseModel): class BaseModel(PyBaseModel):
class Config: class Config:
json_encoders = { json_encoders = {
@ -35,7 +38,7 @@ class Context:
prefetch: t.Optional[Prefetch] prefetch: t.Optional[Prefetch]
def get_object_or_404(queryset: QuerySet, **kwargs): def get_object_or_404(queryset: QuerySet[T], **kwargs) -> T:
try: try:
return queryset.get(**kwargs) return queryset.get(**kwargs)
except ObjectDoesNotExist as e: except ObjectDoesNotExist as e:

@ -15,7 +15,7 @@ class UnicodeUsernameValidator(validators.RegexValidator):
class UserManager(DjangoUserManager): class UserManager(DjangoUserManager):
def get_by_natural_key(self, username): def get_by_natural_key(self, username: str):
return self.get(**{self.model.USERNAME_FIELD + "__iexact": username}) return self.get(**{self.model.USERNAME_FIELD + "__iexact": username})
@ -37,7 +37,7 @@ class User(AbstractUser):
) )
@classmethod @classmethod
def normalize_username(cls, username): def normalize_username(cls, username: str):
return super().normalize_username(username).lower() return super().normalize_username(username).lower()

Loading…
Cancel
Save