feat: Auto-label new shopping list items (#3800)

This commit is contained in:
Michael Genson 2024-06-28 05:03:23 -05:00 committed by GitHub
parent 9795b4c553
commit da11204cd7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 156 additions and 56 deletions

View File

@ -17,12 +17,19 @@ from mealie.schema.group.group_shopping_list import (
ShoppingListMultiPurposeLabelCreate,
ShoppingListSave,
)
from mealie.schema.recipe.recipe_ingredient import IngredientFood, IngredientUnit, RecipeIngredient
from mealie.schema.recipe.recipe_ingredient import (
IngredientFood,
IngredientUnit,
RecipeIngredient,
)
from mealie.schema.response.pagination import OrderDirection, PaginationQuery
from mealie.schema.user.user import GroupInDB, UserOut
from mealie.services.parser_services._base import DataMatcher
class ShoppingListService:
DEFAULT_FOOD_FUZZY_MATCH_THRESHOLD = 80
def __init__(self, repos: AllRepositories, group: GroupInDB, user: UserOut):
self.repos = repos
self.group = group
@ -31,6 +38,9 @@ class ShoppingListService:
self.list_items = repos.group_shopping_list_item
self.list_item_refs = repos.group_shopping_list_item_references
self.list_refs = repos.group_shopping_list_recipe_refs
self.data_matcher = DataMatcher(
self.group.id, self.repos, food_fuzzy_match_threshold=self.DEFAULT_FOOD_FUZZY_MATCH_THRESHOLD
)
@staticmethod
def can_merge(item1: ShoppingListItemBase, item2: ShoppingListItemBase) -> bool:
@ -108,7 +118,23 @@ class ShoppingListService:
if list_refs_to_delete:
self.list_refs.delete_many(list_refs_to_delete)
def bulk_create_items(self, create_items: list[ShoppingListItemCreate]) -> ShoppingListItemsCollectionOut:
def find_matching_label(self, item: ShoppingListItemBase) -> UUID4 | None:
    """
    Work out which label id a shopping list item should carry.

    Resolution order:
    1. the item's own ``label_id`` when it is set;
    2. otherwise, when the item has a food, that food's ``label_id``
       (returned as-is, even if it is ``None`` - no name search is attempted);
    3. otherwise, the ``label_id`` of the food found by looking up the item's
       ``display`` text through the data matcher, or ``None`` when nothing matches.
    """
    explicit_label = item.label_id
    if explicit_label:
        return explicit_label

    linked_food = item.food
    if linked_food:
        return linked_food.label_id

    matched_food = self.data_matcher.find_food_match(item.display)
    if not matched_food:
        return None
    return matched_food.label_id
def bulk_create_items(
self, create_items: list[ShoppingListItemCreate], auto_find_labels=True
) -> ShoppingListItemsCollectionOut:
"""
Create a list of items, merging into existing ones where possible.
Optionally try to find a label for each item if one isn't provided using the item's food data or display name.
"""
# consolidate items to be created
consolidated_create_items: list[ShoppingListItemCreate] = []
for create_item in create_items:
@ -157,18 +183,13 @@ class ShoppingListService:
if create_item.checked:
# checked items should not have recipe references
create_item.recipe_references = []
if auto_find_labels:
create_item.label_id = self.find_matching_label(create_item)
filtered_create_items.append(create_item)
created_items = cast(
list[ShoppingListItemOut],
self.list_items.create_many(filtered_create_items) if filtered_create_items else [], # type: ignore
)
updated_items = cast(
list[ShoppingListItemOut],
self.list_items.update_many(update_items) if update_items else [], # type: ignore
)
created_items = self.list_items.create_many(filtered_create_items) if filtered_create_items else []
updated_items = self.list_items.update_many(update_items) if update_items else []
for list_id in set(item.shopping_list_id for item in created_items + updated_items):
self.remove_unused_recipe_references(list_id)

View File

@ -20,26 +20,26 @@ from mealie.schema.response.pagination import PaginationQuery
T = TypeVar("T", bound=BaseModel)
class ABCIngredientParser(ABC):
"""
Abstract class for ingredient parsers.
"""
def __init__(self, group_id: UUID4, session: Session) -> None:
class DataMatcher:
def __init__(
self,
group_id: UUID4,
repos: AllRepositories,
food_fuzzy_match_threshold: int = 85,
unit_fuzzy_match_threshold: int = 70,
) -> None:
self.group_id = group_id
self.session = session
self.repos = repos
self._food_fuzzy_match_threshold = food_fuzzy_match_threshold
self._unit_fuzzy_match_threshold = unit_fuzzy_match_threshold
self._foods_by_alias: dict[str, IngredientFood] | None = None
self._units_by_alias: dict[str, IngredientUnit] | None = None
@property
def _repos(self) -> AllRepositories:
return get_repositories(self.session)
@property
def foods_by_alias(self) -> dict[str, IngredientFood]:
if self._foods_by_alias is None:
foods_repo = self._repos.ingredient_foods.by_group(self.group_id)
foods_repo = self.repos.ingredient_foods.by_group(self.group_id)
query = PaginationQuery(page=1, per_page=-1)
all_foods = foods_repo.page_all(query).items
@ -61,7 +61,7 @@ class ABCIngredientParser(ABC):
@property
def units_by_alias(self) -> dict[str, IngredientUnit]:
if self._units_by_alias is None:
units_repo = self._repos.ingredient_units.by_group(self.group_id)
units_repo = self.repos.ingredient_units.by_group(self.group_id)
query = PaginationQuery(page=1, per_page=-1)
all_units = units_repo.page_all(query).items
@ -84,24 +84,6 @@ class ABCIngredientParser(ABC):
return self._units_by_alias
@property
def food_fuzzy_match_threshold(self) -> int:
"""Minimum threshold to fuzzy match against a database food search"""
return 85
@property
def unit_fuzzy_match_threshold(self) -> int:
"""Minimum threshold to fuzzy match against a database unit search"""
return 70
@abstractmethod
async def parse_one(self, ingredient_string: str) -> ParsedIngredient: ...
@abstractmethod
async def parse(self, ingredients: list[str]) -> list[ParsedIngredient]: ...
@classmethod
def find_match(cls, match_value: str, *, store_map: dict[str, T], fuzzy_match_threshold: int = 0) -> T | None:
# check for literal matches
@ -126,7 +108,7 @@ class ABCIngredientParser(ABC):
return self.find_match(
match_value,
store_map=self.foods_by_alias,
fuzzy_match_threshold=self.food_fuzzy_match_threshold,
fuzzy_match_threshold=self._food_fuzzy_match_threshold,
)
def find_unit_match(self, unit: IngredientUnit | CreateIngredientUnit | str) -> IngredientUnit | None:
@ -138,21 +120,56 @@ class ABCIngredientParser(ABC):
return self.find_match(
match_value,
store_map=self.units_by_alias,
fuzzy_match_threshold=self.unit_fuzzy_match_threshold,
fuzzy_match_threshold=self._unit_fuzzy_match_threshold,
)
class ABCIngredientParser(ABC):
"""
Abstract class for ingredient parsers.
"""
def __init__(self, group_id: UUID4, session: Session) -> None:
    """
    Store the group and session, then build the parser's ``DataMatcher``.

    The matcher receives the repositories obtained from ``self._repos`` and the
    food/unit fuzzy-match thresholds exposed by this parser's properties.
    """
    self.group_id = group_id
    self.session = session

    self.data_matcher = DataMatcher(
        self.group_id,
        self._repos,
        food_fuzzy_match_threshold=self.food_fuzzy_match_threshold,
        unit_fuzzy_match_threshold=self.unit_fuzzy_match_threshold,
    )
@property
def _repos(self) -> AllRepositories:
    """Repositories for this parser's session; calls ``get_repositories(self.session)`` on every access."""
    return get_repositories(self.session)
@property
def food_fuzzy_match_threshold(self) -> int:
    """
    Minimum threshold to fuzzy match against a database food search.

    Read in ``__init__`` and handed to ``DataMatcher`` as its food threshold.
    """
    return 85
@property
def unit_fuzzy_match_threshold(self) -> int:
    """
    Minimum threshold to fuzzy match against a database unit search.

    Read in ``__init__`` and handed to ``DataMatcher`` as its unit threshold.
    """
    return 70
# Abstract coroutine: takes one ingredient string and returns a ParsedIngredient.
# No implementation is provided on this base class.
@abstractmethod
async def parse_one(self, ingredient_string: str) -> ParsedIngredient: ...
# Abstract coroutine: takes a list of ingredient strings and returns a list of ParsedIngredient.
# No implementation is provided on this base class.
@abstractmethod
async def parse(self, ingredients: list[str]) -> list[ParsedIngredient]: ...
def find_ingredient_match(self, ingredient: ParsedIngredient) -> ParsedIngredient:
if ingredient.ingredient.food and (food_match := self.find_food_match(ingredient.ingredient.food)):
if ingredient.ingredient.food and (food_match := self.data_matcher.find_food_match(ingredient.ingredient.food)):
ingredient.ingredient.food = food_match
if ingredient.ingredient.unit and (unit_match := self.find_unit_match(ingredient.ingredient.unit)):
if ingredient.ingredient.unit and (unit_match := self.data_matcher.find_unit_match(ingredient.ingredient.unit)):
ingredient.ingredient.unit = unit_match
# Parser might have wrongly split a food into a unit and food.
if isinstance(ingredient.ingredient.food, CreateIngredientFood) and isinstance(
ingredient.ingredient.unit, CreateIngredientUnit
):
if food_match := self.find_food_match(
if food_match := self.data_matcher.find_food_match(
f"{ingredient.ingredient.unit.name} {ingredient.ingredient.food.name}"
):
ingredient.ingredient.food = food_match

View File

@ -194,7 +194,7 @@ def parse(ing_str, parser) -> BruteParsedIngredient:
# try to parse as unit and ingredient (e.g. "a tblsp salt"), with unit in first three tokens
# won't work for units that have spaces
for index, token in enumerate(tokens[:3]):
if parser.find_unit_match(token):
if parser.data_matcher.find_unit_match(token):
unit = token
ingredient, note = parse_ingredient(tokens[index + 1 :])
break

View File

@ -45,7 +45,7 @@ class OpenAIParser(ABCIngredientParser):
),
]
if service.send_db_data and self.units_by_alias:
if service.send_db_data and self.data_matcher.units_by_alias:
data_injections.extend(
[
OpenAIDataInjection(
@ -55,7 +55,7 @@ class OpenAIParser(ABCIngredientParser):
"find a unit in the input that does not exist in this list. This should not prevent "
"you from parsing that text as a unit, however it may lower your confidence level."
),
value=list(set(self.units_by_alias)),
value=list(set(self.data_matcher.units_by_alias)),
),
]
)

View File

@ -2,22 +2,21 @@ import random
from math import ceil, floor
from uuid import uuid4
import pytest
from fastapi.testclient import TestClient
from pydantic import UUID4
from mealie.repos.repository_factory import AllRepositories
from mealie.schema.group.group_shopping_list import ShoppingListItemOut, ShoppingListOut
from mealie.schema.recipe.recipe_ingredient import SaveIngredientFood
from tests import utils
from tests.utils import api_routes
from tests.utils.factories import random_int, random_string
from tests.utils.fixture_schemas import TestUser
def create_item(list_id: UUID4) -> dict:
return {
"shopping_list_id": str(list_id),
"note": random_string(10),
"quantity": random_int(1, 10),
}
def create_item(list_id: UUID4, **kwargs) -> dict:
    """
    Build a shopping list item payload for the given list.

    The payload gets a random 10-character note and a random quantity from
    ``random_int(1, 10)``; any keyword arguments are merged in last, so they
    can add fields or override the generated ones.
    """
    payload = {
        "shopping_list_id": str(list_id),
        "note": random_string(10),
        "quantity": random_int(1, 10),
    }
    payload.update(kwargs)
    return payload
def serialize_list_items(list_items: list[ShoppingListItemOut]) -> list:
@ -89,6 +88,69 @@ def test_shopping_list_items_create_many(
assert not created_item_ids
def test_shopping_list_items_auto_assign_label_with_food_without_label(
    api_client: TestClient, unique_user: TestUser, shopping_list: ShoppingListOut, database: AllRepositories
):
    """An item created with a food that has no label should come back without a label."""
    unlabeled_food = database.ingredient_foods.create(
        SaveIngredientFood(name=random_string(10), group_id=unique_user.group_id)
    )
    payload = create_item(shopping_list.id, food_id=str(unlabeled_food.id))

    response = api_client.post(api_routes.groups_shopping_items, json=payload, headers=unique_user.token)
    body = utils.assert_derserialize(response, 201)

    created = body["createdItems"]
    assert len(created) == 1

    created_item = ShoppingListItemOut.model_validate(created[0])
    assert created_item.label_id is None
    assert created_item.label is None
def test_shopping_list_items_auto_assign_label_with_food_with_label(
    api_client: TestClient, unique_user: TestUser, shopping_list: ShoppingListOut, database: AllRepositories
):
    """An item created with a labelled food should inherit that food's label."""
    expected_label = database.group_multi_purpose_labels.create(
        {"name": random_string(10), "group_id": unique_user.group_id}
    )
    labelled_food = database.ingredient_foods.create(
        SaveIngredientFood(name=random_string(10), group_id=unique_user.group_id, label_id=expected_label.id)
    )
    payload = create_item(shopping_list.id, food_id=str(labelled_food.id))

    response = api_client.post(api_routes.groups_shopping_items, json=payload, headers=unique_user.token)
    body = utils.assert_derserialize(response, 201)

    created = body["createdItems"]
    assert len(created) == 1

    created_item = ShoppingListItemOut.model_validate(created[0])
    assert created_item.label_id == expected_label.id
    assert created_item.label
    assert created_item.label.id == expected_label.id
@pytest.mark.parametrize("use_fuzzy_name", [True, False])
def test_shopping_list_items_auto_assign_label_with_food_search(
    api_client: TestClient,
    unique_user: TestUser,
    shopping_list: ShoppingListOut,
    database: AllRepositories,
    use_fuzzy_name: bool,
):
    """
    An item with no food should pick up a label from a food whose name matches its note.

    Runs once with the food's exact name and once with two random characters
    appended to it.
    """
    expected_label = database.group_multi_purpose_labels.create(
        {"name": random_string(10), "group_id": unique_user.group_id}
    )
    labelled_food = database.ingredient_foods.create(
        SaveIngredientFood(name=random_string(20), group_id=unique_user.group_id, label_id=expected_label.id)
    )

    # Build the payload first, then overwrite its note, so random values are
    # generated in the same order regardless of the parametrized flag.
    payload = create_item(shopping_list.id)
    payload["note"] = labelled_food.name + random_string(2) if use_fuzzy_name else labelled_food.name

    response = api_client.post(api_routes.groups_shopping_items, json=payload, headers=unique_user.token)
    body = utils.assert_derserialize(response, 201)

    created = body["createdItems"]
    assert len(created) == 1

    created_item = ShoppingListItemOut.model_validate(created[0])
    assert created_item.label_id == expected_label.id
    assert created_item.label
    assert created_item.label.id == expected_label.id
def test_shopping_list_items_get_one(
api_client: TestClient,
unique_user: TestUser,