feature: query filter support for common SQL keywords (#2366)

* added support for SQL keywords IS, IN, LIKE, NOT
deprecated datetime workaround for "<> null"
updated frontend reference for "<> null" to "IS NOT NULL"

* tests

* refactored query filtering to leverage orm

* added CONTAINS ALL keyword

* tests

* fixed bug where "and" or "or" was in an attr name

* more tests

* linter fixes

* TIL this works
This commit is contained in:
Michael Genson 2023-05-06 17:28:40 -05:00 committed by GitHub
parent 9b726126ed
commit 5d87b7e411
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 760 additions and 117 deletions

View File

@ -217,7 +217,7 @@ export default defineComponent({
const queryFilter = computed(() => { const queryFilter = computed(() => {
const orderBy = props.query?.orderBy || preferences.value.orderBy; const orderBy = props.query?.orderBy || preferences.value.orderBy;
return preferences.value.filterNull && orderBy ? `${orderBy} <> null` : null; return preferences.value.filterNull && orderBy ? `${orderBy} IS NOT NULL` : null;
}); });
async function fetchRecipes(pageCount = 1) { async function fetchRecipes(pageCount = 1) {

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import datetime
import re import re
from collections import deque
from enum import Enum from enum import Enum
from typing import Any, TypeVar, cast from typing import Any, TypeVar, cast
from uuid import UUID from uuid import UUID
@ -9,16 +9,66 @@ from uuid import UUID
from dateutil import parser as date_parser from dateutil import parser as date_parser
from dateutil.parser import ParserError from dateutil.parser import ParserError
from humps import decamelize from humps import decamelize
from sqlalchemy import Select, bindparam, inspect, text from sqlalchemy import ColumnElement, Select, and_, inspect, or_
from sqlalchemy.orm import Mapper from sqlalchemy.orm import InstrumentedAttribute, Mapper
from sqlalchemy.sql import sqltypes from sqlalchemy.sql import sqltypes
from sqlalchemy.sql.expression import BindParameter
from mealie.db.models._model_utils.guid import GUID from mealie.db.models._model_utils.guid import GUID
Model = TypeVar("Model") Model = TypeVar("Model")
class RelationalKeyword(Enum):
    """SQL-style relational keywords that may appear between an attribute name and its value."""

    IS = "IS"
    IS_NOT = "IS NOT"
    IN = "IN"
    NOT_IN = "NOT IN"
    CONTAINS_ALL = "CONTAINS ALL"
    LIKE = "LIKE"
    NOT_LIKE = "NOT LIKE"

    @classmethod
    def _match_keyword(cls, candidate: str) -> str | None:
        """
        Return the canonical keyword text matching `candidate`, or None if there is no match

        Matching ignores case, surrounding whitespace, and the amount of whitespace
        between the words of a multi-word keyword (e.g. "is   not" matches "IS NOT")
        """

        normalized = " ".join(candidate.split()).lower()
        for keyword in cls:
            if keyword.value.lower() == normalized:
                return keyword.value

        return None

    @classmethod
    def parse_component(cls, component: str) -> list[str] | None:
        """
        Try to parse a component using a relational keyword

        Returns `[attribute_name, keyword]` when the component holds no value, or
        `[attribute_name, keyword, value]` when the last whitespace-separated token is the value.
        The keyword is always returned in its canonical form.

        If no matching keyword is found, returns None
        """

        # extract the attribute name from the component
        attribute_and_remainder = component.split(maxsplit=1)
        if len(attribute_and_remainder) < 2:
            return None

        attribute_name, remainder = attribute_and_remainder

        # assume the component has already filtered out the value and try to match a keyword
        # if we try to filter out the value without checking first, keywords with spaces won't parse correctly
        keyword = cls._match_keyword(remainder)
        if keyword is not None:
            return [attribute_name, keyword]

        # there was no match, so the component may still have the value in it
        keyword_and_value = remainder.rsplit(maxsplit=1)
        if len(keyword_and_value) < 2:
            # the component has no value to filter out
            return None

        possible_keyword, value = keyword_and_value
        keyword = cls._match_keyword(possible_keyword)
        if keyword is None:
            return None

        return [attribute_name, keyword, value]
class RelationalOperator(Enum): class RelationalOperator(Enum):
EQ = "=" EQ = "="
NOTEQ = "<>" NOTEQ = "<>"
@ -27,6 +77,24 @@ class RelationalOperator(Enum):
GTE = ">=" GTE = ">="
LTE = "<=" LTE = "<="
@classmethod
def parse_component(cls, component: str) -> list[str] | None:
"""
Try to parse a component using a relational operator
If no matching operator is found, returns None
"""
for rel_op in sorted([operator.value for operator in cls], key=len, reverse=True):
if rel_op not in component:
continue
parsed_component = [base_component.strip() for base_component in component.split(rel_op) if base_component]
parsed_component.insert(1, rel_op)
return parsed_component
return None
class LogicalOperator(Enum): class LogicalOperator(Enum):
AND = "AND" AND = "AND"
@ -36,31 +104,107 @@ class LogicalOperator(Enum):
class QueryFilterComponent:
    """A single relational statement"""

    @staticmethod
    def strip_quotes_from_string(val: str) -> str:
        """Return `val` without its encasing double quotes; unquoted values are returned unchanged"""

        # ">= 2" so that an empty quoted string ('""') is reduced to an empty string
        # a lone quote character is not a quoted string and is left untouched
        if len(val) >= 2 and val[0] == '"' and val[-1] == '"':
            return val[1:-1]

        return val

    def __init__(
        self, attribute_name: str, relationship: RelationalKeyword | RelationalOperator, value: str | list[str]
    ) -> None:
        """
        Store the decamelized attribute name, the relationship, and the unquoted value

        Raises a ValueError when the value does not fit the relationship:
        list keywords require a list, and IS / IS NOT only accept "null" or "none"
        """

        self.attribute_name = decamelize(attribute_name)
        self.relationship = relationship

        # remove encasing quotes
        if isinstance(value, str):
            value = self.strip_quotes_from_string(value)

        elif isinstance(value, list):
            value = [self.strip_quotes_from_string(v) for v in value]

        # validate relationship/value pairs
        if relationship in [
            RelationalKeyword.IN,
            RelationalKeyword.NOT_IN,
            RelationalKeyword.CONTAINS_ALL,
        ] and not isinstance(value, list):
            raise ValueError(
                f"invalid query string: {relationship.value} must be given a list of values "
                f"enclosed by {QueryFilter.l_list_sep} and {QueryFilter.r_list_sep}"
            )

        if relationship is RelationalKeyword.IS or relationship is RelationalKeyword.IS_NOT:
            if not isinstance(value, str) or value.lower() not in ["null", "none"]:
                raise ValueError(
                    f'invalid query string: "{relationship.value}" can only be used with "NULL", not "{value}"'
                )

            # IS / IS NOT always compare against null, which is stored as None
            self.value = None
        else:
            self.value = value

    def __repr__(self) -> str:
        return f"[{self.attribute_name} {self.relationship.value} {self.value}]"

    def validate(self, model_attr_type: Any) -> Any:
        """
        Validate value against a model attribute's type and return a validated value, or raise a ValueError

        A list value yields a new list of validated values; any other value yields a single validated value.
        The stored value is never modified, so validating the same component again gives the same result.
        """

        is_list = isinstance(self.value, list)

        # work on a copy so converted values are never written back into self.value
        sanitized_values: list[Any] = list(self.value) if is_list else [self.value]

        for i, v in enumerate(sanitized_values):
            # always allow querying for null values
            if v is None:
                continue

            if self.relationship is RelationalKeyword.LIKE or self.relationship is RelationalKeyword.NOT_LIKE:
                if not isinstance(model_attr_type, sqltypes.String):
                    raise ValueError(
                        f'invalid query string: "{self.relationship.value}" can only be used with string columns'
                    )

            if isinstance(model_attr_type, GUID):
                try:
                    # we don't set value since a UUID is functionally identical to a string here
                    UUID(v)
                except ValueError as e:
                    raise ValueError(f"invalid query string: invalid UUID '{v}'") from e

            if isinstance(model_attr_type, sqltypes.Date | sqltypes.DateTime):
                try:
                    sanitized_values[i] = date_parser.parse(v)
                except ParserError as e:
                    raise ValueError(f"invalid query string: unknown date or datetime format '{v}'") from e

            if isinstance(model_attr_type, sqltypes.Boolean):
                try:
                    # an empty string has no first character, which surfaces as an IndexError
                    sanitized_values[i] = v.lower()[0] in ["t", "y"] or v == "1"
                except IndexError as e:
                    raise ValueError("invalid query string") from e

        return sanitized_values if is_list else sanitized_values[0]
class QueryFilter: class QueryFilter:
lsep: str = "(" l_group_sep: str = "("
rsep: str = ")" r_group_sep: str = ")"
group_seps: set[str] = {l_group_sep, r_group_sep}
seps: set[str] = {lsep, rsep} l_list_sep: str = "["
r_list_sep: str = "]"
list_item_sep: str = ","
def __init__(self, filter_string: str) -> None: def __init__(self, filter_string: str) -> None:
# parse filter string # parse filter string
components = QueryFilter._break_filter_string_into_components(filter_string) components = QueryFilter._break_filter_string_into_components(filter_string)
base_components = QueryFilter._break_components_into_base_components(components) base_components = QueryFilter._break_components_into_base_components(components)
if base_components.count(QueryFilter.lsep) != base_components.count(QueryFilter.rsep): if base_components.count(QueryFilter.l_group_sep) != base_components.count(QueryFilter.r_group_sep):
raise ValueError("invalid filter string: parenthesis are unbalanced") raise ValueError("invalid query string: parenthesis are unbalanced")
# parse base components into a filter group # parse base components into a filter group
self.filter_components = QueryFilter._parse_base_components_into_filter_components(base_components) self.filter_components = QueryFilter._parse_base_components_into_filter_components(base_components)
@ -75,97 +219,125 @@ class QueryFilter:
return f"<<{joined}>>" return f"<<{joined}>>"
@classmethod
def _consolidate_group(cls, group: list[ColumnElement], logical_operators: deque[LogicalOperator]) -> ColumnElement:
consolidated_group_builder: ColumnElement | None = None
for i, element in enumerate(reversed(group)):
if not i:
consolidated_group_builder = element
else:
operator = logical_operators.pop()
if operator is LogicalOperator.AND:
consolidated_group_builder = and_(consolidated_group_builder, element)
elif operator is LogicalOperator.OR:
consolidated_group_builder = or_(consolidated_group_builder, element)
else:
raise ValueError(f"invalid logical operator {operator}")
if i == len(group) - 1:
return consolidated_group_builder.self_group()
def filter_query(self, query: Select, model: type[Model]) -> Select:
    """
    Apply the parsed filter components to `query` and return the filtered query

    Dotted attribute names are resolved by walking the model's relationships and
    joining each one; every relationship path is joined only once, no matter how
    many components reference it. The components are then turned into column
    expressions, grouped by parenthesis, and combined with their logical operators.

    Raises a ValueError for empty or unknown attribute names, parenthesis that close
    before they open, values that fail validation, and unknown relationships
    """

    # join tables and build model chain
    attr_model_map: dict[int, Any] = {}
    joined_paths: set[tuple[str, ...]] = set()
    model_attr: InstrumentedAttribute
    for i, component in enumerate(self.filter_components):
        if not isinstance(component, QueryFilterComponent):
            continue

        # str.split never returns an empty list, so check the individual links instead
        attribute_chain = component.attribute_name.split(".")
        if not all(attribute_chain):
            raise ValueError("invalid query string: attribute name cannot be empty")

        current_model: Any = model
        last_index = len(attribute_chain) - 1
        for j, attribute_link in enumerate(attribute_chain):
            try:
                model_attr = getattr(current_model, attribute_link)

                # at the end of the chain there are no more relationships to inspect
                if j == last_index:
                    break

                # resolve the relationship before joining, so a non-relationship link
                # is reported as an invalid query string
                mapper: Mapper = inspect(current_model)
                relationship = mapper.relationships[attribute_link]

                # joining the same relationship twice would duplicate the joined table
                path = tuple(attribute_chain[: j + 1])
                if path not in joined_paths:
                    query = query.join(model_attr)
                    joined_paths.add(path)

                current_model = relationship.mapper.class_

            except (AttributeError, KeyError) as e:
                raise ValueError(
                    f"invalid query string: '{component.attribute_name}' does not exist on this schema"
                ) from e

        attr_model_map[i] = current_model

    # build query filter
    partial_group: list[ColumnElement] = []
    partial_group_stack: deque[list[ColumnElement]] = deque()
    logical_operator_stack: deque[LogicalOperator] = deque()
    for i, component in enumerate(self.filter_components):
        if component == self.l_group_sep:
            partial_group_stack.append(partial_group)
            partial_group = []

        elif component == self.r_group_sep:
            # a closing parenthesis with nothing open means the order is wrong,
            # even when the overall counts are balanced
            if not partial_group_stack:
                raise ValueError("invalid query string: parenthesis are unbalanced")

            if partial_group:
                complete_group = self._consolidate_group(partial_group, logical_operator_stack)
                partial_group = partial_group_stack.pop()
                partial_group.append(complete_group)
            else:
                partial_group = partial_group_stack.pop()

        elif isinstance(component, LogicalOperator):
            logical_operator_stack.append(component)

        else:
            component = cast(QueryFilterComponent, component)
            model_attr = getattr(attr_model_map[i], component.attribute_name.split(".")[-1])
            relationship_kind = component.relationship
            value = component.validate(model_attr.type)
            element: ColumnElement

            # Keywords
            if relationship_kind is RelationalKeyword.IS:
                element = model_attr.is_(value)
            elif relationship_kind is RelationalKeyword.IS_NOT:
                element = model_attr.is_not(value)
            elif relationship_kind is RelationalKeyword.IN:
                element = model_attr.in_(value)
            elif relationship_kind is RelationalKeyword.NOT_IN:
                element = model_attr.not_in(value)
            elif relationship_kind is RelationalKeyword.CONTAINS_ALL:
                primary_model_attr: InstrumentedAttribute = getattr(model, component.attribute_name.split(".")[0])

                # and_() without arguments is deprecated; True is the documented neutral element
                element = and_(True, *(primary_model_attr.any(model_attr == v) for v in value))
            elif relationship_kind is RelationalKeyword.LIKE:
                element = model_attr.like(value)
            elif relationship_kind is RelationalKeyword.NOT_LIKE:
                element = model_attr.not_like(value)

            # Operators
            elif relationship_kind is RelationalOperator.EQ:
                element = model_attr == value
            elif relationship_kind is RelationalOperator.NOTEQ:
                element = model_attr != value
            elif relationship_kind is RelationalOperator.GT:
                element = model_attr > value
            elif relationship_kind is RelationalOperator.LT:
                element = model_attr < value
            elif relationship_kind is RelationalOperator.GTE:
                element = model_attr >= value
            elif relationship_kind is RelationalOperator.LTE:
                element = model_attr <= value
            else:
                raise ValueError(f"invalid relationship {component.relationship}")

            partial_group.append(element)

    # combine the completed groups into one filter
    while True:
        consolidated_group = self._consolidate_group(partial_group, logical_operator_stack)
        if not partial_group_stack:
            return query.filter(consolidated_group)

        partial_group = partial_group_stack.pop()
        partial_group.append(consolidated_group)
@staticmethod @staticmethod
def _break_filter_string_into_components(filter_string: str) -> list[str]: def _break_filter_string_into_components(filter_string: str) -> list[str]:
@ -176,7 +348,7 @@ class QueryFilter:
subcomponents = [] subcomponents = []
for component in components: for component in components:
# don't parse components comprised of only a separator # don't parse components comprised of only a separator
if component in QueryFilter.seps: if component in QueryFilter.group_seps:
subcomponents.append(component) subcomponents.append(component)
continue continue
@ -187,7 +359,7 @@ class QueryFilter:
if c == '"': if c == '"':
in_quotes = not in_quotes in_quotes = not in_quotes
if c in QueryFilter.seps and not in_quotes: if c in QueryFilter.group_seps and not in_quotes:
if new_component: if new_component:
subcomponents.append(new_component) subcomponents.append(new_component)
@ -208,25 +380,50 @@ class QueryFilter:
return components return components
@staticmethod @staticmethod
def _break_components_into_base_components(components: list[str]) -> list[str]: def _break_components_into_base_components(components: list[str]) -> list[str | list[str]]:
"""Further break down components by splitting at relational and logical operators""" """Further break down components by splitting at relational and logical operators"""
logical_operators = re.compile( pattern = "|".join([f"\\b{operator.value}\\b" for operator in LogicalOperator])
f'({"|".join(operator.value for operator in LogicalOperator)})', flags=re.IGNORECASE logical_operators = re.compile(f"({pattern})", flags=re.IGNORECASE)
)
base_components = [] in_list = False
base_components: list[str | list] = []
list_value_components = []
for component in components: for component in components:
offset = 0 # parse out lists as their own singular sub component
subcomponents = component.split(QueryFilter.l_list_sep)
for i, subcomponent in enumerate(subcomponents):
if not i:
continue
for j, list_value_string in enumerate(subcomponent.split(QueryFilter.r_list_sep)):
if j % 2:
continue
list_value_components.append(
[val.strip() for val in list_value_string.split(QueryFilter.list_item_sep)]
)
quote_offset = 0
subcomponents = component.split('"') subcomponents = component.split('"')
for i, subcomponent in enumerate(subcomponents): for i, subcomponent in enumerate(subcomponents):
# we are in a list subcomponent, which is already handled
if in_list:
if QueryFilter.r_list_sep in subcomponent:
# filter out the remainder of the list subcomponent and continue parsing
base_components.append(list_value_components.pop(0))
subcomponent = subcomponent.split(QueryFilter.r_list_sep, maxsplit=1)[-1].strip()
in_list = False
else:
continue
# don't parse components comprised of only a separator # don't parse components comprised of only a separator
if subcomponent in QueryFilter.seps: if subcomponent in QueryFilter.group_seps:
offset += 1 quote_offset += 1
base_components.append(subcomponent) base_components.append(subcomponent)
continue continue
# this subscomponent was surrounded in quotes, so we keep it as-is # this subcomponent was surrounded in quotes, so we keep it as-is
if (i + offset) % 2: if (i + quote_offset) % 2:
base_components.append(f'"{subcomponent.strip()}"') base_components.append(f'"{subcomponent.strip()}"')
continue continue
@ -234,53 +431,70 @@ class QueryFilter:
if not subcomponent: if not subcomponent:
continue continue
# continue parsing this subcomponent up to the list, then skip over subsequent subcomponents
if not in_list and QueryFilter.l_list_sep in subcomponent:
subcomponent, _new_sub_component = subcomponent.split(QueryFilter.l_list_sep, maxsplit=1)
subcomponent = subcomponent.strip()
subcomponents.insert(i + 1, _new_sub_component)
quote_offset += 1
in_list = True
# parse out logical operators # parse out logical operators
new_components = [ new_components = [
base_component.strip() for base_component in logical_operators.split(subcomponent) if base_component base_component.strip() for base_component in logical_operators.split(subcomponent) if base_component
] ]
# parse out relational operators; each base_subcomponent has exactly zero or one relational operator # parse out relational keywords and operators
# we do them one at a time in descending length since some operators overlap (e.g. :> and >) # each base_subcomponent has exactly zero or one keyword or operator
for component in new_components: for component in new_components:
if not component: if not component:
continue continue
added_to_base_components = False # we try relational operators first since they aren't required to be surrounded by spaces
for rel_op in sorted([operator.value for operator in RelationalOperator], key=len, reverse=True): parsed_component = RelationalOperator.parse_component(component)
if rel_op in component: if parsed_component is not None:
new_base_components = [ base_components.extend(parsed_component)
base_component.strip() for base_component in component.split(rel_op) if base_component continue
]
new_base_components.insert(1, rel_op)
base_components.extend(new_base_components)
added_to_base_components = True parsed_component = RelationalKeyword.parse_component(component)
break if parsed_component is not None:
base_components.extend(parsed_component)
continue
if not added_to_base_components: # this component does not have any keywords or operators, so we just add it as-is
base_components.append(component) base_components.append(component)
return base_components return base_components
@staticmethod @staticmethod
def _parse_base_components_into_filter_components( def _parse_base_components_into_filter_components(
base_components: list[str], base_components: list[str | list[str]],
) -> list[str | QueryFilterComponent | LogicalOperator]: ) -> list[str | QueryFilterComponent | LogicalOperator]:
"""Walk through base components and construct filter collections""" """Walk through base components and construct filter collections"""
relational_keywords = [kw.value for kw in RelationalKeyword]
relational_operators = [op.value for op in RelationalOperator] relational_operators = [op.value for op in RelationalOperator]
logical_operators = [op.value for op in LogicalOperator] logical_operators = [op.value for op in LogicalOperator]
# parse QueryFilterComponents and logical operators # parse QueryFilterComponents and logical operators
components: list[str | QueryFilterComponent | LogicalOperator] = [] components: list[str | QueryFilterComponent | LogicalOperator] = []
for i, base_component in enumerate(base_components): for i, base_component in enumerate(base_components):
if base_component in QueryFilter.seps: if isinstance(base_component, list):
continue
if base_component in QueryFilter.group_seps:
components.append(base_component) components.append(base_component)
elif base_component in relational_operators: elif base_component in relational_keywords or base_component in relational_operators:
relationship: RelationalKeyword | RelationalOperator
if base_component in relational_keywords:
relationship = RelationalKeyword(base_components[i])
else:
relationship = RelationalOperator(base_components[i])
components.append( components.append(
QueryFilterComponent( QueryFilterComponent(
attribute_name=base_components[i - 1], attribute_name=base_components[i - 1], # type: ignore
relational_operator=RelationalOperator(base_components[i]), relationship=relationship,
value=base_components[i + 1], value=base_components[i + 1],
) )
) )

View File

@ -1,5 +1,6 @@
import time import time
from collections import defaultdict from collections import defaultdict
from datetime import datetime
from random import randint from random import randint
from urllib.parse import parse_qsl, urlsplit from urllib.parse import parse_qsl, urlsplit
@ -9,7 +10,10 @@ from humps import camelize
from mealie.repos.repository_factory import AllRepositories from mealie.repos.repository_factory import AllRepositories
from mealie.repos.repository_units import RepositoryUnit from mealie.repos.repository_units import RepositoryUnit
from mealie.schema.recipe import Recipe
from mealie.schema.recipe.recipe_category import CategorySave, TagSave
from mealie.schema.recipe.recipe_ingredient import IngredientUnit, SaveIngredientUnit from mealie.schema.recipe.recipe_ingredient import IngredientUnit, SaveIngredientUnit
from mealie.schema.recipe.recipe_tool import RecipeToolSave
from mealie.schema.response.pagination import PaginationQuery from mealie.schema.response.pagination import PaginationQuery
from mealie.services.seeder.seeder_service import SeederService from mealie.services.seeder.seeder_service import SeederService
from tests.utils import api_routes from tests.utils import api_routes
@ -172,6 +176,256 @@ def test_pagination_filter_basic(query_units: tuple[RepositoryUnit, IngredientUn
assert unit_results[0].id == unit_2.id assert unit_results[0].id == unit_2.id
def test_pagination_filter_null(database: AllRepositories, unique_user: TestUser):
    """IS / IS NOT with NULL or NONE split recipes by whether `last_made` is set"""

    def create_recipe(**extra):
        return database.recipes.create(
            Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=random_string(), **extra)
        )

    recipe_not_made_1 = create_recipe()
    recipe_not_made_2 = create_recipe()

    # give one recipe a last made date
    recipe_made = create_recipe(last_made=datetime.now())

    never_made_ids = {recipe_not_made_1.id, recipe_not_made_2.id}
    made_ids = {recipe_made.id}

    recipe_repo = database.recipes.by_group(unique_user.group_id)  # type: ignore

    # (filter string, ids that must be returned, ids that must not be returned)
    cases = [
        ("lastMade IS NONE", never_made_ids, made_ids),
        ("lastMade IS NULL", never_made_ids, made_ids),
        ("lastMade IS NOT NONE", made_ids, never_made_ids),
        ("lastMade IS NOT NULL", made_ids, never_made_ids),
    ]
    for query_filter, expected_ids, excluded_ids in cases:
        query = PaginationQuery(page=1, per_page=-1, query_filter=query_filter)
        recipe_results = recipe_repo.page_all(query).items

        assert len(recipe_results) == len(expected_ids)
        result_ids = {result.id for result in recipe_results}
        assert expected_ids <= result_ids
        assert result_ids.isdisjoint(excluded_ids)
def test_pagination_filter_in(query_units: tuple[RepositoryUnit, IngredientUnit, IngredientUnit, IngredientUnit]):
    """IN / NOT IN match unit names against a bracketed list, with or without quoted items"""

    units_repo, unit_1, unit_2, unit_3 = query_units
    all_ids = {unit_1.id, unit_2.id, unit_3.id}

    # (filter string, ids that must be returned); every other unit must be absent
    cases = [
        (f"name IN [{unit_1.name}, {unit_2.name}]", {unit_1.id, unit_2.id}),
        (f"name NOT IN [{unit_1.name}, {unit_2.name}]", {unit_3.id}),
        (f'name IN ["{unit_3.name}"]', {unit_3.id}),
    ]
    for query_filter, expected_ids in cases:
        query = PaginationQuery(page=1, per_page=-1, query_filter=query_filter)
        unit_results = units_repo.page_all(query).items

        assert len(unit_results) == len(expected_ids)
        result_ids = {unit.id for unit in unit_results}
        assert expected_ids <= result_ids
        assert result_ids.isdisjoint(all_ids - expected_ids)
def test_pagination_filter_in_advanced(database: AllRepositories, unique_user: TestUser):
    """IN and CONTAINS ALL on a related attribute (`tags.name`) select recipes by their tags"""

    tag_slugs = [random_string(10) for _ in range(2)]
    tag_1, tag_2 = (
        database.tags.create(TagSave(group_id=unique_user.group_id, name=tag_slug, slug=tag_slug))
        for tag_slug in tag_slugs
    )

    # Bootstrap the database with recipes
    def create_recipe(recipe_tags):
        recipe_slug = random_string()
        return database.recipes.create(
            Recipe(
                user_id=unique_user.user_id,
                group_id=unique_user.group_id,
                name=recipe_slug,
                slug=recipe_slug,
                tags=recipe_tags,
            )
        )

    recipe_0 = create_recipe([])
    recipe_1 = create_recipe([tag_1])
    recipe_2 = create_recipe([tag_2])
    recipe_1_2 = create_recipe([tag_1, tag_2])
    all_ids = {recipe_0.id, recipe_1.id, recipe_2.id, recipe_1_2.id}

    # (filter string, ids that must be returned); every other recipe must be absent
    cases = [
        (f"tags.name IN [{tag_1.name}]", {recipe_1.id, recipe_1_2.id}),
        (f"tags.name IN [{tag_1.name}, {tag_2.name}]", {recipe_1.id, recipe_2.id, recipe_1_2.id}),
        (f"tags.name CONTAINS ALL [{tag_1.name}, {tag_2.name}]", {recipe_1_2.id}),
    ]
    for query_filter, expected_ids in cases:
        query = PaginationQuery(page=1, per_page=-1, query_filter=query_filter)
        recipe_results = database.recipes.page_all(query).items

        assert len(recipe_results) == len(expected_ids)
        recipe_ids = {recipe.id for recipe in recipe_results}
        assert expected_ids <= recipe_ids
        assert recipe_ids.isdisjoint(all_ids - expected_ids)
def test_pagination_filter_like(query_units: tuple[RepositoryUnit, IngredientUnit, IngredientUnit, IngredientUnit]):
    """LIKE / NOT LIKE match unit names against quoted patterns using the `%` and `_` wildcards"""

    units_repo, unit_1, unit_2, unit_3 = query_units
    all_ids = {unit_1.id, unit_2.id, unit_3.id}

    # (filter string, ids that must be returned); every other unit must be absent
    # every pattern is fully quoted: a pattern with only a trailing quote would be
    # accepted solely because of how the parser splits on quote characters
    cases = [
        (r'name LIKE "test u_it%"', {unit_1.id, unit_2.id, unit_3.id}),
        (r'name LIKE "%unit 1"', {unit_1.id}),
        (r'name NOT LIKE "%t_1"', {unit_2.id, unit_3.id}),
    ]
    for query_filter, expected_ids in cases:
        query = PaginationQuery(page=1, per_page=-1, query_filter=query_filter)
        unit_results = units_repo.page_all(query).items

        assert len(unit_results) == len(expected_ids)
        result_ids = {unit.id for unit in unit_results}
        assert expected_ids <= result_ids
        assert result_ids.isdisjoint(all_ids - expected_ids)
def test_pagination_filter_keyword_namespace_conflict(database: AllRepositories, unique_user: TestUser):
    """Filter on ``rating``, an attribute whose name embeds the keyword ``in``.

    Three recipes rated 1, 2 and 3 are created in the user's group; each
    filter must return exactly the recipes with the expected ratings.
    """
    recipes_by_rating = {}
    for stars in (1, 2, 3):
        recipes_by_rating[stars] = database.recipes.create(
            Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=random_string(), rating=stars)
        )

    recipe_repo = database.recipes.by_group(unique_user.group_id)  # type: ignore

    # "rating" contains the word "in", but we should not parse this as the keyword "IN"
    filter_cases = (
        ("rating > 2", {3}),
        ("rating in [1, 3]", {1, 3}),
    )
    for query_filter, expected_ratings in filter_cases:
        query = PaginationQuery(page=1, per_page=-1, query_filter=query_filter)
        recipe_results = recipe_repo.page_all(query).items

        assert len(recipe_results) == len(expected_ratings)
        result_ids = {recipe.id for recipe in recipe_results}
        for stars, recipe in recipes_by_rating.items():
            # a recipe is in the result set exactly when its rating is expected
            assert (recipe.id in result_ids) == (stars in expected_ratings)
def test_pagination_filter_logical_namespace_conflict(database: AllRepositories, unique_user: TestUser):
    """Filter on ``recipeCategory``, an attribute whose name embeds the operator ``or``.

    One recipe without a category and one recipe per category are created;
    filtering by the first category's id must return only its recipe.
    """
    new_categories = [CategorySave(group_id=unique_user.group_id, name=random_string(10)) for _ in range(2)]
    category_1, category_2 = map(database.categories.create, new_categories)

    # Bootstrap the database with recipes: no category, category 1, category 2
    created_recipes = []
    for relations in ({}, {"recipe_category": [category_1]}, {"recipe_category": [category_2]}):
        slug = random_string()
        created_recipes.append(
            database.recipes.create(
                Recipe(
                    user_id=unique_user.user_id,
                    group_id=unique_user.group_id,
                    name=slug,
                    slug=slug,
                    **relations,
                )
            )
        )
    recipe_category_0, recipe_category_1, recipe_category_2 = created_recipes

    # "recipeCategory" has the substring "or" in it, which shouldn't break queries
    query = PaginationQuery(page=1, per_page=-1, query_filter=f'recipeCategory.id = "{category_1.id}"')
    group_repo = database.recipes.by_group(unique_user.group_id)  # type: ignore
    recipe_results = group_repo.page_all(query).items

    assert len(recipe_results) == 1
    recipe_ids = {recipe.id for recipe in recipe_results}
    assert recipe_category_1.id in recipe_ids
    assert recipe_category_0.id not in recipe_ids
    assert recipe_category_2.id not in recipe_ids
def test_pagination_filter_datetimes( def test_pagination_filter_datetimes(
query_units: tuple[RepositoryUnit, IngredientUnit, IngredientUnit, IngredientUnit] query_units: tuple[RepositoryUnit, IngredientUnit, IngredientUnit, IngredientUnit]
): ):
@ -197,15 +451,183 @@ def test_pagination_filter_booleans(query_units: tuple[RepositoryUnit, Ingredien
def test_pagination_filter_advanced(query_units: tuple[RepositoryUnit, IngredientUnit, IngredientUnit, IngredientUnit]):
    """Combine nested groups, logical operators and keywords in a single filter.

    The first filter mixes ``OR`` / ``AND`` with nested parentheses; the
    second mixes ``LIKE``, ``IN`` and an ``IS NOT NONE`` null check. Both are
    expected to return units 1 and 2 but not unit 3.
    """
    units_repo, unit_1, unit_2, unit_3 = query_units

    # unit 3's own creation time is the cut-off for the "createdAt >" clause
    dt = str(unit_3.created_at.isoformat())  # type: ignore
    qf = f'name="test unit 1" OR (useAbbreviation=f AND (name="{unit_2.name}" OR createdAt > "{dt}"))'
    query = PaginationQuery(page=1, per_page=-1, query_filter=qf)
    unit_results = units_repo.page_all(query).items

    assert len(unit_results) == 2
    result_ids = {unit.id for unit in unit_results}
    assert unit_1.id in result_ids
    assert unit_2.id in result_ids
    assert unit_3.id not in result_ids

    qf = f'(name LIKE %_1 OR name IN ["{unit_2.name}"]) AND createdAt IS NOT NONE'
    query = PaginationQuery(page=1, per_page=-1, query_filter=qf)
    unit_results = units_repo.page_all(query).items

    assert len(unit_results) == 2
    result_ids = {unit.id for unit in unit_results}
    assert unit_1.id in result_ids
    assert unit_2.id in result_ids
    assert unit_3.id not in result_ids
def test_pagination_filter_advanced_frontend_sort(database: AllRepositories, unique_user: TestUser):
    """Run compound ``IN`` / ``CONTAINS ALL`` filters over categories, tags and tools.

    Recipe variable names encode the relations attached to each recipe:
    ``ct`` = categories, ``tg`` = tags, ``tl`` = tools, followed by the
    indices of the attached items (``0`` = none, ``12`` = both).
    """
    categories = [
        CategorySave(group_id=unique_user.group_id, name=random_string(10)),
        CategorySave(group_id=unique_user.group_id, name=random_string(10)),
    ]
    category_1, category_2 = [database.categories.create(category) for category in categories]

    slug1, slug2 = (random_string(10) for _ in range(2))
    tags = [
        TagSave(group_id=unique_user.group_id, name=slug1, slug=slug1),
        TagSave(group_id=unique_user.group_id, name=slug2, slug=slug2),
    ]
    tag_1, tag_2 = [database.tags.create(tag) for tag in tags]

    tools = [
        RecipeToolSave(group_id=unique_user.group_id, name=random_string(10)),
        RecipeToolSave(group_id=unique_user.group_id, name=random_string(10)),
    ]
    tool_1, tool_2 = [database.tools.create(tool) for tool in tools]

    def create_recipe(**relations):
        """Create a recipe in the test user's group with a random name/slug and the given relations."""
        slug = random_string()
        return database.recipes.create(
            Recipe(
                user_id=unique_user.user_id,
                group_id=unique_user.group_id,
                name=slug,
                slug=slug,
                **relations,
            )
        )

    # Bootstrap the database with recipes
    recipe_ct0_tg0_tl0 = create_recipe()
    recipe_ct1_tg0_tl0 = create_recipe(recipe_category=[category_1])
    recipe_ct12_tg0_tl0 = create_recipe(recipe_category=[category_1, category_2])
    recipe_ct1_tg1_tl0 = create_recipe(recipe_category=[category_1], tags=[tag_1])
    recipe_ct1_tg0_tl1 = create_recipe(recipe_category=[category_1], tools=[tool_1])
    recipe_ct0_tg2_tl2 = create_recipe(tags=[tag_2], tools=[tool_2])
    recipe_ct12_tg12_tl2 = create_recipe(
        recipe_category=[category_1, category_2],
        tags=[tag_1, tag_2],
        tools=[tool_2],
    )
    all_recipes = [
        recipe_ct0_tg0_tl0,
        recipe_ct1_tg0_tl0,
        recipe_ct12_tg0_tl0,
        recipe_ct1_tg1_tl0,
        recipe_ct1_tg0_tl1,
        recipe_ct0_tg2_tl2,
        recipe_ct12_tg12_tl2,
    ]

    repo = database.recipes.by_group(unique_user.group_id)  # type: ignore

    def assert_filter_matches(qf, expected_recipes):
        """Assert that ``qf`` returns exactly ``expected_recipes`` out of the recipes created above."""
        query = PaginationQuery(page=1, per_page=-1, query_filter=qf)
        recipe_results = repo.page_all(query).items

        assert len(recipe_results) == len(expected_recipes)
        recipe_ids = {recipe.id for recipe in recipe_results}
        expected_ids = {recipe.id for recipe in expected_recipes}
        for recipe in all_recipes:
            # every bootstrapped recipe is either expected and present, or unexpected and absent
            assert (recipe.id in recipe_ids) == (recipe.id in expected_ids)

    qf = f'recipeCategory.id IN ["{category_1.id}"] AND tools.id IN ["{tool_1.id}"]'
    assert_filter_matches(qf, [recipe_ct1_tg0_tl1])

    qf = f'recipeCategory.id CONTAINS ALL ["{category_1.id}", "{category_2.id}"] AND tags.id IN ["{tag_1.id}"]'
    assert_filter_matches(qf, [recipe_ct12_tg12_tl2])

    qf = f'tags.id IN ["{tag_1.id}", "{tag_2.id}"] AND tools.id IN ["{tool_2.id}"]'
    assert_filter_matches(qf, [recipe_ct0_tg2_tl2, recipe_ct12_tg12_tl2])

    # adjacent string literals are concatenated verbatim, so the separating
    # space before "AND" has to be part of the first literal
    qf = (
        f'recipeCategory.id CONTAINS ALL ["{category_1.id}", "{category_2.id}"] '
        f'AND tags.id IN ["{tag_1.id}", "{tag_2.id}"] AND tools.id IN ["{tool_1.id}", "{tool_2.id}"]'
    )
    assert_filter_matches(qf, [recipe_ct12_tg12_tl2])
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -214,6 +636,13 @@ def test_pagination_filter_advanced(query_units: tuple[RepositoryUnit, Ingredien
pytest.param('(name="test name" AND useAbbreviation=f))', id="unbalanced parenthesis"), pytest.param('(name="test name" AND useAbbreviation=f))', id="unbalanced parenthesis"),
pytest.param('id="this is not a valid UUID"', id="invalid UUID"), pytest.param('id="this is not a valid UUID"', id="invalid UUID"),
pytest.param('createdAt="this is not a valid datetime format"', id="invalid datetime format"), pytest.param('createdAt="this is not a valid datetime format"', id="invalid datetime format"),
pytest.param('name IS "test name"', id="IS can only be used with NULL or NONE"),
pytest.param('name IS NOT "test name"', id="IS NOT can only be used with NULL or NONE"),
pytest.param('name IN "test name"', id="IN must use a list of values"),
pytest.param('name NOT IN "test name"', id="NOT IN must use a list of values"),
pytest.param('name CONTAINS ALL "test name"', id="CONTAINS ALL must use a list of values"),
pytest.param('createdAt LIKE "2023-02-25"', id="LIKE is only valid for string columns"),
pytest.param('createdAt NOT LIKE "2023-02-25"', id="NOT LIKE is only valid for string columns"),
pytest.param('badAttribute="test value"', id="invalid attribute"), pytest.param('badAttribute="test value"', id="invalid attribute"),
pytest.param('group.badAttribute="test value"', id="bad nested attribute"), pytest.param('group.badAttribute="test value"', id="bad nested attribute"),
pytest.param('group.preferences.badAttribute="test value"', id="bad double nested attribute"), pytest.param('group.preferences.badAttribute="test value"', id="bad double nested attribute"),