feature: query filter support for common SQL keywords (#2366)

* added support for SQL keywords IS, IN, LIKE, NOT
deprecated datetime workaround for "<> null"
updated frontend reference for "<> null" to "IS NOT NULL"

* tests

* refactored query filtering to leverage orm

* added CONTAINS ALL keyword

* tests

* fixed bug where "and" or "or" was in an attr name

* more tests

* linter fixes

* TIL this works
This commit is contained in:
Michael Genson 2023-05-06 17:28:40 -05:00 committed by GitHub
parent 9b726126ed
commit 5d87b7e411
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 760 additions and 117 deletions

View File

@ -217,7 +217,7 @@ export default defineComponent({
const queryFilter = computed(() => {
const orderBy = props.query?.orderBy || preferences.value.orderBy;
return preferences.value.filterNull && orderBy ? `${orderBy} <> null` : null;
return preferences.value.filterNull && orderBy ? `${orderBy} IS NOT NULL` : null;
});
async function fetchRecipes(pageCount = 1) {

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import datetime
import re
from collections import deque
from enum import Enum
from typing import Any, TypeVar, cast
from uuid import UUID
@ -9,16 +9,66 @@ from uuid import UUID
from dateutil import parser as date_parser
from dateutil.parser import ParserError
from humps import decamelize
from sqlalchemy import Select, bindparam, inspect, text
from sqlalchemy.orm import Mapper
from sqlalchemy import ColumnElement, Select, and_, inspect, or_
from sqlalchemy.orm import InstrumentedAttribute, Mapper
from sqlalchemy.sql import sqltypes
from sqlalchemy.sql.expression import BindParameter
from mealie.db.models._model_utils.guid import GUID
Model = TypeVar("Model")
class RelationalKeyword(Enum):
IS = "IS"
IS_NOT = "IS NOT"
IN = "IN"
NOT_IN = "NOT IN"
CONTAINS_ALL = "CONTAINS ALL"
LIKE = "LIKE"
NOT_LIKE = "NOT LIKE"
@classmethod
def parse_component(cls, component: str) -> list[str] | None:
"""
Try to parse a component using a relational keyword
If no matching keyword is found, returns None
"""
# extract the attribute name from the component
parsed_component = component.split(maxsplit=1)
if len(parsed_component) < 2:
return None
# assume the component has already filtered out the value and try to match a keyword
# if we try to filter out the value without checking first, keywords with spaces won't parse correctly
possible_keyword = parsed_component[1].strip().lower()
for rel_kw in sorted([keyword.value for keyword in cls], key=len, reverse=True):
if rel_kw.lower() != possible_keyword:
continue
parsed_component[1] = rel_kw
return parsed_component
# there was no match, so the component may still have the value in it
try:
_possible_keyword, _value = parsed_component[-1].rsplit(maxsplit=1)
parsed_component = [parsed_component[0], _possible_keyword, _value]
except ValueError:
# the component has no value to filter out
return None
possible_keyword = parsed_component[1].strip().lower()
for rel_kw in sorted([keyword.value for keyword in cls], key=len, reverse=True):
if rel_kw.lower() != possible_keyword:
continue
parsed_component[1] = rel_kw
return parsed_component
return None
class RelationalOperator(Enum):
EQ = "="
NOTEQ = "<>"
@ -27,6 +77,24 @@ class RelationalOperator(Enum):
GTE = ">="
LTE = "<="
@classmethod
def parse_component(cls, component: str) -> list[str] | None:
"""
Try to parse a component using a relational operator
If no matching operator is found, returns None
"""
for rel_op in sorted([operator.value for operator in cls], key=len, reverse=True):
if rel_op not in component:
continue
parsed_component = [base_component.strip() for base_component in component.split(rel_op) if base_component]
parsed_component.insert(1, rel_op)
return parsed_component
return None
class LogicalOperator(Enum):
AND = "AND"
@ -36,31 +104,107 @@ class LogicalOperator(Enum):
class QueryFilterComponent:
"""A single relational statement"""
def __init__(self, attribute_name: str, relational_operator: RelationalOperator, value: str) -> None:
@staticmethod
def strip_quotes_from_string(val: str) -> str:
if len(val) > 2 and val[0] == '"' and val[-1] == '"':
return val[1:-1]
else:
return val
def __init__(
self, attribute_name: str, relationship: RelationalKeyword | RelationalOperator, value: str | list[str]
) -> None:
self.attribute_name = decamelize(attribute_name)
self.relational_operator = relational_operator
self.value = value
self.relationship = relationship
# remove encasing quotes
if len(value) > 2 and value[0] == '"' and value[-1] == '"':
self.value = value[1:-1]
if isinstance(value, str):
value = self.strip_quotes_from_string(value)
elif isinstance(value, list):
value = [self.strip_quotes_from_string(v) for v in value]
# validate relationship/value pairs
if relationship in [
RelationalKeyword.IN,
RelationalKeyword.NOT_IN,
RelationalKeyword.CONTAINS_ALL,
] and not isinstance(value, list):
raise ValueError(
f"invalid query string: {relationship.value} must be given a list of values"
f"enclosed by {QueryFilter.l_list_sep} and {QueryFilter.r_list_sep}"
)
if relationship is RelationalKeyword.IS or relationship is RelationalKeyword.IS_NOT:
if not isinstance(value, str) or value.lower() not in ["null", "none"]:
raise ValueError(
f'invalid query string: "{relationship.value}" can only be used with "NULL", not "{value}"'
)
self.value = None
else:
self.value = value
def __repr__(self) -> str:
return f"[{self.attribute_name} {self.relational_operator.value} {self.value}]"
return f"[{self.attribute_name} {self.relationship.value} {self.value}]"
def validate(self, model_attr_type: Any) -> Any:
"""Validate value against an model attribute's type and return a validated value, or raise a ValueError"""
sanitized_values: list[Any]
if not isinstance(self.value, list):
sanitized_values = [self.value]
else:
sanitized_values = self.value
for i, v in enumerate(sanitized_values):
# always allow querying for null values
if v is None:
continue
if self.relationship is RelationalKeyword.LIKE or self.relationship is RelationalKeyword.NOT_LIKE:
if not isinstance(model_attr_type, sqltypes.String):
raise ValueError(
f'invalid query string: "{self.relationship.value}" can only be used with string columns'
)
if isinstance(model_attr_type, (GUID)):
try:
# we don't set value since a UUID is functionally identical to a string here
UUID(v)
except ValueError as e:
raise ValueError(f"invalid query string: invalid UUID '{v}'") from e
if isinstance(model_attr_type, sqltypes.Date | sqltypes.DateTime):
try:
sanitized_values[i] = date_parser.parse(v)
except ParserError as e:
raise ValueError(f"invalid query string: unknown date or datetime format '{v}'") from e
if isinstance(model_attr_type, sqltypes.Boolean):
try:
sanitized_values[i] = v.lower()[0] in ["t", "y"] or v == "1"
except IndexError as e:
raise ValueError("invalid query string") from e
return sanitized_values if isinstance(self.value, list) else sanitized_values[0]
class QueryFilter:
lsep: str = "("
rsep: str = ")"
l_group_sep: str = "("
r_group_sep: str = ")"
group_seps: set[str] = {l_group_sep, r_group_sep}
seps: set[str] = {lsep, rsep}
l_list_sep: str = "["
r_list_sep: str = "]"
list_item_sep: str = ","
def __init__(self, filter_string: str) -> None:
# parse filter string
components = QueryFilter._break_filter_string_into_components(filter_string)
base_components = QueryFilter._break_components_into_base_components(components)
if base_components.count(QueryFilter.lsep) != base_components.count(QueryFilter.rsep):
raise ValueError("invalid filter string: parenthesis are unbalanced")
if base_components.count(QueryFilter.l_group_sep) != base_components.count(QueryFilter.r_group_sep):
raise ValueError("invalid query string: parenthesis are unbalanced")
# parse base components into a filter group
self.filter_components = QueryFilter._parse_base_components_into_filter_components(base_components)
@ -75,97 +219,125 @@ class QueryFilter:
return f"<<{joined}>>"
@classmethod
def _consolidate_group(cls, group: list[ColumnElement], logical_operators: deque[LogicalOperator]) -> ColumnElement:
consolidated_group_builder: ColumnElement | None = None
for i, element in enumerate(reversed(group)):
if not i:
consolidated_group_builder = element
else:
operator = logical_operators.pop()
if operator is LogicalOperator.AND:
consolidated_group_builder = and_(consolidated_group_builder, element)
elif operator is LogicalOperator.OR:
consolidated_group_builder = or_(consolidated_group_builder, element)
else:
raise ValueError(f"invalid logical operator {operator}")
if i == len(group) - 1:
return consolidated_group_builder.self_group()
def filter_query(self, query: Select, model: type[Model]) -> Select:
segments: list[str] = []
params: list[BindParameter] = []
# join tables and build model chain
attr_model_map: dict[int, Any] = {}
model_attr: InstrumentedAttribute
for i, component in enumerate(self.filter_components):
if component in QueryFilter.seps:
segments.append(component) # type: ignore
if not isinstance(component, QueryFilterComponent):
continue
if isinstance(component, LogicalOperator):
segments.append(component.value)
continue
# for some reason typing doesn't like the lsep and rsep literals, so
# we explicitly mark this as a filter component instead cast doesn't
# actually do anything at runtime
component = cast(QueryFilterComponent, component)
attribute_chain = component.attribute_name.split(".")
if not attribute_chain:
raise ValueError("invalid query string: attribute name cannot be empty")
attr_model: Any = model
current_model = model
for j, attribute_link in enumerate(attribute_chain):
# last element
if j == len(attribute_chain) - 1:
if not hasattr(attr_model, attribute_link):
raise ValueError(
f"invalid query string: '{component.attribute_name}' does not exist on this schema"
)
attr_value = attribute_link
if j:
# use the nested table name, rather than the dot notation
component.attribute_name = f"{attr_model.__table__.name}.{attr_value}"
continue
# join on nested model
try:
query = query.join(getattr(attr_model, attribute_link))
model_attr = getattr(current_model, attribute_link)
mapper: Mapper = inspect(attr_model)
# at the end of the chain there are no more relationships to inspect
if j == len(attribute_chain) - 1:
break
query = query.join(model_attr)
mapper: Mapper = inspect(current_model)
relationship = mapper.relationships[attribute_link]
attr_model = relationship.mapper.class_
current_model = relationship.mapper.class_
except (AttributeError, KeyError) as e:
raise ValueError(
f"invalid query string: '{component.attribute_name}' does not exist on this schema"
) from e
attr_model_map[i] = current_model
# convert values to their proper types
attr = getattr(attr_model, attr_value)
value: Any = component.value
if isinstance(attr.type, (GUID)):
try:
# we don't set value since a UUID is functionally identical to a string here
UUID(value)
except ValueError as e:
raise ValueError(f"invalid query string: invalid UUID '{component.value}'") from e
if isinstance(attr.type, sqltypes.Date | sqltypes.DateTime):
# TODO: add support for IS NULL and IS NOT NULL
# in the meantime, this will work for the specific usecase of non-null dates/datetimes
if value in ["none", "null"] and component.relational_operator == RelationalOperator.NOTEQ:
component.relational_operator = RelationalOperator.GTE
value = datetime.datetime(datetime.MINYEAR, 1, 1)
# build query filter
partial_group: list[ColumnElement] = []
partial_group_stack: deque[list[ColumnElement]] = deque()
logical_operator_stack: deque[LogicalOperator] = deque()
for i, component in enumerate(self.filter_components):
if component == self.l_group_sep:
partial_group_stack.append(partial_group)
partial_group = []
elif component == self.r_group_sep:
if partial_group:
complete_group = self._consolidate_group(partial_group, logical_operator_stack)
partial_group = partial_group_stack.pop()
partial_group.append(complete_group)
else:
try:
value = date_parser.parse(component.value)
partial_group = partial_group_stack.pop()
except ParserError as e:
raise ValueError(
f"invalid query string: unknown date or datetime format '{component.value}'"
) from e
elif isinstance(component, LogicalOperator):
logical_operator_stack.append(component)
if isinstance(attr.type, sqltypes.Boolean):
try:
value = component.value.lower()[0] in ["t", "y"] or component.value == "1"
else:
component = cast(QueryFilterComponent, component)
model_attr = getattr(attr_model_map[i], component.attribute_name.split(".")[-1])
except IndexError as e:
raise ValueError("invalid query string") from e
# Keywords
if component.relationship is RelationalKeyword.IS:
element = model_attr.is_(component.validate(model_attr.type))
elif component.relationship is RelationalKeyword.IS_NOT:
element = model_attr.is_not(component.validate(model_attr.type))
elif component.relationship is RelationalKeyword.IN:
element = model_attr.in_(component.validate(model_attr.type))
elif component.relationship is RelationalKeyword.NOT_IN:
element = model_attr.not_in(component.validate(model_attr.type))
elif component.relationship is RelationalKeyword.CONTAINS_ALL:
primary_model_attr: InstrumentedAttribute = getattr(model, component.attribute_name.split(".")[0])
element = and_()
for v in component.validate(model_attr.type):
element = and_(element, primary_model_attr.any(model_attr == v))
elif component.relationship is RelationalKeyword.LIKE:
element = model_attr.like(component.validate(model_attr.type))
elif component.relationship is RelationalKeyword.NOT_LIKE:
element = model_attr.not_like(component.validate(model_attr.type))
paramkey = f"P{i+1}"
segments.append(" ".join([component.attribute_name, component.relational_operator.value, f":{paramkey}"]))
params.append(bindparam(paramkey, value, attr.type))
# Operators
elif component.relationship is RelationalOperator.EQ:
element = model_attr == component.validate(model_attr.type)
elif component.relationship is RelationalOperator.NOTEQ:
element = model_attr != component.validate(model_attr.type)
elif component.relationship is RelationalOperator.GT:
element = model_attr > component.validate(model_attr.type)
elif component.relationship is RelationalOperator.LT:
element = model_attr < component.validate(model_attr.type)
elif component.relationship is RelationalOperator.GTE:
element = model_attr >= component.validate(model_attr.type)
elif component.relationship is RelationalOperator.LTE:
element = model_attr <= component.validate(model_attr.type)
else:
raise ValueError(f"invalid relationship {component.relationship}")
qs = text(" ".join(segments)).bindparams(*params)
query = query.filter(qs)
return query
partial_group.append(element)
# combine the completed groups into one filter
while True:
consolidated_group = self._consolidate_group(partial_group, logical_operator_stack)
if not partial_group_stack:
return query.filter(consolidated_group)
else:
partial_group = partial_group_stack.pop()
partial_group.append(consolidated_group)
@staticmethod
def _break_filter_string_into_components(filter_string: str) -> list[str]:
@ -176,7 +348,7 @@ class QueryFilter:
subcomponents = []
for component in components:
# don't parse components comprised of only a separator
if component in QueryFilter.seps:
if component in QueryFilter.group_seps:
subcomponents.append(component)
continue
@ -187,7 +359,7 @@ class QueryFilter:
if c == '"':
in_quotes = not in_quotes
if c in QueryFilter.seps and not in_quotes:
if c in QueryFilter.group_seps and not in_quotes:
if new_component:
subcomponents.append(new_component)
@ -208,25 +380,50 @@ class QueryFilter:
return components
@staticmethod
def _break_components_into_base_components(components: list[str]) -> list[str]:
def _break_components_into_base_components(components: list[str]) -> list[str | list[str]]:
"""Further break down components by splitting at relational and logical operators"""
logical_operators = re.compile(
f'({"|".join(operator.value for operator in LogicalOperator)})', flags=re.IGNORECASE
)
pattern = "|".join([f"\\b{operator.value}\\b" for operator in LogicalOperator])
logical_operators = re.compile(f"({pattern})", flags=re.IGNORECASE)
base_components = []
in_list = False
base_components: list[str | list] = []
list_value_components = []
for component in components:
offset = 0
# parse out lists as their own singular sub component
subcomponents = component.split(QueryFilter.l_list_sep)
for i, subcomponent in enumerate(subcomponents):
if not i:
continue
for j, list_value_string in enumerate(subcomponent.split(QueryFilter.r_list_sep)):
if j % 2:
continue
list_value_components.append(
[val.strip() for val in list_value_string.split(QueryFilter.list_item_sep)]
)
quote_offset = 0
subcomponents = component.split('"')
for i, subcomponent in enumerate(subcomponents):
# we are in a list subcomponent, which is already handled
if in_list:
if QueryFilter.r_list_sep in subcomponent:
# filter out the remainder of the list subcomponent and continue parsing
base_components.append(list_value_components.pop(0))
subcomponent = subcomponent.split(QueryFilter.r_list_sep, maxsplit=1)[-1].strip()
in_list = False
else:
continue
# don't parse components comprised of only a separator
if subcomponent in QueryFilter.seps:
offset += 1
if subcomponent in QueryFilter.group_seps:
quote_offset += 1
base_components.append(subcomponent)
continue
# this subscomponent was surrounded in quotes, so we keep it as-is
if (i + offset) % 2:
# this subcomponent was surrounded in quotes, so we keep it as-is
if (i + quote_offset) % 2:
base_components.append(f'"{subcomponent.strip()}"')
continue
@ -234,53 +431,70 @@ class QueryFilter:
if not subcomponent:
continue
# continue parsing this subcomponent up to the list, then skip over subsequent subcomponents
if not in_list and QueryFilter.l_list_sep in subcomponent:
subcomponent, _new_sub_component = subcomponent.split(QueryFilter.l_list_sep, maxsplit=1)
subcomponent = subcomponent.strip()
subcomponents.insert(i + 1, _new_sub_component)
quote_offset += 1
in_list = True
# parse out logical operators
new_components = [
base_component.strip() for base_component in logical_operators.split(subcomponent) if base_component
]
# parse out relational operators; each base_subcomponent has exactly zero or one relational operator
# we do them one at a time in descending length since some operators overlap (e.g. :> and >)
# parse out relational keywords and operators
# each base_subcomponent has exactly zero or one keyword or operator
for component in new_components:
if not component:
continue
added_to_base_components = False
for rel_op in sorted([operator.value for operator in RelationalOperator], key=len, reverse=True):
if rel_op in component:
new_base_components = [
base_component.strip() for base_component in component.split(rel_op) if base_component
]
new_base_components.insert(1, rel_op)
base_components.extend(new_base_components)
# we try relational operators first since they aren't required to be surrounded by spaces
parsed_component = RelationalOperator.parse_component(component)
if parsed_component is not None:
base_components.extend(parsed_component)
continue
added_to_base_components = True
break
parsed_component = RelationalKeyword.parse_component(component)
if parsed_component is not None:
base_components.extend(parsed_component)
continue
if not added_to_base_components:
base_components.append(component)
# this component does not have any keywords or operators, so we just add it as-is
base_components.append(component)
return base_components
@staticmethod
def _parse_base_components_into_filter_components(
base_components: list[str],
base_components: list[str | list[str]],
) -> list[str | QueryFilterComponent | LogicalOperator]:
"""Walk through base components and construct filter collections"""
relational_keywords = [kw.value for kw in RelationalKeyword]
relational_operators = [op.value for op in RelationalOperator]
logical_operators = [op.value for op in LogicalOperator]
# parse QueryFilterComponents and logical operators
components: list[str | QueryFilterComponent | LogicalOperator] = []
for i, base_component in enumerate(base_components):
if base_component in QueryFilter.seps:
if isinstance(base_component, list):
continue
if base_component in QueryFilter.group_seps:
components.append(base_component)
elif base_component in relational_operators:
elif base_component in relational_keywords or base_component in relational_operators:
relationship: RelationalKeyword | RelationalOperator
if base_component in relational_keywords:
relationship = RelationalKeyword(base_components[i])
else:
relationship = RelationalOperator(base_components[i])
components.append(
QueryFilterComponent(
attribute_name=base_components[i - 1],
relational_operator=RelationalOperator(base_components[i]),
attribute_name=base_components[i - 1], # type: ignore
relationship=relationship,
value=base_components[i + 1],
)
)

View File

@ -1,5 +1,6 @@
import time
from collections import defaultdict
from datetime import datetime
from random import randint
from urllib.parse import parse_qsl, urlsplit
@ -9,7 +10,10 @@ from humps import camelize
from mealie.repos.repository_factory import AllRepositories
from mealie.repos.repository_units import RepositoryUnit
from mealie.schema.recipe import Recipe
from mealie.schema.recipe.recipe_category import CategorySave, TagSave
from mealie.schema.recipe.recipe_ingredient import IngredientUnit, SaveIngredientUnit
from mealie.schema.recipe.recipe_tool import RecipeToolSave
from mealie.schema.response.pagination import PaginationQuery
from mealie.services.seeder.seeder_service import SeederService
from tests.utils import api_routes
@ -172,6 +176,256 @@ def test_pagination_filter_basic(query_units: tuple[RepositoryUnit, IngredientUn
assert unit_results[0].id == unit_2.id
def test_pagination_filter_null(database: AllRepositories, unique_user: TestUser):
recipe_not_made_1 = database.recipes.create(
Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=random_string())
)
recipe_not_made_2 = database.recipes.create(
Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=random_string())
)
# give one recipe a last made date
recipe_made = database.recipes.create(
Recipe(
user_id=unique_user.user_id, group_id=unique_user.group_id, name=random_string(), last_made=datetime.now()
)
)
recipe_repo = database.recipes.by_group(unique_user.group_id) # type: ignore
query = PaginationQuery(page=1, per_page=-1, query_filter="lastMade IS NONE")
recipe_results = recipe_repo.page_all(query).items
assert len(recipe_results) == 2
result_ids = {result.id for result in recipe_results}
assert recipe_not_made_1.id in result_ids
assert recipe_not_made_2.id in result_ids
assert recipe_made.id not in result_ids
query = PaginationQuery(page=1, per_page=-1, query_filter="lastMade IS NULL")
recipe_results = recipe_repo.page_all(query).items
assert len(recipe_results) == 2
result_ids = {result.id for result in recipe_results}
assert recipe_not_made_1.id in result_ids
assert recipe_not_made_2.id in result_ids
assert recipe_made.id not in result_ids
query = PaginationQuery(page=1, per_page=-1, query_filter="lastMade IS NOT NONE")
recipe_results = recipe_repo.page_all(query).items
assert len(recipe_results) == 1
result_ids = {result.id for result in recipe_results}
assert recipe_not_made_1.id not in result_ids
assert recipe_not_made_2.id not in result_ids
assert recipe_made.id in result_ids
query = PaginationQuery(page=1, per_page=-1, query_filter="lastMade IS NOT NULL")
recipe_results = recipe_repo.page_all(query).items
assert len(recipe_results) == 1
result_ids = {result.id for result in recipe_results}
assert recipe_not_made_1.id not in result_ids
assert recipe_not_made_2.id not in result_ids
assert recipe_made.id in result_ids
def test_pagination_filter_in(query_units: tuple[RepositoryUnit, IngredientUnit, IngredientUnit, IngredientUnit]):
units_repo, unit_1, unit_2, unit_3 = query_units
query = PaginationQuery(page=1, per_page=-1, query_filter=f"name IN [{unit_1.name}, {unit_2.name}]")
unit_results = units_repo.page_all(query).items
assert len(unit_results) == 2
result_ids = {unit.id for unit in unit_results}
assert unit_1.id in result_ids
assert unit_2.id in result_ids
assert unit_3.id not in result_ids
query = PaginationQuery(page=1, per_page=-1, query_filter=f"name NOT IN [{unit_1.name}, {unit_2.name}]")
unit_results = units_repo.page_all(query).items
assert len(unit_results) == 1
result_ids = {unit.id for unit in unit_results}
assert unit_1.id not in result_ids
assert unit_2.id not in result_ids
assert unit_3.id in result_ids
query = PaginationQuery(page=1, per_page=-1, query_filter=f'name IN ["{unit_3.name}"]')
unit_results = units_repo.page_all(query).items
assert len(unit_results) == 1
result_ids = {unit.id for unit in unit_results}
assert unit_1.id not in result_ids
assert unit_2.id not in result_ids
assert unit_3.id in result_ids
def test_pagination_filter_in_advanced(database: AllRepositories, unique_user: TestUser):
slug1, slug2 = (random_string(10) for _ in range(2))
tags = [
TagSave(group_id=unique_user.group_id, name=slug1, slug=slug1),
TagSave(group_id=unique_user.group_id, name=slug2, slug=slug2),
]
tag_1, tag_2 = [database.tags.create(tag) for tag in tags]
# Bootstrap the database with recipes
slug = random_string()
recipe_0 = database.recipes.create(
Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=slug, slug=slug, tags=[])
)
slug = random_string()
recipe_1 = database.recipes.create(
Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=slug, slug=slug, tags=[tag_1])
)
slug = random_string()
recipe_2 = database.recipes.create(
Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=slug, slug=slug, tags=[tag_2])
)
slug = random_string()
recipe_1_2 = database.recipes.create(
Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=slug, slug=slug, tags=[tag_1, tag_2])
)
query = PaginationQuery(page=1, per_page=-1, query_filter=f"tags.name IN [{tag_1.name}]")
recipe_results = database.recipes.page_all(query).items
assert len(recipe_results) == 2
recipe_ids = {recipe.id for recipe in recipe_results}
assert recipe_0.id not in recipe_ids
assert recipe_1.id in recipe_ids
assert recipe_2.id not in recipe_ids
assert recipe_1_2.id in recipe_ids
query = PaginationQuery(page=1, per_page=-1, query_filter=f"tags.name IN [{tag_1.name}, {tag_2.name}]")
recipe_results = database.recipes.page_all(query).items
assert len(recipe_results) == 3
recipe_ids = {recipe.id for recipe in recipe_results}
assert recipe_0.id not in recipe_ids
assert recipe_1.id in recipe_ids
assert recipe_2.id in recipe_ids
assert recipe_1_2.id in recipe_ids
query = PaginationQuery(page=1, per_page=-1, query_filter=f"tags.name CONTAINS ALL [{tag_1.name}, {tag_2.name}]")
recipe_results = database.recipes.page_all(query).items
assert len(recipe_results) == 1
recipe_ids = {recipe.id for recipe in recipe_results}
assert recipe_0.id not in recipe_ids
assert recipe_1.id not in recipe_ids
assert recipe_2.id not in recipe_ids
assert recipe_1_2.id in recipe_ids
def test_pagination_filter_like(query_units: tuple[RepositoryUnit, IngredientUnit, IngredientUnit, IngredientUnit]):
units_repo, unit_1, unit_2, unit_3 = query_units
query = PaginationQuery(page=1, per_page=-1, query_filter=r'name LIKE "test u_it%"')
unit_results = units_repo.page_all(query).items
assert len(unit_results) == 3
result_ids = {unit.id for unit in unit_results}
assert unit_1.id in result_ids
assert unit_2.id in result_ids
assert unit_3.id in result_ids
query = PaginationQuery(page=1, per_page=-1, query_filter=r'name LIKE "%unit 1"')
unit_results = units_repo.page_all(query).items
assert len(unit_results) == 1
result_ids = {unit.id for unit in unit_results}
assert unit_1.id in result_ids
assert unit_2.id not in result_ids
assert unit_3.id not in result_ids
query = PaginationQuery(page=1, per_page=-1, query_filter=r'name NOT LIKE %t_1"')
unit_results = units_repo.page_all(query).items
assert len(unit_results) == 2
result_ids = {unit.id for unit in unit_results}
assert unit_1.id not in result_ids
assert unit_2.id in result_ids
assert unit_3.id in result_ids
def test_pagination_filter_keyword_namespace_conflict(database: AllRepositories, unique_user: TestUser):
recipe_rating_1 = database.recipes.create(
Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=random_string(), rating=1)
)
recipe_rating_2 = database.recipes.create(
Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=random_string(), rating=2)
)
recipe_rating_3 = database.recipes.create(
Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=random_string(), rating=3)
)
recipe_repo = database.recipes.by_group(unique_user.group_id) # type: ignore
# "rating" contains the word "in", but we should not parse this as the keyword "IN"
query = PaginationQuery(page=1, per_page=-1, query_filter="rating > 2")
recipe_results = recipe_repo.page_all(query).items
assert len(recipe_results) == 1
result_ids = {recipe.id for recipe in recipe_results}
assert recipe_rating_1.id not in result_ids
assert recipe_rating_2.id not in result_ids
assert recipe_rating_3.id in result_ids
query = PaginationQuery(page=1, per_page=-1, query_filter="rating in [1, 3]")
recipe_results = recipe_repo.page_all(query).items
assert len(recipe_results) == 2
result_ids = {recipe.id for recipe in recipe_results}
assert recipe_rating_1.id in result_ids
assert recipe_rating_2.id not in result_ids
assert recipe_rating_3.id in result_ids
def test_pagination_filter_logical_namespace_conflict(database: AllRepositories, unique_user: TestUser):
categories = [
CategorySave(group_id=unique_user.group_id, name=random_string(10)),
CategorySave(group_id=unique_user.group_id, name=random_string(10)),
]
category_1, category_2 = [database.categories.create(category) for category in categories]
# Bootstrap the database with recipes
slug = random_string()
recipe_category_0 = database.recipes.create(
Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=slug, slug=slug)
)
slug = random_string()
recipe_category_1 = database.recipes.create(
Recipe(
user_id=unique_user.user_id,
group_id=unique_user.group_id,
name=slug,
slug=slug,
recipe_category=[category_1],
)
)
slug = random_string()
recipe_category_2 = database.recipes.create(
Recipe(
user_id=unique_user.user_id,
group_id=unique_user.group_id,
name=slug,
slug=slug,
recipe_category=[category_2],
)
)
# "recipeCategory" has the substring "or" in it, which shouldn't break queries
query = PaginationQuery(page=1, per_page=-1, query_filter=f'recipeCategory.id = "{category_1.id}"')
recipe_results = database.recipes.by_group(unique_user.group_id).page_all(query).items # type: ignore
assert len(recipe_results) == 1
recipe_ids = {recipe.id for recipe in recipe_results}
assert recipe_category_0.id not in recipe_ids
assert recipe_category_1.id in recipe_ids
assert recipe_category_2.id not in recipe_ids
def test_pagination_filter_datetimes(
query_units: tuple[RepositoryUnit, IngredientUnit, IngredientUnit, IngredientUnit]
):
@ -197,15 +451,183 @@ def test_pagination_filter_booleans(query_units: tuple[RepositoryUnit, Ingredien
def test_pagination_filter_advanced(query_units: tuple[RepositoryUnit, IngredientUnit, IngredientUnit, IngredientUnit]):
units_repo = query_units[0]
unit_3 = query_units[3]
units_repo, unit_1, unit_2, unit_3 = query_units
dt = str(unit_3.created_at.isoformat()) # type: ignore
qf = f'name="test unit 1" OR (useAbbreviation=f AND (name="test unit 2" OR createdAt > "{dt}"))'
qf = f'name="test unit 1" OR (useAbbreviation=f AND (name="{unit_2.name}" OR createdAt > "{dt}"))'
query = PaginationQuery(page=1, per_page=-1, query_filter=qf)
unit_results = units_repo.page_all(query).items
assert len(unit_results) == 2
assert unit_3.id not in [unit.id for unit in unit_results]
result_ids = {unit.id for unit in unit_results}
assert unit_1.id in result_ids
assert unit_2.id in result_ids
assert unit_3.id not in result_ids
qf = f'(name LIKE %_1 OR name IN ["{unit_2.name}"]) AND createdAt IS NOT NONE'
query = PaginationQuery(page=1, per_page=-1, query_filter=qf)
unit_results = units_repo.page_all(query).items
assert len(unit_results) == 2
result_ids = {unit.id for unit in unit_results}
assert unit_1.id in result_ids
assert unit_2.id in result_ids
assert unit_3.id not in result_ids
def test_pagination_filter_advanced_frontend_sort(database: AllRepositories, unique_user: TestUser):
categories = [
CategorySave(group_id=unique_user.group_id, name=random_string(10)),
CategorySave(group_id=unique_user.group_id, name=random_string(10)),
]
category_1, category_2 = [database.categories.create(category) for category in categories]
slug1, slug2 = (random_string(10) for _ in range(2))
tags = [
TagSave(group_id=unique_user.group_id, name=slug1, slug=slug1),
TagSave(group_id=unique_user.group_id, name=slug2, slug=slug2),
]
tag_1, tag_2 = [database.tags.create(tag) for tag in tags]
tools = [
RecipeToolSave(group_id=unique_user.group_id, name=random_string(10)),
RecipeToolSave(group_id=unique_user.group_id, name=random_string(10)),
]
tool_1, tool_2 = [database.tools.create(tool) for tool in tools]
# Bootstrap the database with recipes
slug = random_string()
recipe_ct0_tg0_tl0 = database.recipes.create(
Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=slug, slug=slug)
)
slug = random_string()
recipe_ct1_tg0_tl0 = database.recipes.create(
Recipe(
user_id=unique_user.user_id,
group_id=unique_user.group_id,
name=slug,
slug=slug,
recipe_category=[category_1],
)
)
slug = random_string()
recipe_ct12_tg0_tl0 = database.recipes.create(
Recipe(
user_id=unique_user.user_id,
group_id=unique_user.group_id,
name=slug,
slug=slug,
recipe_category=[category_1, category_2],
)
)
slug = random_string()
recipe_ct1_tg1_tl0 = database.recipes.create(
Recipe(
user_id=unique_user.user_id,
group_id=unique_user.group_id,
name=slug,
slug=slug,
recipe_category=[category_1],
tags=[tag_1],
)
)
slug = random_string()
recipe_ct1_tg0_tl1 = database.recipes.create(
Recipe(
user_id=unique_user.user_id,
group_id=unique_user.group_id,
name=slug,
slug=slug,
recipe_category=[category_1],
tools=[tool_1],
)
)
slug = random_string()
recipe_ct0_tg2_tl2 = database.recipes.create(
Recipe(
user_id=unique_user.user_id,
group_id=unique_user.group_id,
name=slug,
slug=slug,
tags=[tag_2],
tools=[tool_2],
)
)
slug = random_string()
recipe_ct12_tg12_tl2 = database.recipes.create(
Recipe(
user_id=unique_user.user_id,
group_id=unique_user.group_id,
name=slug,
slug=slug,
recipe_category=[category_1, category_2],
tags=[tag_1, tag_2],
tools=[tool_2],
)
)
repo = database.recipes.by_group(unique_user.group_id) # type: ignore
qf = f'recipeCategory.id IN ["{category_1.id}"] AND tools.id IN ["{tool_1.id}"]'
query = PaginationQuery(page=1, per_page=-1, query_filter=qf)
recipe_results = repo.page_all(query).items
assert len(recipe_results) == 1
recipe_ids = {recipe.id for recipe in recipe_results}
assert recipe_ct0_tg0_tl0.id not in recipe_ids
assert recipe_ct1_tg0_tl0.id not in recipe_ids
assert recipe_ct12_tg0_tl0.id not in recipe_ids
assert recipe_ct1_tg1_tl0.id not in recipe_ids
assert recipe_ct1_tg0_tl1.id in recipe_ids
assert recipe_ct0_tg2_tl2.id not in recipe_ids
assert recipe_ct12_tg12_tl2.id not in recipe_ids
qf = f'recipeCategory.id CONTAINS ALL ["{category_1.id}", "{category_2.id}"] AND tags.id IN ["{tag_1.id}"]'
query = PaginationQuery(page=1, per_page=-1, query_filter=qf)
recipe_results = repo.page_all(query).items
assert len(recipe_results) == 1
recipe_ids = {recipe.id for recipe in recipe_results}
assert recipe_ct0_tg0_tl0.id not in recipe_ids
assert recipe_ct1_tg0_tl0.id not in recipe_ids
assert recipe_ct12_tg0_tl0.id not in recipe_ids
assert recipe_ct1_tg1_tl0.id not in recipe_ids
assert recipe_ct1_tg0_tl1.id not in recipe_ids
assert recipe_ct0_tg2_tl2.id not in recipe_ids
assert recipe_ct12_tg12_tl2.id in recipe_ids
qf = f'tags.id IN ["{tag_1.id}", "{tag_2.id}"] AND tools.id IN ["{tool_2.id}"]'
query = PaginationQuery(page=1, per_page=-1, query_filter=qf)
recipe_results = repo.page_all(query).items
assert len(recipe_results) == 2
recipe_ids = {recipe.id for recipe in recipe_results}
assert recipe_ct0_tg0_tl0.id not in recipe_ids
assert recipe_ct1_tg0_tl0.id not in recipe_ids
assert recipe_ct12_tg0_tl0.id not in recipe_ids
assert recipe_ct1_tg1_tl0.id not in recipe_ids
assert recipe_ct1_tg0_tl1.id not in recipe_ids
assert recipe_ct0_tg2_tl2.id in recipe_ids
assert recipe_ct12_tg12_tl2.id in recipe_ids
qf = (
f'recipeCategory.id CONTAINS ALL ["{category_1.id}", "{category_2.id}"]'
f'AND tags.id IN ["{tag_1.id}", "{tag_2.id}"] AND tools.id IN ["{tool_1.id}", "{tool_2.id}"]'
)
query = PaginationQuery(page=1, per_page=-1, query_filter=qf)
recipe_results = repo.page_all(query).items
assert len(recipe_results) == 1
recipe_ids = {recipe.id for recipe in recipe_results}
assert recipe_ct0_tg0_tl0.id not in recipe_ids
assert recipe_ct1_tg0_tl0.id not in recipe_ids
assert recipe_ct12_tg0_tl0.id not in recipe_ids
assert recipe_ct1_tg1_tl0.id not in recipe_ids
assert recipe_ct1_tg0_tl1.id not in recipe_ids
assert recipe_ct0_tg2_tl2.id not in recipe_ids
assert recipe_ct12_tg12_tl2.id in recipe_ids
@pytest.mark.parametrize(
@ -214,6 +636,13 @@ def test_pagination_filter_advanced(query_units: tuple[RepositoryUnit, Ingredien
pytest.param('(name="test name" AND useAbbreviation=f))', id="unbalanced parenthesis"),
pytest.param('id="this is not a valid UUID"', id="invalid UUID"),
pytest.param('createdAt="this is not a valid datetime format"', id="invalid datetime format"),
pytest.param('name IS "test name"', id="IS can only be used with NULL or NONE"),
pytest.param('name IS NOT "test name"', id="IS NOT can only be used with NULL or NONE"),
pytest.param('name IN "test name"', id="IN must use a list of values"),
pytest.param('name NOT IN "test name"', id="NOT IN must use a list of values"),
pytest.param('name CONTAINS ALL "test name"', id="CONTAINS ALL must use a list of values"),
pytest.param('createdAt LIKE "2023-02-25"', id="LIKE is only valid for string columns"),
pytest.param('createdAt NOT LIKE "2023-02-25"', id="NOT LIKE is only valid for string columns"),
pytest.param('badAttribute="test value"', id="invalid attribute"),
pytest.param('group.badAttribute="test value"', id="bad nested attribute"),
pytest.param('group.preferences.badAttribute="test value"', id="bad double nested attribute"),