mirror of
https://github.com/mealie-recipes/mealie.git
synced 2025-07-09 03:04:54 -04:00
feat: Advanced Query Filter Record Ordering (#2530)
* added support for multiple order_by strs * refactored qf to expose nested attr logic * added nested attr support to order_by * added tests * changed unique user to be function-level * updated docs * added support for null handling * updated docs * undid fixture changes * fix leaky tests * added advanced shopping list item test --------- Co-authored-by: Hayden <64056131+hay-kot@users.noreply.github.com>
This commit is contained in:
parent
2c5e5a8421
commit
aec4cb4f31
@ -72,6 +72,23 @@ This filter will find all recipes created on or after a particular date: <br>
|
||||
This filter will find all units that have `useAbbreviation` disabled: <br>
|
||||
`useAbbreviation = false`
|
||||
|
||||
This filter will find all foods that are not named "carrot": <br>
|
||||
`name <> "carrot"`
|
||||
|
||||
##### Keyword Filters
|
||||
The API supports many SQL keywords, such as `IS NULL` and `IN`, as well as their negations (e.g. `IS NOT NULL` and `NOT IN`).
|
||||
|
||||
Here is an example of a filter that returns all recipes where the "last made" value is not null: <br>
|
||||
`lastMade IS NOT NULL`
|
||||
|
||||
This filter will find all recipes that don't start with the word "Test": <br>
|
||||
`name NOT LIKE "Test%"`
|
||||
|
||||
> **_NOTE:_** for more information on this, [check out the SQL "LIKE" operator](https://www.w3schools.com/sql/sql_like.asp)
|
||||
|
||||
This filter will find all recipes that have particular slugs: <br>
|
||||
`slug IN ["pasta-fagioli", "delicious-ramen"]`
|
||||
|
||||
##### Nested Property filters
|
||||
When querying tables with relationships, you can filter properties on related tables. For instance, if you want to query all recipes owned by a particular user: <br>
|
||||
`user.username = "SousChef20220320"`
|
||||
@ -79,6 +96,9 @@ When querying tables with relationships, you can filter properties on related ta
|
||||
This timeline event filter will return all timeline events for recipes that were created after a particular date: <br>
|
||||
`recipe.createdAt >= "2023-02-25"`
|
||||
|
||||
This recipe filter will return all recipes that contains a particular set of tags: <br>
|
||||
`tags.name CONTAINS ALL ["Easy", "Cajun"]`
|
||||
|
||||
##### Compound Filters
|
||||
You can combine multiple filter statements using logical operators (`AND`, `OR`).
|
||||
|
||||
@ -96,3 +116,31 @@ You can have multiple filter groups combined by logical operators. You can defin
|
||||
|
||||
Here's a filter that will find all recipes updated between two particular times, but exclude the "Pasta Fagioli" recipe: <br>
|
||||
`(updatedAt > "2022-07-17T15:47:00Z" AND updatedAt < "2022-07-17T15:50:00Z") AND name <> "Pasta Fagioli"`
|
||||
|
||||
#### Advanced Ordering
|
||||
Pagination supports `orderBy`, `orderByNullPosition`, and `orderDirection` params to change how you want your query results to be ordered. These can be fine-tuned for more advanced use-cases.
|
||||
|
||||
##### Order By
|
||||
The pagination `orderBy` attribute allows you to sort your query results by a particular attribute. Sometimes, however, [you may want to sort by more than one attribute](https://www.w3schools.com/sql/sql_orderby.asp). This can be achieved by passing a comma-separated string to the `orderBy` parameter. For instance, if you want to sort recipes by their last made datetime, then by their created datetime, you can pass the following `orderBy` string: <br>
|
||||
`lastMade, createdAt`
|
||||
|
||||
Similar to the standard SQL `ORDER BY` logic, your attribute orders will be applied sequentially. In the above example, *first* recipes will be sorted by `lastMade`, *then* any recipes with an identical `lastMade` value are sorted by `createdAt`. In addition, standard SQL rules apply when handling results with null values (such as when joining related tables). You can apply the `NULLS FIRST` and `NULLS LAST` SQL expressions by setting the `orderByNullPosition` to "first" or "last". If left empty, the default SQL behavior is applied, [which is different depending on which database you're using](https://learnsql.com/blog/how-to-order-rows-with-nulls/).
|
||||
|
||||
##### Order Direction
|
||||
The query will be ordered in ascending or descending order, depending on what you pass to the pagination `orderDirection` param. You can either specify "asc" or "desc".
|
||||
|
||||
When sorting by multiple attributes, if you *also* want one or more of those sorts to be different directions, you can specify them with a colon. For instance, if, like our previous example, say you want to sort by `lastMade` and `createdAt`. However, this time, you want to sort by `lastMade` ascending, but `createdAt` descending. You could pass this `orderBy` string: <br>
|
||||
`lastMade:asc, createdAt:desc`
|
||||
|
||||
In the above example, whatever you pass to `orderDirection` will be ignored. If, however, you only specify the direction on one attribute, all other attributes will use the `orderDirection` value.
|
||||
|
||||
Consider this `orderBy` string: <br>
|
||||
`lastMade:asc, createdAt, slug`
|
||||
|
||||
And this `orderDirection` value: <br>
|
||||
`desc`
|
||||
|
||||
This will result in a recipe query where all recipes are sorted by `lastMade` ascending, then `createdAt` descending, and finally `slug` descending.
|
||||
|
||||
Similar to query filters, when querying tables with relationships, you can order by properties on related tables. For instance, if you want to query all foods with labels, sorted by label name, you could use this `orderBy`: <br>
|
||||
`label.name`
|
||||
|
@ -7,14 +7,14 @@ from typing import Any, Generic, TypeVar
|
||||
|
||||
from fastapi import HTTPException
|
||||
from pydantic import UUID4, BaseModel
|
||||
from sqlalchemy import Select, case, delete, func, select
|
||||
from sqlalchemy import Select, case, delete, func, nulls_first, nulls_last, select
|
||||
from sqlalchemy.orm.session import Session
|
||||
from sqlalchemy.sql import sqltypes
|
||||
|
||||
from mealie.core.root_logger import get_logger
|
||||
from mealie.db.models._model_base import SqlAlchemyBase
|
||||
from mealie.schema._mealie import MealieModel
|
||||
from mealie.schema.response.pagination import OrderDirection, PaginationBase, PaginationQuery
|
||||
from mealie.schema.response.pagination import OrderByNullPosition, OrderDirection, PaginationBase, PaginationQuery
|
||||
from mealie.schema.response.query_filter import QueryFilter
|
||||
from mealie.schema.response.query_search import SearchFilter
|
||||
|
||||
@ -372,32 +372,65 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
||||
pagination.page = 1
|
||||
|
||||
if pagination.order_by:
|
||||
if order_attr := getattr(self.model, pagination.order_by, None):
|
||||
# queries handle uppercase and lowercase differently, which is undesirable
|
||||
if isinstance(order_attr.type, sqltypes.String):
|
||||
order_attr = func.lower(order_attr)
|
||||
|
||||
if pagination.order_direction == OrderDirection.asc:
|
||||
order_attr = order_attr.asc()
|
||||
elif pagination.order_direction == OrderDirection.desc:
|
||||
order_attr = order_attr.desc()
|
||||
|
||||
query = query.order_by(order_attr)
|
||||
|
||||
elif pagination.order_by == "random":
|
||||
# randomize outside of database, since not all db's can set random seeds
|
||||
# this solution is db-independent & stable to paging
|
||||
temp_query = query.with_only_columns(self.model.id)
|
||||
allids = self.session.execute(temp_query).scalars().all() # fast because id is indexed
|
||||
order = list(range(len(allids)))
|
||||
random.seed(pagination.pagination_seed)
|
||||
random.shuffle(order)
|
||||
random_dict = dict(zip(allids, order, strict=True))
|
||||
case_stmt = case(random_dict, value=self.model.id)
|
||||
query = query.order_by(case_stmt)
|
||||
query = self.add_order_by_to_query(query, pagination)
|
||||
|
||||
return query.limit(pagination.per_page).offset((pagination.page - 1) * pagination.per_page), count, total_pages
|
||||
|
||||
def add_order_by_to_query(self, query: Select, pagination: PaginationQuery) -> Select:
|
||||
if not pagination.order_by:
|
||||
return query
|
||||
|
||||
if pagination.order_by == "random":
|
||||
# randomize outside of database, since not all db's can set random seeds
|
||||
# this solution is db-independent & stable to paging
|
||||
temp_query = query.with_only_columns(self.model.id)
|
||||
allids = self.session.execute(temp_query).scalars().all() # fast because id is indexed
|
||||
order = list(range(len(allids)))
|
||||
random.seed(pagination.pagination_seed)
|
||||
random.shuffle(order)
|
||||
random_dict = dict(zip(allids, order, strict=True))
|
||||
case_stmt = case(random_dict, value=self.model.id)
|
||||
return query.order_by(case_stmt)
|
||||
|
||||
else:
|
||||
for order_by_val in pagination.order_by.split(","):
|
||||
try:
|
||||
order_by_val = order_by_val.strip()
|
||||
if ":" in order_by_val:
|
||||
order_by, order_dir_val = order_by_val.split(":")
|
||||
order_dir = OrderDirection(order_dir_val)
|
||||
else:
|
||||
order_by = order_by_val
|
||||
order_dir = pagination.order_direction
|
||||
|
||||
_, order_attr, query = QueryFilter.get_model_and_model_attr_from_attr_string(
|
||||
order_by, self.model, query=query
|
||||
)
|
||||
|
||||
if order_dir is OrderDirection.asc:
|
||||
order_attr = order_attr.asc()
|
||||
elif order_dir is OrderDirection.desc:
|
||||
order_attr = order_attr.desc()
|
||||
|
||||
# queries handle uppercase and lowercase differently, which is undesirable
|
||||
if isinstance(order_attr.type, sqltypes.String):
|
||||
order_attr = func.lower(order_attr)
|
||||
|
||||
if pagination.order_by_null_position is OrderByNullPosition.first:
|
||||
order_attr = nulls_first(order_attr)
|
||||
elif pagination.order_by_null_position is OrderByNullPosition.last:
|
||||
order_attr = nulls_last(order_attr)
|
||||
|
||||
query = query.order_by(order_attr)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f'Invalid order_by statement "{pagination.order_by}": "{order_by_val}" is invalid',
|
||||
) from e
|
||||
|
||||
return query
|
||||
|
||||
def add_search_to_query(self, query: Select, schema: type[Schema], search: str) -> Select:
|
||||
search_filter = SearchFilter(self.session, search, schema._normalize_search)
|
||||
return search_filter.filter_query_by_search(query, schema, self.model)
|
||||
|
@ -16,6 +16,11 @@ class OrderDirection(str, enum.Enum):
|
||||
desc = "desc"
|
||||
|
||||
|
||||
class OrderByNullPosition(str, enum.Enum):
|
||||
first = "first"
|
||||
last = "last"
|
||||
|
||||
|
||||
class RecipeSearchQuery(MealieModel):
|
||||
cookbook: UUID4 | str | None
|
||||
require_all_categories: bool = False
|
||||
@ -30,6 +35,7 @@ class PaginationQuery(MealieModel):
|
||||
page: int = 1
|
||||
per_page: int = 50
|
||||
order_by: str = "created_at"
|
||||
order_by_null_position: OrderByNullPosition | None = None
|
||||
order_direction: OrderDirection = OrderDirection.desc
|
||||
query_filter: str | None = None
|
||||
pagination_seed: str | None = None
|
||||
|
@ -13,9 +13,10 @@ from sqlalchemy import ColumnElement, Select, and_, inspect, or_
|
||||
from sqlalchemy.orm import InstrumentedAttribute, Mapper
|
||||
from sqlalchemy.sql import sqltypes
|
||||
|
||||
from mealie.db.models._model_base import SqlAlchemyBase
|
||||
from mealie.db.models._model_utils.guid import GUID
|
||||
|
||||
Model = TypeVar("Model")
|
||||
Model = TypeVar("Model", bound=SqlAlchemyBase)
|
||||
|
||||
|
||||
class RelationalKeyword(Enum):
|
||||
@ -238,6 +239,53 @@ class QueryFilter:
|
||||
if i == len(group) - 1:
|
||||
return consolidated_group_builder.self_group()
|
||||
|
||||
@classmethod
|
||||
def get_model_and_model_attr_from_attr_string(
|
||||
cls, attr_string: str, model: type[Model], *, query: Select | None = None
|
||||
) -> tuple[SqlAlchemyBase, InstrumentedAttribute, Select | None]:
|
||||
"""
|
||||
Take an attribute string and traverse a database model and its relationships to get the desired
|
||||
model and model attribute. Optionally provide a query to apply the necessary table joins.
|
||||
|
||||
If the attribute string is invalid, raises a `ValueError`.
|
||||
|
||||
For instance, the attribute string "user.name" on `RecipeModel`
|
||||
will return the `User` model's `name` attribute.
|
||||
|
||||
Works with shallow attributes (e.g. "slug" from `RecipeModel`)
|
||||
and arbitrarily deep ones (e.g. "recipe.group.preferences" on `RecipeTimelineEvent`).
|
||||
"""
|
||||
model_attr: InstrumentedAttribute | None = None
|
||||
attribute_chain = attr_string.split(".")
|
||||
if not attribute_chain:
|
||||
raise ValueError("invalid query string: attribute name cannot be empty")
|
||||
|
||||
current_model: SqlAlchemyBase = model # type: ignore
|
||||
for i, attribute_link in enumerate(attribute_chain):
|
||||
try:
|
||||
model_attr = getattr(current_model, attribute_link)
|
||||
|
||||
# at the end of the chain there are no more relationships to inspect
|
||||
if i == len(attribute_chain) - 1:
|
||||
break
|
||||
|
||||
if query is not None:
|
||||
query = query.join(
|
||||
model_attr, isouter=True
|
||||
) # we use outer joins to not unintentionally filter out values
|
||||
|
||||
mapper: Mapper = inspect(current_model)
|
||||
relationship = mapper.relationships[attribute_link]
|
||||
current_model = relationship.mapper.class_
|
||||
|
||||
except (AttributeError, KeyError) as e:
|
||||
raise ValueError(f"invalid attribute string: '{attr_string}' does not exist on this schema") from e
|
||||
|
||||
if model_attr is None:
|
||||
raise ValueError(f"invalid attribute string: '{attr_string}'")
|
||||
|
||||
return current_model, model_attr, query
|
||||
|
||||
def filter_query(self, query: Select, model: type[Model]) -> Select:
|
||||
# join tables and build model chain
|
||||
attr_model_map: dict[int, Any] = {}
|
||||
@ -246,29 +294,10 @@ class QueryFilter:
|
||||
if not isinstance(component, QueryFilterComponent):
|
||||
continue
|
||||
|
||||
attribute_chain = component.attribute_name.split(".")
|
||||
if not attribute_chain:
|
||||
raise ValueError("invalid query string: attribute name cannot be empty")
|
||||
|
||||
current_model = model
|
||||
for j, attribute_link in enumerate(attribute_chain):
|
||||
try:
|
||||
model_attr = getattr(current_model, attribute_link)
|
||||
|
||||
# at the end of the chain there are no more relationships to inspect
|
||||
if j == len(attribute_chain) - 1:
|
||||
break
|
||||
|
||||
query = query.join(model_attr)
|
||||
mapper: Mapper = inspect(current_model)
|
||||
relationship = mapper.relationships[attribute_link]
|
||||
current_model = relationship.mapper.class_
|
||||
|
||||
except (AttributeError, KeyError) as e:
|
||||
raise ValueError(
|
||||
f"invalid query string: '{component.attribute_name}' does not exist on this schema"
|
||||
) from e
|
||||
attr_model_map[i] = current_model
|
||||
nested_model, model_attr, query = self.get_model_and_model_attr_from_attr_string(
|
||||
component.attribute_name, model, query=query
|
||||
)
|
||||
attr_model_map[i] = nested_model
|
||||
|
||||
# build query filter
|
||||
partial_group: list[ColumnElement] = []
|
||||
|
2
tests/fixtures/fixture_users.py
vendored
2
tests/fixtures/fixture_users.py
vendored
@ -3,10 +3,10 @@ from typing import Generator
|
||||
|
||||
from pytest import fixture
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from mealie.db.db_setup import session_context
|
||||
from mealie.db.models.users.users import AuthMethod
|
||||
from mealie.repos.all_repositories import get_repositories
|
||||
|
||||
from tests import utils
|
||||
from tests.utils import api_routes
|
||||
from tests.utils.factories import random_string
|
||||
|
@ -1,3 +1,4 @@
|
||||
import random
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from datetime import date, datetime, timedelta
|
||||
@ -7,21 +8,54 @@ from urllib.parse import parse_qsl, urlsplit
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from humps import camelize
|
||||
from pydantic import UUID4
|
||||
|
||||
from mealie.repos.repository_factory import AllRepositories
|
||||
from mealie.repos.repository_units import RepositoryUnit
|
||||
from mealie.schema.group.group_shopping_list import (
|
||||
ShoppingListItemCreate,
|
||||
ShoppingListMultiPurposeLabelCreate,
|
||||
ShoppingListMultiPurposeLabelOut,
|
||||
ShoppingListSave,
|
||||
)
|
||||
from mealie.schema.labels.multi_purpose_label import MultiPurposeLabelSave
|
||||
from mealie.schema.meal_plan.new_meal import CreatePlanEntry
|
||||
from mealie.schema.recipe import Recipe
|
||||
from mealie.schema.recipe.recipe_category import CategorySave, TagSave
|
||||
from mealie.schema.recipe.recipe_ingredient import IngredientUnit, SaveIngredientUnit
|
||||
from mealie.schema.recipe.recipe_ingredient import IngredientUnit, SaveIngredientFood, SaveIngredientUnit
|
||||
from mealie.schema.recipe.recipe_tool import RecipeToolSave
|
||||
from mealie.schema.response.pagination import PaginationQuery
|
||||
from mealie.schema.response.pagination import OrderByNullPosition, OrderDirection, PaginationQuery
|
||||
from mealie.services.seeder.seeder_service import SeederService
|
||||
from tests.utils import api_routes
|
||||
from tests.utils.factories import random_int, random_string
|
||||
from tests.utils.fixture_schemas import TestUser
|
||||
|
||||
|
||||
class Reversor:
|
||||
"""
|
||||
Enables reversed sorting
|
||||
|
||||
https://stackoverflow.com/a/56842689
|
||||
"""
|
||||
|
||||
def __init__(self, obj):
|
||||
self.obj = obj
|
||||
|
||||
def __eq__(self, other):
|
||||
return other.obj == self.obj
|
||||
|
||||
def __lt__(self, other):
|
||||
return other.obj < self.obj
|
||||
|
||||
|
||||
def get_label_position_from_label_id(label_id: UUID4, label_settings: list[ShoppingListMultiPurposeLabelOut]) -> int:
|
||||
for label_setting in label_settings:
|
||||
if label_setting.label_id == label_id:
|
||||
return label_setting.position
|
||||
|
||||
raise Exception("Something went wrong when parsing label settings")
|
||||
|
||||
|
||||
def test_repository_pagination(database: AllRepositories, unique_user: TestUser):
|
||||
group = database.groups.get_one(unique_user.group_id)
|
||||
assert group
|
||||
@ -153,14 +187,6 @@ def query_units(database: AllRepositories, unique_user: TestUser):
|
||||
unit_ids = [unit.id for unit in [unit_1, unit_2, unit_3]]
|
||||
units_repo = database.ingredient_units.by_group(unique_user.group_id) # type: ignore
|
||||
|
||||
# make sure we can get all of our test units
|
||||
query = PaginationQuery(page=1, per_page=-1)
|
||||
all_units = units_repo.page_all(query).items
|
||||
assert len(all_units) == 3
|
||||
|
||||
for unit in all_units:
|
||||
assert unit.id in unit_ids
|
||||
|
||||
yield units_repo, unit_1, unit_2, unit_3
|
||||
|
||||
for unit_id in unit_ids:
|
||||
@ -233,7 +259,6 @@ def test_pagination_filter_in(query_units: tuple[RepositoryUnit, IngredientUnit,
|
||||
query = PaginationQuery(page=1, per_page=-1, query_filter=f"name IN [{unit_1.name}, {unit_2.name}]")
|
||||
unit_results = units_repo.page_all(query).items
|
||||
|
||||
assert len(unit_results) == 2
|
||||
result_ids = {unit.id for unit in unit_results}
|
||||
assert unit_1.id in result_ids
|
||||
assert unit_2.id in result_ids
|
||||
@ -242,7 +267,6 @@ def test_pagination_filter_in(query_units: tuple[RepositoryUnit, IngredientUnit,
|
||||
query = PaginationQuery(page=1, per_page=-1, query_filter=f"name NOT IN [{unit_1.name}, {unit_2.name}]")
|
||||
unit_results = units_repo.page_all(query).items
|
||||
|
||||
assert len(unit_results) == 1
|
||||
result_ids = {unit.id for unit in unit_results}
|
||||
assert unit_1.id not in result_ids
|
||||
assert unit_2.id not in result_ids
|
||||
@ -251,7 +275,6 @@ def test_pagination_filter_in(query_units: tuple[RepositoryUnit, IngredientUnit,
|
||||
query = PaginationQuery(page=1, per_page=-1, query_filter=f'name IN ["{unit_3.name}"]')
|
||||
unit_results = units_repo.page_all(query).items
|
||||
|
||||
assert len(unit_results) == 1
|
||||
result_ids = {unit.id for unit in unit_results}
|
||||
assert unit_1.id not in result_ids
|
||||
assert unit_2.id not in result_ids
|
||||
@ -521,6 +544,282 @@ def test_pagination_filter_datetimes(
|
||||
assert len(unit_ids) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("order_direction", [OrderDirection.asc, OrderDirection.desc], ids=["ascending", "descending"])
|
||||
def test_pagination_order_by_multiple(
|
||||
database: AllRepositories, unique_user: TestUser, order_direction: OrderDirection
|
||||
):
|
||||
current_time = datetime.now()
|
||||
|
||||
alphabet = ["a", "b", "c", "d", "e"]
|
||||
abbreviations = alphabet.copy()
|
||||
descriptions = alphabet.copy()
|
||||
|
||||
random.shuffle(abbreviations)
|
||||
random.shuffle(descriptions)
|
||||
assert abbreviations != descriptions
|
||||
|
||||
units_to_create: list[SaveIngredientUnit] = []
|
||||
for abbreviation in abbreviations:
|
||||
for description in descriptions:
|
||||
units_to_create.append(
|
||||
SaveIngredientUnit(
|
||||
group_id=unique_user.group_id,
|
||||
name=random_string(),
|
||||
abbreviation=abbreviation,
|
||||
description=description,
|
||||
)
|
||||
)
|
||||
|
||||
sorted_units = database.ingredient_units.create_many(units_to_create)
|
||||
sorted_units.sort(key=lambda x: (x.abbreviation, x.description), reverse=order_direction is OrderDirection.desc)
|
||||
|
||||
query = database.ingredient_units.page_all(
|
||||
PaginationQuery(
|
||||
page=1,
|
||||
per_page=-1,
|
||||
order_by="abbreviation, description",
|
||||
order_direction=order_direction,
|
||||
query_filter=f'created_at >= "{current_time.isoformat()}"',
|
||||
)
|
||||
)
|
||||
|
||||
assert query.items == sorted_units
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"order_by_str, order_direction",
|
||||
[
|
||||
("abbreviation:asc, description:desc", OrderDirection.asc),
|
||||
("abbreviation:asc, description:desc", OrderDirection.desc),
|
||||
("abbreviation, description:desc", OrderDirection.asc),
|
||||
("abbreviation:asc, description", OrderDirection.desc),
|
||||
],
|
||||
ids=[
|
||||
"order_by_asc_explicit_order_bys",
|
||||
"order_by_desc_explicit_order_bys",
|
||||
"order_by_asc_inferred_order_by",
|
||||
"order_by_desc_inferred_order_by",
|
||||
],
|
||||
)
|
||||
def test_pagination_order_by_multiple_directions(
|
||||
database: AllRepositories, unique_user: TestUser, order_by_str: str, order_direction: OrderDirection
|
||||
):
|
||||
current_time = datetime.now()
|
||||
|
||||
alphabet = ["a", "b", "c", "d", "e"]
|
||||
abbreviations = alphabet.copy()
|
||||
descriptions = alphabet.copy()
|
||||
|
||||
random.shuffle(abbreviations)
|
||||
random.shuffle(descriptions)
|
||||
assert abbreviations != descriptions
|
||||
|
||||
units_to_create: list[SaveIngredientUnit] = []
|
||||
for abbreviation in abbreviations:
|
||||
for description in descriptions:
|
||||
units_to_create.append(
|
||||
SaveIngredientUnit(
|
||||
group_id=unique_user.group_id,
|
||||
name=random_string(),
|
||||
abbreviation=abbreviation,
|
||||
description=description,
|
||||
)
|
||||
)
|
||||
|
||||
sorted_units = database.ingredient_units.create_many(units_to_create)
|
||||
|
||||
# sort by abbreviation ascending, description descending
|
||||
sorted_units.sort(key=lambda x: (x.abbreviation, Reversor(x.description)))
|
||||
|
||||
query = database.ingredient_units.page_all(
|
||||
PaginationQuery(
|
||||
page=1,
|
||||
per_page=-1,
|
||||
order_by=order_by_str,
|
||||
order_direction=order_direction,
|
||||
query_filter=f'created_at >= "{current_time.isoformat()}"',
|
||||
)
|
||||
)
|
||||
|
||||
assert query.items == sorted_units
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"order_direction",
|
||||
[OrderDirection.asc, OrderDirection.desc],
|
||||
ids=["order_ascending", "order_descending"],
|
||||
)
|
||||
def test_pagination_order_by_nested_model(
|
||||
database: AllRepositories, unique_user: TestUser, order_direction: OrderDirection
|
||||
):
|
||||
current_time = datetime.now()
|
||||
|
||||
alphabet = ["a", "b", "c", "d", "e"]
|
||||
labels = database.group_multi_purpose_labels.create_many(
|
||||
[MultiPurposeLabelSave(group_id=unique_user.group_id, name=letter) for letter in alphabet]
|
||||
)
|
||||
random.shuffle(labels)
|
||||
|
||||
sorted_foods = database.ingredient_foods.create_many(
|
||||
[SaveIngredientFood(group_id=unique_user.group_id, name=random_string(), label_id=label.id) for label in labels]
|
||||
)
|
||||
|
||||
sorted_foods.sort(key=lambda x: x.label.name, reverse=order_direction is OrderDirection.desc) # type: ignore
|
||||
query = database.ingredient_foods.page_all(
|
||||
PaginationQuery(
|
||||
page=1,
|
||||
per_page=-1,
|
||||
order_by="label.name",
|
||||
order_direction=order_direction,
|
||||
query_filter=f'created_at >= "{current_time.isoformat()}"',
|
||||
)
|
||||
)
|
||||
|
||||
assert query.items == sorted_foods
|
||||
|
||||
|
||||
def test_pagination_order_by_doesnt_filter(database: AllRepositories, unique_user: TestUser):
|
||||
current_time = datetime.now()
|
||||
|
||||
label = database.group_multi_purpose_labels.create(
|
||||
MultiPurposeLabelSave(name=random_string(), group_id=unique_user.group_id)
|
||||
)
|
||||
food_with_label = database.ingredient_foods.create(
|
||||
SaveIngredientFood(name=random_string(), label_id=label.id, group_id=unique_user.group_id)
|
||||
)
|
||||
food_without_label = database.ingredient_foods.create(
|
||||
SaveIngredientFood(name=random_string(), group_id=unique_user.group_id)
|
||||
)
|
||||
|
||||
query = database.ingredient_foods.by_group(unique_user.group_id).page_all(
|
||||
PaginationQuery(per_page=-1, query_filter=f"created_at>{current_time.isoformat()}", order_by="label.name")
|
||||
)
|
||||
assert len(query.items) == 2
|
||||
found_ids = {item.id for item in query.items}
|
||||
assert food_with_label.id in found_ids
|
||||
assert food_without_label.id in found_ids
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"null_position, order_direction",
|
||||
[
|
||||
(OrderByNullPosition.first, OrderDirection.asc),
|
||||
(OrderByNullPosition.last, OrderDirection.asc),
|
||||
(OrderByNullPosition.first, OrderDirection.asc),
|
||||
(OrderByNullPosition.last, OrderDirection.asc),
|
||||
],
|
||||
ids=[
|
||||
"order_by_nulls_first_order_direction_asc",
|
||||
"order_by_nulls_last_order_direction_asc",
|
||||
"order_by_nulls_first_order_direction_desc",
|
||||
"order_by_nulls_last_order_direction_desc",
|
||||
],
|
||||
)
|
||||
def test_pagination_order_by_nulls(
|
||||
database: AllRepositories,
|
||||
unique_user: TestUser,
|
||||
null_position: OrderByNullPosition,
|
||||
order_direction: OrderDirection,
|
||||
):
|
||||
current_time = datetime.now()
|
||||
|
||||
label = database.group_multi_purpose_labels.create(
|
||||
MultiPurposeLabelSave(name=random_string(), group_id=unique_user.group_id)
|
||||
)
|
||||
food_with_label = database.ingredient_foods.create(
|
||||
SaveIngredientFood(name=random_string(), label_id=label.id, group_id=unique_user.group_id)
|
||||
)
|
||||
food_without_label = database.ingredient_foods.create(
|
||||
SaveIngredientFood(name=random_string(), group_id=unique_user.group_id)
|
||||
)
|
||||
|
||||
query = database.ingredient_foods.page_all(
|
||||
PaginationQuery(
|
||||
per_page=-1,
|
||||
query_filter=f"created_at >= {current_time.isoformat()}",
|
||||
order_by="label.name",
|
||||
order_by_null_position=null_position,
|
||||
order_direction=order_direction,
|
||||
)
|
||||
)
|
||||
assert len(query.items) == 2
|
||||
|
||||
if null_position is OrderByNullPosition.first:
|
||||
assert query.items[0] == food_without_label
|
||||
assert query.items[1] == food_with_label
|
||||
else:
|
||||
assert query.items[0] == food_with_label
|
||||
assert query.items[1] == food_without_label
|
||||
|
||||
|
||||
def test_pagination_shopping_list_items_with_labels(database: AllRepositories, unique_user: TestUser):
|
||||
# create a shopping list and populate it with some items with labels, and some without labels
|
||||
shopping_list = database.group_shopping_lists.create(
|
||||
ShoppingListSave(name=random_string(), group_id=unique_user.group_id)
|
||||
)
|
||||
|
||||
labels = database.group_multi_purpose_labels.create_many(
|
||||
[MultiPurposeLabelSave(name=random_string(), group_id=unique_user.group_id) for _ in range(8)]
|
||||
)
|
||||
random.shuffle(labels)
|
||||
|
||||
label_settings = database.shopping_list_multi_purpose_labels.create_many(
|
||||
[
|
||||
ShoppingListMultiPurposeLabelCreate(shopping_list_id=shopping_list.id, label_id=label.id, position=i)
|
||||
for i, label in enumerate(labels)
|
||||
]
|
||||
)
|
||||
random.shuffle(label_settings)
|
||||
|
||||
with_labels_positions = list(range(0, random_int(20, 25)))
|
||||
random.shuffle(with_labels_positions)
|
||||
items_with_labels = database.group_shopping_list_item.create_many(
|
||||
[
|
||||
ShoppingListItemCreate(
|
||||
note=random_string(),
|
||||
shopping_list_id=shopping_list.id,
|
||||
label_id=random.choice(labels).id,
|
||||
position=position,
|
||||
)
|
||||
for position in with_labels_positions
|
||||
]
|
||||
)
|
||||
# sort by item label position ascending, then item position ascending
|
||||
items_with_labels.sort(
|
||||
key=lambda x: (
|
||||
get_label_position_from_label_id(x.label.id, label_settings), # type: ignore[union-attr]
|
||||
x.position,
|
||||
)
|
||||
)
|
||||
|
||||
without_labels_positions = list(range(len(with_labels_positions), random_int(5, 10)))
|
||||
random.shuffle(without_labels_positions)
|
||||
items_without_labels = database.group_shopping_list_item.create_many(
|
||||
[
|
||||
ShoppingListItemCreate(
|
||||
note=random_string(),
|
||||
shopping_list_id=shopping_list.id,
|
||||
label_id=random.choice(labels).id,
|
||||
position=position,
|
||||
)
|
||||
for position in without_labels_positions
|
||||
]
|
||||
)
|
||||
items_without_labels.sort(key=lambda x: x.position)
|
||||
|
||||
# verify they're in order
|
||||
query = database.group_shopping_list_item.page_all(
|
||||
PaginationQuery(
|
||||
per_page=-1,
|
||||
order_by="label.shopping_lists_label_settings.position, position",
|
||||
order_direction=OrderDirection.asc,
|
||||
order_by_null_position=OrderByNullPosition.first,
|
||||
),
|
||||
)
|
||||
|
||||
assert query.items == items_without_labels + items_with_labels
|
||||
|
||||
|
||||
def test_pagination_filter_dates(api_client: TestClient, unique_user: TestUser):
|
||||
yesterday = date.today() - timedelta(days=1)
|
||||
today = date.today()
|
||||
@ -616,7 +915,11 @@ def test_pagination_filter_booleans(query_units: tuple[RepositoryUnit, Ingredien
|
||||
units_repo = query_units[0]
|
||||
unit_1 = query_units[1]
|
||||
|
||||
query = PaginationQuery(page=1, per_page=-1, query_filter="useAbbreviation=true")
|
||||
query = PaginationQuery(
|
||||
page=1,
|
||||
per_page=-1,
|
||||
query_filter=f"useAbbreviation=true AND id IN [{', '.join([str(unit.id) for unit in query_units[1:]])}]",
|
||||
)
|
||||
unit_results = units_repo.page_all(query).items
|
||||
assert len(unit_results) == 1
|
||||
assert unit_results[0].id == unit_1.id
|
||||
@ -630,7 +933,6 @@ def test_pagination_filter_advanced(query_units: tuple[RepositoryUnit, Ingredien
|
||||
query = PaginationQuery(page=1, per_page=-1, query_filter=qf)
|
||||
unit_results = units_repo.page_all(query).items
|
||||
|
||||
assert len(unit_results) == 2
|
||||
result_ids = {unit.id for unit in unit_results}
|
||||
assert unit_1.id in result_ids
|
||||
assert unit_2.id in result_ids
|
||||
@ -640,7 +942,6 @@ def test_pagination_filter_advanced(query_units: tuple[RepositoryUnit, Ingredien
|
||||
query = PaginationQuery(page=1, per_page=-1, query_filter=qf)
|
||||
unit_results = units_repo.page_all(query).items
|
||||
|
||||
assert len(unit_results) == 2
|
||||
result_ids = {unit.id for unit in unit_results}
|
||||
assert unit_1.id in result_ids
|
||||
assert unit_2.id in result_ids
|
||||
|
Loading…
x
Reference in New Issue
Block a user