Michael Genson 2dfbe9f08d
feat: Improved Ingredient Matching (#2535)
* added normalization to foods and units

* changed search to reference new normalized fields

* fix tests

* added parsed food matching to backend

* prevent pagination from ordering when searching

* added extra fuzzy matching to SQLite ingredient matching

* added tests

* only apply search ordering when order_by is null

* enabled post-search fuzzy matching for postgres

* fixed postgres fuzzy search test

* idk why this is failing

* 🤦

* simplified frontend ingredient matching and restored automatic unit creation

* tightened food fuzzy threshold

* change to rapidfuzz

* sped up fuzzy matching with process

* fixed units not matching by abbreviation

* fast return for exact matches

* replace db searching with pure fuzz

* added fuzzy normalization

* tightened unit fuzzy matching thresh

* cleaned up comments/var names

* ran matching logic through the dryer

* oops

* simplified order by application logic
2023-09-15 17:19:34 +00:00

136 lines
4.4 KiB
Python

from datetime import datetime
import pytest
from mealie.repos.repository_factory import AllRepositories
from mealie.schema.recipe.recipe_ingredient import IngredientUnit, SaveIngredientUnit
from mealie.schema.response.pagination import OrderDirection, PaginationQuery
from mealie.schema.user.user import GroupBase
from tests.utils.factories import random_int, random_string
@pytest.fixture()
def unique_local_group_id(database: AllRepositories) -> str:
    """Create a throwaway group with a random name and return its id as a string."""
    new_group = database.groups.create(GroupBase(name=random_string()))
    return str(new_group.id)
@pytest.fixture()
def search_units(database: AllRepositories, unique_local_group_id: str) -> list[IngredientUnit]:
    """Seed the test group with the ingredient units the search tests query against.

    A fixed set of named units is created first (only some have an abbreviation),
    followed by 12-20 filler units named ``"<random string> unit"``.

    Returns:
        Whatever ``database.ingredient_units.create_many`` returns for the new units.
    """
    # Only pass the fields each unit actually sets, so units without an
    # abbreviation fall back to the schema default exactly as before.
    named_unit_fields: list[dict[str, str]] = [
        {"name": "Tea Spoon", "abbreviation": "tsp"},
        {"name": "Table Spoon", "abbreviation": "tbsp"},
        {"name": "Cup"},
        {"name": "Píñch"},  # non-ASCII name; presumably exercises name normalization
        {"name": "Unit with a very cool name"},
        {"name": "Unit with a pretty cool name"},
        {"name": "Unit with a correct horse battery staple"},
    ]
    units = [SaveIngredientUnit(group_id=unique_local_group_id, **fields) for fields in named_unit_fields]

    # Filler units (all containing "unit") give the random-ordering test
    # enough matching rows to shuffle.
    filler_count = random_int(12, 20)
    units += [
        SaveIngredientUnit(group_id=unique_local_group_id, name=f"{random_string()} unit")
        for _ in range(filler_count)
    ]

    return database.ingredient_units.create_many(units)
@pytest.mark.parametrize(
    "search, expected_names",
    [
        pytest.param(random_string(), [], id="no_match"),
        pytest.param("Cup", ["Cup"], id="search_by_name"),
        pytest.param("tbsp", ["Table Spoon"], id="search_by_unit"),
        pytest.param(
            "very cool name",
            ["Unit with a very cool name", "Unit with a pretty cool name"],
            id="match_order",
        ),
        pytest.param('"Tea Spoon"', ["Tea Spoon"], id="literal_search"),
        pytest.param(
            "correct staple",
            ["Unit with a correct horse battery staple"],
            id="token_separation",
        ),
    ],
)
def test_basic_search(
    search: str,
    expected_names: list[str],
    database: AllRepositories,
    search_units: list[IngredientUnit],  # required so database is populated
    unique_local_group_id: str,
):
    """Search the group's units and check the leading results, in order.

    When ``expected_names`` is empty the search must return nothing. Otherwise,
    the first ``len(expected_names)`` results must carry exactly those names.
    """
    repo = database.ingredient_units.by_group(unique_local_group_id)
    pagination = PaginationQuery(page=1, per_page=-1, order_by="created_at", order_direction=OrderDirection.asc)
    results = repo.page_all(pagination, search=search).items

    if not expected_names:
        assert not results
        return

    # Extra trailing results are tolerated; only the leading ranks are checked.
    assert len(results) >= len(expected_names)
    leading_names = [unit.name for unit in results[: len(expected_names)]]
    assert leading_names == expected_names
def test_fuzzy_search(
    database: AllRepositories,
    search_units: list[IngredientUnit],  # required so database is populated
    unique_local_group_id: str,
):
    """A misspelled query should still rank the intended unit first.

    Only runs against PostgreSQL; on any other backend the test is reported as
    skipped instead of silently passing without asserting anything.
    """
    # this only works on postgres
    if database.session.get_bind().name != "postgresql":
        pytest.skip("fuzzy unit search is only supported on PostgreSQL")

    repo = database.ingredient_units.by_group(unique_local_group_id)
    pagination = PaginationQuery(page=1, per_page=-1, order_by="created_at", order_direction=OrderDirection.asc)
    results = repo.page_all(pagination, search="tabel spoone").items
    assert results and results[0].name == "Table Spoon"
def test_random_order_search(
    database: AllRepositories,
    search_units: list[IngredientUnit],  # required so database is populated
    unique_local_group_id: str,
):
    """Random ordering combined with a search should vary with the seed.

    Runs the same ``search="unit"`` query five times with ``order_by="random"``,
    using a different pagination seed each time. At least one resulting ordering
    must differ from the first.
    """
    repo = database.ingredient_units.by_group(unique_local_group_id)
    pagination = PaginationQuery(
        page=1,
        per_page=-1,
        order_by="random",
        pagination_seed=str(datetime.now()),
        order_direction=OrderDirection.asc,
    )

    orderings: list[list] = []
    for attempt in range(5):
        # Back-to-back datetime.now() calls can return the same value on clocks
        # with coarse resolution; suffixing the attempt index keeps seeds distinct.
        pagination.pagination_seed = f"{datetime.now()}-{attempt}"
        results = repo.page_all(pagination, search="unit").items
        # Compare ids so the check depends only on ordering, not model equality.
        orderings.append([unit.id for unit in results])

    assert not all(ordering == orderings[0] for ordering in orderings)