chore: backend page_all route cleanup (#1483)

* refactored to remove duplicate code

* refactored meal plan slice to use a query filter
This commit is contained in:
Michael Genson 2022-07-26 20:43:25 -05:00 committed by GitHub
parent f00280e32b
commit 3d4e5441dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 54 additions and 140 deletions

View File

@ -4,6 +4,7 @@ from typing import Any, Generic, TypeVar, Union
from fastapi import HTTPException from fastapi import HTTPException
from pydantic import UUID4, BaseModel from pydantic import UUID4, BaseModel
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.orm import Query
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
from sqlalchemy.sql import sqltypes from sqlalchemy.sql import sqltypes
@ -246,16 +247,43 @@ class RepositoryGeneric(Generic[Schema, Model]):
fltr = self._filter_builder() fltr = self._filter_builder()
q = q.filter_by(**fltr) q = q.filter_by(**fltr)
q, count, total_pages = self.add_pagination_to_query(q, pagination)
try:
data = q.all()
except Exception as e:
self._log_exception(e)
self.session.rollback()
raise e
return PaginationBase(
page=pagination.page,
per_page=pagination.per_page,
total=count,
total_pages=total_pages,
items=[eff_schema.from_orm(s) for s in data],
)
def add_pagination_to_query(self, query: Query, pagination: PaginationQuery) -> tuple[Query, int, int]:
"""
Adds pagination data to an existing query.
:returns:
- query - modified query with pagination data
- count - total number of records (without pagination)
- total_pages - the total number of pages in the query
"""
if pagination.query_filter: if pagination.query_filter:
try: try:
qf = QueryFilter(pagination.query_filter) query_filter = QueryFilter(pagination.query_filter)
q = qf.filter_query(q, model=self.model) query = query_filter.filter_query(query, model=self.model)
except ValueError as e: except ValueError as e:
self.logger.error(e) self.logger.error(e)
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
count = q.count() count = query.count()
# interpret -1 as "get_all" # interpret -1 as "get_all"
if pagination.per_page == -1: if pagination.per_page == -1:
@ -286,21 +314,6 @@ class RepositoryGeneric(Generic[Schema, Model]):
elif pagination.order_direction == OrderDirection.desc: elif pagination.order_direction == OrderDirection.desc:
order_attr = order_attr.desc() order_attr = order_attr.desc()
q = q.order_by(order_attr) query = query.order_by(order_attr)
q = q.limit(pagination.per_page).offset((pagination.page - 1) * pagination.per_page) return query.limit(pagination.per_page).offset((pagination.page - 1) * pagination.per_page), count, total_pages
try:
data = q.all()
except Exception as e:
self._log_exception(e)
self.session.rollback()
raise e
return PaginationBase(
page=pagination.page,
per_page=pagination.per_page,
total=count,
total_pages=total_pages,
items=[eff_schema.from_orm(s) for s in data],
)

View File

@ -1,13 +1,8 @@
from datetime import date from datetime import date
from math import ceil
from uuid import UUID from uuid import UUID
from sqlalchemy import func
from sqlalchemy.sql import sqltypes
from mealie.db.models.group import GroupMealPlan from mealie.db.models.group import GroupMealPlan
from mealie.schema.meal_plan.new_meal import PlanEntryPagination, ReadPlanEntry from mealie.schema.meal_plan.new_meal import ReadPlanEntry
from mealie.schema.response.pagination import OrderDirection, PaginationQuery
from .repository_generic import RepositoryGeneric from .repository_generic import RepositoryGeneric
@ -16,68 +11,6 @@ class RepositoryMeals(RepositoryGeneric[ReadPlanEntry, GroupMealPlan]):
def by_group(self, group_id: UUID) -> "RepositoryMeals": def by_group(self, group_id: UUID) -> "RepositoryMeals":
return super().by_group(group_id) # type: ignore return super().by_group(group_id) # type: ignore
def get_slice(
self, pagination: PaginationQuery, start_date: date, end_date: date, group_id: UUID
) -> PlanEntryPagination:
start_str = start_date.strftime("%Y-%m-%d")
end_str = end_date.strftime("%Y-%m-%d")
# get the total number of documents
q = self.session.query(GroupMealPlan).filter(
GroupMealPlan.date.between(start_str, end_str),
GroupMealPlan.group_id == group_id,
)
count = q.count()
# interpret -1 as "get_all"
if pagination.per_page == -1:
pagination.per_page = count
try:
total_pages = ceil(count / pagination.per_page)
except ZeroDivisionError:
total_pages = 0
# interpret -1 as "last page"
if pagination.page == -1:
pagination.page = total_pages
# failsafe for user input error
if pagination.page < 1:
pagination.page = 1
if pagination.order_by:
if order_attr := getattr(self.model, pagination.order_by, None):
# queries handle uppercase and lowercase differently, which is undesirable
if isinstance(order_attr.type, sqltypes.String):
order_attr = func.lower(order_attr)
if pagination.order_direction == OrderDirection.asc:
order_attr = order_attr.asc()
elif pagination.order_direction == OrderDirection.desc:
order_attr = order_attr.desc()
q = q.order_by(order_attr)
q = q.limit(pagination.per_page).offset((pagination.page - 1) * pagination.per_page)
try:
data = [self.schema.from_orm(x) for x in q.all()]
except Exception as e:
self._log_exception(e)
self.session.rollback()
raise e
return PlanEntryPagination(
page=pagination.page,
per_page=pagination.per_page,
total=count,
total_pages=total_pages,
items=data,
)
def get_today(self, group_id: UUID) -> list[ReadPlanEntry]: def get_today(self, group_id: UUID) -> list[ReadPlanEntry]:
today = date.today() today = date.today()
qry = self.session.query(GroupMealPlan).filter(GroupMealPlan.date == today, GroupMealPlan.group_id == group_id) qry = self.session.query(GroupMealPlan).filter(GroupMealPlan.date == today, GroupMealPlan.group_id == group_id)

View File

@ -1,15 +1,12 @@
from math import ceil
from random import randint from random import randint
from typing import Any, Optional from typing import Any, Optional
from uuid import UUID from uuid import UUID
from fastapi import HTTPException
from pydantic import UUID4 from pydantic import UUID4
from slugify import slugify from slugify import slugify
from sqlalchemy import and_, func from sqlalchemy import and_, func
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from sqlalchemy.sql import sqltypes
from mealie.db.models.recipe.category import Category from mealie.db.models.recipe.category import Category
from mealie.db.models.recipe.ingredient import RecipeIngredient from mealie.db.models.recipe.ingredient import RecipeIngredient
@ -20,8 +17,7 @@ from mealie.db.models.recipe.tool import Tool
from mealie.schema.recipe import Recipe from mealie.schema.recipe import Recipe
from mealie.schema.recipe.recipe import RecipeCategory, RecipePagination, RecipeSummary, RecipeTag, RecipeTool from mealie.schema.recipe.recipe import RecipeCategory, RecipePagination, RecipeSummary, RecipeTag, RecipeTool
from mealie.schema.recipe.recipe_category import CategoryBase, TagBase from mealie.schema.recipe.recipe_category import CategoryBase, TagBase
from mealie.schema.response.pagination import OrderDirection, PaginationQuery from mealie.schema.response.pagination import PaginationQuery
from mealie.schema.response.query_filter import QueryFilter
from .repository_generic import RepositoryGeneric from .repository_generic import RepositoryGeneric
@ -149,49 +145,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
fltr = self._filter_builder() fltr = self._filter_builder()
q = q.filter_by(**fltr) q = q.filter_by(**fltr)
if pagination.query_filter: q, count, total_pages = self.add_pagination_to_query(q, pagination)
try:
qf = QueryFilter(pagination.query_filter)
q = qf.filter_query(q, model=self.model)
except ValueError as e:
self.logger.error(e)
raise HTTPException(status_code=400, detail=str(e))
count = q.count()
# interpret -1 as "get_all"
if pagination.per_page == -1:
pagination.per_page = count
try:
total_pages = ceil(count / pagination.per_page)
except ZeroDivisionError:
total_pages = 0
# interpret -1 as "last page"
if pagination.page == -1:
pagination.page = total_pages
# failsafe for user input error
if pagination.page < 1:
pagination.page = 1
if pagination.order_by:
if order_attr := getattr(self.model, pagination.order_by, None):
# queries handle uppercase and lowercase differently, which is undesirable
if isinstance(order_attr.type, sqltypes.String):
order_attr = func.lower(order_attr)
if pagination.order_direction == OrderDirection.asc:
order_attr = order_attr.asc()
elif pagination.order_direction == OrderDirection.desc:
order_attr = order_attr.desc()
q = q.order_by(order_attr)
q = q.limit(pagination.per_page).offset((pagination.page - 1) * pagination.per_page)
try: try:
data = q.all() data = q.all()

View File

@ -1,4 +1,4 @@
from datetime import date, timedelta from datetime import date
from functools import cached_property from functools import cached_property
from typing import Optional from typing import Optional
@ -94,10 +94,24 @@ class GroupMealplanController(BaseUserController):
start_date: Optional[date] = None, start_date: Optional[date] = None,
end_date: Optional[date] = None, end_date: Optional[date] = None,
): ):
start_date = start_date or date.today() - timedelta(days=999) # merge start and end dates into pagination query only if either is provided
end_date = end_date or date.today() + timedelta(days=999) if start_date or end_date:
if not start_date:
date_filter = f"date <= {end_date}"
return self.repo.get_slice(pagination=q, start_date=start_date, end_date=end_date, group_id=self.group.id) elif not end_date:
date_filter = f"date >= {start_date}"
else:
date_filter = f"date >= {start_date} AND date <= {end_date}"
if q.query_filter:
q.query_filter = f"({q.query_filter}) AND ({date_filter})"
else:
q.query_filter = date_filter
return self.repo.page_all(pagination=q)
@router.post("", response_model=ReadPlanEntry, status_code=201) @router.post("", response_model=ReadPlanEntry, status_code=201)
def create_one(self, data: CreatePlanEntry): def create_one(self, data: CreatePlanEntry):