mirror of
https://github.com/mealie-recipes/mealie.git
synced 2025-05-31 12:15:42 -04:00
feat: Generalize Search to Other Models (#2472)
* generalized search logic to SearchFilter * added default search behavior for all models * fix for schema overrides * added search support to several models * fix for label search * tests and fixes * add config for normalizing characters * dramatically simplified search tests * bark bark * fix normalization bug * tweaked tests * maybe this time? --------- Co-authored-by: Hayden <64056131+hay-kot@users.noreply.github.com>
This commit is contained in:
parent
76ae0bafc7
commit
99372aa2b6
@ -16,6 +16,7 @@ from mealie.db.models._model_base import SqlAlchemyBase
|
|||||||
from mealie.schema._mealie import MealieModel
|
from mealie.schema._mealie import MealieModel
|
||||||
from mealie.schema.response.pagination import OrderDirection, PaginationBase, PaginationQuery
|
from mealie.schema.response.pagination import OrderDirection, PaginationBase, PaginationQuery
|
||||||
from mealie.schema.response.query_filter import QueryFilter
|
from mealie.schema.response.query_filter import QueryFilter
|
||||||
|
from mealie.schema.response.query_search import SearchFilter
|
||||||
|
|
||||||
Schema = TypeVar("Schema", bound=MealieModel)
|
Schema = TypeVar("Schema", bound=MealieModel)
|
||||||
Model = TypeVar("Model", bound=SqlAlchemyBase)
|
Model = TypeVar("Model", bound=SqlAlchemyBase)
|
||||||
@ -291,7 +292,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
|||||||
q = self._query(override_schema=eff_schema).filter(attribute_name == attr_match)
|
q = self._query(override_schema=eff_schema).filter(attribute_name == attr_match)
|
||||||
return [eff_schema.from_orm(x) for x in self.session.execute(q).scalars().all()]
|
return [eff_schema.from_orm(x) for x in self.session.execute(q).scalars().all()]
|
||||||
|
|
||||||
def page_all(self, pagination: PaginationQuery, override=None) -> PaginationBase[Schema]:
|
def page_all(self, pagination: PaginationQuery, override=None, search: str | None = None) -> PaginationBase[Schema]:
|
||||||
"""
|
"""
|
||||||
pagination is a method to interact with the filtered database table and return a paginated result
|
pagination is a method to interact with the filtered database table and return a paginated result
|
||||||
using the PaginationBase that provides several data points that are needed to manage pagination
|
using the PaginationBase that provides several data points that are needed to manage pagination
|
||||||
@ -302,12 +303,16 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
|||||||
as the override, as the type system is not able to infer the result of this method.
|
as the override, as the type system is not able to infer the result of this method.
|
||||||
"""
|
"""
|
||||||
eff_schema = override or self.schema
|
eff_schema = override or self.schema
|
||||||
|
# Copy this, because calling methods (e.g. tests) might rely on it not getting mutated
|
||||||
|
pagination_result = pagination.copy()
|
||||||
q = self._query(override_schema=eff_schema, with_options=False)
|
q = self._query(override_schema=eff_schema, with_options=False)
|
||||||
|
|
||||||
fltr = self._filter_builder()
|
fltr = self._filter_builder()
|
||||||
q = q.filter_by(**fltr)
|
q = q.filter_by(**fltr)
|
||||||
q, count, total_pages = self.add_pagination_to_query(q, pagination)
|
if search:
|
||||||
|
q = self.add_search_to_query(q, eff_schema, search)
|
||||||
|
|
||||||
|
q, count, total_pages = self.add_pagination_to_query(q, pagination_result)
|
||||||
|
|
||||||
# Apply options late, so they do not get used for counting
|
# Apply options late, so they do not get used for counting
|
||||||
q = q.options(*eff_schema.loader_options())
|
q = q.options(*eff_schema.loader_options())
|
||||||
@ -318,8 +323,8 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
|||||||
self.session.rollback()
|
self.session.rollback()
|
||||||
raise e
|
raise e
|
||||||
return PaginationBase(
|
return PaginationBase(
|
||||||
page=pagination.page,
|
page=pagination_result.page,
|
||||||
per_page=pagination.per_page,
|
per_page=pagination_result.per_page,
|
||||||
total=count,
|
total=count,
|
||||||
total_pages=total_pages,
|
total_pages=total_pages,
|
||||||
items=[eff_schema.from_orm(s) for s in data],
|
items=[eff_schema.from_orm(s) for s in data],
|
||||||
@ -392,3 +397,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
|||||||
query = query.order_by(case_stmt)
|
query = query.order_by(case_stmt)
|
||||||
|
|
||||||
return query.limit(pagination.per_page).offset((pagination.page - 1) * pagination.per_page), count, total_pages
|
return query.limit(pagination.per_page).offset((pagination.page - 1) * pagination.per_page), count, total_pages
|
||||||
|
|
||||||
|
def add_search_to_query(self, query: Select, schema: type[Schema], search: str) -> Select:
|
||||||
|
search_filter = SearchFilter(self.session, search, schema._normalize_search)
|
||||||
|
return search_filter.filter_query_by_search(query, schema, self.model)
|
||||||
|
@ -5,10 +5,9 @@ from uuid import UUID
|
|||||||
|
|
||||||
from pydantic import UUID4
|
from pydantic import UUID4
|
||||||
from slugify import slugify
|
from slugify import slugify
|
||||||
from sqlalchemy import Select, and_, desc, func, or_, select, text
|
from sqlalchemy import and_, func, select
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlalchemy.orm import joinedload
|
from sqlalchemy.orm import joinedload
|
||||||
from text_unidecode import unidecode
|
|
||||||
|
|
||||||
from mealie.db.models.recipe.category import Category
|
from mealie.db.models.recipe.category import Category
|
||||||
from mealie.db.models.recipe.ingredient import RecipeIngredientModel
|
from mealie.db.models.recipe.ingredient import RecipeIngredientModel
|
||||||
@ -18,13 +17,7 @@ from mealie.db.models.recipe.tag import Tag
|
|||||||
from mealie.db.models.recipe.tool import Tool
|
from mealie.db.models.recipe.tool import Tool
|
||||||
from mealie.schema.cookbook.cookbook import ReadCookBook
|
from mealie.schema.cookbook.cookbook import ReadCookBook
|
||||||
from mealie.schema.recipe import Recipe
|
from mealie.schema.recipe import Recipe
|
||||||
from mealie.schema.recipe.recipe import (
|
from mealie.schema.recipe.recipe import RecipeCategory, RecipePagination, RecipeSummary, RecipeTag, RecipeTool
|
||||||
RecipeCategory,
|
|
||||||
RecipePagination,
|
|
||||||
RecipeSummary,
|
|
||||||
RecipeTag,
|
|
||||||
RecipeTool,
|
|
||||||
)
|
|
||||||
from mealie.schema.recipe.recipe_category import CategoryBase, TagBase
|
from mealie.schema.recipe.recipe_category import CategoryBase, TagBase
|
||||||
from mealie.schema.response.pagination import PaginationQuery
|
from mealie.schema.response.pagination import PaginationQuery
|
||||||
|
|
||||||
@ -151,98 +144,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
|
|||||||
additional_ids = self.session.execute(select(model.id).filter(model.slug.in_(slugs))).scalars().all()
|
additional_ids = self.session.execute(select(model.id).filter(model.slug.in_(slugs))).scalars().all()
|
||||||
return ids + additional_ids
|
return ids + additional_ids
|
||||||
|
|
||||||
def _add_search_to_query(self, query: Select, search: str) -> Select:
|
def page_all( # type: ignore
|
||||||
"""
|
|
||||||
0. fuzzy search (postgres only) and tokenized search are performed separately
|
|
||||||
1. take search string and do a little pre-normalization
|
|
||||||
2. look for internal quoted strings and keep them together as "literal" parts of the search
|
|
||||||
3. remove special characters from each non-literal search string
|
|
||||||
4. token search looks for any individual exact hit in name, description, and ingredients
|
|
||||||
5. fuzzy search looks for trigram hits in name, description, and ingredients
|
|
||||||
6. Sort order is determined by closeness to the recipe name
|
|
||||||
Should search also look at tags?
|
|
||||||
"""
|
|
||||||
|
|
||||||
normalized_search = unidecode(search).lower().strip()
|
|
||||||
punctuation = "!\#$%&()*+,-./:;<=>?@[\\]^_`{|}~" # string.punctuation with ' & " removed
|
|
||||||
# keep quoted phrases together as literal portions of the search string
|
|
||||||
literal = False
|
|
||||||
quoted_regex = re.compile(r"""(["'])(?:(?=(\\?))\2.)*?\1""") # thank you stack exchange!
|
|
||||||
removequotes_regex = re.compile(r"""['"](.*)['"]""")
|
|
||||||
if quoted_regex.search(normalized_search):
|
|
||||||
literal = True
|
|
||||||
temp = normalized_search
|
|
||||||
quoted_search_list = [match.group() for match in quoted_regex.finditer(temp)] # all quoted strings
|
|
||||||
quoted_search_list = [removequotes_regex.sub("\\1", x) for x in quoted_search_list] # remove outer quotes
|
|
||||||
temp = quoted_regex.sub("", temp) # remove all quoted strings, leaving just non-quoted
|
|
||||||
temp = temp.translate(
|
|
||||||
str.maketrans(punctuation, " " * len(punctuation))
|
|
||||||
) # punctuation->spaces for splitting, but only on unquoted strings
|
|
||||||
unquoted_search_list = temp.split() # all unquoted strings
|
|
||||||
normalized_search_list = quoted_search_list + unquoted_search_list
|
|
||||||
else:
|
|
||||||
#
|
|
||||||
normalized_search = normalized_search.translate(str.maketrans(punctuation, " " * len(punctuation)))
|
|
||||||
normalized_search_list = normalized_search.split()
|
|
||||||
normalized_search_list = [x.strip() for x in normalized_search_list] # remove padding whitespace inside quotes
|
|
||||||
# I would prefer to just do this in the recipe_ingredient.any part of the main query, but it turns out
|
|
||||||
# that at least sqlite wont use indexes for that correctly anymore and takes a big hit, so prefiltering it is
|
|
||||||
if (self.session.get_bind().name == "postgresql") & (literal is False): # fuzzy search
|
|
||||||
ingredient_ids = (
|
|
||||||
self.session.execute(
|
|
||||||
select(RecipeIngredientModel.id).filter(
|
|
||||||
or_(
|
|
||||||
RecipeIngredientModel.note_normalized.op("%>")(normalized_search),
|
|
||||||
RecipeIngredientModel.original_text_normalized.op("%>")(normalized_search),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
.scalars()
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
else: # exact token search
|
|
||||||
ingredient_ids = (
|
|
||||||
self.session.execute(
|
|
||||||
select(RecipeIngredientModel.id).filter(
|
|
||||||
or_(
|
|
||||||
*[RecipeIngredientModel.note_normalized.like(f"%{ns}%") for ns in normalized_search_list],
|
|
||||||
*[
|
|
||||||
RecipeIngredientModel.original_text_normalized.like(f"%{ns}%")
|
|
||||||
for ns in normalized_search_list
|
|
||||||
],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
.scalars()
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
if (self.session.get_bind().name == "postgresql") & (literal is False): # fuzzy search
|
|
||||||
# default = 0.7 is too strict for effective fuzzing
|
|
||||||
self.session.execute(text("set pg_trgm.word_similarity_threshold = 0.5;"))
|
|
||||||
q = query.filter(
|
|
||||||
or_(
|
|
||||||
RecipeModel.name_normalized.op("%>")(normalized_search),
|
|
||||||
RecipeModel.description_normalized.op("%>")(normalized_search),
|
|
||||||
RecipeModel.recipe_ingredient.any(RecipeIngredientModel.id.in_(ingredient_ids)),
|
|
||||||
)
|
|
||||||
).order_by( # trigram ordering could be too slow on million record db, but is fine with thousands.
|
|
||||||
func.least(
|
|
||||||
RecipeModel.name_normalized.op("<->>")(normalized_search),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else: # exact token search
|
|
||||||
q = query.filter(
|
|
||||||
or_(
|
|
||||||
*[RecipeModel.name_normalized.like(f"%{ns}%") for ns in normalized_search_list],
|
|
||||||
*[RecipeModel.description_normalized.like(f"%{ns}%") for ns in normalized_search_list],
|
|
||||||
RecipeModel.recipe_ingredient.any(RecipeIngredientModel.id.in_(ingredient_ids)),
|
|
||||||
)
|
|
||||||
).order_by(desc(RecipeModel.name_normalized.like(f"%{normalized_search}%")))
|
|
||||||
|
|
||||||
return q
|
|
||||||
|
|
||||||
def page_all(
|
|
||||||
self,
|
self,
|
||||||
pagination: PaginationQuery,
|
pagination: PaginationQuery,
|
||||||
override=None,
|
override=None,
|
||||||
@ -299,7 +201,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
|
|||||||
)
|
)
|
||||||
q = q.filter(*filters)
|
q = q.filter(*filters)
|
||||||
if search:
|
if search:
|
||||||
q = self._add_search_to_query(q, search)
|
q = self.add_search_to_query(q, self.schema, search)
|
||||||
|
|
||||||
q, count, total_pages = self.add_pagination_to_query(q, pagination_result)
|
q, count, total_pages = self.add_pagination_to_query(q, pagination_result)
|
||||||
|
|
||||||
|
@ -41,10 +41,11 @@ class MultiPurposeLabelsController(BaseUserController):
|
|||||||
return HttpRepo(self.repo, self.logger, self.registered_exceptions, self.t("generic.server-error"))
|
return HttpRepo(self.repo, self.logger, self.registered_exceptions, self.t("generic.server-error"))
|
||||||
|
|
||||||
@router.get("", response_model=MultiPurposeLabelPagination)
|
@router.get("", response_model=MultiPurposeLabelPagination)
|
||||||
def get_all(self, q: PaginationQuery = Depends(PaginationQuery)):
|
def get_all(self, q: PaginationQuery = Depends(PaginationQuery), search: str | None = None):
|
||||||
response = self.repo.page_all(
|
response = self.repo.page_all(
|
||||||
pagination=q,
|
pagination=q,
|
||||||
override=MultiPurposeLabelSummary,
|
override=MultiPurposeLabelSummary,
|
||||||
|
search=search,
|
||||||
)
|
)
|
||||||
|
|
||||||
response.set_pagination_guides(router.url_path_for("get_all"), q.dict())
|
response.set_pagination_guides(router.url_path_for("get_all"), q.dict())
|
||||||
|
@ -38,11 +38,12 @@ class RecipeCategoryController(BaseCrudController):
|
|||||||
return HttpRepo(self.repo, self.logger)
|
return HttpRepo(self.repo, self.logger)
|
||||||
|
|
||||||
@router.get("", response_model=RecipeCategoryPagination)
|
@router.get("", response_model=RecipeCategoryPagination)
|
||||||
def get_all(self, q: PaginationQuery = Depends(PaginationQuery)):
|
def get_all(self, q: PaginationQuery = Depends(PaginationQuery), search: str | None = None):
|
||||||
"""Returns a list of available categories in the database"""
|
"""Returns a list of available categories in the database"""
|
||||||
response = self.repo.page_all(
|
response = self.repo.page_all(
|
||||||
pagination=q,
|
pagination=q,
|
||||||
override=RecipeCategory,
|
override=RecipeCategory,
|
||||||
|
search=search,
|
||||||
)
|
)
|
||||||
|
|
||||||
response.set_pagination_guides(router.url_path_for("get_all"), q.dict())
|
response.set_pagination_guides(router.url_path_for("get_all"), q.dict())
|
||||||
|
@ -27,11 +27,12 @@ class TagController(BaseCrudController):
|
|||||||
return HttpRepo(self.repo, self.logger)
|
return HttpRepo(self.repo, self.logger)
|
||||||
|
|
||||||
@router.get("", response_model=RecipeTagPagination)
|
@router.get("", response_model=RecipeTagPagination)
|
||||||
async def get_all(self, q: PaginationQuery = Depends(PaginationQuery)):
|
async def get_all(self, q: PaginationQuery = Depends(PaginationQuery), search: str | None = None):
|
||||||
"""Returns a list of available tags in the database"""
|
"""Returns a list of available tags in the database"""
|
||||||
response = self.repo.page_all(
|
response = self.repo.page_all(
|
||||||
pagination=q,
|
pagination=q,
|
||||||
override=RecipeTag,
|
override=RecipeTag,
|
||||||
|
search=search,
|
||||||
)
|
)
|
||||||
|
|
||||||
response.set_pagination_guides(router.url_path_for("get_all"), q.dict())
|
response.set_pagination_guides(router.url_path_for("get_all"), q.dict())
|
||||||
|
@ -25,10 +25,11 @@ class RecipeToolController(BaseUserController):
|
|||||||
return HttpRepo[RecipeToolCreate, RecipeTool, RecipeToolCreate](self.repo, self.logger)
|
return HttpRepo[RecipeToolCreate, RecipeTool, RecipeToolCreate](self.repo, self.logger)
|
||||||
|
|
||||||
@router.get("", response_model=RecipeToolPagination)
|
@router.get("", response_model=RecipeToolPagination)
|
||||||
def get_all(self, q: PaginationQuery = Depends(PaginationQuery)):
|
def get_all(self, q: PaginationQuery = Depends(PaginationQuery), search: str | None = None):
|
||||||
response = self.repo.page_all(
|
response = self.repo.page_all(
|
||||||
pagination=q,
|
pagination=q,
|
||||||
override=RecipeTool,
|
override=RecipeTool,
|
||||||
|
search=search,
|
||||||
)
|
)
|
||||||
|
|
||||||
response.set_pagination_guides(router.url_path_for("get_all"), q.dict())
|
response.set_pagination_guides(router.url_path_for("get_all"), q.dict())
|
||||||
|
@ -45,10 +45,11 @@ class IngredientFoodsController(BaseUserController):
|
|||||||
raise HTTPException(500, "Failed to merge foods") from e
|
raise HTTPException(500, "Failed to merge foods") from e
|
||||||
|
|
||||||
@router.get("", response_model=IngredientFoodPagination)
|
@router.get("", response_model=IngredientFoodPagination)
|
||||||
def get_all(self, q: PaginationQuery = Depends(PaginationQuery)):
|
def get_all(self, q: PaginationQuery = Depends(PaginationQuery), search: str | None = None):
|
||||||
response = self.repo.page_all(
|
response = self.repo.page_all(
|
||||||
pagination=q,
|
pagination=q,
|
||||||
override=IngredientFood,
|
override=IngredientFood,
|
||||||
|
search=search,
|
||||||
)
|
)
|
||||||
|
|
||||||
response.set_pagination_guides(router.url_path_for("get_all"), q.dict())
|
response.set_pagination_guides(router.url_path_for("get_all"), q.dict())
|
||||||
|
@ -45,10 +45,11 @@ class IngredientUnitsController(BaseUserController):
|
|||||||
raise HTTPException(500, "Failed to merge units") from e
|
raise HTTPException(500, "Failed to merge units") from e
|
||||||
|
|
||||||
@router.get("", response_model=IngredientUnitPagination)
|
@router.get("", response_model=IngredientUnitPagination)
|
||||||
def get_all(self, q: PaginationQuery = Depends(PaginationQuery)):
|
def get_all(self, q: PaginationQuery = Depends(PaginationQuery), search: str | None = None):
|
||||||
response = self.repo.page_all(
|
response = self.repo.page_all(
|
||||||
pagination=q,
|
pagination=q,
|
||||||
override=IngredientUnit,
|
override=IngredientUnit,
|
||||||
|
search=search,
|
||||||
)
|
)
|
||||||
|
|
||||||
response.set_pagination_guides(router.url_path_for("get_all"), q.dict())
|
response.set_pagination_guides(router.url_path_for("get_all"), q.dict())
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
# This file is auto-generated by gen_schema_exports.py
|
# This file is auto-generated by gen_schema_exports.py
|
||||||
from .mealie_model import HasUUID, MealieModel
|
from .mealie_model import HasUUID, MealieModel, SearchType
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"HasUUID",
|
"HasUUID",
|
||||||
"MealieModel",
|
"MealieModel",
|
||||||
|
"SearchType",
|
||||||
]
|
]
|
||||||
|
@ -1,16 +1,34 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Protocol, TypeVar
|
from enum import Enum
|
||||||
|
from typing import ClassVar, Protocol, TypeVar
|
||||||
|
|
||||||
from humps.main import camelize
|
from humps.main import camelize
|
||||||
from pydantic import UUID4, BaseModel
|
from pydantic import UUID4, BaseModel
|
||||||
|
from sqlalchemy import Select, desc, func, or_, text
|
||||||
|
from sqlalchemy.orm import InstrumentedAttribute, Session
|
||||||
from sqlalchemy.orm.interfaces import LoaderOption
|
from sqlalchemy.orm.interfaces import LoaderOption
|
||||||
|
|
||||||
|
from mealie.db.models._model_base import SqlAlchemyBase
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
|
class SearchType(Enum):
|
||||||
|
fuzzy = "fuzzy"
|
||||||
|
tokenized = "tokenized"
|
||||||
|
|
||||||
|
|
||||||
class MealieModel(BaseModel):
|
class MealieModel(BaseModel):
|
||||||
|
_fuzzy_similarity_threshold: ClassVar[float] = 0.5
|
||||||
|
_normalize_search: ClassVar[bool] = False
|
||||||
|
_searchable_properties: ClassVar[list[str]] = []
|
||||||
|
"""
|
||||||
|
Searchable properties for the search API.
|
||||||
|
The first property will be used for sorting (order_by)
|
||||||
|
"""
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
alias_generator = camelize
|
alias_generator = camelize
|
||||||
allow_population_by_field_name = True
|
allow_population_by_field_name = True
|
||||||
@ -59,6 +77,40 @@ class MealieModel(BaseModel):
|
|||||||
def loader_options(cls) -> list[LoaderOption]:
|
def loader_options(cls) -> list[LoaderOption]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def filter_search_query(
|
||||||
|
cls,
|
||||||
|
db_model: type[SqlAlchemyBase],
|
||||||
|
query: Select,
|
||||||
|
session: Session,
|
||||||
|
search_type: SearchType,
|
||||||
|
search: str,
|
||||||
|
search_list: list[str],
|
||||||
|
) -> Select:
|
||||||
|
"""
|
||||||
|
Filters a search query based on model attributes
|
||||||
|
|
||||||
|
Can be overridden to support a more advanced search
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not cls._searchable_properties:
|
||||||
|
raise AttributeError("Not Implemented")
|
||||||
|
|
||||||
|
model_properties: list[InstrumentedAttribute] = [getattr(db_model, prop) for prop in cls._searchable_properties]
|
||||||
|
if search_type is SearchType.fuzzy:
|
||||||
|
session.execute(text(f"set pg_trgm.word_similarity_threshold = {cls._fuzzy_similarity_threshold};"))
|
||||||
|
filters = [prop.op("%>")(search) for prop in model_properties]
|
||||||
|
|
||||||
|
# trigram ordering by the first searchable property
|
||||||
|
return query.filter(or_(*filters)).order_by(func.least(model_properties[0].op("<->>")(search)))
|
||||||
|
else:
|
||||||
|
filters = []
|
||||||
|
for prop in model_properties:
|
||||||
|
filters.extend([prop.like(f"%{s}%") for s in search_list])
|
||||||
|
|
||||||
|
# order by how close the result is to the first searchable property
|
||||||
|
return query.filter(or_(*filters)).order_by(desc(model_properties[0].like(f"%{search}%")))
|
||||||
|
|
||||||
|
|
||||||
class HasUUID(Protocol):
|
class HasUUID(Protocol):
|
||||||
id: UUID4
|
id: UUID4
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
from pydantic import UUID4
|
from pydantic import UUID4
|
||||||
|
|
||||||
from mealie.schema._mealie import MealieModel
|
from mealie.schema._mealie import MealieModel
|
||||||
@ -20,7 +22,7 @@ class MultiPurposeLabelUpdate(MultiPurposeLabelSave):
|
|||||||
|
|
||||||
|
|
||||||
class MultiPurposeLabelSummary(MultiPurposeLabelUpdate):
|
class MultiPurposeLabelSummary(MultiPurposeLabelUpdate):
|
||||||
pass
|
_searchable_properties: ClassVar[list[str]] = ["name"]
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
orm_mode = True
|
orm_mode = True
|
||||||
@ -31,14 +33,5 @@ class MultiPurposeLabelPagination(PaginationBase):
|
|||||||
|
|
||||||
|
|
||||||
class MultiPurposeLabelOut(MultiPurposeLabelUpdate):
|
class MultiPurposeLabelOut(MultiPurposeLabelUpdate):
|
||||||
# shopping_list_items: list[ShoppingListItemOut] = []
|
|
||||||
# foods: list[IngredientFood] = []
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
orm_mode = True
|
orm_mode = True
|
||||||
|
|
||||||
|
|
||||||
# from mealie.schema.recipe.recipe_ingredient import IngredientFood
|
|
||||||
# from mealie.schema.group.group_shopping_list import ShoppingListItemOut
|
|
||||||
|
|
||||||
# MultiPurposeLabelOut.update_forward_refs()
|
|
||||||
|
@ -2,16 +2,17 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any, ClassVar
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from pydantic import UUID4, BaseModel, Field, validator
|
from pydantic import UUID4, BaseModel, Field, validator
|
||||||
from slugify import slugify
|
from slugify import slugify
|
||||||
from sqlalchemy.orm import joinedload, selectinload
|
from sqlalchemy import Select, desc, func, or_, select, text
|
||||||
|
from sqlalchemy.orm import Session, joinedload, selectinload
|
||||||
from sqlalchemy.orm.interfaces import LoaderOption
|
from sqlalchemy.orm.interfaces import LoaderOption
|
||||||
|
|
||||||
from mealie.core.config import get_app_dirs
|
from mealie.core.config import get_app_dirs
|
||||||
from mealie.schema._mealie import MealieModel
|
from mealie.schema._mealie import MealieModel, SearchType
|
||||||
from mealie.schema.response.pagination import PaginationBase
|
from mealie.schema.response.pagination import PaginationBase
|
||||||
|
|
||||||
from ...db.models.recipe import (
|
from ...db.models.recipe import (
|
||||||
@ -37,6 +38,8 @@ class RecipeTag(MealieModel):
|
|||||||
name: str
|
name: str
|
||||||
slug: str
|
slug: str
|
||||||
|
|
||||||
|
_searchable_properties: ClassVar[list[str]] = ["name"]
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
orm_mode = True
|
orm_mode = True
|
||||||
|
|
||||||
@ -78,6 +81,7 @@ class CreateRecipe(MealieModel):
|
|||||||
|
|
||||||
class RecipeSummary(MealieModel):
|
class RecipeSummary(MealieModel):
|
||||||
id: UUID4 | None
|
id: UUID4 | None
|
||||||
|
_normalize_search: ClassVar[bool] = True
|
||||||
|
|
||||||
user_id: UUID4 = Field(default_factory=uuid4)
|
user_id: UUID4 = Field(default_factory=uuid4)
|
||||||
group_id: UUID4 = Field(default_factory=uuid4)
|
group_id: UUID4 = Field(default_factory=uuid4)
|
||||||
@ -259,6 +263,69 @@ class Recipe(RecipeSummary):
|
|||||||
selectinload(RecipeModel.notes),
|
selectinload(RecipeModel.notes),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def filter_search_query(
|
||||||
|
cls, db_model, query: Select, session: Session, search_type: SearchType, search: str, search_list: list[str]
|
||||||
|
) -> Select:
|
||||||
|
"""
|
||||||
|
1. token search looks for any individual exact hit in name, description, and ingredients
|
||||||
|
2. fuzzy search looks for trigram hits in name, description, and ingredients
|
||||||
|
3. Sort order is determined by closeness to the recipe name
|
||||||
|
Should search also look at tags?
|
||||||
|
"""
|
||||||
|
|
||||||
|
if search_type is SearchType.fuzzy:
|
||||||
|
# I would prefer to just do this in the recipe_ingredient.any part of the main query,
|
||||||
|
# but it turns out that at least sqlite wont use indexes for that correctly anymore and
|
||||||
|
# takes a big hit, so prefiltering it is
|
||||||
|
ingredient_ids = (
|
||||||
|
session.execute(
|
||||||
|
select(RecipeIngredientModel.id).filter(
|
||||||
|
or_(
|
||||||
|
RecipeIngredientModel.note_normalized.op("%>")(search),
|
||||||
|
RecipeIngredientModel.original_text_normalized.op("%>")(search),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.scalars()
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
session.execute(text(f"set pg_trgm.word_similarity_threshold = {cls._fuzzy_similarity_threshold};"))
|
||||||
|
return query.filter(
|
||||||
|
or_(
|
||||||
|
RecipeModel.name_normalized.op("%>")(search),
|
||||||
|
RecipeModel.description_normalized.op("%>")(search),
|
||||||
|
RecipeModel.recipe_ingredient.any(RecipeIngredientModel.id.in_(ingredient_ids)),
|
||||||
|
)
|
||||||
|
).order_by( # trigram ordering could be too slow on million record db, but is fine with thousands.
|
||||||
|
func.least(
|
||||||
|
RecipeModel.name_normalized.op("<->>")(search),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
ingredient_ids = (
|
||||||
|
session.execute(
|
||||||
|
select(RecipeIngredientModel.id).filter(
|
||||||
|
or_(
|
||||||
|
*[RecipeIngredientModel.note_normalized.like(f"%{ns}%") for ns in search_list],
|
||||||
|
*[RecipeIngredientModel.original_text_normalized.like(f"%{ns}%") for ns in search_list],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.scalars()
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
return query.filter(
|
||||||
|
or_(
|
||||||
|
*[RecipeModel.name_normalized.like(f"%{ns}%") for ns in search_list],
|
||||||
|
*[RecipeModel.description_normalized.like(f"%{ns}%") for ns in search_list],
|
||||||
|
RecipeModel.recipe_ingredient.any(RecipeIngredientModel.id.in_(ingredient_ids)),
|
||||||
|
)
|
||||||
|
).order_by(desc(RecipeModel.name_normalized.like(f"%{search}%")))
|
||||||
|
|
||||||
|
|
||||||
class RecipeLastMade(BaseModel):
|
class RecipeLastMade(BaseModel):
|
||||||
timestamp: datetime.datetime
|
timestamp: datetime.datetime
|
||||||
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
import datetime
|
import datetime
|
||||||
import enum
|
import enum
|
||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
|
from typing import ClassVar
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
from pydantic import UUID4, Field, validator
|
from pydantic import UUID4, Field, validator
|
||||||
@ -50,6 +51,8 @@ class IngredientFood(CreateIngredientFood):
|
|||||||
created_at: datetime.datetime | None
|
created_at: datetime.datetime | None
|
||||||
update_at: datetime.datetime | None
|
update_at: datetime.datetime | None
|
||||||
|
|
||||||
|
_searchable_properties: ClassVar[list[str]] = ["name", "description"]
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
orm_mode = True
|
orm_mode = True
|
||||||
getter_dict = ExtrasGetterDict
|
getter_dict = ExtrasGetterDict
|
||||||
@ -78,6 +81,8 @@ class IngredientUnit(CreateIngredientUnit):
|
|||||||
created_at: datetime.datetime | None
|
created_at: datetime.datetime | None
|
||||||
update_at: datetime.datetime | None
|
update_at: datetime.datetime | None
|
||||||
|
|
||||||
|
_searchable_properties: ClassVar[list[str]] = ["name", "abbreviation", "description"]
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
orm_mode = True
|
orm_mode = True
|
||||||
|
|
||||||
|
67
mealie/schema/response/query_search.py
Normal file
67
mealie/schema/response/query_search.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
import re
|
||||||
|
|
||||||
|
from sqlalchemy import Select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from text_unidecode import unidecode
|
||||||
|
|
||||||
|
from ...db.models._model_base import SqlAlchemyBase
|
||||||
|
from .._mealie import MealieModel, SearchType
|
||||||
|
|
||||||
|
|
||||||
|
class SearchFilter:
|
||||||
|
"""
|
||||||
|
0. fuzzy search (postgres only) and tokenized search are performed separately
|
||||||
|
1. take search string and do a little pre-normalization
|
||||||
|
2. look for internal quoted strings and keep them together as "literal" parts of the search
|
||||||
|
3. remove special characters from each non-literal search string
|
||||||
|
"""
|
||||||
|
|
||||||
|
punctuation = "!\#$%&()*+,-./:;<=>?@[\\]^_`{|}~" # string.punctuation with ' & " removed
|
||||||
|
quoted_regex = re.compile(r"""(["'])(?:(?=(\\?))\2.)*?\1""")
|
||||||
|
remove_quotes_regex = re.compile(r"""['"](.*)['"]""")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _normalize_search(cls, search: str, normalize_characters: bool) -> str:
|
||||||
|
search = search.translate(str.maketrans(cls.punctuation, " " * len(cls.punctuation)))
|
||||||
|
|
||||||
|
if normalize_characters:
|
||||||
|
search = unidecode(search).lower().strip()
|
||||||
|
else:
|
||||||
|
search = search.strip()
|
||||||
|
|
||||||
|
return search
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _build_search_list(cls, search: str) -> list[str]:
|
||||||
|
if cls.quoted_regex.search(search):
|
||||||
|
# all quoted strings
|
||||||
|
quoted_search_list = [match.group() for match in cls.quoted_regex.finditer(search)]
|
||||||
|
|
||||||
|
# remove outer quotes
|
||||||
|
quoted_search_list = [cls.remove_quotes_regex.sub("\\1", x) for x in quoted_search_list]
|
||||||
|
|
||||||
|
# punctuation->spaces for splitting, but only on unquoted strings
|
||||||
|
search = cls.quoted_regex.sub("", search) # remove all quoted strings, leaving just non-quoted
|
||||||
|
search = search.translate(str.maketrans(cls.punctuation, " " * len(cls.punctuation)))
|
||||||
|
|
||||||
|
# all unquoted strings
|
||||||
|
unquoted_search_list = search.split()
|
||||||
|
search_list = quoted_search_list + unquoted_search_list
|
||||||
|
else:
|
||||||
|
search_list = search.translate(str.maketrans(cls.punctuation, " " * len(cls.punctuation))).split()
|
||||||
|
|
||||||
|
# remove padding whitespace inside quotes
|
||||||
|
return [x.strip() for x in search_list]
|
||||||
|
|
||||||
|
def __init__(self, session: Session, search: str, normalize_characters: bool = False) -> None:
|
||||||
|
if session.get_bind().name != "postgresql" or self.quoted_regex.search(search.strip()):
|
||||||
|
self.search_type = SearchType.tokenized
|
||||||
|
else:
|
||||||
|
self.search_type = SearchType.fuzzy
|
||||||
|
|
||||||
|
self.session = session
|
||||||
|
self.search = self._normalize_search(search, normalize_characters)
|
||||||
|
self.search_list = self._build_search_list(self.search)
|
||||||
|
|
||||||
|
def filter_query_by_search(self, query: Select, schema: type[MealieModel], model: type[SqlAlchemyBase]) -> Select:
|
||||||
|
return schema.filter_search_query(model, query, self.session, self.search_type, self.search, self.search_list)
|
@ -1,17 +1,108 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from mealie.repos.repository_factory import AllRepositories
|
from mealie.repos.repository_factory import AllRepositories
|
||||||
from mealie.repos.repository_recipes import RepositoryRecipes
|
from mealie.repos.repository_recipes import RepositoryRecipes
|
||||||
from mealie.schema.recipe import RecipeIngredient, SaveIngredientFood, RecipeStep
|
from mealie.schema.recipe import RecipeIngredient, SaveIngredientFood
|
||||||
from mealie.schema.recipe.recipe import Recipe, RecipeCategory, RecipeSummary
|
from mealie.schema.recipe.recipe import Recipe, RecipeCategory, RecipeSummary
|
||||||
from mealie.schema.recipe.recipe_category import CategoryOut, CategorySave, TagSave
|
from mealie.schema.recipe.recipe_category import CategoryOut, CategorySave, TagSave
|
||||||
from mealie.schema.recipe.recipe_tool import RecipeToolSave
|
from mealie.schema.recipe.recipe_tool import RecipeToolSave
|
||||||
from mealie.schema.response import OrderDirection, PaginationQuery
|
from mealie.schema.response import OrderDirection, PaginationQuery
|
||||||
from tests.utils.factories import random_string
|
from mealie.schema.user.user import GroupBase
|
||||||
|
from tests.utils.factories import random_email, random_string
|
||||||
from tests.utils.fixture_schemas import TestUser
|
from tests.utils.fixture_schemas import TestUser
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def unique_local_group_id(database: AllRepositories) -> str:
|
||||||
|
return str(database.groups.create(GroupBase(name=random_string())).id)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def unique_local_user_id(database: AllRepositories, unique_local_group_id: str) -> str:
|
||||||
|
return str(
|
||||||
|
database.users.create(
|
||||||
|
{
|
||||||
|
"username": random_string(),
|
||||||
|
"email": random_email(),
|
||||||
|
"group_id": unique_local_group_id,
|
||||||
|
"full_name": random_string(),
|
||||||
|
"password": random_string(),
|
||||||
|
"admin": False,
|
||||||
|
}
|
||||||
|
).id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def search_recipes(database: AllRepositories, unique_local_group_id: str, unique_local_user_id: str) -> list[Recipe]:
|
||||||
|
recipes = [
|
||||||
|
Recipe(
|
||||||
|
user_id=unique_local_user_id,
|
||||||
|
group_id=unique_local_group_id,
|
||||||
|
name="Steinbock Sloop",
|
||||||
|
description=f"My favorite horns are delicious",
|
||||||
|
recipe_ingredient=[
|
||||||
|
RecipeIngredient(note="alpine animal"),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
Recipe(
|
||||||
|
user_id=unique_local_user_id,
|
||||||
|
group_id=unique_local_group_id,
|
||||||
|
name="Fiddlehead Fern Stir Fry",
|
||||||
|
recipe_ingredient=[
|
||||||
|
RecipeIngredient(note="moss"),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
Recipe(
|
||||||
|
user_id=unique_local_user_id,
|
||||||
|
group_id=unique_local_group_id,
|
||||||
|
name="Animal Sloop",
|
||||||
|
),
|
||||||
|
# Test diacritics
|
||||||
|
Recipe(
|
||||||
|
user_id=unique_local_user_id,
|
||||||
|
group_id=unique_local_group_id,
|
||||||
|
name="Rátàtôuile",
|
||||||
|
),
|
||||||
|
# Add a bunch of recipes for stable randomization
|
||||||
|
Recipe(
|
||||||
|
user_id=unique_local_user_id,
|
||||||
|
group_id=unique_local_group_id,
|
||||||
|
name=f"{random_string(10)} soup",
|
||||||
|
),
|
||||||
|
Recipe(
|
||||||
|
user_id=unique_local_user_id,
|
||||||
|
group_id=unique_local_group_id,
|
||||||
|
name=f"{random_string(10)} soup",
|
||||||
|
),
|
||||||
|
Recipe(
|
||||||
|
user_id=unique_local_user_id,
|
||||||
|
group_id=unique_local_group_id,
|
||||||
|
name=f"{random_string(10)} soup",
|
||||||
|
),
|
||||||
|
Recipe(
|
||||||
|
user_id=unique_local_user_id,
|
||||||
|
group_id=unique_local_group_id,
|
||||||
|
name=f"{random_string(10)} soup",
|
||||||
|
),
|
||||||
|
Recipe(
|
||||||
|
user_id=unique_local_user_id,
|
||||||
|
group_id=unique_local_group_id,
|
||||||
|
name=f"{random_string(10)} soup",
|
||||||
|
),
|
||||||
|
Recipe(
|
||||||
|
user_id=unique_local_user_id,
|
||||||
|
group_id=unique_local_group_id,
|
||||||
|
name=f"{random_string(10)} soup",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
return database.recipes.create_many(recipes)
|
||||||
|
|
||||||
|
|
||||||
def test_recipe_repo_get_by_categories_basic(database: AllRepositories, unique_user: TestUser):
|
def test_recipe_repo_get_by_categories_basic(database: AllRepositories, unique_user: TestUser):
|
||||||
# Bootstrap the database with categories
|
# Bootstrap the database with categories
|
||||||
slug1, slug2, slug3 = (random_string(10) for _ in range(3))
|
slug1, slug2, slug3 = (random_string(10) for _ in range(3))
|
||||||
@ -112,7 +203,7 @@ def test_recipe_repo_get_by_categories_multi(database: AllRepositories, unique_u
|
|||||||
database.recipes.create(recipe)
|
database.recipes.create(recipe)
|
||||||
|
|
||||||
# Get all recipes by both categories
|
# Get all recipes by both categories
|
||||||
repo: RepositoryRecipes = database.recipes.by_group(unique_user.group_id) # type: ignore
|
repo: RepositoryRecipes = database.recipes.by_group(unique_local_group_id) # type: ignore
|
||||||
by_category = repo.get_by_categories(cast(list[RecipeCategory], created_categories))
|
by_category = repo.get_by_categories(cast(list[RecipeCategory], created_categories))
|
||||||
|
|
||||||
assert len(by_category) == 10
|
assert len(by_category) == 10
|
||||||
@ -490,129 +581,72 @@ def test_recipe_repo_pagination_by_foods(database: AllRepositories, unique_user:
|
|||||||
assert not all(i == random_ordered[0] for i in random_ordered)
|
assert not all(i == random_ordered[0] for i in random_ordered)
|
||||||
|
|
||||||
|
|
||||||
def test_recipe_repo_search(database: AllRepositories, unique_user: TestUser):
|
@pytest.mark.parametrize(
|
||||||
recipes = [
|
"search, expected_names",
|
||||||
Recipe(
|
[
|
||||||
user_id=unique_user.user_id,
|
(random_string(), []),
|
||||||
group_id=unique_user.group_id,
|
("Steinbock", ["Steinbock Sloop"]),
|
||||||
name="Steinbock Sloop",
|
("horns", ["Steinbock Sloop"]),
|
||||||
description=f"My favorite horns are delicious",
|
("moss", ["Fiddlehead Fern Stir Fry"]),
|
||||||
recipe_ingredient=[
|
('"Animal Sloop"', ["Animal Sloop"]),
|
||||||
RecipeIngredient(note="alpine animal"),
|
("animal-sloop", ["Animal Sloop"]),
|
||||||
|
("ratat", ["Rátàtôuile"]),
|
||||||
|
("delicious horns", ["Steinbock Sloop"]),
|
||||||
],
|
],
|
||||||
),
|
ids=[
|
||||||
Recipe(
|
"no_match",
|
||||||
user_id=unique_user.user_id,
|
"search_by_title",
|
||||||
group_id=unique_user.group_id,
|
"search_by_description",
|
||||||
name="Fiddlehead Fern Stir Fry",
|
"search_by_ingredient",
|
||||||
recipe_ingredient=[
|
"literal_search",
|
||||||
RecipeIngredient(note="moss"),
|
"special_character_removal",
|
||||||
|
"normalization",
|
||||||
|
"token_separation",
|
||||||
],
|
],
|
||||||
),
|
)
|
||||||
Recipe(
|
def test_basic_recipe_search(
|
||||||
user_id=unique_user.user_id,
|
search: str,
|
||||||
group_id=unique_user.group_id,
|
expected_names: list[str],
|
||||||
name="Animal Sloop",
|
database: AllRepositories,
|
||||||
),
|
search_recipes: list[Recipe], # required so database is populated
|
||||||
# Test diacritics
|
unique_local_group_id: str,
|
||||||
Recipe(
|
):
|
||||||
user_id=unique_user.user_id,
|
repo = database.recipes.by_group(unique_local_group_id) # type: ignore
|
||||||
group_id=unique_user.group_id,
|
pagination = PaginationQuery(page=1, per_page=-1, order_by="created_at", order_direction=OrderDirection.asc)
|
||||||
name="Rátàtôuile",
|
results = repo.page_all(pagination, search=search).items
|
||||||
),
|
|
||||||
# Add a bunch of recipes for stable randomization
|
|
||||||
Recipe(
|
|
||||||
user_id=unique_user.user_id,
|
|
||||||
group_id=unique_user.group_id,
|
|
||||||
name=f"{random_string(10)} soup",
|
|
||||||
),
|
|
||||||
Recipe(
|
|
||||||
user_id=unique_user.user_id,
|
|
||||||
group_id=unique_user.group_id,
|
|
||||||
name=f"{random_string(10)} soup",
|
|
||||||
),
|
|
||||||
Recipe(
|
|
||||||
user_id=unique_user.user_id,
|
|
||||||
group_id=unique_user.group_id,
|
|
||||||
name=f"{random_string(10)} soup",
|
|
||||||
),
|
|
||||||
Recipe(
|
|
||||||
user_id=unique_user.user_id,
|
|
||||||
group_id=unique_user.group_id,
|
|
||||||
name=f"{random_string(10)} soup",
|
|
||||||
),
|
|
||||||
Recipe(
|
|
||||||
user_id=unique_user.user_id,
|
|
||||||
group_id=unique_user.group_id,
|
|
||||||
name=f"{random_string(10)} soup",
|
|
||||||
),
|
|
||||||
Recipe(
|
|
||||||
user_id=unique_user.user_id,
|
|
||||||
group_id=unique_user.group_id,
|
|
||||||
name=f"{random_string(10)} soup",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
for recipe in recipes:
|
if len(expected_names) == 0:
|
||||||
database.recipes.create(recipe)
|
assert len(results) == 0
|
||||||
|
else:
|
||||||
|
# if more results are returned, that's acceptable, as long as they are ranked correctly
|
||||||
|
assert len(results) >= len(expected_names)
|
||||||
|
for recipe, name in zip(results, expected_names, strict=False):
|
||||||
|
assert recipe.name == name
|
||||||
|
|
||||||
pagination_query = PaginationQuery(page=1, per_page=-1, order_by="created_at", order_direction=OrderDirection.asc)
|
|
||||||
|
|
||||||
# No hits
|
def test_fuzzy_recipe_search(
|
||||||
empty_result = database.recipes.page_all(pagination_query, search=random_string(10)).items
|
database: AllRepositories,
|
||||||
assert len(empty_result) == 0
|
search_recipes: list[Recipe], # required so database is populated
|
||||||
|
unique_local_group_id: str,
|
||||||
|
):
|
||||||
|
# this only works on postgres
|
||||||
|
if database.session.get_bind().name != "postgresql":
|
||||||
|
return
|
||||||
|
|
||||||
# Search by title
|
repo = database.recipes.by_group(unique_local_group_id) # type: ignore
|
||||||
title_result = database.recipes.page_all(pagination_query, search="Steinbock").items
|
pagination = PaginationQuery(page=1, per_page=-1, order_by="created_at", order_direction=OrderDirection.asc)
|
||||||
assert len(title_result) == 1
|
results = repo.page_all(pagination, search="Steinbuck").items
|
||||||
assert title_result[0].name == "Steinbock Sloop"
|
|
||||||
|
|
||||||
# Search by description
|
assert results and results[0].name == "Steinbock Sloop"
|
||||||
description_result = database.recipes.page_all(pagination_query, search="horns").items
|
|
||||||
assert len(description_result) == 1
|
|
||||||
assert description_result[0].name == "Steinbock Sloop"
|
|
||||||
|
|
||||||
# Search by ingredient
|
|
||||||
ingredient_result = database.recipes.page_all(pagination_query, search="moss").items
|
|
||||||
assert len(ingredient_result) == 1
|
|
||||||
assert ingredient_result[0].name == "Fiddlehead Fern Stir Fry"
|
|
||||||
|
|
||||||
# Make sure title matches are ordered in front
|
def test_random_order_recipe_search(
|
||||||
ordered_result = database.recipes.page_all(pagination_query, search="animal sloop").items
|
database: AllRepositories,
|
||||||
assert len(ordered_result) == 2
|
search_recipes: list[Recipe], # required so database is populated
|
||||||
assert ordered_result[0].name == "Animal Sloop"
|
unique_local_group_id: str,
|
||||||
assert ordered_result[1].name == "Steinbock Sloop"
|
):
|
||||||
|
repo = database.recipes.by_group(unique_local_group_id) # type: ignore
|
||||||
# Test literal search
|
pagination = PaginationQuery(
|
||||||
literal_result = database.recipes.page_all(pagination_query, search='"Animal Sloop"').items
|
|
||||||
assert len(literal_result) == 1
|
|
||||||
assert literal_result[0].name == "Animal Sloop"
|
|
||||||
|
|
||||||
# Test special character removal from non-literal searches
|
|
||||||
character_result = database.recipes.page_all(pagination_query, search="animal-sloop").items
|
|
||||||
assert len(character_result) == 2
|
|
||||||
assert character_result[0].name == "Animal Sloop"
|
|
||||||
assert character_result[1].name == "Steinbock Sloop"
|
|
||||||
|
|
||||||
# Test string normalization
|
|
||||||
normalized_result = database.recipes.page_all(pagination_query, search="ratat").items
|
|
||||||
print([r.name for r in normalized_result])
|
|
||||||
assert len(normalized_result) == 1
|
|
||||||
assert normalized_result[0].name == "Rátàtôuile"
|
|
||||||
|
|
||||||
# Test token separation
|
|
||||||
token_result = database.recipes.page_all(pagination_query, search="delicious horns").items
|
|
||||||
assert len(token_result) == 1
|
|
||||||
assert token_result[0].name == "Steinbock Sloop"
|
|
||||||
|
|
||||||
# Test fuzzy search
|
|
||||||
if database.session.get_bind().name == "postgresql":
|
|
||||||
fuzzy_result = database.recipes.page_all(pagination_query, search="Steinbuck").items
|
|
||||||
assert len(fuzzy_result) == 1
|
|
||||||
assert fuzzy_result[0].name == "Steinbock Sloop"
|
|
||||||
|
|
||||||
# Test random ordering with search
|
|
||||||
pagination_query = PaginationQuery(
|
|
||||||
page=1,
|
page=1,
|
||||||
per_page=-1,
|
per_page=-1,
|
||||||
order_by="random",
|
order_by="random",
|
||||||
@ -620,7 +654,7 @@ def test_recipe_repo_search(database: AllRepositories, unique_user: TestUser):
|
|||||||
order_direction=OrderDirection.asc,
|
order_direction=OrderDirection.asc,
|
||||||
)
|
)
|
||||||
random_ordered = []
|
random_ordered = []
|
||||||
for i in range(5):
|
for _ in range(5):
|
||||||
pagination_query.pagination_seed = str(datetime.now())
|
pagination.pagination_seed = str(datetime.now())
|
||||||
random_ordered.append(database.recipes.page_all(pagination_query, search="soup").items)
|
random_ordered.append(repo.page_all(pagination, search="soup").items)
|
||||||
assert not all(i == random_ordered[0] for i in random_ordered)
|
assert not all(i == random_ordered[0] for i in random_ordered)
|
||||||
|
135
tests/unit_tests/repository_tests/test_search.py
Normal file
135
tests/unit_tests/repository_tests/test_search.py
Normal file
@ -0,0 +1,135 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from mealie.repos.repository_factory import AllRepositories
|
||||||
|
from mealie.schema.recipe.recipe_ingredient import IngredientUnit, SaveIngredientUnit
|
||||||
|
from mealie.schema.response.pagination import OrderDirection, PaginationQuery
|
||||||
|
from mealie.schema.user.user import GroupBase
|
||||||
|
from tests.utils.factories import random_int, random_string
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def unique_local_group_id(database: AllRepositories) -> str:
|
||||||
|
return str(database.groups.create(GroupBase(name=random_string())).id)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def search_units(database: AllRepositories, unique_local_group_id: str) -> list[IngredientUnit]:
|
||||||
|
units = [
|
||||||
|
SaveIngredientUnit(
|
||||||
|
group_id=unique_local_group_id,
|
||||||
|
name="Tea Spoon",
|
||||||
|
abbreviation="tsp",
|
||||||
|
),
|
||||||
|
SaveIngredientUnit(
|
||||||
|
group_id=unique_local_group_id,
|
||||||
|
name="Table Spoon",
|
||||||
|
description="unique description",
|
||||||
|
abbreviation="tbsp",
|
||||||
|
),
|
||||||
|
SaveIngredientUnit(
|
||||||
|
group_id=unique_local_group_id,
|
||||||
|
name="Cup",
|
||||||
|
description="A bucket that's full",
|
||||||
|
),
|
||||||
|
SaveIngredientUnit(
|
||||||
|
group_id=unique_local_group_id,
|
||||||
|
name="Píñch",
|
||||||
|
),
|
||||||
|
SaveIngredientUnit(
|
||||||
|
group_id=unique_local_group_id,
|
||||||
|
name="Unit with a very cool name",
|
||||||
|
),
|
||||||
|
SaveIngredientUnit(
|
||||||
|
group_id=unique_local_group_id,
|
||||||
|
name="Unit with a pretty cool name",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add a bunch of units for stable randomization
|
||||||
|
units.extend(
|
||||||
|
[
|
||||||
|
SaveIngredientUnit(group_id=unique_local_group_id, name=f"{random_string()} unit")
|
||||||
|
for _ in range(random_int(12, 20))
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return database.ingredient_units.create_many(units)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"search, expected_names",
|
||||||
|
[
|
||||||
|
(random_string(), []),
|
||||||
|
("Cup", ["Cup"]),
|
||||||
|
("tbsp", ["Table Spoon"]),
|
||||||
|
("unique description", ["Table Spoon"]),
|
||||||
|
("very cool name", ["Unit with a very cool name", "Unit with a pretty cool name"]),
|
||||||
|
('"Tea Spoon"', ["Tea Spoon"]),
|
||||||
|
("full bucket", ["Cup"]),
|
||||||
|
],
|
||||||
|
ids=[
|
||||||
|
"no_match",
|
||||||
|
"search_by_name",
|
||||||
|
"search_by_unit",
|
||||||
|
"search_by_description",
|
||||||
|
"match_order",
|
||||||
|
"literal_search",
|
||||||
|
"token_separation",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_basic_search(
|
||||||
|
search: str,
|
||||||
|
expected_names: list[str],
|
||||||
|
database: AllRepositories,
|
||||||
|
search_units: list[IngredientUnit], # required so database is populated
|
||||||
|
unique_local_group_id: str,
|
||||||
|
):
|
||||||
|
repo = database.ingredient_units.by_group(unique_local_group_id)
|
||||||
|
pagination = PaginationQuery(page=1, per_page=-1, order_by="created_at", order_direction=OrderDirection.asc)
|
||||||
|
results = repo.page_all(pagination, search=search).items
|
||||||
|
|
||||||
|
if len(expected_names) == 0:
|
||||||
|
assert len(results) == 0
|
||||||
|
else:
|
||||||
|
# if more results are returned, that's acceptable, as long as they are ranked correctly
|
||||||
|
assert len(results) >= len(expected_names)
|
||||||
|
for unit, name in zip(results, expected_names, strict=False):
|
||||||
|
assert unit.name == name
|
||||||
|
|
||||||
|
|
||||||
|
def test_fuzzy_search(
|
||||||
|
database: AllRepositories,
|
||||||
|
search_units: list[IngredientUnit], # required so database is populated
|
||||||
|
unique_local_group_id: str,
|
||||||
|
):
|
||||||
|
# this only works on postgres
|
||||||
|
if database.session.get_bind().name != "postgresql":
|
||||||
|
return
|
||||||
|
|
||||||
|
repo = database.ingredient_units.by_group(unique_local_group_id)
|
||||||
|
pagination = PaginationQuery(page=1, per_page=-1, order_by="created_at", order_direction=OrderDirection.asc)
|
||||||
|
results = repo.page_all(pagination, search="unique decsription").items
|
||||||
|
|
||||||
|
assert results and results[0].name == "Table Spoon"
|
||||||
|
|
||||||
|
|
||||||
|
def test_random_order_search(
|
||||||
|
database: AllRepositories,
|
||||||
|
search_units: list[IngredientUnit], # required so database is populated
|
||||||
|
unique_local_group_id: str,
|
||||||
|
):
|
||||||
|
repo = database.ingredient_units.by_group(unique_local_group_id)
|
||||||
|
pagination = PaginationQuery(
|
||||||
|
page=1,
|
||||||
|
per_page=-1,
|
||||||
|
order_by="random",
|
||||||
|
pagination_seed=str(datetime.now()),
|
||||||
|
order_direction=OrderDirection.asc,
|
||||||
|
)
|
||||||
|
random_ordered = []
|
||||||
|
for _ in range(5):
|
||||||
|
pagination.pagination_seed = str(datetime.now())
|
||||||
|
random_ordered.append(repo.page_all(pagination, search="unit").items)
|
||||||
|
assert not all(i == random_ordered[0] for i in random_ordered)
|
Loading…
x
Reference in New Issue
Block a user