feat: Generalize Search to Other Models (#2472)

* generalized search logic to SearchFilter

* added default search behavior for all models

* fix for schema overrides

* added search support to several models

* fix for label search

* tests and fixes

* add config for normalizing characters

* dramatically simplified search tests

* bark bark

* fix normalization bug

* tweaked tests

* maybe this time?

---------

Co-authored-by: Hayden <64056131+hay-kot@users.noreply.github.com>
This commit is contained in:
Michael Genson 2023-08-20 13:30:21 -05:00 committed by GitHub
parent 76ae0bafc7
commit 99372aa2b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 521 additions and 250 deletions

View File

@ -16,6 +16,7 @@ from mealie.db.models._model_base import SqlAlchemyBase
from mealie.schema._mealie import MealieModel
from mealie.schema.response.pagination import OrderDirection, PaginationBase, PaginationQuery
from mealie.schema.response.query_filter import QueryFilter
from mealie.schema.response.query_search import SearchFilter
Schema = TypeVar("Schema", bound=MealieModel)
Model = TypeVar("Model", bound=SqlAlchemyBase)
@ -291,7 +292,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
q = self._query(override_schema=eff_schema).filter(attribute_name == attr_match)
return [eff_schema.from_orm(x) for x in self.session.execute(q).scalars().all()]
def page_all(self, pagination: PaginationQuery, override=None) -> PaginationBase[Schema]:
def page_all(self, pagination: PaginationQuery, override=None, search: str | None = None) -> PaginationBase[Schema]:
"""
pagination is a method to interact with the filtered database table and return a paginated result
using the PaginationBase that provides several data points that are needed to manage pagination
@ -302,12 +303,16 @@ class RepositoryGeneric(Generic[Schema, Model]):
as the override, as the type system is not able to infer the result of this method.
"""
eff_schema = override or self.schema
# Copy this, because calling methods (e.g. tests) might rely on it not getting mutated
pagination_result = pagination.copy()
q = self._query(override_schema=eff_schema, with_options=False)
fltr = self._filter_builder()
q = q.filter_by(**fltr)
q, count, total_pages = self.add_pagination_to_query(q, pagination)
if search:
q = self.add_search_to_query(q, eff_schema, search)
q, count, total_pages = self.add_pagination_to_query(q, pagination_result)
# Apply options late, so they do not get used for counting
q = q.options(*eff_schema.loader_options())
@ -318,8 +323,8 @@ class RepositoryGeneric(Generic[Schema, Model]):
self.session.rollback()
raise e
return PaginationBase(
page=pagination.page,
per_page=pagination.per_page,
page=pagination_result.page,
per_page=pagination_result.per_page,
total=count,
total_pages=total_pages,
items=[eff_schema.from_orm(s) for s in data],
@ -392,3 +397,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
query = query.order_by(case_stmt)
return query.limit(pagination.per_page).offset((pagination.page - 1) * pagination.per_page), count, total_pages
def add_search_to_query(self, query: Select, schema: type[Schema], search: str) -> Select:
    """Filter ``query`` by ``search`` using the schema's search configuration."""
    search_filter = SearchFilter(self.session, search, normalize_characters=schema._normalize_search)
    return search_filter.filter_query_by_search(query, schema, self.model)

View File

@ -5,10 +5,9 @@ from uuid import UUID
from pydantic import UUID4
from slugify import slugify
from sqlalchemy import Select, and_, desc, func, or_, select, text
from sqlalchemy import and_, func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import joinedload
from text_unidecode import unidecode
from mealie.db.models.recipe.category import Category
from mealie.db.models.recipe.ingredient import RecipeIngredientModel
@ -18,13 +17,7 @@ from mealie.db.models.recipe.tag import Tag
from mealie.db.models.recipe.tool import Tool
from mealie.schema.cookbook.cookbook import ReadCookBook
from mealie.schema.recipe import Recipe
from mealie.schema.recipe.recipe import (
RecipeCategory,
RecipePagination,
RecipeSummary,
RecipeTag,
RecipeTool,
)
from mealie.schema.recipe.recipe import RecipeCategory, RecipePagination, RecipeSummary, RecipeTag, RecipeTool
from mealie.schema.recipe.recipe_category import CategoryBase, TagBase
from mealie.schema.response.pagination import PaginationQuery
@ -151,98 +144,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
additional_ids = self.session.execute(select(model.id).filter(model.slug.in_(slugs))).scalars().all()
return ids + additional_ids
def _add_search_to_query(self, query: Select, search: str) -> Select:
"""
0. fuzzy search (postgres only) and tokenized search are performed separately
1. take search string and do a little pre-normalization
2. look for internal quoted strings and keep them together as "literal" parts of the search
3. remove special characters from each non-literal search string
4. token search looks for any individual exact hit in name, description, and ingredients
5. fuzzy search looks for trigram hits in name, description, and ingredients
6. Sort order is determined by closeness to the recipe name
Should search also look at tags?
"""
normalized_search = unidecode(search).lower().strip()
punctuation = "!\#$%&()*+,-./:;<=>?@[\\]^_`{|}~" # string.punctuation with ' & " removed
# keep quoted phrases together as literal portions of the search string
literal = False
quoted_regex = re.compile(r"""(["'])(?:(?=(\\?))\2.)*?\1""") # thank you stack exchange!
removequotes_regex = re.compile(r"""['"](.*)['"]""")
if quoted_regex.search(normalized_search):
literal = True
temp = normalized_search
quoted_search_list = [match.group() for match in quoted_regex.finditer(temp)] # all quoted strings
quoted_search_list = [removequotes_regex.sub("\\1", x) for x in quoted_search_list] # remove outer quotes
temp = quoted_regex.sub("", temp) # remove all quoted strings, leaving just non-quoted
temp = temp.translate(
str.maketrans(punctuation, " " * len(punctuation))
) # punctuation->spaces for splitting, but only on unquoted strings
unquoted_search_list = temp.split() # all unquoted strings
normalized_search_list = quoted_search_list + unquoted_search_list
else:
#
normalized_search = normalized_search.translate(str.maketrans(punctuation, " " * len(punctuation)))
normalized_search_list = normalized_search.split()
normalized_search_list = [x.strip() for x in normalized_search_list] # remove padding whitespace inside quotes
# I would prefer to just do this in the recipe_ingredient.any part of the main query, but it turns out
# that at least sqlite wont use indexes for that correctly anymore and takes a big hit, so prefiltering it is
if (self.session.get_bind().name == "postgresql") & (literal is False): # fuzzy search
ingredient_ids = (
self.session.execute(
select(RecipeIngredientModel.id).filter(
or_(
RecipeIngredientModel.note_normalized.op("%>")(normalized_search),
RecipeIngredientModel.original_text_normalized.op("%>")(normalized_search),
)
)
)
.scalars()
.all()
)
else: # exact token search
ingredient_ids = (
self.session.execute(
select(RecipeIngredientModel.id).filter(
or_(
*[RecipeIngredientModel.note_normalized.like(f"%{ns}%") for ns in normalized_search_list],
*[
RecipeIngredientModel.original_text_normalized.like(f"%{ns}%")
for ns in normalized_search_list
],
)
)
)
.scalars()
.all()
)
if (self.session.get_bind().name == "postgresql") & (literal is False): # fuzzy search
# default = 0.7 is too strict for effective fuzzing
self.session.execute(text("set pg_trgm.word_similarity_threshold = 0.5;"))
q = query.filter(
or_(
RecipeModel.name_normalized.op("%>")(normalized_search),
RecipeModel.description_normalized.op("%>")(normalized_search),
RecipeModel.recipe_ingredient.any(RecipeIngredientModel.id.in_(ingredient_ids)),
)
).order_by( # trigram ordering could be too slow on million record db, but is fine with thousands.
func.least(
RecipeModel.name_normalized.op("<->>")(normalized_search),
)
)
else: # exact token search
q = query.filter(
or_(
*[RecipeModel.name_normalized.like(f"%{ns}%") for ns in normalized_search_list],
*[RecipeModel.description_normalized.like(f"%{ns}%") for ns in normalized_search_list],
RecipeModel.recipe_ingredient.any(RecipeIngredientModel.id.in_(ingredient_ids)),
)
).order_by(desc(RecipeModel.name_normalized.like(f"%{normalized_search}%")))
return q
def page_all(
def page_all( # type: ignore
self,
pagination: PaginationQuery,
override=None,
@ -299,7 +201,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
)
q = q.filter(*filters)
if search:
q = self._add_search_to_query(q, search)
q = self.add_search_to_query(q, self.schema, search)
q, count, total_pages = self.add_pagination_to_query(q, pagination_result)

View File

@ -41,10 +41,11 @@ class MultiPurposeLabelsController(BaseUserController):
return HttpRepo(self.repo, self.logger, self.registered_exceptions, self.t("generic.server-error"))
@router.get("", response_model=MultiPurposeLabelPagination)
def get_all(self, q: PaginationQuery = Depends(PaginationQuery)):
def get_all(self, q: PaginationQuery = Depends(PaginationQuery), search: str | None = None):
response = self.repo.page_all(
pagination=q,
override=MultiPurposeLabelSummary,
search=search,
)
response.set_pagination_guides(router.url_path_for("get_all"), q.dict())

View File

@ -38,11 +38,12 @@ class RecipeCategoryController(BaseCrudController):
return HttpRepo(self.repo, self.logger)
@router.get("", response_model=RecipeCategoryPagination)
def get_all(self, q: PaginationQuery = Depends(PaginationQuery)):
def get_all(self, q: PaginationQuery = Depends(PaginationQuery), search: str | None = None):
"""Returns a list of available categories in the database"""
response = self.repo.page_all(
pagination=q,
override=RecipeCategory,
search=search,
)
response.set_pagination_guides(router.url_path_for("get_all"), q.dict())

View File

@ -27,11 +27,12 @@ class TagController(BaseCrudController):
return HttpRepo(self.repo, self.logger)
@router.get("", response_model=RecipeTagPagination)
async def get_all(self, q: PaginationQuery = Depends(PaginationQuery)):
async def get_all(self, q: PaginationQuery = Depends(PaginationQuery), search: str | None = None):
"""Returns a list of available tags in the database"""
response = self.repo.page_all(
pagination=q,
override=RecipeTag,
search=search,
)
response.set_pagination_guides(router.url_path_for("get_all"), q.dict())

View File

@ -25,10 +25,11 @@ class RecipeToolController(BaseUserController):
return HttpRepo[RecipeToolCreate, RecipeTool, RecipeToolCreate](self.repo, self.logger)
@router.get("", response_model=RecipeToolPagination)
def get_all(self, q: PaginationQuery = Depends(PaginationQuery)):
def get_all(self, q: PaginationQuery = Depends(PaginationQuery), search: str | None = None):
response = self.repo.page_all(
pagination=q,
override=RecipeTool,
search=search,
)
response.set_pagination_guides(router.url_path_for("get_all"), q.dict())

View File

@ -45,10 +45,11 @@ class IngredientFoodsController(BaseUserController):
raise HTTPException(500, "Failed to merge foods") from e
@router.get("", response_model=IngredientFoodPagination)
def get_all(self, q: PaginationQuery = Depends(PaginationQuery)):
def get_all(self, q: PaginationQuery = Depends(PaginationQuery), search: str | None = None):
response = self.repo.page_all(
pagination=q,
override=IngredientFood,
search=search,
)
response.set_pagination_guides(router.url_path_for("get_all"), q.dict())

View File

@ -45,10 +45,11 @@ class IngredientUnitsController(BaseUserController):
raise HTTPException(500, "Failed to merge units") from e
@router.get("", response_model=IngredientUnitPagination)
def get_all(self, q: PaginationQuery = Depends(PaginationQuery)):
def get_all(self, q: PaginationQuery = Depends(PaginationQuery), search: str | None = None):
response = self.repo.page_all(
pagination=q,
override=IngredientUnit,
search=search,
)
response.set_pagination_guides(router.url_path_for("get_all"), q.dict())

View File

@ -1,7 +1,8 @@
# This file is auto-generated by gen_schema_exports.py
from .mealie_model import HasUUID, MealieModel
from .mealie_model import HasUUID, MealieModel, SearchType
__all__ = [
"HasUUID",
"MealieModel",
"SearchType",
]

View File

@ -1,16 +1,34 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Protocol, TypeVar
from enum import Enum
from typing import ClassVar, Protocol, TypeVar
from humps.main import camelize
from pydantic import UUID4, BaseModel
from sqlalchemy import Select, desc, func, or_, text
from sqlalchemy.orm import InstrumentedAttribute, Session
from sqlalchemy.orm.interfaces import LoaderOption
from mealie.db.models._model_base import SqlAlchemyBase
T = TypeVar("T", bound=BaseModel)
class SearchType(Enum):
    """How a search string is matched against a model's searchable columns.

    ``fuzzy`` uses trigram similarity operators; ``tokenized`` matches each
    search token as a substring (LIKE). SearchFilter selects fuzzy only for
    PostgreSQL back ends and non-quoted searches.
    """

    fuzzy = "fuzzy"
    tokenized = "tokenized"
class MealieModel(BaseModel):
_fuzzy_similarity_threshold: ClassVar[float] = 0.5
_normalize_search: ClassVar[bool] = False
_searchable_properties: ClassVar[list[str]] = []
"""
Searchable properties for the search API.
The first property will be used for sorting (order_by)
"""
class Config:
alias_generator = camelize
allow_population_by_field_name = True
@ -59,6 +77,40 @@ class MealieModel(BaseModel):
def loader_options(cls) -> list[LoaderOption]:
return []
@classmethod
def filter_search_query(
    cls,
    db_model: type[SqlAlchemyBase],
    query: Select,
    session: Session,
    search_type: SearchType,
    search: str,
    search_list: list[str],
) -> Select:
    """
    Default search: filter ``query`` against the columns named in
    ``cls._searchable_properties``. Can be overridden to support a more
    advanced search (see Recipe).
    """

    if not cls._searchable_properties:
        raise AttributeError("Not Implemented")

    columns: list[InstrumentedAttribute] = [getattr(db_model, name) for name in cls._searchable_properties]
    primary_column = columns[0]

    if search_type is SearchType.fuzzy:
        # loosen the trigram threshold; this is a postgres-only session setting
        session.execute(text(f"set pg_trgm.word_similarity_threshold = {cls._fuzzy_similarity_threshold};"))
        fuzzy_filters = [column.op("%>")(search) for column in columns]

        # trigram ordering by the first searchable property
        return query.filter(or_(*fuzzy_filters)).order_by(func.least(primary_column.op("<->>")(search)))

    token_filters = [column.like(f"%{token}%") for column in columns for token in search_list]

    # order by how close the result is to the first searchable property
    return query.filter(or_(*token_filters)).order_by(desc(primary_column.like(f"%{search}%")))
class HasUUID(Protocol):
id: UUID4

View File

@ -1,5 +1,7 @@
from __future__ import annotations
from typing import ClassVar
from pydantic import UUID4
from mealie.schema._mealie import MealieModel
@ -20,7 +22,7 @@ class MultiPurposeLabelUpdate(MultiPurposeLabelSave):
class MultiPurposeLabelSummary(MultiPurposeLabelUpdate):
pass
_searchable_properties: ClassVar[list[str]] = ["name"]
class Config:
orm_mode = True
@ -31,14 +33,5 @@ class MultiPurposeLabelPagination(PaginationBase):
class MultiPurposeLabelOut(MultiPurposeLabelUpdate):
# shopping_list_items: list[ShoppingListItemOut] = []
# foods: list[IngredientFood] = []
class Config:
orm_mode = True
# from mealie.schema.recipe.recipe_ingredient import IngredientFood
# from mealie.schema.group.group_shopping_list import ShoppingListItemOut
# MultiPurposeLabelOut.update_forward_refs()

View File

@ -2,16 +2,17 @@ from __future__ import annotations
import datetime
from pathlib import Path
from typing import Any
from typing import Any, ClassVar
from uuid import uuid4
from pydantic import UUID4, BaseModel, Field, validator
from slugify import slugify
from sqlalchemy.orm import joinedload, selectinload
from sqlalchemy import Select, desc, func, or_, select, text
from sqlalchemy.orm import Session, joinedload, selectinload
from sqlalchemy.orm.interfaces import LoaderOption
from mealie.core.config import get_app_dirs
from mealie.schema._mealie import MealieModel
from mealie.schema._mealie import MealieModel, SearchType
from mealie.schema.response.pagination import PaginationBase
from ...db.models.recipe import (
@ -37,6 +38,8 @@ class RecipeTag(MealieModel):
name: str
slug: str
_searchable_properties: ClassVar[list[str]] = ["name"]
class Config:
orm_mode = True
@ -78,6 +81,7 @@ class CreateRecipe(MealieModel):
class RecipeSummary(MealieModel):
id: UUID4 | None
_normalize_search: ClassVar[bool] = True
user_id: UUID4 = Field(default_factory=uuid4)
group_id: UUID4 = Field(default_factory=uuid4)
@ -259,6 +263,69 @@ class Recipe(RecipeSummary):
selectinload(RecipeModel.notes),
]
@classmethod
def filter_search_query(
    cls, db_model, query: Select, session: Session, search_type: SearchType, search: str, search_list: list[str]
) -> Select:
    """
    Recipe-specific override of the default search.

    1. token search looks for any individual exact hit in name, description, and ingredients
    2. fuzzy search looks for trigram hits in name, description, and ingredients
    3. Sort order is determined by closeness to the recipe name

    Should search also look at tags?
    """
    if search_type is SearchType.fuzzy:
        # I would prefer to just do this in the recipe_ingredient.any part of the main query,
        # but it turns out that at least sqlite wont use indexes for that correctly anymore and
        # takes a big hit, so prefiltering it is
        ingredient_ids = (
            session.execute(
                select(RecipeIngredientModel.id).filter(
                    or_(
                        RecipeIngredientModel.note_normalized.op("%>")(search),
                        RecipeIngredientModel.original_text_normalized.op("%>")(search),
                    )
                )
            )
            .scalars()
            .all()
        )

        # loosen the trigram match threshold (postgres-only session setting)
        session.execute(text(f"set pg_trgm.word_similarity_threshold = {cls._fuzzy_similarity_threshold};"))
        return query.filter(
            or_(
                RecipeModel.name_normalized.op("%>")(search),
                RecipeModel.description_normalized.op("%>")(search),
                RecipeModel.recipe_ingredient.any(RecipeIngredientModel.id.in_(ingredient_ids)),
            )
        ).order_by(  # trigram ordering could be too slow on million record db, but is fine with thousands.
            func.least(
                RecipeModel.name_normalized.op("<->>")(search),
            )
        )
    else:
        # exact token search: prefilter matching ingredient ids for the same index-usage reason as above
        ingredient_ids = (
            session.execute(
                select(RecipeIngredientModel.id).filter(
                    or_(
                        *[RecipeIngredientModel.note_normalized.like(f"%{ns}%") for ns in search_list],
                        *[RecipeIngredientModel.original_text_normalized.like(f"%{ns}%") for ns in search_list],
                    )
                )
            )
            .scalars()
            .all()
        )

        # any token may hit name, description, or a prefiltered ingredient;
        # recipes whose name contains the whole search string sort first
        return query.filter(
            or_(
                *[RecipeModel.name_normalized.like(f"%{ns}%") for ns in search_list],
                *[RecipeModel.description_normalized.like(f"%{ns}%") for ns in search_list],
                RecipeModel.recipe_ingredient.any(RecipeIngredientModel.id.in_(ingredient_ids)),
            )
        ).order_by(desc(RecipeModel.name_normalized.like(f"%{search}%")))
class RecipeLastMade(BaseModel):
timestamp: datetime.datetime

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import datetime
import enum
from fractions import Fraction
from typing import ClassVar
from uuid import UUID, uuid4
from pydantic import UUID4, Field, validator
@ -50,6 +51,8 @@ class IngredientFood(CreateIngredientFood):
created_at: datetime.datetime | None
update_at: datetime.datetime | None
_searchable_properties: ClassVar[list[str]] = ["name", "description"]
class Config:
orm_mode = True
getter_dict = ExtrasGetterDict
@ -78,6 +81,8 @@ class IngredientUnit(CreateIngredientUnit):
created_at: datetime.datetime | None
update_at: datetime.datetime | None
_searchable_properties: ClassVar[list[str]] = ["name", "abbreviation", "description"]
class Config:
orm_mode = True

View File

@ -0,0 +1,67 @@
import re
from sqlalchemy import Select
from sqlalchemy.orm import Session
from text_unidecode import unidecode
from ...db.models._model_base import SqlAlchemyBase
from .._mealie import MealieModel, SearchType
class SearchFilter:
    """
    Normalizes a raw search string and applies it to a query via the schema.

    0. fuzzy search (postgres only) and tokenized search are performed separately
    1. take search string and do a little pre-normalization
    2. look for internal quoted strings and keep them together as "literal" parts of the search
    3. remove special characters from each non-literal search string
    """

    # string.punctuation with ' & " removed (backslash kept).
    # NOTE: written with an explicit "\\" escape; the previous "\#" was an
    # invalid escape sequence (SyntaxWarning on Python >= 3.12).
    punctuation = "!#$%&()*+,-./:;<=>?@[\\]^_`{|}~"

    # matches a single- or double-quoted span, honoring escaped quotes
    quoted_regex = re.compile(r"""(["'])(?:(?=(\\?))\2.)*?\1""")
    # captures the contents between the outermost pair of quotes
    remove_quotes_regex = re.compile(r"""['"](.*)['"]""")

    @classmethod
    def _normalize_search(cls, search: str, normalize_characters: bool) -> str:
        """Replace punctuation with spaces; optionally fold diacritics and case."""
        search = search.translate(str.maketrans(cls.punctuation, " " * len(cls.punctuation)))
        if normalize_characters:
            # transliterate to ASCII and lowercase so e.g. "Rátàtôuile" matches "ratat"
            search = unidecode(search).lower().strip()
        else:
            search = search.strip()

        return search

    @classmethod
    def _build_search_list(cls, search: str) -> list[str]:
        """Split ``search`` into tokens, keeping quoted phrases together as single tokens."""
        if cls.quoted_regex.search(search):
            # all quoted strings
            quoted_search_list = [match.group() for match in cls.quoted_regex.finditer(search)]

            # remove outer quotes
            quoted_search_list = [cls.remove_quotes_regex.sub("\\1", x) for x in quoted_search_list]

            # punctuation->spaces for splitting, but only on unquoted strings
            search = cls.quoted_regex.sub("", search)  # remove all quoted strings, leaving just non-quoted
            search = search.translate(str.maketrans(cls.punctuation, " " * len(cls.punctuation)))

            # all unquoted strings
            unquoted_search_list = search.split()

            search_list = quoted_search_list + unquoted_search_list

        else:
            search_list = search.translate(str.maketrans(cls.punctuation, " " * len(cls.punctuation))).split()

        # remove padding whitespace inside quotes
        return [x.strip() for x in search_list]

    def __init__(self, session: Session, search: str, normalize_characters: bool = False) -> None:
        # quoted (literal) searches and non-postgres back ends fall back to tokenized search;
        # fuzzy (trigram) search is only available on postgresql
        if session.get_bind().name != "postgresql" or self.quoted_regex.search(search.strip()):
            self.search_type = SearchType.tokenized
        else:
            self.search_type = SearchType.fuzzy

        self.session = session
        self.search = self._normalize_search(search, normalize_characters)
        self.search_list = self._build_search_list(self.search)

    def filter_query_by_search(self, query: Select, schema: type[MealieModel], model: type[SqlAlchemyBase]) -> Select:
        """Delegate filtering to the schema's (possibly overridden) search implementation."""
        return schema.filter_search_query(model, query, self.session, self.search_type, self.search, self.search_list)

View File

@ -1,17 +1,108 @@
from datetime import datetime
from typing import cast
import pytest
from mealie.repos.repository_factory import AllRepositories
from mealie.repos.repository_recipes import RepositoryRecipes
from mealie.schema.recipe import RecipeIngredient, SaveIngredientFood, RecipeStep
from mealie.schema.recipe import RecipeIngredient, SaveIngredientFood
from mealie.schema.recipe.recipe import Recipe, RecipeCategory, RecipeSummary
from mealie.schema.recipe.recipe_category import CategoryOut, CategorySave, TagSave
from mealie.schema.recipe.recipe_tool import RecipeToolSave
from mealie.schema.response import OrderDirection, PaginationQuery
from tests.utils.factories import random_string
from mealie.schema.user.user import GroupBase
from tests.utils.factories import random_email, random_string
from tests.utils.fixture_schemas import TestUser
@pytest.fixture()
def unique_local_group_id(database: AllRepositories) -> str:
    """Create a fresh group and return its id, isolating this test's data."""
    group = database.groups.create(GroupBase(name=random_string()))
    return str(group.id)
@pytest.fixture()
def unique_local_user_id(database: AllRepositories, unique_local_group_id: str) -> str:
    """Create a fresh non-admin user inside the local group and return its id."""
    user_data = {
        "username": random_string(),
        "email": random_email(),
        "group_id": unique_local_group_id,
        "full_name": random_string(),
        "password": random_string(),
        "admin": False,
    }
    return str(database.users.create(user_data).id)
@pytest.fixture()
def search_recipes(database: AllRepositories, unique_local_group_id: str, unique_local_user_id: str) -> list[Recipe]:
    """Seed a known set of recipes (plus random filler) for search tests."""
    recipes = [
        Recipe(
            user_id=unique_local_user_id,
            group_id=unique_local_group_id,
            name="Steinbock Sloop",
            # plain string; was an f-string with no placeholders (F541)
            description="My favorite horns are delicious",
            recipe_ingredient=[
                RecipeIngredient(note="alpine animal"),
            ],
        ),
        Recipe(
            user_id=unique_local_user_id,
            group_id=unique_local_group_id,
            name="Fiddlehead Fern Stir Fry",
            recipe_ingredient=[
                RecipeIngredient(note="moss"),
            ],
        ),
        Recipe(
            user_id=unique_local_user_id,
            group_id=unique_local_group_id,
            name="Animal Sloop",
        ),
        # Test diacritics
        Recipe(
            user_id=unique_local_user_id,
            group_id=unique_local_group_id,
            name="Rátàtôuile",
        ),
    ]

    # Add a bunch of recipes for stable randomization
    recipes.extend(
        Recipe(
            user_id=unique_local_user_id,
            group_id=unique_local_group_id,
            name=f"{random_string(10)} soup",
        )
        for _ in range(6)
    )

    return database.recipes.create_many(recipes)
def test_recipe_repo_get_by_categories_basic(database: AllRepositories, unique_user: TestUser):
# Bootstrap the database with categories
slug1, slug2, slug3 = (random_string(10) for _ in range(3))
@ -112,7 +203,7 @@ def test_recipe_repo_get_by_categories_multi(database: AllRepositories, unique_u
database.recipes.create(recipe)
# Get all recipes by both categories
repo: RepositoryRecipes = database.recipes.by_group(unique_user.group_id) # type: ignore
repo: RepositoryRecipes = database.recipes.by_group(unique_local_group_id) # type: ignore
by_category = repo.get_by_categories(cast(list[RecipeCategory], created_categories))
assert len(by_category) == 10
@ -490,129 +581,72 @@ def test_recipe_repo_pagination_by_foods(database: AllRepositories, unique_user:
assert not all(i == random_ordered[0] for i in random_ordered)
def test_recipe_repo_search(database: AllRepositories, unique_user: TestUser):
recipes = [
Recipe(
user_id=unique_user.user_id,
group_id=unique_user.group_id,
name="Steinbock Sloop",
description=f"My favorite horns are delicious",
recipe_ingredient=[
RecipeIngredient(note="alpine animal"),
@pytest.mark.parametrize(
"search, expected_names",
[
(random_string(), []),
("Steinbock", ["Steinbock Sloop"]),
("horns", ["Steinbock Sloop"]),
("moss", ["Fiddlehead Fern Stir Fry"]),
('"Animal Sloop"', ["Animal Sloop"]),
("animal-sloop", ["Animal Sloop"]),
("ratat", ["Rátàtôuile"]),
("delicious horns", ["Steinbock Sloop"]),
],
),
Recipe(
user_id=unique_user.user_id,
group_id=unique_user.group_id,
name="Fiddlehead Fern Stir Fry",
recipe_ingredient=[
RecipeIngredient(note="moss"),
ids=[
"no_match",
"search_by_title",
"search_by_description",
"search_by_ingredient",
"literal_search",
"special_character_removal",
"normalization",
"token_separation",
],
),
Recipe(
user_id=unique_user.user_id,
group_id=unique_user.group_id,
name="Animal Sloop",
),
# Test diacritics
Recipe(
user_id=unique_user.user_id,
group_id=unique_user.group_id,
name="Rátàtôuile",
),
# Add a bunch of recipes for stable randomization
Recipe(
user_id=unique_user.user_id,
group_id=unique_user.group_id,
name=f"{random_string(10)} soup",
),
Recipe(
user_id=unique_user.user_id,
group_id=unique_user.group_id,
name=f"{random_string(10)} soup",
),
Recipe(
user_id=unique_user.user_id,
group_id=unique_user.group_id,
name=f"{random_string(10)} soup",
),
Recipe(
user_id=unique_user.user_id,
group_id=unique_user.group_id,
name=f"{random_string(10)} soup",
),
Recipe(
user_id=unique_user.user_id,
group_id=unique_user.group_id,
name=f"{random_string(10)} soup",
),
Recipe(
user_id=unique_user.user_id,
group_id=unique_user.group_id,
name=f"{random_string(10)} soup",
),
]
)
def test_basic_recipe_search(
search: str,
expected_names: list[str],
database: AllRepositories,
search_recipes: list[Recipe], # required so database is populated
unique_local_group_id: str,
):
repo = database.recipes.by_group(unique_local_group_id) # type: ignore
pagination = PaginationQuery(page=1, per_page=-1, order_by="created_at", order_direction=OrderDirection.asc)
results = repo.page_all(pagination, search=search).items
for recipe in recipes:
database.recipes.create(recipe)
if len(expected_names) == 0:
assert len(results) == 0
else:
# if more results are returned, that's acceptable, as long as they are ranked correctly
assert len(results) >= len(expected_names)
for recipe, name in zip(results, expected_names, strict=False):
assert recipe.name == name
pagination_query = PaginationQuery(page=1, per_page=-1, order_by="created_at", order_direction=OrderDirection.asc)
# No hits
empty_result = database.recipes.page_all(pagination_query, search=random_string(10)).items
assert len(empty_result) == 0
def test_fuzzy_recipe_search(
database: AllRepositories,
search_recipes: list[Recipe], # required so database is populated
unique_local_group_id: str,
):
# this only works on postgres
if database.session.get_bind().name != "postgresql":
return
# Search by title
title_result = database.recipes.page_all(pagination_query, search="Steinbock").items
assert len(title_result) == 1
assert title_result[0].name == "Steinbock Sloop"
repo = database.recipes.by_group(unique_local_group_id) # type: ignore
pagination = PaginationQuery(page=1, per_page=-1, order_by="created_at", order_direction=OrderDirection.asc)
results = repo.page_all(pagination, search="Steinbuck").items
# Search by description
description_result = database.recipes.page_all(pagination_query, search="horns").items
assert len(description_result) == 1
assert description_result[0].name == "Steinbock Sloop"
assert results and results[0].name == "Steinbock Sloop"
# Search by ingredient
ingredient_result = database.recipes.page_all(pagination_query, search="moss").items
assert len(ingredient_result) == 1
assert ingredient_result[0].name == "Fiddlehead Fern Stir Fry"
# Make sure title matches are ordered in front
ordered_result = database.recipes.page_all(pagination_query, search="animal sloop").items
assert len(ordered_result) == 2
assert ordered_result[0].name == "Animal Sloop"
assert ordered_result[1].name == "Steinbock Sloop"
# Test literal search
literal_result = database.recipes.page_all(pagination_query, search='"Animal Sloop"').items
assert len(literal_result) == 1
assert literal_result[0].name == "Animal Sloop"
# Test special character removal from non-literal searches
character_result = database.recipes.page_all(pagination_query, search="animal-sloop").items
assert len(character_result) == 2
assert character_result[0].name == "Animal Sloop"
assert character_result[1].name == "Steinbock Sloop"
# Test string normalization
normalized_result = database.recipes.page_all(pagination_query, search="ratat").items
print([r.name for r in normalized_result])
assert len(normalized_result) == 1
assert normalized_result[0].name == "Rátàtôuile"
# Test token separation
token_result = database.recipes.page_all(pagination_query, search="delicious horns").items
assert len(token_result) == 1
assert token_result[0].name == "Steinbock Sloop"
# Test fuzzy search
if database.session.get_bind().name == "postgresql":
fuzzy_result = database.recipes.page_all(pagination_query, search="Steinbuck").items
assert len(fuzzy_result) == 1
assert fuzzy_result[0].name == "Steinbock Sloop"
# Test random ordering with search
pagination_query = PaginationQuery(
def test_random_order_recipe_search(
database: AllRepositories,
search_recipes: list[Recipe], # required so database is populated
unique_local_group_id: str,
):
repo = database.recipes.by_group(unique_local_group_id) # type: ignore
pagination = PaginationQuery(
page=1,
per_page=-1,
order_by="random",
@ -620,7 +654,7 @@ def test_recipe_repo_search(database: AllRepositories, unique_user: TestUser):
order_direction=OrderDirection.asc,
)
random_ordered = []
for i in range(5):
pagination_query.pagination_seed = str(datetime.now())
random_ordered.append(database.recipes.page_all(pagination_query, search="soup").items)
for _ in range(5):
pagination.pagination_seed = str(datetime.now())
random_ordered.append(repo.page_all(pagination, search="soup").items)
assert not all(i == random_ordered[0] for i in random_ordered)

View File

@ -0,0 +1,135 @@
from datetime import datetime
import pytest
from mealie.repos.repository_factory import AllRepositories
from mealie.schema.recipe.recipe_ingredient import IngredientUnit, SaveIngredientUnit
from mealie.schema.response.pagination import OrderDirection, PaginationQuery
from mealie.schema.user.user import GroupBase
from tests.utils.factories import random_int, random_string
@pytest.fixture()
def unique_local_group_id(database: AllRepositories) -> str:
    """Create a fresh group and return its id, isolating this test's data."""
    group = database.groups.create(GroupBase(name=random_string()))
    return str(group.id)
@pytest.fixture()
def search_units(database: AllRepositories, unique_local_group_id: str) -> list[IngredientUnit]:
    """Seed a known set of ingredient units (plus random filler) for search tests."""
    units = [
        SaveIngredientUnit(
            group_id=unique_local_group_id,
            name="Tea Spoon",
            abbreviation="tsp",
        ),
        SaveIngredientUnit(
            group_id=unique_local_group_id,
            name="Table Spoon",
            description="unique description",
            abbreviation="tbsp",
        ),
        SaveIngredientUnit(
            group_id=unique_local_group_id,
            name="Cup",
            description="A bucket that's full",
        ),
        SaveIngredientUnit(
            group_id=unique_local_group_id,
            name="Píñch",
        ),
        SaveIngredientUnit(
            group_id=unique_local_group_id,
            name="Unit with a very cool name",
        ),
        SaveIngredientUnit(
            group_id=unique_local_group_id,
            name="Unit with a pretty cool name",
        ),
    ]

    # Add a bunch of units for stable randomization
    filler_count = random_int(12, 20)
    units += [
        SaveIngredientUnit(group_id=unique_local_group_id, name=f"{random_string()} unit")
        for _ in range(filler_count)
    ]

    return database.ingredient_units.create_many(units)
@pytest.mark.parametrize(
    "search, expected_names",
    [
        (random_string(), []),
        ("Cup", ["Cup"]),
        ("tbsp", ["Table Spoon"]),
        ("unique description", ["Table Spoon"]),
        ("very cool name", ["Unit with a very cool name", "Unit with a pretty cool name"]),
        ('"Tea Spoon"', ["Tea Spoon"]),
        ("full bucket", ["Cup"]),
    ],
    ids=[
        "no_match",
        "search_by_name",
        "search_by_unit",
        "search_by_description",
        "match_order",
        "literal_search",
        "token_separation",
    ],
)
def test_basic_search(
    search: str,
    expected_names: list[str],
    database: AllRepositories,
    search_units: list[IngredientUnit],  # required so database is populated
    unique_local_group_id: str,
):
    """Tokenized search returns expected units in ranked order."""
    repo = database.ingredient_units.by_group(unique_local_group_id)
    pagination = PaginationQuery(page=1, per_page=-1, order_by="created_at", order_direction=OrderDirection.asc)
    results = repo.page_all(pagination, search=search).items

    if not expected_names:
        assert len(results) == 0
        return

    # if more results are returned, that's acceptable, as long as they are ranked correctly
    assert len(results) >= len(expected_names)
    for unit, name in zip(results, expected_names, strict=False):
        assert unit.name == name
def test_fuzzy_search(
    database: AllRepositories,
    search_units: list[IngredientUnit],  # required so database is populated
    unique_local_group_id: str,
):
    """A misspelled search still finds the right unit via trigram matching."""
    if database.session.get_bind().name != "postgresql":
        # fuzzy (trigram) search only works on postgres
        return

    repo = database.ingredient_units.by_group(unique_local_group_id)
    pagination = PaginationQuery(page=1, per_page=-1, order_by="created_at", order_direction=OrderDirection.asc)

    items = repo.page_all(pagination, search="unique decsription").items
    assert items and items[0].name == "Table Spoon"
def test_random_order_search(
    database: AllRepositories,
    search_units: list[IngredientUnit],  # required so database is populated
    unique_local_group_id: str,
):
    """Random ordering combined with search yields differently ordered pages."""
    repo = database.ingredient_units.by_group(unique_local_group_id)
    pagination = PaginationQuery(
        page=1,
        per_page=-1,
        order_by="random",
        pagination_seed=str(datetime.now()),
        order_direction=OrderDirection.asc,
    )

    pages = []
    for _ in range(5):
        # reseed each request so the shuffle differs between pages
        pagination.pagination_seed = str(datetime.now())
        pages.append(repo.page_all(pagination, search="unit").items)

    assert any(page != pages[0] for page in pages)