feat: Auto-label new shopping list items (#3800)

This commit is contained in:
Michael Genson 2024-06-28 05:03:23 -05:00 committed by GitHub
parent 9795b4c553
commit da11204cd7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 156 additions and 56 deletions

View File

@ -17,12 +17,19 @@ from mealie.schema.group.group_shopping_list import (
ShoppingListMultiPurposeLabelCreate, ShoppingListMultiPurposeLabelCreate,
ShoppingListSave, ShoppingListSave,
) )
from mealie.schema.recipe.recipe_ingredient import IngredientFood, IngredientUnit, RecipeIngredient from mealie.schema.recipe.recipe_ingredient import (
IngredientFood,
IngredientUnit,
RecipeIngredient,
)
from mealie.schema.response.pagination import OrderDirection, PaginationQuery from mealie.schema.response.pagination import OrderDirection, PaginationQuery
from mealie.schema.user.user import GroupInDB, UserOut from mealie.schema.user.user import GroupInDB, UserOut
from mealie.services.parser_services._base import DataMatcher
class ShoppingListService: class ShoppingListService:
DEFAULT_FOOD_FUZZY_MATCH_THRESHOLD = 80
def __init__(self, repos: AllRepositories, group: GroupInDB, user: UserOut): def __init__(self, repos: AllRepositories, group: GroupInDB, user: UserOut):
self.repos = repos self.repos = repos
self.group = group self.group = group
@ -31,6 +38,9 @@ class ShoppingListService:
self.list_items = repos.group_shopping_list_item self.list_items = repos.group_shopping_list_item
self.list_item_refs = repos.group_shopping_list_item_references self.list_item_refs = repos.group_shopping_list_item_references
self.list_refs = repos.group_shopping_list_recipe_refs self.list_refs = repos.group_shopping_list_recipe_refs
self.data_matcher = DataMatcher(
self.group.id, self.repos, food_fuzzy_match_threshold=self.DEFAULT_FOOD_FUZZY_MATCH_THRESHOLD
)
@staticmethod @staticmethod
def can_merge(item1: ShoppingListItemBase, item2: ShoppingListItemBase) -> bool: def can_merge(item1: ShoppingListItemBase, item2: ShoppingListItemBase) -> bool:
@ -108,7 +118,23 @@ class ShoppingListService:
if list_refs_to_delete: if list_refs_to_delete:
self.list_refs.delete_many(list_refs_to_delete) self.list_refs.delete_many(list_refs_to_delete)
def bulk_create_items(self, create_items: list[ShoppingListItemCreate]) -> ShoppingListItemsCollectionOut: def find_matching_label(self, item: ShoppingListItemBase) -> UUID4 | None:
if item.label_id:
return item.label_id
if item.food:
return item.food.label_id
food_search = self.data_matcher.find_food_match(item.display)
return food_search.label_id if food_search else None
def bulk_create_items(
self, create_items: list[ShoppingListItemCreate], auto_find_labels=True
) -> ShoppingListItemsCollectionOut:
"""
Create a list of items, merging into existing ones where possible.
Optionally try to find a label for each item if one isn't provided using the item's food data or display name.
"""
# consolidate items to be created # consolidate items to be created
consolidated_create_items: list[ShoppingListItemCreate] = [] consolidated_create_items: list[ShoppingListItemCreate] = []
for create_item in create_items: for create_item in create_items:
@ -157,18 +183,13 @@ class ShoppingListService:
if create_item.checked: if create_item.checked:
# checked items should not have recipe references # checked items should not have recipe references
create_item.recipe_references = [] create_item.recipe_references = []
if auto_find_labels:
create_item.label_id = self.find_matching_label(create_item)
filtered_create_items.append(create_item) filtered_create_items.append(create_item)
created_items = cast( created_items = self.list_items.create_many(filtered_create_items) if filtered_create_items else []
list[ShoppingListItemOut], updated_items = self.list_items.update_many(update_items) if update_items else []
self.list_items.create_many(filtered_create_items) if filtered_create_items else [], # type: ignore
)
updated_items = cast(
list[ShoppingListItemOut],
self.list_items.update_many(update_items) if update_items else [], # type: ignore
)
for list_id in set(item.shopping_list_id for item in created_items + updated_items): for list_id in set(item.shopping_list_id for item in created_items + updated_items):
self.remove_unused_recipe_references(list_id) self.remove_unused_recipe_references(list_id)

View File

@ -20,26 +20,26 @@ from mealie.schema.response.pagination import PaginationQuery
T = TypeVar("T", bound=BaseModel) T = TypeVar("T", bound=BaseModel)
class ABCIngredientParser(ABC): class DataMatcher:
""" def __init__(
Abstract class for ingredient parsers. self,
""" group_id: UUID4,
repos: AllRepositories,
def __init__(self, group_id: UUID4, session: Session) -> None: food_fuzzy_match_threshold: int = 85,
unit_fuzzy_match_threshold: int = 70,
) -> None:
self.group_id = group_id self.group_id = group_id
self.session = session self.repos = repos
self._food_fuzzy_match_threshold = food_fuzzy_match_threshold
self._unit_fuzzy_match_threshold = unit_fuzzy_match_threshold
self._foods_by_alias: dict[str, IngredientFood] | None = None self._foods_by_alias: dict[str, IngredientFood] | None = None
self._units_by_alias: dict[str, IngredientUnit] | None = None self._units_by_alias: dict[str, IngredientUnit] | None = None
@property
def _repos(self) -> AllRepositories:
return get_repositories(self.session)
@property @property
def foods_by_alias(self) -> dict[str, IngredientFood]: def foods_by_alias(self) -> dict[str, IngredientFood]:
if self._foods_by_alias is None: if self._foods_by_alias is None:
foods_repo = self._repos.ingredient_foods.by_group(self.group_id) foods_repo = self.repos.ingredient_foods.by_group(self.group_id)
query = PaginationQuery(page=1, per_page=-1) query = PaginationQuery(page=1, per_page=-1)
all_foods = foods_repo.page_all(query).items all_foods = foods_repo.page_all(query).items
@ -61,7 +61,7 @@ class ABCIngredientParser(ABC):
@property @property
def units_by_alias(self) -> dict[str, IngredientUnit]: def units_by_alias(self) -> dict[str, IngredientUnit]:
if self._units_by_alias is None: if self._units_by_alias is None:
units_repo = self._repos.ingredient_units.by_group(self.group_id) units_repo = self.repos.ingredient_units.by_group(self.group_id)
query = PaginationQuery(page=1, per_page=-1) query = PaginationQuery(page=1, per_page=-1)
all_units = units_repo.page_all(query).items all_units = units_repo.page_all(query).items
@ -84,24 +84,6 @@ class ABCIngredientParser(ABC):
return self._units_by_alias return self._units_by_alias
@property
def food_fuzzy_match_threshold(self) -> int:
"""Minimum threshold to fuzzy match against a database food search"""
return 85
@property
def unit_fuzzy_match_threshold(self) -> int:
"""Minimum threshold to fuzzy match against a database unit search"""
return 70
@abstractmethod
async def parse_one(self, ingredient_string: str) -> ParsedIngredient: ...
@abstractmethod
async def parse(self, ingredients: list[str]) -> list[ParsedIngredient]: ...
@classmethod @classmethod
def find_match(cls, match_value: str, *, store_map: dict[str, T], fuzzy_match_threshold: int = 0) -> T | None: def find_match(cls, match_value: str, *, store_map: dict[str, T], fuzzy_match_threshold: int = 0) -> T | None:
# check for literal matches # check for literal matches
@ -126,7 +108,7 @@ class ABCIngredientParser(ABC):
return self.find_match( return self.find_match(
match_value, match_value,
store_map=self.foods_by_alias, store_map=self.foods_by_alias,
fuzzy_match_threshold=self.food_fuzzy_match_threshold, fuzzy_match_threshold=self._food_fuzzy_match_threshold,
) )
def find_unit_match(self, unit: IngredientUnit | CreateIngredientUnit | str) -> IngredientUnit | None: def find_unit_match(self, unit: IngredientUnit | CreateIngredientUnit | str) -> IngredientUnit | None:
@ -138,21 +120,56 @@ class ABCIngredientParser(ABC):
return self.find_match( return self.find_match(
match_value, match_value,
store_map=self.units_by_alias, store_map=self.units_by_alias,
fuzzy_match_threshold=self.unit_fuzzy_match_threshold, fuzzy_match_threshold=self._unit_fuzzy_match_threshold,
) )
class ABCIngredientParser(ABC):
"""
Abstract class for ingredient parsers.
"""
def __init__(self, group_id: UUID4, session: Session) -> None:
self.group_id = group_id
self.session = session
self.data_matcher = DataMatcher(
self.group_id, self._repos, self.food_fuzzy_match_threshold, self.unit_fuzzy_match_threshold
)
@property
def _repos(self) -> AllRepositories:
return get_repositories(self.session)
@property
def food_fuzzy_match_threshold(self) -> int:
"""Minimum threshold to fuzzy match against a database food search"""
return 85
@property
def unit_fuzzy_match_threshold(self) -> int:
"""Minimum threshold to fuzzy match against a database unit search"""
return 70
@abstractmethod
async def parse_one(self, ingredient_string: str) -> ParsedIngredient: ...
@abstractmethod
async def parse(self, ingredients: list[str]) -> list[ParsedIngredient]: ...
def find_ingredient_match(self, ingredient: ParsedIngredient) -> ParsedIngredient: def find_ingredient_match(self, ingredient: ParsedIngredient) -> ParsedIngredient:
if ingredient.ingredient.food and (food_match := self.find_food_match(ingredient.ingredient.food)): if ingredient.ingredient.food and (food_match := self.data_matcher.find_food_match(ingredient.ingredient.food)):
ingredient.ingredient.food = food_match ingredient.ingredient.food = food_match
if ingredient.ingredient.unit and (unit_match := self.find_unit_match(ingredient.ingredient.unit)): if ingredient.ingredient.unit and (unit_match := self.data_matcher.find_unit_match(ingredient.ingredient.unit)):
ingredient.ingredient.unit = unit_match ingredient.ingredient.unit = unit_match
# Parser might have wrongly split a food into a unit and food. # Parser might have wrongly split a food into a unit and food.
if isinstance(ingredient.ingredient.food, CreateIngredientFood) and isinstance( if isinstance(ingredient.ingredient.food, CreateIngredientFood) and isinstance(
ingredient.ingredient.unit, CreateIngredientUnit ingredient.ingredient.unit, CreateIngredientUnit
): ):
if food_match := self.find_food_match( if food_match := self.data_matcher.find_food_match(
f"{ingredient.ingredient.unit.name} {ingredient.ingredient.food.name}" f"{ingredient.ingredient.unit.name} {ingredient.ingredient.food.name}"
): ):
ingredient.ingredient.food = food_match ingredient.ingredient.food = food_match

View File

@ -194,7 +194,7 @@ def parse(ing_str, parser) -> BruteParsedIngredient:
# try to parse as unit and ingredient (e.g. "a tblsp salt"), with unit in first three tokens # try to parse as unit and ingredient (e.g. "a tblsp salt"), with unit in first three tokens
# won't work for units that have spaces # won't work for units that have spaces
for index, token in enumerate(tokens[:3]): for index, token in enumerate(tokens[:3]):
if parser.find_unit_match(token): if parser.data_matcher.find_unit_match(token):
unit = token unit = token
ingredient, note = parse_ingredient(tokens[index + 1 :]) ingredient, note = parse_ingredient(tokens[index + 1 :])
break break

View File

@ -45,7 +45,7 @@ class OpenAIParser(ABCIngredientParser):
), ),
] ]
if service.send_db_data and self.units_by_alias: if service.send_db_data and self.data_matcher.units_by_alias:
data_injections.extend( data_injections.extend(
[ [
OpenAIDataInjection( OpenAIDataInjection(
@ -55,7 +55,7 @@ class OpenAIParser(ABCIngredientParser):
"find a unit in the input that does not exist in this list. This should not prevent " "find a unit in the input that does not exist in this list. This should not prevent "
"you from parsing that text as a unit, however it may lower your confidence level." "you from parsing that text as a unit, however it may lower your confidence level."
), ),
value=list(set(self.units_by_alias)), value=list(set(self.data_matcher.units_by_alias)),
), ),
] ]
) )

View File

@ -2,22 +2,21 @@ import random
from math import ceil, floor from math import ceil, floor
from uuid import uuid4 from uuid import uuid4
import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from pydantic import UUID4 from pydantic import UUID4
from mealie.repos.repository_factory import AllRepositories
from mealie.schema.group.group_shopping_list import ShoppingListItemOut, ShoppingListOut from mealie.schema.group.group_shopping_list import ShoppingListItemOut, ShoppingListOut
from mealie.schema.recipe.recipe_ingredient import SaveIngredientFood
from tests import utils from tests import utils
from tests.utils import api_routes from tests.utils import api_routes
from tests.utils.factories import random_int, random_string from tests.utils.factories import random_int, random_string
from tests.utils.fixture_schemas import TestUser from tests.utils.fixture_schemas import TestUser
def create_item(list_id: UUID4) -> dict: def create_item(list_id: UUID4, **kwargs) -> dict:
return { return {"shopping_list_id": str(list_id), "note": random_string(10), "quantity": random_int(1, 10), **kwargs}
"shopping_list_id": str(list_id),
"note": random_string(10),
"quantity": random_int(1, 10),
}
def serialize_list_items(list_items: list[ShoppingListItemOut]) -> list: def serialize_list_items(list_items: list[ShoppingListItemOut]) -> list:
@ -89,6 +88,69 @@ def test_shopping_list_items_create_many(
assert not created_item_ids assert not created_item_ids
def test_shopping_list_items_auto_assign_label_with_food_without_label(
api_client: TestClient, unique_user: TestUser, shopping_list: ShoppingListOut, database: AllRepositories
):
food = database.ingredient_foods.create(SaveIngredientFood(name=random_string(10), group_id=unique_user.group_id))
item = create_item(shopping_list.id, food_id=str(food.id))
response = api_client.post(api_routes.groups_shopping_items, json=item, headers=unique_user.token)
as_json = utils.assert_derserialize(response, 201)
assert len(as_json["createdItems"]) == 1
item_out = ShoppingListItemOut.model_validate(as_json["createdItems"][0])
assert item_out.label_id is None
assert item_out.label is None
def test_shopping_list_items_auto_assign_label_with_food_with_label(
api_client: TestClient, unique_user: TestUser, shopping_list: ShoppingListOut, database: AllRepositories
):
label = database.group_multi_purpose_labels.create({"name": random_string(10), "group_id": unique_user.group_id})
food = database.ingredient_foods.create(
SaveIngredientFood(name=random_string(10), group_id=unique_user.group_id, label_id=label.id)
)
item = create_item(shopping_list.id, food_id=str(food.id))
response = api_client.post(api_routes.groups_shopping_items, json=item, headers=unique_user.token)
as_json = utils.assert_derserialize(response, 201)
assert len(as_json["createdItems"]) == 1
item_out = ShoppingListItemOut.model_validate(as_json["createdItems"][0])
assert item_out.label_id == label.id
assert item_out.label
assert item_out.label.id == label.id
@pytest.mark.parametrize("use_fuzzy_name", [True, False])
def test_shopping_list_items_auto_assign_label_with_food_search(
api_client: TestClient,
unique_user: TestUser,
shopping_list: ShoppingListOut,
database: AllRepositories,
use_fuzzy_name: bool,
):
label = database.group_multi_purpose_labels.create({"name": random_string(10), "group_id": unique_user.group_id})
food = database.ingredient_foods.create(
SaveIngredientFood(name=random_string(20), group_id=unique_user.group_id, label_id=label.id)
)
item = create_item(shopping_list.id)
name = food.name
if use_fuzzy_name:
name = name + random_string(2)
item["note"] = name
response = api_client.post(api_routes.groups_shopping_items, json=item, headers=unique_user.token)
as_json = utils.assert_derserialize(response, 201)
assert len(as_json["createdItems"]) == 1
item_out = ShoppingListItemOut.model_validate(as_json["createdItems"][0])
assert item_out.label_id == label.id
assert item_out.label
assert item_out.label.id == label.id
def test_shopping_list_items_get_one( def test_shopping_list_items_get_one(
api_client: TestClient, api_client: TestClient,
unique_user: TestUser, unique_user: TestUser,