From aec4cb4f31879c22798eb16a480e0cb418646a36 Mon Sep 17 00:00:00 2001
From: Michael Genson <71845777+michael-genson@users.noreply.github.com>
Date: Thu, 14 Sep 2023 09:09:05 -0500
Subject: [PATCH] feat: Advanced Query Filter Record Ordering (#2530)
* added support for multiple order_by strs
* refactored qf to expose nested attr logic
* added nested attr support to order_by
* added tests
* changed unique user to be function-level
* updated docs
* added support for null handling
* updated docs
* undid fixture changes
* fix leaky tests
* added advanced shopping list item test
---------
Co-authored-by: Hayden <64056131+hay-kot@users.noreply.github.com>
---
.../getting-started/api-usage.md | 48 +++
mealie/repos/repository_generic.py | 83 +++--
mealie/schema/response/pagination.py | 6 +
mealie/schema/response/query_filter.py | 77 ++--
tests/fixtures/fixture_users.py | 2 +-
.../repository_tests/test_pagination.py | 333 +++++++++++++++++-
6 files changed, 483 insertions(+), 66 deletions(-)
diff --git a/docs/docs/documentation/getting-started/api-usage.md b/docs/docs/documentation/getting-started/api-usage.md
index c12464afd4f0..e34f4afad161 100644
--- a/docs/docs/documentation/getting-started/api-usage.md
+++ b/docs/docs/documentation/getting-started/api-usage.md
@@ -72,6 +72,23 @@ This filter will find all recipes created on or after a particular date:
This filter will find all units that have `useAbbreviation` disabled:
`useAbbreviation = false`
+This filter will find all foods that are not named "carrot":
+`name <> "carrot"`
+
+##### Keyword Filters
+The API supports many SQL keywords, such as `IS NULL` and `IN`, as well as their negations (e.g. `IS NOT NULL` and `NOT IN`).
+
+Here is an example of a filter that returns all recipes where the "last made" value is not null:
+`lastMade IS NOT NULL`
+
+This filter will find all recipes that don't start with the word "Test":
+`name NOT LIKE "Test%"`
+
+> **_NOTE:_** for more information on this, [check out the SQL "LIKE" operator](https://www.w3schools.com/sql/sql_like.asp)
+
+This filter will find all recipes that have particular slugs:
+`slug IN ["pasta-fagioli", "delicious-ramen"]`
+
##### Nested Property filters
When querying tables with relationships, you can filter properties on related tables. For instance, if you want to query all recipes owned by a particular user:
`user.username = "SousChef20220320"`
@@ -79,6 +96,9 @@ When querying tables with relationships, you can filter properties on related ta
This timeline event filter will return all timeline events for recipes that were created after a particular date:
`recipe.createdAt >= "2023-02-25"`
+This recipe filter will return all recipes that contains a particular set of tags:
+`tags.name CONTAINS ALL ["Easy", "Cajun"]`
+
##### Compound Filters
You can combine multiple filter statements using logical operators (`AND`, `OR`).
@@ -96,3 +116,31 @@ You can have multiple filter groups combined by logical operators. You can defin
Here's a filter that will find all recipes updated between two particular times, but exclude the "Pasta Fagioli" recipe:
`(updatedAt > "2022-07-17T15:47:00Z" AND updatedAt < "2022-07-17T15:50:00Z") AND name <> "Pasta Fagioli"`
+
+#### Advanced Ordering
+Pagination supports `orderBy`, `orderByNullPosition`, and `orderDirection` params to change how you want your query results to be ordered. These can be fine-tuned for more advanced use-cases.
+
+##### Order By
+The pagination `orderBy` attribute allows you to sort your query results by a particular attribute. Sometimes, however, [you may want to sort by more than one attribute](https://www.w3schools.com/sql/sql_orderby.asp). This can be achieved by passing a comma-separated string to the `orderBy` parameter. For instance, if you want to sort recipes by their last made datetime, then by their created datetime, you can pass the following `orderBy` string:
+`lastMade, createdAt`
+
+Similar to the standard SQL `ORDER BY` logic, your attribute orders will be applied sequentially. In the above example, *first* recipes will be sorted by `lastMade`, *then* any recipes with an identical `lastMade` value are sorted by `createdAt`. In addition, standard SQL rules apply when handling results with null values (such as when joining related tables). You can apply the `NULLS FIRST` and `NULLS LAST` SQL expressions by setting the `orderByNullPosition` to "first" or "last". If left empty, the default SQL behavior is applied, [which is different depending on which database you're using](https://learnsql.com/blog/how-to-order-rows-with-nulls/).
+
+##### Order Direction
+The query will be ordered in ascending or descending order, depending on what you pass to the pagination `orderDirection` param. You can either specify "asc" or "desc".
+
+When sorting by multiple attributes, if you *also* want one or more of those sorts to be different directions, you can specify them with a colon. For instance, if, like our previous example, say you want to sort by `lastMade` and `createdAt`. However, this time, you want to sort by `lastMade` ascending, but `createdAt` descending. You could pass this `orderBy` string:
+`lastMade:asc, createdAt:desc`
+
+In the above example, whatever you pass to `orderDirection` will be ignored. If, however, you only specify the direction on one attribute, all other attributes will use the `orderDirection` value.
+
+Consider this `orderBy` string:
+`lastMade:asc, createdAt, slug`
+
+And this `orderDirection` value:
+`desc`
+
+This will result in a recipe query where all recipes are sorted by `lastMade` ascending, then `createdAt` descending, and finally `slug` descending.
+
+Similar to query filters, when querying tables with relationships, you can order by properties on related tables. For instance, if you want to query all foods with labels, sorted by label name, you could use this `orderBy`:
+`label.name`
diff --git a/mealie/repos/repository_generic.py b/mealie/repos/repository_generic.py
index 85419090d978..fefaacdcb9cd 100644
--- a/mealie/repos/repository_generic.py
+++ b/mealie/repos/repository_generic.py
@@ -7,14 +7,14 @@ from typing import Any, Generic, TypeVar
from fastapi import HTTPException
from pydantic import UUID4, BaseModel
-from sqlalchemy import Select, case, delete, func, select
+from sqlalchemy import Select, case, delete, func, nulls_first, nulls_last, select
from sqlalchemy.orm.session import Session
from sqlalchemy.sql import sqltypes
from mealie.core.root_logger import get_logger
from mealie.db.models._model_base import SqlAlchemyBase
from mealie.schema._mealie import MealieModel
-from mealie.schema.response.pagination import OrderDirection, PaginationBase, PaginationQuery
+from mealie.schema.response.pagination import OrderByNullPosition, OrderDirection, PaginationBase, PaginationQuery
from mealie.schema.response.query_filter import QueryFilter
from mealie.schema.response.query_search import SearchFilter
@@ -372,32 +372,65 @@ class RepositoryGeneric(Generic[Schema, Model]):
pagination.page = 1
if pagination.order_by:
- if order_attr := getattr(self.model, pagination.order_by, None):
- # queries handle uppercase and lowercase differently, which is undesirable
- if isinstance(order_attr.type, sqltypes.String):
- order_attr = func.lower(order_attr)
-
- if pagination.order_direction == OrderDirection.asc:
- order_attr = order_attr.asc()
- elif pagination.order_direction == OrderDirection.desc:
- order_attr = order_attr.desc()
-
- query = query.order_by(order_attr)
-
- elif pagination.order_by == "random":
- # randomize outside of database, since not all db's can set random seeds
- # this solution is db-independent & stable to paging
- temp_query = query.with_only_columns(self.model.id)
- allids = self.session.execute(temp_query).scalars().all() # fast because id is indexed
- order = list(range(len(allids)))
- random.seed(pagination.pagination_seed)
- random.shuffle(order)
- random_dict = dict(zip(allids, order, strict=True))
- case_stmt = case(random_dict, value=self.model.id)
- query = query.order_by(case_stmt)
+ query = self.add_order_by_to_query(query, pagination)
return query.limit(pagination.per_page).offset((pagination.page - 1) * pagination.per_page), count, total_pages
+ def add_order_by_to_query(self, query: Select, pagination: PaginationQuery) -> Select:
+ if not pagination.order_by:
+ return query
+
+ if pagination.order_by == "random":
+ # randomize outside of database, since not all db's can set random seeds
+ # this solution is db-independent & stable to paging
+ temp_query = query.with_only_columns(self.model.id)
+ allids = self.session.execute(temp_query).scalars().all() # fast because id is indexed
+ order = list(range(len(allids)))
+ random.seed(pagination.pagination_seed)
+ random.shuffle(order)
+ random_dict = dict(zip(allids, order, strict=True))
+ case_stmt = case(random_dict, value=self.model.id)
+ return query.order_by(case_stmt)
+
+ else:
+ for order_by_val in pagination.order_by.split(","):
+ try:
+ order_by_val = order_by_val.strip()
+ if ":" in order_by_val:
+ order_by, order_dir_val = order_by_val.split(":")
+ order_dir = OrderDirection(order_dir_val)
+ else:
+ order_by = order_by_val
+ order_dir = pagination.order_direction
+
+ _, order_attr, query = QueryFilter.get_model_and_model_attr_from_attr_string(
+ order_by, self.model, query=query
+ )
+
+ if order_dir is OrderDirection.asc:
+ order_attr = order_attr.asc()
+ elif order_dir is OrderDirection.desc:
+ order_attr = order_attr.desc()
+
+ # queries handle uppercase and lowercase differently, which is undesirable
+ if isinstance(order_attr.type, sqltypes.String):
+ order_attr = func.lower(order_attr)
+
+ if pagination.order_by_null_position is OrderByNullPosition.first:
+ order_attr = nulls_first(order_attr)
+ elif pagination.order_by_null_position is OrderByNullPosition.last:
+ order_attr = nulls_last(order_attr)
+
+ query = query.order_by(order_attr)
+
+ except ValueError as e:
+ raise HTTPException(
+ status_code=400,
+ detail=f'Invalid order_by statement "{pagination.order_by}": "{order_by_val}" is invalid',
+ ) from e
+
+ return query
+
def add_search_to_query(self, query: Select, schema: type[Schema], search: str) -> Select:
search_filter = SearchFilter(self.session, search, schema._normalize_search)
return search_filter.filter_query_by_search(query, schema, self.model)
diff --git a/mealie/schema/response/pagination.py b/mealie/schema/response/pagination.py
index f6605ed67dd2..fdde83f1073b 100644
--- a/mealie/schema/response/pagination.py
+++ b/mealie/schema/response/pagination.py
@@ -16,6 +16,11 @@ class OrderDirection(str, enum.Enum):
desc = "desc"
+class OrderByNullPosition(str, enum.Enum):
+ first = "first"
+ last = "last"
+
+
class RecipeSearchQuery(MealieModel):
cookbook: UUID4 | str | None
require_all_categories: bool = False
@@ -30,6 +35,7 @@ class PaginationQuery(MealieModel):
page: int = 1
per_page: int = 50
order_by: str = "created_at"
+ order_by_null_position: OrderByNullPosition | None = None
order_direction: OrderDirection = OrderDirection.desc
query_filter: str | None = None
pagination_seed: str | None = None
diff --git a/mealie/schema/response/query_filter.py b/mealie/schema/response/query_filter.py
index ad351225ecaa..c916dc025d14 100644
--- a/mealie/schema/response/query_filter.py
+++ b/mealie/schema/response/query_filter.py
@@ -13,9 +13,10 @@ from sqlalchemy import ColumnElement, Select, and_, inspect, or_
from sqlalchemy.orm import InstrumentedAttribute, Mapper
from sqlalchemy.sql import sqltypes
+from mealie.db.models._model_base import SqlAlchemyBase
from mealie.db.models._model_utils.guid import GUID
-Model = TypeVar("Model")
+Model = TypeVar("Model", bound=SqlAlchemyBase)
class RelationalKeyword(Enum):
@@ -238,6 +239,53 @@ class QueryFilter:
if i == len(group) - 1:
return consolidated_group_builder.self_group()
+ @classmethod
+ def get_model_and_model_attr_from_attr_string(
+ cls, attr_string: str, model: type[Model], *, query: Select | None = None
+ ) -> tuple[SqlAlchemyBase, InstrumentedAttribute, Select | None]:
+ """
+ Take an attribute string and traverse a database model and its relationships to get the desired
+ model and model attribute. Optionally provide a query to apply the necessary table joins.
+
+ If the attribute string is invalid, raises a `ValueError`.
+
+ For instance, the attribute string "user.name" on `RecipeModel`
+ will return the `User` model's `name` attribute.
+
+ Works with shallow attributes (e.g. "slug" from `RecipeModel`)
+ and arbitrarily deep ones (e.g. "recipe.group.preferences" on `RecipeTimelineEvent`).
+ """
+ model_attr: InstrumentedAttribute | None = None
+ attribute_chain = attr_string.split(".")
+ if not attribute_chain:
+ raise ValueError("invalid query string: attribute name cannot be empty")
+
+ current_model: SqlAlchemyBase = model # type: ignore
+ for i, attribute_link in enumerate(attribute_chain):
+ try:
+ model_attr = getattr(current_model, attribute_link)
+
+ # at the end of the chain there are no more relationships to inspect
+ if i == len(attribute_chain) - 1:
+ break
+
+ if query is not None:
+ query = query.join(
+ model_attr, isouter=True
+ ) # we use outer joins to not unintentionally filter out values
+
+ mapper: Mapper = inspect(current_model)
+ relationship = mapper.relationships[attribute_link]
+ current_model = relationship.mapper.class_
+
+ except (AttributeError, KeyError) as e:
+ raise ValueError(f"invalid attribute string: '{attr_string}' does not exist on this schema") from e
+
+ if model_attr is None:
+ raise ValueError(f"invalid attribute string: '{attr_string}'")
+
+ return current_model, model_attr, query
+
def filter_query(self, query: Select, model: type[Model]) -> Select:
# join tables and build model chain
attr_model_map: dict[int, Any] = {}
@@ -246,29 +294,10 @@ class QueryFilter:
if not isinstance(component, QueryFilterComponent):
continue
- attribute_chain = component.attribute_name.split(".")
- if not attribute_chain:
- raise ValueError("invalid query string: attribute name cannot be empty")
-
- current_model = model
- for j, attribute_link in enumerate(attribute_chain):
- try:
- model_attr = getattr(current_model, attribute_link)
-
- # at the end of the chain there are no more relationships to inspect
- if j == len(attribute_chain) - 1:
- break
-
- query = query.join(model_attr)
- mapper: Mapper = inspect(current_model)
- relationship = mapper.relationships[attribute_link]
- current_model = relationship.mapper.class_
-
- except (AttributeError, KeyError) as e:
- raise ValueError(
- f"invalid query string: '{component.attribute_name}' does not exist on this schema"
- ) from e
- attr_model_map[i] = current_model
+ nested_model, model_attr, query = self.get_model_and_model_attr_from_attr_string(
+ component.attribute_name, model, query=query
+ )
+ attr_model_map[i] = nested_model
# build query filter
partial_group: list[ColumnElement] = []
diff --git a/tests/fixtures/fixture_users.py b/tests/fixtures/fixture_users.py
index 6d938b583480..1bc7fe752632 100644
--- a/tests/fixtures/fixture_users.py
+++ b/tests/fixtures/fixture_users.py
@@ -3,10 +3,10 @@ from typing import Generator
from pytest import fixture
from starlette.testclient import TestClient
+
from mealie.db.db_setup import session_context
from mealie.db.models.users.users import AuthMethod
from mealie.repos.all_repositories import get_repositories
-
from tests import utils
from tests.utils import api_routes
from tests.utils.factories import random_string
diff --git a/tests/unit_tests/repository_tests/test_pagination.py b/tests/unit_tests/repository_tests/test_pagination.py
index dab953f66fbe..91bc65ea7b6a 100644
--- a/tests/unit_tests/repository_tests/test_pagination.py
+++ b/tests/unit_tests/repository_tests/test_pagination.py
@@ -1,3 +1,4 @@
+import random
import time
from collections import defaultdict
from datetime import date, datetime, timedelta
@@ -7,21 +8,54 @@ from urllib.parse import parse_qsl, urlsplit
import pytest
from fastapi.testclient import TestClient
from humps import camelize
+from pydantic import UUID4
from mealie.repos.repository_factory import AllRepositories
from mealie.repos.repository_units import RepositoryUnit
+from mealie.schema.group.group_shopping_list import (
+ ShoppingListItemCreate,
+ ShoppingListMultiPurposeLabelCreate,
+ ShoppingListMultiPurposeLabelOut,
+ ShoppingListSave,
+)
+from mealie.schema.labels.multi_purpose_label import MultiPurposeLabelSave
from mealie.schema.meal_plan.new_meal import CreatePlanEntry
from mealie.schema.recipe import Recipe
from mealie.schema.recipe.recipe_category import CategorySave, TagSave
-from mealie.schema.recipe.recipe_ingredient import IngredientUnit, SaveIngredientUnit
+from mealie.schema.recipe.recipe_ingredient import IngredientUnit, SaveIngredientFood, SaveIngredientUnit
from mealie.schema.recipe.recipe_tool import RecipeToolSave
-from mealie.schema.response.pagination import PaginationQuery
+from mealie.schema.response.pagination import OrderByNullPosition, OrderDirection, PaginationQuery
from mealie.services.seeder.seeder_service import SeederService
from tests.utils import api_routes
from tests.utils.factories import random_int, random_string
from tests.utils.fixture_schemas import TestUser
+class Reversor:
+ """
+ Enables reversed sorting
+
+ https://stackoverflow.com/a/56842689
+ """
+
+ def __init__(self, obj):
+ self.obj = obj
+
+ def __eq__(self, other):
+ return other.obj == self.obj
+
+ def __lt__(self, other):
+ return other.obj < self.obj
+
+
+def get_label_position_from_label_id(label_id: UUID4, label_settings: list[ShoppingListMultiPurposeLabelOut]) -> int:
+ for label_setting in label_settings:
+ if label_setting.label_id == label_id:
+ return label_setting.position
+
+ raise Exception("Something went wrong when parsing label settings")
+
+
def test_repository_pagination(database: AllRepositories, unique_user: TestUser):
group = database.groups.get_one(unique_user.group_id)
assert group
@@ -153,14 +187,6 @@ def query_units(database: AllRepositories, unique_user: TestUser):
unit_ids = [unit.id for unit in [unit_1, unit_2, unit_3]]
units_repo = database.ingredient_units.by_group(unique_user.group_id) # type: ignore
- # make sure we can get all of our test units
- query = PaginationQuery(page=1, per_page=-1)
- all_units = units_repo.page_all(query).items
- assert len(all_units) == 3
-
- for unit in all_units:
- assert unit.id in unit_ids
-
yield units_repo, unit_1, unit_2, unit_3
for unit_id in unit_ids:
@@ -233,7 +259,6 @@ def test_pagination_filter_in(query_units: tuple[RepositoryUnit, IngredientUnit,
query = PaginationQuery(page=1, per_page=-1, query_filter=f"name IN [{unit_1.name}, {unit_2.name}]")
unit_results = units_repo.page_all(query).items
- assert len(unit_results) == 2
result_ids = {unit.id for unit in unit_results}
assert unit_1.id in result_ids
assert unit_2.id in result_ids
@@ -242,7 +267,6 @@ def test_pagination_filter_in(query_units: tuple[RepositoryUnit, IngredientUnit,
query = PaginationQuery(page=1, per_page=-1, query_filter=f"name NOT IN [{unit_1.name}, {unit_2.name}]")
unit_results = units_repo.page_all(query).items
- assert len(unit_results) == 1
result_ids = {unit.id for unit in unit_results}
assert unit_1.id not in result_ids
assert unit_2.id not in result_ids
@@ -251,7 +275,6 @@ def test_pagination_filter_in(query_units: tuple[RepositoryUnit, IngredientUnit,
query = PaginationQuery(page=1, per_page=-1, query_filter=f'name IN ["{unit_3.name}"]')
unit_results = units_repo.page_all(query).items
- assert len(unit_results) == 1
result_ids = {unit.id for unit in unit_results}
assert unit_1.id not in result_ids
assert unit_2.id not in result_ids
@@ -521,6 +544,282 @@ def test_pagination_filter_datetimes(
assert len(unit_ids) == 0
+@pytest.mark.parametrize("order_direction", [OrderDirection.asc, OrderDirection.desc], ids=["ascending", "descending"])
+def test_pagination_order_by_multiple(
+ database: AllRepositories, unique_user: TestUser, order_direction: OrderDirection
+):
+ current_time = datetime.now()
+
+ alphabet = ["a", "b", "c", "d", "e"]
+ abbreviations = alphabet.copy()
+ descriptions = alphabet.copy()
+
+ random.shuffle(abbreviations)
+ random.shuffle(descriptions)
+ assert abbreviations != descriptions
+
+ units_to_create: list[SaveIngredientUnit] = []
+ for abbreviation in abbreviations:
+ for description in descriptions:
+ units_to_create.append(
+ SaveIngredientUnit(
+ group_id=unique_user.group_id,
+ name=random_string(),
+ abbreviation=abbreviation,
+ description=description,
+ )
+ )
+
+ sorted_units = database.ingredient_units.create_many(units_to_create)
+ sorted_units.sort(key=lambda x: (x.abbreviation, x.description), reverse=order_direction is OrderDirection.desc)
+
+ query = database.ingredient_units.page_all(
+ PaginationQuery(
+ page=1,
+ per_page=-1,
+ order_by="abbreviation, description",
+ order_direction=order_direction,
+ query_filter=f'created_at >= "{current_time.isoformat()}"',
+ )
+ )
+
+ assert query.items == sorted_units
+
+
+@pytest.mark.parametrize(
+ "order_by_str, order_direction",
+ [
+ ("abbreviation:asc, description:desc", OrderDirection.asc),
+ ("abbreviation:asc, description:desc", OrderDirection.desc),
+ ("abbreviation, description:desc", OrderDirection.asc),
+ ("abbreviation:asc, description", OrderDirection.desc),
+ ],
+ ids=[
+ "order_by_asc_explicit_order_bys",
+ "order_by_desc_explicit_order_bys",
+ "order_by_asc_inferred_order_by",
+ "order_by_desc_inferred_order_by",
+ ],
+)
+def test_pagination_order_by_multiple_directions(
+ database: AllRepositories, unique_user: TestUser, order_by_str: str, order_direction: OrderDirection
+):
+ current_time = datetime.now()
+
+ alphabet = ["a", "b", "c", "d", "e"]
+ abbreviations = alphabet.copy()
+ descriptions = alphabet.copy()
+
+ random.shuffle(abbreviations)
+ random.shuffle(descriptions)
+ assert abbreviations != descriptions
+
+ units_to_create: list[SaveIngredientUnit] = []
+ for abbreviation in abbreviations:
+ for description in descriptions:
+ units_to_create.append(
+ SaveIngredientUnit(
+ group_id=unique_user.group_id,
+ name=random_string(),
+ abbreviation=abbreviation,
+ description=description,
+ )
+ )
+
+ sorted_units = database.ingredient_units.create_many(units_to_create)
+
+ # sort by abbreviation ascending, description descending
+ sorted_units.sort(key=lambda x: (x.abbreviation, Reversor(x.description)))
+
+ query = database.ingredient_units.page_all(
+ PaginationQuery(
+ page=1,
+ per_page=-1,
+ order_by=order_by_str,
+ order_direction=order_direction,
+ query_filter=f'created_at >= "{current_time.isoformat()}"',
+ )
+ )
+
+ assert query.items == sorted_units
+
+
+@pytest.mark.parametrize(
+ "order_direction",
+ [OrderDirection.asc, OrderDirection.desc],
+ ids=["order_ascending", "order_descending"],
+)
+def test_pagination_order_by_nested_model(
+ database: AllRepositories, unique_user: TestUser, order_direction: OrderDirection
+):
+ current_time = datetime.now()
+
+ alphabet = ["a", "b", "c", "d", "e"]
+ labels = database.group_multi_purpose_labels.create_many(
+ [MultiPurposeLabelSave(group_id=unique_user.group_id, name=letter) for letter in alphabet]
+ )
+ random.shuffle(labels)
+
+ sorted_foods = database.ingredient_foods.create_many(
+ [SaveIngredientFood(group_id=unique_user.group_id, name=random_string(), label_id=label.id) for label in labels]
+ )
+
+ sorted_foods.sort(key=lambda x: x.label.name, reverse=order_direction is OrderDirection.desc) # type: ignore
+ query = database.ingredient_foods.page_all(
+ PaginationQuery(
+ page=1,
+ per_page=-1,
+ order_by="label.name",
+ order_direction=order_direction,
+ query_filter=f'created_at >= "{current_time.isoformat()}"',
+ )
+ )
+
+ assert query.items == sorted_foods
+
+
+def test_pagination_order_by_doesnt_filter(database: AllRepositories, unique_user: TestUser):
+ current_time = datetime.now()
+
+ label = database.group_multi_purpose_labels.create(
+ MultiPurposeLabelSave(name=random_string(), group_id=unique_user.group_id)
+ )
+ food_with_label = database.ingredient_foods.create(
+ SaveIngredientFood(name=random_string(), label_id=label.id, group_id=unique_user.group_id)
+ )
+ food_without_label = database.ingredient_foods.create(
+ SaveIngredientFood(name=random_string(), group_id=unique_user.group_id)
+ )
+
+ query = database.ingredient_foods.by_group(unique_user.group_id).page_all(
+ PaginationQuery(per_page=-1, query_filter=f"created_at>{current_time.isoformat()}", order_by="label.name")
+ )
+ assert len(query.items) == 2
+ found_ids = {item.id for item in query.items}
+ assert food_with_label.id in found_ids
+ assert food_without_label.id in found_ids
+
+
+@pytest.mark.parametrize(
+ "null_position, order_direction",
+ [
+ (OrderByNullPosition.first, OrderDirection.asc),
+ (OrderByNullPosition.last, OrderDirection.asc),
+ (OrderByNullPosition.first, OrderDirection.asc),
+ (OrderByNullPosition.last, OrderDirection.asc),
+ ],
+ ids=[
+ "order_by_nulls_first_order_direction_asc",
+ "order_by_nulls_last_order_direction_asc",
+ "order_by_nulls_first_order_direction_desc",
+ "order_by_nulls_last_order_direction_desc",
+ ],
+)
+def test_pagination_order_by_nulls(
+ database: AllRepositories,
+ unique_user: TestUser,
+ null_position: OrderByNullPosition,
+ order_direction: OrderDirection,
+):
+ current_time = datetime.now()
+
+ label = database.group_multi_purpose_labels.create(
+ MultiPurposeLabelSave(name=random_string(), group_id=unique_user.group_id)
+ )
+ food_with_label = database.ingredient_foods.create(
+ SaveIngredientFood(name=random_string(), label_id=label.id, group_id=unique_user.group_id)
+ )
+ food_without_label = database.ingredient_foods.create(
+ SaveIngredientFood(name=random_string(), group_id=unique_user.group_id)
+ )
+
+ query = database.ingredient_foods.page_all(
+ PaginationQuery(
+ per_page=-1,
+ query_filter=f"created_at >= {current_time.isoformat()}",
+ order_by="label.name",
+ order_by_null_position=null_position,
+ order_direction=order_direction,
+ )
+ )
+ assert len(query.items) == 2
+
+ if null_position is OrderByNullPosition.first:
+ assert query.items[0] == food_without_label
+ assert query.items[1] == food_with_label
+ else:
+ assert query.items[0] == food_with_label
+ assert query.items[1] == food_without_label
+
+
+def test_pagination_shopping_list_items_with_labels(database: AllRepositories, unique_user: TestUser):
+ # create a shopping list and populate it with some items with labels, and some without labels
+ shopping_list = database.group_shopping_lists.create(
+ ShoppingListSave(name=random_string(), group_id=unique_user.group_id)
+ )
+
+ labels = database.group_multi_purpose_labels.create_many(
+ [MultiPurposeLabelSave(name=random_string(), group_id=unique_user.group_id) for _ in range(8)]
+ )
+ random.shuffle(labels)
+
+ label_settings = database.shopping_list_multi_purpose_labels.create_many(
+ [
+ ShoppingListMultiPurposeLabelCreate(shopping_list_id=shopping_list.id, label_id=label.id, position=i)
+ for i, label in enumerate(labels)
+ ]
+ )
+ random.shuffle(label_settings)
+
+ with_labels_positions = list(range(0, random_int(20, 25)))
+ random.shuffle(with_labels_positions)
+ items_with_labels = database.group_shopping_list_item.create_many(
+ [
+ ShoppingListItemCreate(
+ note=random_string(),
+ shopping_list_id=shopping_list.id,
+ label_id=random.choice(labels).id,
+ position=position,
+ )
+ for position in with_labels_positions
+ ]
+ )
+ # sort by item label position ascending, then item position ascending
+ items_with_labels.sort(
+ key=lambda x: (
+ get_label_position_from_label_id(x.label.id, label_settings), # type: ignore[union-attr]
+ x.position,
+ )
+ )
+
+ without_labels_positions = list(range(len(with_labels_positions), random_int(5, 10)))
+ random.shuffle(without_labels_positions)
+ items_without_labels = database.group_shopping_list_item.create_many(
+ [
+ ShoppingListItemCreate(
+ note=random_string(),
+ shopping_list_id=shopping_list.id,
+ label_id=random.choice(labels).id,
+ position=position,
+ )
+ for position in without_labels_positions
+ ]
+ )
+ items_without_labels.sort(key=lambda x: x.position)
+
+ # verify they're in order
+ query = database.group_shopping_list_item.page_all(
+ PaginationQuery(
+ per_page=-1,
+ order_by="label.shopping_lists_label_settings.position, position",
+ order_direction=OrderDirection.asc,
+ order_by_null_position=OrderByNullPosition.first,
+ ),
+ )
+
+ assert query.items == items_without_labels + items_with_labels
+
+
def test_pagination_filter_dates(api_client: TestClient, unique_user: TestUser):
yesterday = date.today() - timedelta(days=1)
today = date.today()
@@ -616,7 +915,11 @@ def test_pagination_filter_booleans(query_units: tuple[RepositoryUnit, Ingredien
units_repo = query_units[0]
unit_1 = query_units[1]
- query = PaginationQuery(page=1, per_page=-1, query_filter="useAbbreviation=true")
+ query = PaginationQuery(
+ page=1,
+ per_page=-1,
+ query_filter=f"useAbbreviation=true AND id IN [{', '.join([str(unit.id) for unit in query_units[1:]])}]",
+ )
unit_results = units_repo.page_all(query).items
assert len(unit_results) == 1
assert unit_results[0].id == unit_1.id
@@ -630,7 +933,6 @@ def test_pagination_filter_advanced(query_units: tuple[RepositoryUnit, Ingredien
query = PaginationQuery(page=1, per_page=-1, query_filter=qf)
unit_results = units_repo.page_all(query).items
- assert len(unit_results) == 2
result_ids = {unit.id for unit in unit_results}
assert unit_1.id in result_ids
assert unit_2.id in result_ids
@@ -640,7 +942,6 @@ def test_pagination_filter_advanced(query_units: tuple[RepositoryUnit, Ingredien
query = PaginationQuery(page=1, per_page=-1, query_filter=qf)
unit_results = units_repo.page_all(query).items
- assert len(unit_results) == 2
result_ids = {unit.id for unit in unit_results}
assert unit_1.id in result_ids
assert unit_2.id in result_ids