From da11204cd74e088b6fdc828fe51e1a9083d31f7e Mon Sep 17 00:00:00 2001 From: Michael Genson <71845777+michael-genson@users.noreply.github.com> Date: Fri, 28 Jun 2024 05:03:23 -0500 Subject: [PATCH] feat: Auto-label new shopping list items (#3800) --- .../services/group_services/shopping_lists.py | 43 ++++++--- mealie/services/parser_services/_base.py | 89 +++++++++++-------- .../services/parser_services/brute/process.py | 2 +- .../services/parser_services/openai/parser.py | 4 +- .../test_group_shopping_list_items.py | 74 +++++++++++++-- 5 files changed, 156 insertions(+), 56 deletions(-) diff --git a/mealie/services/group_services/shopping_lists.py b/mealie/services/group_services/shopping_lists.py index 09350b603786..38a4c7e6777a 100644 --- a/mealie/services/group_services/shopping_lists.py +++ b/mealie/services/group_services/shopping_lists.py @@ -17,12 +17,19 @@ from mealie.schema.group.group_shopping_list import ( ShoppingListMultiPurposeLabelCreate, ShoppingListSave, ) -from mealie.schema.recipe.recipe_ingredient import IngredientFood, IngredientUnit, RecipeIngredient +from mealie.schema.recipe.recipe_ingredient import ( + IngredientFood, + IngredientUnit, + RecipeIngredient, +) from mealie.schema.response.pagination import OrderDirection, PaginationQuery from mealie.schema.user.user import GroupInDB, UserOut +from mealie.services.parser_services._base import DataMatcher class ShoppingListService: + DEFAULT_FOOD_FUZZY_MATCH_THRESHOLD = 80 + def __init__(self, repos: AllRepositories, group: GroupInDB, user: UserOut): self.repos = repos self.group = group @@ -31,6 +38,9 @@ class ShoppingListService: self.list_items = repos.group_shopping_list_item self.list_item_refs = repos.group_shopping_list_item_references self.list_refs = repos.group_shopping_list_recipe_refs + self.data_matcher = DataMatcher( + self.group.id, self.repos, food_fuzzy_match_threshold=self.DEFAULT_FOOD_FUZZY_MATCH_THRESHOLD + ) @staticmethod def can_merge(item1: ShoppingListItemBase, item2: ShoppingListItemBase) -> bool: @@ -108,7 +118,23 @@ class ShoppingListService: if list_refs_to_delete: self.list_refs.delete_many(list_refs_to_delete) - def bulk_create_items(self, create_items: list[ShoppingListItemCreate]) -> ShoppingListItemsCollectionOut: + def find_matching_label(self, item: ShoppingListItemBase) -> UUID4 | None: + if item.label_id: + return item.label_id + if item.food: + return item.food.label_id + + food_search = self.data_matcher.find_food_match(item.display) + return food_search.label_id if food_search else None + + def bulk_create_items( + self, create_items: list[ShoppingListItemCreate], auto_find_labels=True + ) -> ShoppingListItemsCollectionOut: + """ + Create a list of items, merging into existing ones where possible. + Optionally try to find a label for each item if one isn't provided using the item's food data or display name. + """ + # consolidate items to be created consolidated_create_items: list[ShoppingListItemCreate] = [] for create_item in create_items: @@ -157,18 +183,13 @@ class ShoppingListService: if create_item.checked: # checked items should not have recipe references create_item.recipe_references = [] + if auto_find_labels: + create_item.label_id = self.find_matching_label(create_item) filtered_create_items.append(create_item) - created_items = cast( - list[ShoppingListItemOut], - self.list_items.create_many(filtered_create_items) if filtered_create_items else [], # type: ignore - ) - - updated_items = cast( - list[ShoppingListItemOut], - self.list_items.update_many(update_items) if update_items else [], # type: ignore - ) + created_items = self.list_items.create_many(filtered_create_items) if filtered_create_items else [] + updated_items = self.list_items.update_many(update_items) if update_items else [] for list_id in set(item.shopping_list_id for item in created_items + updated_items): self.remove_unused_recipe_references(list_id) diff --git a/mealie/services/parser_services/_base.py b/mealie/services/parser_services/_base.py index 4be2492ee4a4..fdd242e1c696 100644 --- a/mealie/services/parser_services/_base.py +++ b/mealie/services/parser_services/_base.py @@ -20,26 +20,26 @@ from mealie.schema.response.pagination import PaginationQuery T = TypeVar("T", bound=BaseModel) -class ABCIngredientParser(ABC): - """ - Abstract class for ingredient parsers. - """ - - def __init__(self, group_id: UUID4, session: Session) -> None: +class DataMatcher: + def __init__( + self, + group_id: UUID4, + repos: AllRepositories, + food_fuzzy_match_threshold: int = 85, + unit_fuzzy_match_threshold: int = 70, + ) -> None: self.group_id = group_id - self.session = session + self.repos = repos + self._food_fuzzy_match_threshold = food_fuzzy_match_threshold + self._unit_fuzzy_match_threshold = unit_fuzzy_match_threshold self._foods_by_alias: dict[str, IngredientFood] | None = None self._units_by_alias: dict[str, IngredientUnit] | None = None - @property - def _repos(self) -> AllRepositories: - return get_repositories(self.session) - @property def foods_by_alias(self) -> dict[str, IngredientFood]: if self._foods_by_alias is None: - foods_repo = self._repos.ingredient_foods.by_group(self.group_id) + foods_repo = self.repos.ingredient_foods.by_group(self.group_id) query = PaginationQuery(page=1, per_page=-1) all_foods = foods_repo.page_all(query).items @@ -61,7 +61,7 @@ class ABCIngredientParser(ABC): @property def units_by_alias(self) -> dict[str, IngredientUnit]: if self._units_by_alias is None: - units_repo = self._repos.ingredient_units.by_group(self.group_id) + units_repo = self.repos.ingredient_units.by_group(self.group_id) query = PaginationQuery(page=1, per_page=-1) all_units = units_repo.page_all(query).items @@ -84,24 +84,6 @@ class ABCIngredientParser(ABC): return self._units_by_alias - @property - def food_fuzzy_match_threshold(self) -> int: - """Minimum threshold to fuzzy match against a database food search""" - - return 85 - - @property - def unit_fuzzy_match_threshold(self) -> int: - """Minimum threshold to fuzzy match against a database unit search""" - - return 70 - - @abstractmethod - async def parse_one(self, ingredient_string: str) -> ParsedIngredient: ... - - @abstractmethod - async def parse(self, ingredients: list[str]) -> list[ParsedIngredient]: ... - @classmethod def find_match(cls, match_value: str, *, store_map: dict[str, T], fuzzy_match_threshold: int = 0) -> T | None: # check for literal matches @@ -126,7 +108,7 @@ class ABCIngredientParser(ABC): return self.find_match( match_value, store_map=self.foods_by_alias, - fuzzy_match_threshold=self.food_fuzzy_match_threshold, + fuzzy_match_threshold=self._food_fuzzy_match_threshold, ) def find_unit_match(self, unit: IngredientUnit | CreateIngredientUnit | str) -> IngredientUnit | None: @@ -138,21 +120,56 @@ class ABCIngredientParser(ABC): return self.find_match( match_value, store_map=self.units_by_alias, - fuzzy_match_threshold=self.unit_fuzzy_match_threshold, + fuzzy_match_threshold=self._unit_fuzzy_match_threshold, ) + +class ABCIngredientParser(ABC): + """ + Abstract class for ingredient parsers. + """ + + def __init__(self, group_id: UUID4, session: Session) -> None: + self.group_id = group_id + self.session = session + self.data_matcher = DataMatcher( + self.group_id, self._repos, self.food_fuzzy_match_threshold, self.unit_fuzzy_match_threshold + ) + + @property + def _repos(self) -> AllRepositories: + return get_repositories(self.session) + + @property + def food_fuzzy_match_threshold(self) -> int: + """Minimum threshold to fuzzy match against a database food search""" + + return 85 + + @property + def unit_fuzzy_match_threshold(self) -> int: + """Minimum threshold to fuzzy match against a database unit search""" + + return 70 + + @abstractmethod + async def parse_one(self, ingredient_string: str) -> ParsedIngredient: ... + + @abstractmethod + async def parse(self, ingredients: list[str]) -> list[ParsedIngredient]: ... + def find_ingredient_match(self, ingredient: ParsedIngredient) -> ParsedIngredient: - if ingredient.ingredient.food and (food_match := self.find_food_match(ingredient.ingredient.food)): + if ingredient.ingredient.food and (food_match := self.data_matcher.find_food_match(ingredient.ingredient.food)): ingredient.ingredient.food = food_match - if ingredient.ingredient.unit and (unit_match := self.find_unit_match(ingredient.ingredient.unit)): + if ingredient.ingredient.unit and (unit_match := self.data_matcher.find_unit_match(ingredient.ingredient.unit)): ingredient.ingredient.unit = unit_match # Parser might have wrongly split a food into a unit and food. if isinstance(ingredient.ingredient.food, CreateIngredientFood) and isinstance( ingredient.ingredient.unit, CreateIngredientUnit ): - if food_match := self.find_food_match( + if food_match := self.data_matcher.find_food_match( f"{ingredient.ingredient.unit.name} {ingredient.ingredient.food.name}" ): ingredient.ingredient.food = food_match diff --git a/mealie/services/parser_services/brute/process.py b/mealie/services/parser_services/brute/process.py index b2d559f3cb51..8d7e7fa9916a 100644 --- a/mealie/services/parser_services/brute/process.py +++ b/mealie/services/parser_services/brute/process.py @@ -194,7 +194,7 @@ def parse(ing_str, parser) -> BruteParsedIngredient: # try to parse as unit and ingredient (e.g. "a tblsp salt"), with unit in first three tokens # won't work for units that have spaces for index, token in enumerate(tokens[:3]): - if parser.find_unit_match(token): + if parser.data_matcher.find_unit_match(token): unit = token ingredient, note = parse_ingredient(tokens[index + 1 :]) break diff --git a/mealie/services/parser_services/openai/parser.py b/mealie/services/parser_services/openai/parser.py index c63a5b056b51..386a8ccfecd6 100644 --- a/mealie/services/parser_services/openai/parser.py +++ b/mealie/services/parser_services/openai/parser.py @@ -45,7 +45,7 @@ class OpenAIParser(ABCIngredientParser): ), ] - if service.send_db_data and self.units_by_alias: + if service.send_db_data and self.data_matcher.units_by_alias: data_injections.extend( [ OpenAIDataInjection( @@ -55,7 +55,7 @@ class OpenAIParser(ABCIngredientParser): "find a unit in the input that does not exist in this list. This should not prevent " "you from parsing that text as a unit, however it may lower your confidence level." ), - value=list(set(self.units_by_alias)), + value=list(set(self.data_matcher.units_by_alias)), ), ] ) diff --git a/tests/integration_tests/user_group_tests/test_group_shopping_list_items.py b/tests/integration_tests/user_group_tests/test_group_shopping_list_items.py index 6f18edb9a81c..9d807adec323 100644 --- a/tests/integration_tests/user_group_tests/test_group_shopping_list_items.py +++ b/tests/integration_tests/user_group_tests/test_group_shopping_list_items.py @@ -2,22 +2,21 @@ import random from math import ceil, floor from uuid import uuid4 +import pytest from fastapi.testclient import TestClient from pydantic import UUID4 +from mealie.repos.repository_factory import AllRepositories from mealie.schema.group.group_shopping_list import ShoppingListItemOut, ShoppingListOut +from mealie.schema.recipe.recipe_ingredient import SaveIngredientFood from tests import utils from tests.utils import api_routes from tests.utils.factories import random_int, random_string from tests.utils.fixture_schemas import TestUser -def create_item(list_id: UUID4) -> dict: - return { - "shopping_list_id": str(list_id), - "note": random_string(10), - "quantity": random_int(1, 10), - } +def create_item(list_id: UUID4, **kwargs) -> dict: + return {"shopping_list_id": str(list_id), "note": random_string(10), "quantity": random_int(1, 10), **kwargs} def serialize_list_items(list_items: list[ShoppingListItemOut]) -> list: @@ -89,6 +88,69 @@ def test_shopping_list_items_create_many( assert not created_item_ids +def test_shopping_list_items_auto_assign_label_with_food_without_label( + api_client: TestClient, unique_user: TestUser, shopping_list: ShoppingListOut, database: AllRepositories +): + food = database.ingredient_foods.create(SaveIngredientFood(name=random_string(10), group_id=unique_user.group_id)) + + item = create_item(shopping_list.id, food_id=str(food.id)) + response = api_client.post(api_routes.groups_shopping_items, json=item, headers=unique_user.token) + as_json = utils.assert_derserialize(response, 201) + assert len(as_json["createdItems"]) == 1 + + item_out = ShoppingListItemOut.model_validate(as_json["createdItems"][0]) + assert item_out.label_id is None + assert item_out.label is None + + +def test_shopping_list_items_auto_assign_label_with_food_with_label( + api_client: TestClient, unique_user: TestUser, shopping_list: ShoppingListOut, database: AllRepositories +): + label = database.group_multi_purpose_labels.create({"name": random_string(10), "group_id": unique_user.group_id}) + food = database.ingredient_foods.create( + SaveIngredientFood(name=random_string(10), group_id=unique_user.group_id, label_id=label.id) + ) + + item = create_item(shopping_list.id, food_id=str(food.id)) + response = api_client.post(api_routes.groups_shopping_items, json=item, headers=unique_user.token) + as_json = utils.assert_derserialize(response, 201) + assert len(as_json["createdItems"]) == 1 + + item_out = ShoppingListItemOut.model_validate(as_json["createdItems"][0]) + assert item_out.label_id == label.id + assert item_out.label + assert item_out.label.id == label.id + + +@pytest.mark.parametrize("use_fuzzy_name", [True, False]) +def test_shopping_list_items_auto_assign_label_with_food_search( + api_client: TestClient, + unique_user: TestUser, + shopping_list: ShoppingListOut, + database: AllRepositories, + use_fuzzy_name: bool, +): + label = database.group_multi_purpose_labels.create({"name": random_string(10), "group_id": unique_user.group_id}) + food = database.ingredient_foods.create( + SaveIngredientFood(name=random_string(20), group_id=unique_user.group_id, label_id=label.id) + ) + + item = create_item(shopping_list.id) + name = food.name + if use_fuzzy_name: + name = name + random_string(2) + item["note"] = name + + response = api_client.post(api_routes.groups_shopping_items, json=item, headers=unique_user.token) + as_json = utils.assert_derserialize(response, 201) + assert len(as_json["createdItems"]) == 1 + + item_out = ShoppingListItemOut.model_validate(as_json["createdItems"][0]) + assert item_out.label_id == label.id + assert item_out.label + assert item_out.label.id == label.id + + def test_shopping_list_items_get_one( api_client: TestClient, unique_user: TestUser,