wip: pagination-repository (#1316)

* bump mypy

* add pagination + refactor generic repo

* add pagination test

* remove all query object
This commit is contained in:
Hayden 2022-05-30 10:30:54 -08:00 committed by GitHub
parent 00f144a622
commit 4c594a48dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 237 additions and 215 deletions

View File

@ -10,11 +10,11 @@ class RepositoryFood(RepositoryGeneric[IngredientFood, IngredientFoodModel]):
def merge(self, from_food: UUID4, to_food: UUID4) -> IngredientFood | None: def merge(self, from_food: UUID4, to_food: UUID4) -> IngredientFood | None:
from_model: IngredientFoodModel = ( from_model: IngredientFoodModel = (
self.session.query(self.sql_model).filter_by(**self._filter_builder(**{"id": from_food})).one() self.session.query(self.model).filter_by(**self._filter_builder(**{"id": from_food})).one()
) )
to_model: IngredientFoodModel = ( to_model: IngredientFoodModel = (
self.session.query(self.sql_model).filter_by(**self._filter_builder(**{"id": to_food})).one() self.session.query(self.model).filter_by(**self._filter_builder(**{"id": to_food})).one()
) )
to_model.ingredients += from_model.ingredients to_model.ingredients += from_model.ingredients

View File

@ -1,90 +1,75 @@
from collections.abc import Callable
from typing import Any, Generic, TypeVar, Union from typing import Any, Generic, TypeVar, Union
from pydantic import UUID4, BaseModel from pydantic import UUID4, BaseModel
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.orm import load_only
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
T = TypeVar("T", bound=BaseModel) from mealie.core.root_logger import get_logger
D = TypeVar("D") from mealie.schema.response.pagination import OrderDirection, PaginationBase, PaginationQuery
Schema = TypeVar("Schema", bound=BaseModel)
Model = TypeVar("Model")
class RepositoryGeneric(Generic[T, D]): class RepositoryGeneric(Generic[Schema, Model]):
"""A Generic BaseAccess Model method to perform common operations on the database """A Generic BaseAccess Model method to perform common operations on the database
Args: Args:
Generic ([T]): Represents the Pydantic Model Generic ([Schema]): Represents the Pydantic Model
Generic ([D]): Represents the SqlAlchemyModel Model Generic ([Model]): Represents the SqlAlchemyModel Model
""" """
def __init__(self, session: Session, primary_key: str, sql_model: type[D], schema: type[T]) -> None: user_id: UUID4 = None
group_id: UUID4 = None
def __init__(self, session: Session, primary_key: str, sql_model: type[Model], schema: type[Schema]) -> None:
self.session = session self.session = session
self.primary_key = primary_key self.primary_key = primary_key
self.sql_model = sql_model self.model = sql_model
self.schema = schema self.schema = schema
self.observers: list = []
self.limit_by_group = False self.logger = get_logger()
self.user_id: UUID4 = None
self.limit_by_user = False def by_user(self, user_id: UUID4) -> "RepositoryGeneric[Schema, Model]":
self.group_id: UUID4 = None
def subscribe(self, func: Callable) -> None:
self.observers.append(func)
def by_user(self, user_id: UUID4) -> "RepositoryGeneric[T, D]":
self.limit_by_user = True
self.user_id = user_id self.user_id = user_id
return self return self
def by_group(self, group_id: UUID4) -> "RepositoryGeneric[T, D]": def by_group(self, group_id: UUID4) -> "RepositoryGeneric[Schema, Model]":
self.limit_by_group = True
self.group_id = group_id self.group_id = group_id
return self return self
def _log_exception(self, e: Exception) -> None:
self.logger.error(f"Error processing query for Repo model={self.model.__name__} schema={self.schema.__name__}")
self.logger.error(e)
def _query(self):
return self.session.query(self.model)
def _filter_builder(self, **kwargs) -> dict[str, Any]: def _filter_builder(self, **kwargs) -> dict[str, Any]:
dct = {} dct = {}
if self.limit_by_user: if self.user_id:
dct["user_id"] = self.user_id dct["user_id"] = self.user_id
if self.limit_by_group: if self.group_id:
dct["group_id"] = self.group_id dct["group_id"] = self.group_id
return {**dct, **kwargs} return {**dct, **kwargs}
# TODO: Run Observer in Async Background Task def get_all(self, limit: int = None, order_by: str = None, start=0, override=None) -> list[Schema]:
def update_observers(self) -> None: # sourcery skip: remove-unnecessary-cast
if self.observers: eff_schema = override or self.schema
for observer in self.observers:
observer()
def get_all(self, limit: int = None, order_by: str = None, start=0, override_schema=None) -> list[T]: fltr = self._filter_builder()
eff_schema = override_schema or self.schema
filter = self._filter_builder() q = self._query().filter_by(**fltr)
order_attr = None
if order_by: if order_by:
order_attr = getattr(self.sql_model, str(order_by)) if order_attr := getattr(self.model, str(order_by)):
order_attr = order_attr.desc() order_attr = order_attr.desc()
q = q.order_by(order_attr)
return [ return [eff_schema.from_orm(x) for x in q.offset(start).limit(limit).all()]
eff_schema.from_orm(x)
for x in self.session.query(self.sql_model)
.order_by(order_attr)
.filter_by(**filter)
.offset(start)
.limit(limit)
.all()
]
return [
eff_schema.from_orm(x)
for x in self.session.query(self.sql_model).filter_by(**filter).offset(start).limit(limit).all()
]
def multi_query( def multi_query(
self, self,
@ -93,55 +78,21 @@ class RepositoryGeneric(Generic[T, D]):
limit: int = None, limit: int = None,
override_schema=None, override_schema=None,
order_by: str = None, order_by: str = None,
) -> list[T]: ) -> list[Schema]:
# sourcery skip: remove-unnecessary-cast
eff_schema = override_schema or self.schema eff_schema = override_schema or self.schema
filer = self._filter_builder(**query_by) fltr = self._filter_builder(**query_by)
q = self._query().filter_by(**fltr)
order_attr = None
if order_by: if order_by:
order_attr = getattr(self.sql_model, str(order_by)) if order_attr := getattr(self.model, str(order_by)):
order_attr = order_attr.desc() order_attr = order_attr.desc()
q = q.order_by(order_attr)
return [ return [eff_schema.from_orm(x) for x in q.offset(start).limit(limit).all()]
eff_schema.from_orm(x)
for x in self.session.query(self.sql_model)
.filter_by(**filer)
.order_by(order_attr)
.offset(start)
.limit(limit)
.all()
]
def get_all_limit_columns(self, fields: list[str], limit: int = None) -> list[D]: def _query_one(self, match_value: str | int | UUID4, match_key: str = None) -> Model:
"""Queries the database for the selected model. Restricts return responses to the
keys specified under "fields"
Args:
session (Session): Database Session Object
fields (list[str]): list of column names to query
limit (int): A limit of values to return
Returns:
list[SqlAlchemyBase]: Returns a list of ORM objects
"""
return self.session.query(self.sql_model).options(load_only(*fields)).limit(limit).all()
def get_all_primary_keys(self) -> list[str]:
"""Queries the database of the selected model and returns a list
of all primary_key values
Args:
session (Session): Database Session object
Returns:
list[str]:
"""
results = self.session.query(self.sql_model).options(load_only(str(self.primary_key)))
results_as_dict = [x.dict() for x in results]
return [x.get(self.primary_key) for x in results_as_dict]
def _query_one(self, match_value: str | int | UUID4, match_key: str = None) -> D:
""" """
Query the sql database for one item an return the sql alchemy model Query the sql database for one item an return the sql alchemy model
object. If no match key is provided the primary_key attribute will be used. object. If no match key is provided the primary_key attribute will be used.
@ -150,18 +101,18 @@ class RepositoryGeneric(Generic[T, D]):
match_key = self.primary_key match_key = self.primary_key
fltr = self._filter_builder(**{match_key: match_value}) fltr = self._filter_builder(**{match_key: match_value})
return self.session.query(self.sql_model).filter_by(**fltr).one() return self._query().filter_by(**fltr).one()
def get_one(self, value: str | int | UUID4, key: str = None, any_case=False, override_schema=None) -> T | None: def get_one(self, value: str | int | UUID4, key: str = None, any_case=False, override_schema=None) -> Schema | None:
key = key or self.primary_key key = key or self.primary_key
q = self.session.query(self.sql_model) q = self.session.query(self.model)
if any_case: if any_case:
search_attr = getattr(self.sql_model, key) search_attr = getattr(self.model, key)
q = q.filter(func.lower(search_attr) == str(value).lower()).filter_by(**self._filter_builder()) q = q.where(func.lower(search_attr) == str(value).lower()).filter_by(**self._filter_builder())
else: else:
q = self.session.query(self.sql_model).filter_by(**self._filter_builder(**{key: value})) q = q.filter_by(**self._filter_builder(**{key: value}))
result = q.one_or_none() result = q.one_or_none()
@ -173,32 +124,20 @@ class RepositoryGeneric(Generic[T, D]):
def get( def get(
self, match_value: str | int | UUID4, match_key: str = None, limit=1, any_case=False, override_schema=None self, match_value: str | int | UUID4, match_key: str = None, limit=1, any_case=False, override_schema=None
) -> T | list[T] | None: ) -> Schema | list[Schema] | None:
"""Retrieves an entry from the database by matching a key/value pair. If no self.logger.info("DEPRECATED: use get_one or get_all instead")
key is provided the class objects primary key will be used to match against.
Args:
match_value (str): A value used to match against the key/value in the database
match_key (str, optional): They key to match the value against. Defaults to None.
limit (int, optional): A limit to returned responses. Defaults to 1.
Returns:
dict or list[dict]:
"""
match_key = match_key or self.primary_key match_key = match_key or self.primary_key
if any_case: if any_case:
search_attr = getattr(self.sql_model, match_key) search_attr = getattr(self.model, match_key)
result = ( result = (
self.session.query(self.sql_model) self.session.query(self.model)
.filter(func.lower(search_attr) == match_value.lower()) # type: ignore .filter(func.lower(search_attr) == match_value.lower()) # type: ignore
.limit(limit) .limit(limit)
.all() .all()
) )
else: else:
result = self.session.query(self.sql_model).filter_by(**{match_key: match_value}).limit(limit).all() result = self.session.query(self.model).filter_by(**{match_key: match_value}).limit(limit).all()
eff_schema = override_schema or self.schema eff_schema = override_schema or self.schema
@ -210,28 +149,29 @@ class RepositoryGeneric(Generic[T, D]):
return [eff_schema.from_orm(x) for x in result] return [eff_schema.from_orm(x) for x in result]
def create(self, document: T | BaseModel | dict) -> T: def create(self, data: Schema | BaseModel | dict) -> Schema:
"""Creates a new database entry for the given SQL Alchemy Model. data = data if isinstance(data, dict) else data.dict()
new_document = self.model(session=self.session, **data) # type: ignore
Args:
session (Session): A Database Session
document (dict): A python dictionary representing the data structure
Returns:
dict: A dictionary representation of the database entry
"""
document = document if isinstance(document, dict) else document.dict()
new_document = self.sql_model(session=self.session, **document) # type: ignore
self.session.add(new_document) self.session.add(new_document)
self.session.commit() self.session.commit()
self.session.refresh(new_document) self.session.refresh(new_document)
if self.observers:
self.update_observers()
return self.schema.from_orm(new_document) return self.schema.from_orm(new_document)
def update(self, match_value: str | int | UUID4, new_data: dict | BaseModel) -> T: def create_many(self, data: list[Schema | dict]) -> list[Schema]:
new_documents = []
for document in data:
document = document if isinstance(document, dict) else document.dict()
new_document = self.model(session=self.session, **document) # type: ignore
new_documents.append(new_document)
self.session.add_all(new_documents)
self.session.commit()
self.session.refresh(new_documents)
return [self.schema.from_orm(x) for x in new_documents]
def update(self, match_value: str | int | UUID4, new_data: dict | BaseModel) -> Schema:
"""Update a database entry. """Update a database entry.
Args: Args:
session (Session): Database Session session (Session): Database Session
@ -246,30 +186,23 @@ class RepositoryGeneric(Generic[T, D]):
entry = self._query_one(match_value=match_value) entry = self._query_one(match_value=match_value)
entry.update(session=self.session, **new_data) # type: ignore entry.update(session=self.session, **new_data) # type: ignore
if self.observers:
self.update_observers()
self.session.commit() self.session.commit()
return self.schema.from_orm(entry) return self.schema.from_orm(entry)
def patch(self, match_value: str | int | UUID4, new_data: dict | BaseModel) -> T | None: def patch(self, match_value: str | int | UUID4, new_data: dict | BaseModel) -> Schema:
new_data = new_data if isinstance(new_data, dict) else new_data.dict() new_data = new_data if isinstance(new_data, dict) else new_data.dict()
entry = self._query_one(match_value=match_value) entry = self._query_one(match_value=match_value)
if not entry:
# TODO: Should raise exception
return None
entry_as_dict = self.schema.from_orm(entry).dict() entry_as_dict = self.schema.from_orm(entry).dict()
entry_as_dict.update(new_data) entry_as_dict.update(new_data)
return self.update(match_value, entry_as_dict) return self.update(match_value, entry_as_dict)
def delete(self, value, match_key: str | None = None) -> T: def delete(self, value, match_key: str | None = None) -> Schema:
match_key = match_key or self.primary_key match_key = match_key or self.primary_key
result = self.session.query(self.sql_model).filter_by(**{match_key: value}).one() result = self._query().filter_by(**{match_key: value}).one()
results_as_model = self.schema.from_orm(result) results_as_model = self.schema.from_orm(result)
try: try:
@ -279,23 +212,17 @@ class RepositoryGeneric(Generic[T, D]):
self.session.rollback() self.session.rollback()
raise e raise e
if self.observers:
self.update_observers()
return results_as_model return results_as_model
def delete_all(self) -> None: def delete_all(self) -> None:
self.session.query(self.sql_model).delete() self._query().delete()
self.session.commit() self.session.commit()
if self.observers:
self.update_observers()
def count_all(self, match_key=None, match_value=None) -> int: def count_all(self, match_key=None, match_value=None) -> int:
if None in [match_key, match_value]: if None in [match_key, match_value]:
return self.session.query(self.sql_model).count() return self._query().count()
else: else:
return self.session.query(self.sql_model).filter_by(**{match_key: match_value}).count() return self._query().filter_by(**{match_key: match_value}).count()
def _count_attribute( def _count_attribute(
self, self,
@ -303,27 +230,57 @@ class RepositoryGeneric(Generic[T, D]):
attr_match: str = None, attr_match: str = None,
count=True, count=True,
override_schema=None, override_schema=None,
) -> Union[int, list[T]]: ) -> Union[int, list[Schema]]: # sourcery skip: assign-if-exp
eff_schema = override_schema or self.schema eff_schema = override_schema or self.schema
# attr_filter = getattr(self.sql_model, attribute_name)
q = self._query().filter(attribute_name == attr_match)
if count: if count:
return self.session.query(self.sql_model).filter(attribute_name == attr_match).count() # noqa: 711 return q.count()
else: else:
return [ return [eff_schema.from_orm(x) for x in q.all()]
eff_schema.from_orm(x)
for x in self.session.query(self.sql_model).filter(attribute_name == attr_match).all() # noqa: 711
]
def create_many(self, documents: list[T | dict]) -> list[T]: def pagination(self, pagination: PaginationQuery, override=None) -> PaginationBase[Schema]:
new_documents = [] """
for document in documents: pagination is a method to interact with the filtered database table and return a paginated result
document = document if isinstance(document, dict) else document.dict() using the PaginationBase that provides several data points that are needed to manage pagination
new_document = self.sql_model(session=self.session, **document) # type: ignore on the client side. This method does utilize the _filter_build method to ensure that the results
new_documents.append(new_document) are filtered by the user and group id when applicable.
self.session.add_all(new_documents) NOTE: When you provide an override you'll need to manually type the result of this method
self.session.commit() as the override, as the type system, is not able to infer the result of this method.
self.session.refresh(new_documents) """
eff_schema = override or self.schema
return [self.schema.from_orm(x) for x in new_documents] q = self.session.query(self.model)
fltr = self._filter_builder()
q = q.filter_by(**fltr)
count = q.count()
if pagination.order_by:
if order_attr := getattr(self.model, pagination.order_by, None):
if pagination.order_direction == OrderDirection.asc:
order_attr = order_attr.asc()
elif pagination.order_direction == OrderDirection.desc:
order_attr = order_attr.desc()
q = q.order_by(order_attr)
q = q.limit(pagination.per_page).offset((pagination.page - 1) * pagination.per_page)
try:
data = q.all()
except Exception as e:
self._log_exception(e)
self.session.rollback()
raise e
return PaginationBase(
page=pagination.page,
per_page=pagination.per_page,
total=count,
total_pages=int(count / pagination.per_page) + 1,
data=[eff_schema.from_orm(s) for s in data],
)

View File

@ -16,7 +16,7 @@ from .repository_generic import RepositoryGeneric
class RepositoryGroup(RepositoryGeneric[GroupInDB, Group]): class RepositoryGroup(RepositoryGeneric[GroupInDB, Group]):
def get_by_name(self, name: str, limit=1) -> Union[GroupInDB, Group, None]: def get_by_name(self, name: str, limit=1) -> Union[GroupInDB, Group, None]:
dbgroup = self.session.query(self.sql_model).filter_by(**{"name": name}).one_or_none() dbgroup = self.session.query(self.model).filter_by(**{"name": name}).one_or_none()
if dbgroup is None: if dbgroup is None:
return None return None
return self.schema.from_orm(dbgroup) return self.schema.from_orm(dbgroup)

View File

@ -44,11 +44,11 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
eff_schema = override_schema or self.schema eff_schema = override_schema or self.schema
if order_by: if order_by:
order_attr = getattr(self.sql_model, str(order_by)) order_attr = getattr(self.model, str(order_by))
return [ return [
eff_schema.from_orm(x) eff_schema.from_orm(x)
for x in self.session.query(self.sql_model) for x in self.session.query(self.model)
.join(RecipeSettings) .join(RecipeSettings)
.filter(RecipeSettings.public == True) # noqa: 711 .filter(RecipeSettings.public == True) # noqa: 711
.order_by(order_attr.desc()) .order_by(order_attr.desc())
@ -59,7 +59,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
return [ return [
eff_schema.from_orm(x) eff_schema.from_orm(x)
for x in self.session.query(self.sql_model) for x in self.session.query(self.model)
.join(RecipeSettings) .join(RecipeSettings)
.filter(RecipeSettings.public == True) # noqa: 711 .filter(RecipeSettings.public == True) # noqa: 711
.offset(start) .offset(start)

View File

@ -10,11 +10,11 @@ class RepositoryUnit(RepositoryGeneric[IngredientUnit, IngredientUnitModel]):
def merge(self, from_unit: UUID4, to_unit: UUID4) -> IngredientUnit | None: def merge(self, from_unit: UUID4, to_unit: UUID4) -> IngredientUnit | None:
from_model: IngredientUnitModel = ( from_model: IngredientUnitModel = (
self.session.query(self.sql_model).filter_by(**self._filter_builder(**{"id": from_unit})).one() self.session.query(self.model).filter_by(**self._filter_builder(**{"id": from_unit})).one()
) )
to_model: IngredientUnitModel = ( to_model: IngredientUnitModel = (
self.session.query(self.sql_model).filter_by(**self._filter_builder(**{"id": to_unit})).one() self.session.query(self.model).filter_by(**self._filter_builder(**{"id": to_unit})).one()
) )
to_model.ingredients += from_model.ingredients to_model.ingredients += from_model.ingredients

View File

@ -18,7 +18,7 @@ class RepositoryUsers(RepositoryGeneric[PrivateUser, User]):
return self.schema.from_orm(entry) return self.schema.from_orm(entry)
def create(self, user: PrivateUser | dict): def create(self, user: PrivateUser | dict): # type: ignore
new_user = super().create(user) new_user = super().create(user)
# Select Random Image # Select Random Image
@ -42,4 +42,4 @@ class RepositoryUsers(RepositoryGeneric[PrivateUser, User]):
dbuser = self.session.query(User).filter(User.username == username).one_or_none() dbuser = self.session.query(User).filter(User.username == username).one_or_none()
if dbuser is None: if dbuser is None:
return None return None
return self.schema.from_orm(dbuser) return self.schema.from_orm(dbuser) # type: ignore

View File

@ -41,7 +41,7 @@ class AdminUserManagementRoutes(BaseAdminController):
@router.get("", response_model=list[GroupInDB]) @router.get("", response_model=list[GroupInDB])
def get_all(self, q: GetAll = Depends(GetAll)): def get_all(self, q: GetAll = Depends(GetAll)):
return self.repo.get_all(start=q.start, limit=q.limit, override_schema=GroupInDB) return self.repo.get_all(start=q.start, limit=q.limit, override=GroupInDB)
@router.post("", response_model=GroupInDB, status_code=status.HTTP_201_CREATED) @router.post("", response_model=GroupInDB, status_code=status.HTTP_201_CREATED)
def create_one(self, data: GroupBase): def create_one(self, data: GroupBase):

View File

@ -34,7 +34,7 @@ class AdminUserManagementRoutes(BaseAdminController):
@router.get("", response_model=list[UserOut]) @router.get("", response_model=list[UserOut])
def get_all(self, q: GetAll = Depends(GetAll)): def get_all(self, q: GetAll = Depends(GetAll)):
return self.repo.get_all(start=q.start, limit=q.limit, override_schema=UserOut) return self.repo.get_all(start=q.start, limit=q.limit, override=UserOut)
@router.post("", response_model=UserOut, status_code=201) @router.post("", response_model=UserOut, status_code=201)
def create_one(self, data: UserIn): def create_one(self, data: UserIn):

View File

@ -43,7 +43,7 @@ class RecipeCommentRoutes(BaseUserController):
@router.get("", response_model=list[RecipeCommentOut]) @router.get("", response_model=list[RecipeCommentOut])
def get_all(self, q: GetAll = Depends(GetAll)): def get_all(self, q: GetAll = Depends(GetAll)):
return self.repo.get_all(start=q.start, limit=q.limit, override_schema=RecipeCommentOut) return self.repo.get_all(start=q.start, limit=q.limit, override=RecipeCommentOut)
@router.post("", response_model=RecipeCommentOut, status_code=201) @router.post("", response_model=RecipeCommentOut, status_code=201)
def create_one(self, data: RecipeCommentCreate): def create_one(self, data: RecipeCommentCreate):

View File

@ -37,7 +37,7 @@ class MultiPurposeLabelsController(BaseUserController):
@router.get("", response_model=list[MultiPurposeLabelSummary]) @router.get("", response_model=list[MultiPurposeLabelSummary])
def get_all(self, q: GetAll = Depends(GetAll)): def get_all(self, q: GetAll = Depends(GetAll)):
return self.repo.get_all(start=q.start, limit=q.limit, override_schema=MultiPurposeLabelSummary) return self.repo.get_all(start=q.start, limit=q.limit, override=MultiPurposeLabelSummary)
@router.post("", response_model=MultiPurposeLabelOut) @router.post("", response_model=MultiPurposeLabelOut)
def create_one(self, data: MultiPurposeLabelCreate): def create_one(self, data: MultiPurposeLabelCreate):

View File

@ -24,7 +24,7 @@ class GroupMealplanConfigController(BaseUserController):
@router.get("", response_model=list[PlanRulesOut]) @router.get("", response_model=list[PlanRulesOut])
def get_all(self): def get_all(self):
return self.repo.get_all(override_schema=PlanRulesOut) return self.repo.get_all(override=PlanRulesOut)
@router.post("", response_model=PlanRulesOut, status_code=201) @router.post("", response_model=PlanRulesOut, status_code=201)
def create_one(self, data: PlanRulesCreate): def create_one(self, data: PlanRulesCreate):

View File

@ -114,7 +114,7 @@ class ShoppingListController(BaseUserController):
@router.get("", response_model=list[ShoppingListSummary]) @router.get("", response_model=list[ShoppingListSummary])
def get_all(self, q: GetAll = Depends(GetAll)): def get_all(self, q: GetAll = Depends(GetAll)):
return self.repo.get_all(start=q.start, limit=q.limit, override_schema=ShoppingListSummary) return self.repo.get_all(start=q.start, limit=q.limit, override=ShoppingListSummary)
@router.post("", response_model=ShoppingListOut, status_code=201) @router.post("", response_model=ShoppingListOut, status_code=201)
def create_one(self, data: ShoppingListCreate): def create_one(self, data: ShoppingListCreate):

View File

@ -25,7 +25,7 @@ class ReadWebhookController(BaseUserController):
@router.get("", response_model=list[ReadWebhook]) @router.get("", response_model=list[ReadWebhook])
def get_all(self, q: GetAll = Depends(GetAll)): def get_all(self, q: GetAll = Depends(GetAll)):
return self.repo.get_all(start=q.start, limit=q.limit, override_schema=ReadWebhook) return self.repo.get_all(start=q.start, limit=q.limit, override=ReadWebhook)
@router.post("", response_model=ReadWebhook, status_code=201) @router.post("", response_model=ReadWebhook, status_code=201)
def create_one(self, data: CreateWebhook): def create_one(self, data: CreateWebhook):

View File

@ -43,7 +43,7 @@ class RecipeCategoryController(BaseUserController):
@router.get("", response_model=list[RecipeCategory]) @router.get("", response_model=list[RecipeCategory])
def get_all(self): def get_all(self):
"""Returns a list of available categories in the database""" """Returns a list of available categories in the database"""
return self.repo.get_all(override_schema=RecipeCategory) return self.repo.get_all(override=RecipeCategory)
@router.post("", status_code=201) @router.post("", status_code=201)
def create_one(self, category: CategoryIn): def create_one(self, category: CategoryIn):

View File

@ -32,7 +32,7 @@ class TagController(BaseUserController):
@router.get("") @router.get("")
async def get_all(self): async def get_all(self):
"""Returns a list of available tags in the database""" """Returns a list of available tags in the database"""
return self.repo.get_all(override_schema=RecipeTag) return self.repo.get_all(override=RecipeTag)
@router.get("/empty") @router.get("/empty")
def get_empty_tags(self): def get_empty_tags(self):

View File

@ -26,7 +26,7 @@ class RecipeToolController(BaseUserController):
@router.get("", response_model=list[RecipeTool]) @router.get("", response_model=list[RecipeTool])
def get_all(self, q: GetAll = Depends(GetAll)): def get_all(self, q: GetAll = Depends(GetAll)):
return self.repo.get_all(start=q.start, limit=q.limit, override_schema=RecipeTool) return self.repo.get_all(start=q.start, limit=q.limit, override=RecipeTool)
@router.post("", response_model=RecipeTool, status_code=201) @router.post("", response_model=RecipeTool, status_code=201)
def create_one(self, data: RecipeToolCreate): def create_one(self, data: RecipeToolCreate):

View File

@ -26,7 +26,7 @@ class RecipeSharedController(BaseUserController):
if recipe_id: if recipe_id:
return self.repo.multi_query({"recipe_id": recipe_id}, override_schema=RecipeShareTokenSummary) return self.repo.multi_query({"recipe_id": recipe_id}, override_schema=RecipeShareTokenSummary)
else: else:
return self.repo.get_all(override_schema=RecipeShareTokenSummary) return self.repo.get_all(override=RecipeShareTokenSummary)
@router.post("", response_model=RecipeShareToken, status_code=201) @router.post("", response_model=RecipeShareToken, status_code=201)
def create_one(self, data: RecipeShareTokenCreate) -> RecipeShareToken: def create_one(self, data: RecipeShareTokenCreate) -> RecipeShareToken:

View File

@ -0,0 +1,27 @@
import enum
from typing import Generic, TypeVar
from pydantic import BaseModel
from pydantic.generics import GenericModel
DataT = TypeVar("DataT", bound=BaseModel)
class OrderDirection(str, enum.Enum):
asc = "asc"
desc = "desc"
class PaginationQuery(BaseModel):
page: int = 1
order_by: str = "created_at"
order_direction: OrderDirection = OrderDirection.desc
per_page: int = 50
class PaginationBase(GenericModel, Generic[DataT]):
page: int = 1
per_page: int = 10
total: int = 0
total_pages: int = 0
data: list[DataT]

52
poetry.lock generated
View File

@ -724,7 +724,7 @@ python-versions = ">=3.6"
[[package]] [[package]]
name = "mypy" name = "mypy"
version = "0.940" version = "0.960"
description = "Optional static typing for Python" description = "Optional static typing for Python"
category = "dev" category = "dev"
optional = false optional = false
@ -732,7 +732,7 @@ python-versions = ">=3.6"
[package.dependencies] [package.dependencies]
mypy-extensions = ">=0.4.3" mypy-extensions = ">=0.4.3"
tomli = ">=1.1.0" tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
typing-extensions = ">=3.10" typing-extensions = ">=3.10"
[package.extras] [package.extras]
@ -1545,7 +1545,7 @@ pgsql = ["psycopg2-binary"]
[metadata] [metadata]
lock-version = "1.1" lock-version = "1.1"
python-versions = "^3.10" python-versions = "^3.10"
content-hash = "45c28207b80dd8ecd82030410c132be32e8f2e46925c92641d4dd1626fec7786" content-hash = "5ceeaf3c1f6ddf5011d96f6a6f6d76da02da48e80de9dea9804eeb458dbe69b5"
[metadata.files] [metadata.files]
aiofiles = [ aiofiles = [
@ -2087,29 +2087,29 @@ mkdocs-material-extensions = [
{file = "mkdocs_material_extensions-1.0.3-py3-none-any.whl", hash = "sha256:a82b70e533ce060b2a5d9eb2bc2e1be201cf61f901f93704b4acf6e3d5983a44"}, {file = "mkdocs_material_extensions-1.0.3-py3-none-any.whl", hash = "sha256:a82b70e533ce060b2a5d9eb2bc2e1be201cf61f901f93704b4acf6e3d5983a44"},
] ]
mypy = [ mypy = [
{file = "mypy-0.940-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0fdc9191a49c77ab5fa0439915d405e80a1118b163ab03cd2a530f346b12566a"}, {file = "mypy-0.960-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3a3e525cd76c2c4f90f1449fd034ba21fcca68050ff7c8397bb7dd25dd8b8248"},
{file = "mypy-0.940-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1903c92ff8642d521b4627e51a67e49f5be5aedb1fb03465b3aae4c3338ec491"}, {file = "mypy-0.960-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7a76dc4f91e92db119b1be293892df8379b08fd31795bb44e0ff84256d34c251"},
{file = "mypy-0.940-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:471af97c35a32061883b0f8a3305ac17947fd42ce962ca9e2b0639eb9141492f"}, {file = "mypy-0.960-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ffdad80a92c100d1b0fe3d3cf1a4724136029a29afe8566404c0146747114382"},
{file = "mypy-0.940-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:13677cb8b050f03b5bb2e8bf7b2668cd918b001d56c2435082bbfc9d5f730f42"}, {file = "mypy-0.960-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7d390248ec07fa344b9f365e6ed9d205bd0205e485c555bed37c4235c868e9d5"},
{file = "mypy-0.940-cp310-cp310-win_amd64.whl", hash = "sha256:2efd76893fb8327eca7e942e21b373e6f3c5c083ff860fb1e82ddd0462d662bd"}, {file = "mypy-0.960-cp310-cp310-win_amd64.whl", hash = "sha256:925aa84369a07846b7f3b8556ccade1f371aa554f2bd4fb31cb97a24b73b036e"},
{file = "mypy-0.940-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:f8fe1bfab792e4300f80013edaf9949b34e4c056a7b2531b5ef3a0fb9d598ae2"}, {file = "mypy-0.960-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:239d6b2242d6c7f5822163ee082ef7a28ee02e7ac86c35593ef923796826a385"},
{file = "mypy-0.940-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2dba92f58610d116f68ec1221fb2de2a346d081d17b24a784624389b17a4b3f9"}, {file = "mypy-0.960-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f1ba54d440d4feee49d8768ea952137316d454b15301c44403db3f2cb51af024"},
{file = "mypy-0.940-cp36-cp36m-win_amd64.whl", hash = "sha256:712affcc456de637e774448c73e21c84dfa5a70bcda34e9b0be4fb898a9e8e07"}, {file = "mypy-0.960-cp36-cp36m-win_amd64.whl", hash = "sha256:cb7752b24528c118a7403ee955b6a578bfcf5879d5ee91790667c8ea511d2085"},
{file = "mypy-0.940-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8aaf18d0f8bc3ffba56d32a85971dfbd371a5be5036da41ac16aefec440eff17"}, {file = "mypy-0.960-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:826a2917c275e2ee05b7c7b736c1e6549a35b7ea5a198ca457f8c2ebea2cbecf"},
{file = "mypy-0.940-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:51be997c1922e2b7be514a5215d1e1799a40832c0a0dee325ba8794f2c48818f"}, {file = "mypy-0.960-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:3eabcbd2525f295da322dff8175258f3fc4c3eb53f6d1929644ef4d99b92e72d"},
{file = "mypy-0.940-cp37-cp37m-win_amd64.whl", hash = "sha256:628f5513268ebbc563750af672ccba5eef7f92d2d90154233edd498dfb98ca4e"}, {file = "mypy-0.960-cp37-cp37m-win_amd64.whl", hash = "sha256:f47322796c412271f5aea48381a528a613f33e0a115452d03ae35d673e6064f8"},
{file = "mypy-0.940-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:68038d514ae59d5b2f326be502a359160158d886bd153fc2489dbf7a03c44c96"}, {file = "mypy-0.960-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:2c7f8bb9619290836a4e167e2ef1f2cf14d70e0bc36c04441e41487456561409"},
{file = "mypy-0.940-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b2fa5f2d597478ccfe1f274f8da2f50ea1e63da5a7ae2342c5b3b2f3e57ec340"}, {file = "mypy-0.960-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:fbfb873cf2b8d8c3c513367febde932e061a5f73f762896826ba06391d932b2a"},
{file = "mypy-0.940-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b1a116c451b41e35afc09618f454b5c2704ba7a4e36f9ff65014fef26bb6075b"}, {file = "mypy-0.960-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cc537885891382e08129d9862553b3d00d4be3eb15b8cae9e2466452f52b0117"},
{file = "mypy-0.940-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:1f66f2309cdbb07e95e60e83fb4a8272095bd4ea6ee58bf9a70d5fb304ec3e3f"}, {file = "mypy-0.960-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:481f98c6b24383188c928f33dd2f0776690807e12e9989dd0419edd5c74aa53b"},
{file = "mypy-0.940-cp38-cp38-win_amd64.whl", hash = "sha256:3ac14949677ae9cb1adc498c423b194ad4d25b13322f6fe889fb72b664c79121"}, {file = "mypy-0.960-cp38-cp38-win_amd64.whl", hash = "sha256:29dc94d9215c3eb80ac3c2ad29d0c22628accfb060348fd23d73abe3ace6c10d"},
{file = "mypy-0.940-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:6eab2bcc2b9489b7df87d7c20743b66d13254ad4d6430e1dfe1a655d51f0933d"}, {file = "mypy-0.960-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:33d53a232bb79057f33332dbbb6393e68acbcb776d2f571ba4b1d50a2c8ba873"},
{file = "mypy-0.940-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0b52778a018559a256c819ee31b2e21e10b31ddca8705624317253d6d08dbc35"}, {file = "mypy-0.960-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8d645e9e7f7a5da3ec3bbcc314ebb9bb22c7ce39e70367830eb3c08d0140b9ce"},
{file = "mypy-0.940-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d9d7647505bf427bc7931e8baf6cacf9be97e78a397724511f20ddec2a850752"}, {file = "mypy-0.960-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:85cf2b14d32b61db24ade8ac9ae7691bdfc572a403e3cb8537da936e74713275"},
{file = "mypy-0.940-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a0e5657ccaedeb5fdfda59918cc98fc6d8a8e83041bc0cec347a2ab6915f9998"}, {file = "mypy-0.960-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a85a20b43fa69efc0b955eba1db435e2ffecb1ca695fe359768e0503b91ea89f"},
{file = "mypy-0.940-cp39-cp39-win_amd64.whl", hash = "sha256:83f66190e3c32603217105913fbfe0a3ef154ab6bbc7ef2c989f5b2957b55840"}, {file = "mypy-0.960-cp39-cp39-win_amd64.whl", hash = "sha256:0ebfb3f414204b98c06791af37a3a96772203da60636e2897408517fcfeee7a8"},
{file = "mypy-0.940-py3-none-any.whl", hash = "sha256:a168da06eccf51875fdff5f305a47f021f23f300e2b89768abdac24538b1f8ec"}, {file = "mypy-0.960-py3-none-any.whl", hash = "sha256:bfd4f6536bd384c27c392a8b8f790fd0ed5c0cf2f63fc2fed7bce56751d53026"},
{file = "mypy-0.940.tar.gz", hash = "sha256:71bec3d2782d0b1fecef7b1c436253544d81c1c0e9ca58190aed9befd8f081c5"}, {file = "mypy-0.960.tar.gz", hash = "sha256:d4fccf04c1acf750babd74252e0f2db6bd2ac3aa8fe960797d9f3ef41cf2bfd4"},
] ]
mypy-extensions = [ mypy-extensions = [
{file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"}, {file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"},

View File

@ -52,7 +52,7 @@ isort = "^5.9.3"
flake8-print = "^4.0.0" flake8-print = "^4.0.0"
black = "^21.12b0" black = "^21.12b0"
coveragepy-lcov = "^0.1.1" coveragepy-lcov = "^0.1.1"
mypy = "^0.940" mypy = "^0.960"
types-python-slugify = "^5.0.3" types-python-slugify = "^5.0.3"
types-PyYAML = "^6.0.4" types-PyYAML = "^6.0.4"
types-requests = "^2.27.12" types-requests = "^2.27.12"

View File

@ -0,0 +1,38 @@
from mealie.repos.repository_factory import AllRepositories
from mealie.schema.response.pagination import PaginationQuery
from mealie.services.seeder.seeder_service import SeederService
from tests.utils.fixture_schemas import TestUser
def test_repository_pagination(database: AllRepositories, unique_user: TestUser):
group = database.groups.get_one(unique_user.group_id)
seeder = SeederService(database, None, group)
seeder.seed_foods("en-US")
foods_repo = database.ingredient_foods.by_group(unique_user.group_id) # type: ignore
query = PaginationQuery(
page=1,
order_by="id",
per_page=10,
)
seen = []
for _ in range(10):
results = foods_repo.pagination(query)
assert len(results.data) == 10
for result in results.data:
assert result.id not in seen
seen += [result.id for result in results.data]
query.page += 1
results = foods_repo.pagination(query)
for result in results.data:
assert result.id not in seen