feat: Advanced Query Filter Record Ordering (#2530)

* added support for multiple order_by strs

* refactored qf to expose nested attr logic

* added nested attr support to order_by

* added tests

* changed unique user to be function-level

* updated docs

* added support for null handling

* updated docs

* undid fixture changes

* fix leaky tests

* added advanced shopping list item test

---------

Co-authored-by: Hayden <64056131+hay-kot@users.noreply.github.com>
This commit is contained in:
Michael Genson 2023-09-14 09:09:05 -05:00 committed by GitHub
parent 2c5e5a8421
commit aec4cb4f31
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 483 additions and 66 deletions

View File

@ -72,6 +72,23 @@ This filter will find all recipes created on or after a particular date: <br>
This filter will find all units that have `useAbbreviation` disabled: <br>
`useAbbreviation = false`
This filter will find all foods that are not named "carrot": <br>
`name <> "carrot"`
##### Keyword Filters
The API supports many SQL keywords, such as `IS NULL` and `IN`, as well as their negations (e.g. `IS NOT NULL` and `NOT IN`).
Here is an example of a filter that returns all recipes where the "last made" value is not null: <br>
`lastMade IS NOT NULL`
This filter will find all recipes that don't start with the word "Test": <br>
`name NOT LIKE "Test%"`
> **_NOTE:_** for more information on this, [check out the SQL "LIKE" operator](https://www.w3schools.com/sql/sql_like.asp)
This filter will find all recipes that have particular slugs: <br>
`slug IN ["pasta-fagioli", "delicious-ramen"]`
##### Nested Property filters
When querying tables with relationships, you can filter properties on related tables. For instance, if you want to query all recipes owned by a particular user: <br>
`user.username = "SousChef20220320"`
@ -79,6 +96,9 @@ When querying tables with relationships, you can filter properties on related ta
This timeline event filter will return all timeline events for recipes that were created after a particular date: <br>
`recipe.createdAt >= "2023-02-25"`
This recipe filter will return all recipes that contains a particular set of tags: <br>
`tags.name CONTAINS ALL ["Easy", "Cajun"]`
##### Compound Filters
You can combine multiple filter statements using logical operators (`AND`, `OR`).
@ -96,3 +116,31 @@ You can have multiple filter groups combined by logical operators. You can defin
Here's a filter that will find all recipes updated between two particular times, but exclude the "Pasta Fagioli" recipe: <br>
`(updatedAt > "2022-07-17T15:47:00Z" AND updatedAt < "2022-07-17T15:50:00Z") AND name <> "Pasta Fagioli"`
#### Advanced Ordering
Pagination supports `orderBy`, `orderByNullPosition`, and `orderDirection` params to change how you want your query results to be ordered. These can be fine-tuned for more advanced use-cases.
##### Order By
The pagination `orderBy` attribute allows you to sort your query results by a particular attribute. Sometimes, however, [you may want to sort by more than one attribute](https://www.w3schools.com/sql/sql_orderby.asp). This can be achieved by passing a comma-separated string to the `orderBy` parameter. For instance, if you want to sort recipes by their last made datetime, then by their created datetime, you can pass the following `orderBy` string: <br>
`lastMade, createdAt`
Similar to the standard SQL `ORDER BY` logic, your attribute orders will be applied sequentially. In the above example, *first* recipes will be sorted by `lastMade`, *then* any recipes with an identical `lastMade` value are sorted by `createdAt`. In addition, standard SQL rules apply when handling results with null values (such as when joining related tables). You can apply the `NULLS FIRST` and `NULLS LAST` SQL expressions by setting the `orderByNullPosition` to "first" or "last". If left empty, the default SQL behavior is applied, [which is different depending on which database you're using](https://learnsql.com/blog/how-to-order-rows-with-nulls/).
##### Order Direction
The query will be ordered in ascending or descending order, depending on what you pass to the pagination `orderDirection` param. You can either specify "asc" or "desc".
When sorting by multiple attributes, if you *also* want one or more of those sorts to be different directions, you can specify them with a colon. For instance, if, like our previous example, say you want to sort by `lastMade` and `createdAt`. However, this time, you want to sort by `lastMade` ascending, but `createdAt` descending. You could pass this `orderBy` string: <br>
`lastMade:asc, createdAt:desc`
In the above example, whatever you pass to `orderDirection` will be ignored. If, however, you only specify the direction on one attribute, all other attributes will use the `orderDirection` value.
Consider this `orderBy` string: <br>
`lastMade:asc, createdAt, slug`
And this `orderDirection` value: <br>
`desc`
This will result in a recipe query where all recipes are sorted by `lastMade` ascending, then `createdAt` descending, and finally `slug` descending.
Similar to query filters, when querying tables with relationships, you can order by properties on related tables. For instance, if you want to query all foods with labels, sorted by label name, you could use this `orderBy`: <br>
`label.name`

View File

@ -7,14 +7,14 @@ from typing import Any, Generic, TypeVar
from fastapi import HTTPException
from pydantic import UUID4, BaseModel
from sqlalchemy import Select, case, delete, func, select
from sqlalchemy import Select, case, delete, func, nulls_first, nulls_last, select
from sqlalchemy.orm.session import Session
from sqlalchemy.sql import sqltypes
from mealie.core.root_logger import get_logger
from mealie.db.models._model_base import SqlAlchemyBase
from mealie.schema._mealie import MealieModel
from mealie.schema.response.pagination import OrderDirection, PaginationBase, PaginationQuery
from mealie.schema.response.pagination import OrderByNullPosition, OrderDirection, PaginationBase, PaginationQuery
from mealie.schema.response.query_filter import QueryFilter
from mealie.schema.response.query_search import SearchFilter
@ -372,32 +372,65 @@ class RepositoryGeneric(Generic[Schema, Model]):
pagination.page = 1
if pagination.order_by:
if order_attr := getattr(self.model, pagination.order_by, None):
# queries handle uppercase and lowercase differently, which is undesirable
if isinstance(order_attr.type, sqltypes.String):
order_attr = func.lower(order_attr)
if pagination.order_direction == OrderDirection.asc:
order_attr = order_attr.asc()
elif pagination.order_direction == OrderDirection.desc:
order_attr = order_attr.desc()
query = query.order_by(order_attr)
elif pagination.order_by == "random":
# randomize outside of database, since not all db's can set random seeds
# this solution is db-independent & stable to paging
temp_query = query.with_only_columns(self.model.id)
allids = self.session.execute(temp_query).scalars().all() # fast because id is indexed
order = list(range(len(allids)))
random.seed(pagination.pagination_seed)
random.shuffle(order)
random_dict = dict(zip(allids, order, strict=True))
case_stmt = case(random_dict, value=self.model.id)
query = query.order_by(case_stmt)
query = self.add_order_by_to_query(query, pagination)
return query.limit(pagination.per_page).offset((pagination.page - 1) * pagination.per_page), count, total_pages
def add_order_by_to_query(self, query: Select, pagination: PaginationQuery) -> Select:
if not pagination.order_by:
return query
if pagination.order_by == "random":
# randomize outside of database, since not all db's can set random seeds
# this solution is db-independent & stable to paging
temp_query = query.with_only_columns(self.model.id)
allids = self.session.execute(temp_query).scalars().all() # fast because id is indexed
order = list(range(len(allids)))
random.seed(pagination.pagination_seed)
random.shuffle(order)
random_dict = dict(zip(allids, order, strict=True))
case_stmt = case(random_dict, value=self.model.id)
return query.order_by(case_stmt)
else:
for order_by_val in pagination.order_by.split(","):
try:
order_by_val = order_by_val.strip()
if ":" in order_by_val:
order_by, order_dir_val = order_by_val.split(":")
order_dir = OrderDirection(order_dir_val)
else:
order_by = order_by_val
order_dir = pagination.order_direction
_, order_attr, query = QueryFilter.get_model_and_model_attr_from_attr_string(
order_by, self.model, query=query
)
if order_dir is OrderDirection.asc:
order_attr = order_attr.asc()
elif order_dir is OrderDirection.desc:
order_attr = order_attr.desc()
# queries handle uppercase and lowercase differently, which is undesirable
if isinstance(order_attr.type, sqltypes.String):
order_attr = func.lower(order_attr)
if pagination.order_by_null_position is OrderByNullPosition.first:
order_attr = nulls_first(order_attr)
elif pagination.order_by_null_position is OrderByNullPosition.last:
order_attr = nulls_last(order_attr)
query = query.order_by(order_attr)
except ValueError as e:
raise HTTPException(
status_code=400,
detail=f'Invalid order_by statement "{pagination.order_by}": "{order_by_val}" is invalid',
) from e
return query
def add_search_to_query(self, query: Select, schema: type[Schema], search: str) -> Select:
search_filter = SearchFilter(self.session, search, schema._normalize_search)
return search_filter.filter_query_by_search(query, schema, self.model)

View File

@ -16,6 +16,11 @@ class OrderDirection(str, enum.Enum):
desc = "desc"
class OrderByNullPosition(str, enum.Enum):
first = "first"
last = "last"
class RecipeSearchQuery(MealieModel):
cookbook: UUID4 | str | None
require_all_categories: bool = False
@ -30,6 +35,7 @@ class PaginationQuery(MealieModel):
page: int = 1
per_page: int = 50
order_by: str = "created_at"
order_by_null_position: OrderByNullPosition | None = None
order_direction: OrderDirection = OrderDirection.desc
query_filter: str | None = None
pagination_seed: str | None = None

View File

@ -13,9 +13,10 @@ from sqlalchemy import ColumnElement, Select, and_, inspect, or_
from sqlalchemy.orm import InstrumentedAttribute, Mapper
from sqlalchemy.sql import sqltypes
from mealie.db.models._model_base import SqlAlchemyBase
from mealie.db.models._model_utils.guid import GUID
Model = TypeVar("Model")
Model = TypeVar("Model", bound=SqlAlchemyBase)
class RelationalKeyword(Enum):
@ -238,6 +239,53 @@ class QueryFilter:
if i == len(group) - 1:
return consolidated_group_builder.self_group()
@classmethod
def get_model_and_model_attr_from_attr_string(
cls, attr_string: str, model: type[Model], *, query: Select | None = None
) -> tuple[SqlAlchemyBase, InstrumentedAttribute, Select | None]:
"""
Take an attribute string and traverse a database model and its relationships to get the desired
model and model attribute. Optionally provide a query to apply the necessary table joins.
If the attribute string is invalid, raises a `ValueError`.
For instance, the attribute string "user.name" on `RecipeModel`
will return the `User` model's `name` attribute.
Works with shallow attributes (e.g. "slug" from `RecipeModel`)
and arbitrarily deep ones (e.g. "recipe.group.preferences" on `RecipeTimelineEvent`).
"""
model_attr: InstrumentedAttribute | None = None
attribute_chain = attr_string.split(".")
if not attribute_chain:
raise ValueError("invalid query string: attribute name cannot be empty")
current_model: SqlAlchemyBase = model # type: ignore
for i, attribute_link in enumerate(attribute_chain):
try:
model_attr = getattr(current_model, attribute_link)
# at the end of the chain there are no more relationships to inspect
if i == len(attribute_chain) - 1:
break
if query is not None:
query = query.join(
model_attr, isouter=True
) # we use outer joins to not unintentionally filter out values
mapper: Mapper = inspect(current_model)
relationship = mapper.relationships[attribute_link]
current_model = relationship.mapper.class_
except (AttributeError, KeyError) as e:
raise ValueError(f"invalid attribute string: '{attr_string}' does not exist on this schema") from e
if model_attr is None:
raise ValueError(f"invalid attribute string: '{attr_string}'")
return current_model, model_attr, query
def filter_query(self, query: Select, model: type[Model]) -> Select:
# join tables and build model chain
attr_model_map: dict[int, Any] = {}
@ -246,29 +294,10 @@ class QueryFilter:
if not isinstance(component, QueryFilterComponent):
continue
attribute_chain = component.attribute_name.split(".")
if not attribute_chain:
raise ValueError("invalid query string: attribute name cannot be empty")
current_model = model
for j, attribute_link in enumerate(attribute_chain):
try:
model_attr = getattr(current_model, attribute_link)
# at the end of the chain there are no more relationships to inspect
if j == len(attribute_chain) - 1:
break
query = query.join(model_attr)
mapper: Mapper = inspect(current_model)
relationship = mapper.relationships[attribute_link]
current_model = relationship.mapper.class_
except (AttributeError, KeyError) as e:
raise ValueError(
f"invalid query string: '{component.attribute_name}' does not exist on this schema"
) from e
attr_model_map[i] = current_model
nested_model, model_attr, query = self.get_model_and_model_attr_from_attr_string(
component.attribute_name, model, query=query
)
attr_model_map[i] = nested_model
# build query filter
partial_group: list[ColumnElement] = []

View File

@ -3,10 +3,10 @@ from typing import Generator
from pytest import fixture
from starlette.testclient import TestClient
from mealie.db.db_setup import session_context
from mealie.db.models.users.users import AuthMethod
from mealie.repos.all_repositories import get_repositories
from tests import utils
from tests.utils import api_routes
from tests.utils.factories import random_string

View File

@ -1,3 +1,4 @@
import random
import time
from collections import defaultdict
from datetime import date, datetime, timedelta
@ -7,21 +8,54 @@ from urllib.parse import parse_qsl, urlsplit
import pytest
from fastapi.testclient import TestClient
from humps import camelize
from pydantic import UUID4
from mealie.repos.repository_factory import AllRepositories
from mealie.repos.repository_units import RepositoryUnit
from mealie.schema.group.group_shopping_list import (
ShoppingListItemCreate,
ShoppingListMultiPurposeLabelCreate,
ShoppingListMultiPurposeLabelOut,
ShoppingListSave,
)
from mealie.schema.labels.multi_purpose_label import MultiPurposeLabelSave
from mealie.schema.meal_plan.new_meal import CreatePlanEntry
from mealie.schema.recipe import Recipe
from mealie.schema.recipe.recipe_category import CategorySave, TagSave
from mealie.schema.recipe.recipe_ingredient import IngredientUnit, SaveIngredientUnit
from mealie.schema.recipe.recipe_ingredient import IngredientUnit, SaveIngredientFood, SaveIngredientUnit
from mealie.schema.recipe.recipe_tool import RecipeToolSave
from mealie.schema.response.pagination import PaginationQuery
from mealie.schema.response.pagination import OrderByNullPosition, OrderDirection, PaginationQuery
from mealie.services.seeder.seeder_service import SeederService
from tests.utils import api_routes
from tests.utils.factories import random_int, random_string
from tests.utils.fixture_schemas import TestUser
class Reversor:
"""
Enables reversed sorting
https://stackoverflow.com/a/56842689
"""
def __init__(self, obj):
self.obj = obj
def __eq__(self, other):
return other.obj == self.obj
def __lt__(self, other):
return other.obj < self.obj
def get_label_position_from_label_id(label_id: UUID4, label_settings: list[ShoppingListMultiPurposeLabelOut]) -> int:
for label_setting in label_settings:
if label_setting.label_id == label_id:
return label_setting.position
raise Exception("Something went wrong when parsing label settings")
def test_repository_pagination(database: AllRepositories, unique_user: TestUser):
group = database.groups.get_one(unique_user.group_id)
assert group
@ -153,14 +187,6 @@ def query_units(database: AllRepositories, unique_user: TestUser):
unit_ids = [unit.id for unit in [unit_1, unit_2, unit_3]]
units_repo = database.ingredient_units.by_group(unique_user.group_id) # type: ignore
# make sure we can get all of our test units
query = PaginationQuery(page=1, per_page=-1)
all_units = units_repo.page_all(query).items
assert len(all_units) == 3
for unit in all_units:
assert unit.id in unit_ids
yield units_repo, unit_1, unit_2, unit_3
for unit_id in unit_ids:
@ -233,7 +259,6 @@ def test_pagination_filter_in(query_units: tuple[RepositoryUnit, IngredientUnit,
query = PaginationQuery(page=1, per_page=-1, query_filter=f"name IN [{unit_1.name}, {unit_2.name}]")
unit_results = units_repo.page_all(query).items
assert len(unit_results) == 2
result_ids = {unit.id for unit in unit_results}
assert unit_1.id in result_ids
assert unit_2.id in result_ids
@ -242,7 +267,6 @@ def test_pagination_filter_in(query_units: tuple[RepositoryUnit, IngredientUnit,
query = PaginationQuery(page=1, per_page=-1, query_filter=f"name NOT IN [{unit_1.name}, {unit_2.name}]")
unit_results = units_repo.page_all(query).items
assert len(unit_results) == 1
result_ids = {unit.id for unit in unit_results}
assert unit_1.id not in result_ids
assert unit_2.id not in result_ids
@ -251,7 +275,6 @@ def test_pagination_filter_in(query_units: tuple[RepositoryUnit, IngredientUnit,
query = PaginationQuery(page=1, per_page=-1, query_filter=f'name IN ["{unit_3.name}"]')
unit_results = units_repo.page_all(query).items
assert len(unit_results) == 1
result_ids = {unit.id for unit in unit_results}
assert unit_1.id not in result_ids
assert unit_2.id not in result_ids
@ -521,6 +544,282 @@ def test_pagination_filter_datetimes(
assert len(unit_ids) == 0
@pytest.mark.parametrize("order_direction", [OrderDirection.asc, OrderDirection.desc], ids=["ascending", "descending"])
def test_pagination_order_by_multiple(
database: AllRepositories, unique_user: TestUser, order_direction: OrderDirection
):
current_time = datetime.now()
alphabet = ["a", "b", "c", "d", "e"]
abbreviations = alphabet.copy()
descriptions = alphabet.copy()
random.shuffle(abbreviations)
random.shuffle(descriptions)
assert abbreviations != descriptions
units_to_create: list[SaveIngredientUnit] = []
for abbreviation in abbreviations:
for description in descriptions:
units_to_create.append(
SaveIngredientUnit(
group_id=unique_user.group_id,
name=random_string(),
abbreviation=abbreviation,
description=description,
)
)
sorted_units = database.ingredient_units.create_many(units_to_create)
sorted_units.sort(key=lambda x: (x.abbreviation, x.description), reverse=order_direction is OrderDirection.desc)
query = database.ingredient_units.page_all(
PaginationQuery(
page=1,
per_page=-1,
order_by="abbreviation, description",
order_direction=order_direction,
query_filter=f'created_at >= "{current_time.isoformat()}"',
)
)
assert query.items == sorted_units
@pytest.mark.parametrize(
"order_by_str, order_direction",
[
("abbreviation:asc, description:desc", OrderDirection.asc),
("abbreviation:asc, description:desc", OrderDirection.desc),
("abbreviation, description:desc", OrderDirection.asc),
("abbreviation:asc, description", OrderDirection.desc),
],
ids=[
"order_by_asc_explicit_order_bys",
"order_by_desc_explicit_order_bys",
"order_by_asc_inferred_order_by",
"order_by_desc_inferred_order_by",
],
)
def test_pagination_order_by_multiple_directions(
database: AllRepositories, unique_user: TestUser, order_by_str: str, order_direction: OrderDirection
):
current_time = datetime.now()
alphabet = ["a", "b", "c", "d", "e"]
abbreviations = alphabet.copy()
descriptions = alphabet.copy()
random.shuffle(abbreviations)
random.shuffle(descriptions)
assert abbreviations != descriptions
units_to_create: list[SaveIngredientUnit] = []
for abbreviation in abbreviations:
for description in descriptions:
units_to_create.append(
SaveIngredientUnit(
group_id=unique_user.group_id,
name=random_string(),
abbreviation=abbreviation,
description=description,
)
)
sorted_units = database.ingredient_units.create_many(units_to_create)
# sort by abbreviation ascending, description descending
sorted_units.sort(key=lambda x: (x.abbreviation, Reversor(x.description)))
query = database.ingredient_units.page_all(
PaginationQuery(
page=1,
per_page=-1,
order_by=order_by_str,
order_direction=order_direction,
query_filter=f'created_at >= "{current_time.isoformat()}"',
)
)
assert query.items == sorted_units
@pytest.mark.parametrize(
"order_direction",
[OrderDirection.asc, OrderDirection.desc],
ids=["order_ascending", "order_descending"],
)
def test_pagination_order_by_nested_model(
database: AllRepositories, unique_user: TestUser, order_direction: OrderDirection
):
current_time = datetime.now()
alphabet = ["a", "b", "c", "d", "e"]
labels = database.group_multi_purpose_labels.create_many(
[MultiPurposeLabelSave(group_id=unique_user.group_id, name=letter) for letter in alphabet]
)
random.shuffle(labels)
sorted_foods = database.ingredient_foods.create_many(
[SaveIngredientFood(group_id=unique_user.group_id, name=random_string(), label_id=label.id) for label in labels]
)
sorted_foods.sort(key=lambda x: x.label.name, reverse=order_direction is OrderDirection.desc) # type: ignore
query = database.ingredient_foods.page_all(
PaginationQuery(
page=1,
per_page=-1,
order_by="label.name",
order_direction=order_direction,
query_filter=f'created_at >= "{current_time.isoformat()}"',
)
)
assert query.items == sorted_foods
def test_pagination_order_by_doesnt_filter(database: AllRepositories, unique_user: TestUser):
current_time = datetime.now()
label = database.group_multi_purpose_labels.create(
MultiPurposeLabelSave(name=random_string(), group_id=unique_user.group_id)
)
food_with_label = database.ingredient_foods.create(
SaveIngredientFood(name=random_string(), label_id=label.id, group_id=unique_user.group_id)
)
food_without_label = database.ingredient_foods.create(
SaveIngredientFood(name=random_string(), group_id=unique_user.group_id)
)
query = database.ingredient_foods.by_group(unique_user.group_id).page_all(
PaginationQuery(per_page=-1, query_filter=f"created_at>{current_time.isoformat()}", order_by="label.name")
)
assert len(query.items) == 2
found_ids = {item.id for item in query.items}
assert food_with_label.id in found_ids
assert food_without_label.id in found_ids
@pytest.mark.parametrize(
"null_position, order_direction",
[
(OrderByNullPosition.first, OrderDirection.asc),
(OrderByNullPosition.last, OrderDirection.asc),
(OrderByNullPosition.first, OrderDirection.asc),
(OrderByNullPosition.last, OrderDirection.asc),
],
ids=[
"order_by_nulls_first_order_direction_asc",
"order_by_nulls_last_order_direction_asc",
"order_by_nulls_first_order_direction_desc",
"order_by_nulls_last_order_direction_desc",
],
)
def test_pagination_order_by_nulls(
database: AllRepositories,
unique_user: TestUser,
null_position: OrderByNullPosition,
order_direction: OrderDirection,
):
current_time = datetime.now()
label = database.group_multi_purpose_labels.create(
MultiPurposeLabelSave(name=random_string(), group_id=unique_user.group_id)
)
food_with_label = database.ingredient_foods.create(
SaveIngredientFood(name=random_string(), label_id=label.id, group_id=unique_user.group_id)
)
food_without_label = database.ingredient_foods.create(
SaveIngredientFood(name=random_string(), group_id=unique_user.group_id)
)
query = database.ingredient_foods.page_all(
PaginationQuery(
per_page=-1,
query_filter=f"created_at >= {current_time.isoformat()}",
order_by="label.name",
order_by_null_position=null_position,
order_direction=order_direction,
)
)
assert len(query.items) == 2
if null_position is OrderByNullPosition.first:
assert query.items[0] == food_without_label
assert query.items[1] == food_with_label
else:
assert query.items[0] == food_with_label
assert query.items[1] == food_without_label
def test_pagination_shopping_list_items_with_labels(database: AllRepositories, unique_user: TestUser):
# create a shopping list and populate it with some items with labels, and some without labels
shopping_list = database.group_shopping_lists.create(
ShoppingListSave(name=random_string(), group_id=unique_user.group_id)
)
labels = database.group_multi_purpose_labels.create_many(
[MultiPurposeLabelSave(name=random_string(), group_id=unique_user.group_id) for _ in range(8)]
)
random.shuffle(labels)
label_settings = database.shopping_list_multi_purpose_labels.create_many(
[
ShoppingListMultiPurposeLabelCreate(shopping_list_id=shopping_list.id, label_id=label.id, position=i)
for i, label in enumerate(labels)
]
)
random.shuffle(label_settings)
with_labels_positions = list(range(0, random_int(20, 25)))
random.shuffle(with_labels_positions)
items_with_labels = database.group_shopping_list_item.create_many(
[
ShoppingListItemCreate(
note=random_string(),
shopping_list_id=shopping_list.id,
label_id=random.choice(labels).id,
position=position,
)
for position in with_labels_positions
]
)
# sort by item label position ascending, then item position ascending
items_with_labels.sort(
key=lambda x: (
get_label_position_from_label_id(x.label.id, label_settings), # type: ignore[union-attr]
x.position,
)
)
without_labels_positions = list(range(len(with_labels_positions), random_int(5, 10)))
random.shuffle(without_labels_positions)
items_without_labels = database.group_shopping_list_item.create_many(
[
ShoppingListItemCreate(
note=random_string(),
shopping_list_id=shopping_list.id,
label_id=random.choice(labels).id,
position=position,
)
for position in without_labels_positions
]
)
items_without_labels.sort(key=lambda x: x.position)
# verify they're in order
query = database.group_shopping_list_item.page_all(
PaginationQuery(
per_page=-1,
order_by="label.shopping_lists_label_settings.position, position",
order_direction=OrderDirection.asc,
order_by_null_position=OrderByNullPosition.first,
),
)
assert query.items == items_without_labels + items_with_labels
def test_pagination_filter_dates(api_client: TestClient, unique_user: TestUser):
yesterday = date.today() - timedelta(days=1)
today = date.today()
@ -616,7 +915,11 @@ def test_pagination_filter_booleans(query_units: tuple[RepositoryUnit, Ingredien
units_repo = query_units[0]
unit_1 = query_units[1]
query = PaginationQuery(page=1, per_page=-1, query_filter="useAbbreviation=true")
query = PaginationQuery(
page=1,
per_page=-1,
query_filter=f"useAbbreviation=true AND id IN [{', '.join([str(unit.id) for unit in query_units[1:]])}]",
)
unit_results = units_repo.page_all(query).items
assert len(unit_results) == 1
assert unit_results[0].id == unit_1.id
@ -630,7 +933,6 @@ def test_pagination_filter_advanced(query_units: tuple[RepositoryUnit, Ingredien
query = PaginationQuery(page=1, per_page=-1, query_filter=qf)
unit_results = units_repo.page_all(query).items
assert len(unit_results) == 2
result_ids = {unit.id for unit in unit_results}
assert unit_1.id in result_ids
assert unit_2.id in result_ids
@ -640,7 +942,6 @@ def test_pagination_filter_advanced(query_units: tuple[RepositoryUnit, Ingredien
query = PaginationQuery(page=1, per_page=-1, query_filter=qf)
unit_results = units_repo.page_all(query).items
assert len(unit_results) == 2
result_ids = {unit.id for unit in unit_results}
assert unit_1.id in result_ids
assert unit_2.id in result_ids