mirror of
https://github.com/mealie-recipes/mealie.git
synced 2025-07-09 03:04:54 -04:00
feat: Auto-label new shopping list items (#3800)
This commit is contained in:
parent
9795b4c553
commit
da11204cd7
@ -17,12 +17,19 @@ from mealie.schema.group.group_shopping_list import (
|
||||
ShoppingListMultiPurposeLabelCreate,
|
||||
ShoppingListSave,
|
||||
)
|
||||
from mealie.schema.recipe.recipe_ingredient import IngredientFood, IngredientUnit, RecipeIngredient
|
||||
from mealie.schema.recipe.recipe_ingredient import (
|
||||
IngredientFood,
|
||||
IngredientUnit,
|
||||
RecipeIngredient,
|
||||
)
|
||||
from mealie.schema.response.pagination import OrderDirection, PaginationQuery
|
||||
from mealie.schema.user.user import GroupInDB, UserOut
|
||||
from mealie.services.parser_services._base import DataMatcher
|
||||
|
||||
|
||||
class ShoppingListService:
|
||||
DEFAULT_FOOD_FUZZY_MATCH_THRESHOLD = 80
|
||||
|
||||
def __init__(self, repos: AllRepositories, group: GroupInDB, user: UserOut):
|
||||
self.repos = repos
|
||||
self.group = group
|
||||
@ -31,6 +38,9 @@ class ShoppingListService:
|
||||
self.list_items = repos.group_shopping_list_item
|
||||
self.list_item_refs = repos.group_shopping_list_item_references
|
||||
self.list_refs = repos.group_shopping_list_recipe_refs
|
||||
self.data_matcher = DataMatcher(
|
||||
self.group.id, self.repos, food_fuzzy_match_threshold=self.DEFAULT_FOOD_FUZZY_MATCH_THRESHOLD
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def can_merge(item1: ShoppingListItemBase, item2: ShoppingListItemBase) -> bool:
|
||||
@ -108,7 +118,23 @@ class ShoppingListService:
|
||||
if list_refs_to_delete:
|
||||
self.list_refs.delete_many(list_refs_to_delete)
|
||||
|
||||
def bulk_create_items(self, create_items: list[ShoppingListItemCreate]) -> ShoppingListItemsCollectionOut:
|
||||
def find_matching_label(self, item: ShoppingListItemBase) -> UUID4 | None:
|
||||
if item.label_id:
|
||||
return item.label_id
|
||||
if item.food:
|
||||
return item.food.label_id
|
||||
|
||||
food_search = self.data_matcher.find_food_match(item.display)
|
||||
return food_search.label_id if food_search else None
|
||||
|
||||
def bulk_create_items(
|
||||
self, create_items: list[ShoppingListItemCreate], auto_find_labels=True
|
||||
) -> ShoppingListItemsCollectionOut:
|
||||
"""
|
||||
Create a list of items, merging into existing ones where possible.
|
||||
Optionally try to find a label for each item if one isn't provided using the item's food data or display name.
|
||||
"""
|
||||
|
||||
# consolidate items to be created
|
||||
consolidated_create_items: list[ShoppingListItemCreate] = []
|
||||
for create_item in create_items:
|
||||
@ -157,18 +183,13 @@ class ShoppingListService:
|
||||
if create_item.checked:
|
||||
# checked items should not have recipe references
|
||||
create_item.recipe_references = []
|
||||
if auto_find_labels:
|
||||
create_item.label_id = self.find_matching_label(create_item)
|
||||
|
||||
filtered_create_items.append(create_item)
|
||||
|
||||
created_items = cast(
|
||||
list[ShoppingListItemOut],
|
||||
self.list_items.create_many(filtered_create_items) if filtered_create_items else [], # type: ignore
|
||||
)
|
||||
|
||||
updated_items = cast(
|
||||
list[ShoppingListItemOut],
|
||||
self.list_items.update_many(update_items) if update_items else [], # type: ignore
|
||||
)
|
||||
created_items = self.list_items.create_many(filtered_create_items) if filtered_create_items else []
|
||||
updated_items = self.list_items.update_many(update_items) if update_items else []
|
||||
|
||||
for list_id in set(item.shopping_list_id for item in created_items + updated_items):
|
||||
self.remove_unused_recipe_references(list_id)
|
||||
|
@ -20,26 +20,26 @@ from mealie.schema.response.pagination import PaginationQuery
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class ABCIngredientParser(ABC):
|
||||
"""
|
||||
Abstract class for ingredient parsers.
|
||||
"""
|
||||
|
||||
def __init__(self, group_id: UUID4, session: Session) -> None:
|
||||
class DataMatcher:
|
||||
def __init__(
|
||||
self,
|
||||
group_id: UUID4,
|
||||
repos: AllRepositories,
|
||||
food_fuzzy_match_threshold: int = 85,
|
||||
unit_fuzzy_match_threshold: int = 70,
|
||||
) -> None:
|
||||
self.group_id = group_id
|
||||
self.session = session
|
||||
self.repos = repos
|
||||
|
||||
self._food_fuzzy_match_threshold = food_fuzzy_match_threshold
|
||||
self._unit_fuzzy_match_threshold = unit_fuzzy_match_threshold
|
||||
self._foods_by_alias: dict[str, IngredientFood] | None = None
|
||||
self._units_by_alias: dict[str, IngredientUnit] | None = None
|
||||
|
||||
@property
|
||||
def _repos(self) -> AllRepositories:
|
||||
return get_repositories(self.session)
|
||||
|
||||
@property
|
||||
def foods_by_alias(self) -> dict[str, IngredientFood]:
|
||||
if self._foods_by_alias is None:
|
||||
foods_repo = self._repos.ingredient_foods.by_group(self.group_id)
|
||||
foods_repo = self.repos.ingredient_foods.by_group(self.group_id)
|
||||
query = PaginationQuery(page=1, per_page=-1)
|
||||
all_foods = foods_repo.page_all(query).items
|
||||
|
||||
@ -61,7 +61,7 @@ class ABCIngredientParser(ABC):
|
||||
@property
|
||||
def units_by_alias(self) -> dict[str, IngredientUnit]:
|
||||
if self._units_by_alias is None:
|
||||
units_repo = self._repos.ingredient_units.by_group(self.group_id)
|
||||
units_repo = self.repos.ingredient_units.by_group(self.group_id)
|
||||
query = PaginationQuery(page=1, per_page=-1)
|
||||
all_units = units_repo.page_all(query).items
|
||||
|
||||
@ -84,24 +84,6 @@ class ABCIngredientParser(ABC):
|
||||
|
||||
return self._units_by_alias
|
||||
|
||||
@property
|
||||
def food_fuzzy_match_threshold(self) -> int:
|
||||
"""Minimum threshold to fuzzy match against a database food search"""
|
||||
|
||||
return 85
|
||||
|
||||
@property
|
||||
def unit_fuzzy_match_threshold(self) -> int:
|
||||
"""Minimum threshold to fuzzy match against a database unit search"""
|
||||
|
||||
return 70
|
||||
|
||||
@abstractmethod
|
||||
async def parse_one(self, ingredient_string: str) -> ParsedIngredient: ...
|
||||
|
||||
@abstractmethod
|
||||
async def parse(self, ingredients: list[str]) -> list[ParsedIngredient]: ...
|
||||
|
||||
@classmethod
|
||||
def find_match(cls, match_value: str, *, store_map: dict[str, T], fuzzy_match_threshold: int = 0) -> T | None:
|
||||
# check for literal matches
|
||||
@ -126,7 +108,7 @@ class ABCIngredientParser(ABC):
|
||||
return self.find_match(
|
||||
match_value,
|
||||
store_map=self.foods_by_alias,
|
||||
fuzzy_match_threshold=self.food_fuzzy_match_threshold,
|
||||
fuzzy_match_threshold=self._food_fuzzy_match_threshold,
|
||||
)
|
||||
|
||||
def find_unit_match(self, unit: IngredientUnit | CreateIngredientUnit | str) -> IngredientUnit | None:
|
||||
@ -138,21 +120,56 @@ class ABCIngredientParser(ABC):
|
||||
return self.find_match(
|
||||
match_value,
|
||||
store_map=self.units_by_alias,
|
||||
fuzzy_match_threshold=self.unit_fuzzy_match_threshold,
|
||||
fuzzy_match_threshold=self._unit_fuzzy_match_threshold,
|
||||
)
|
||||
|
||||
|
||||
class ABCIngredientParser(ABC):
|
||||
"""
|
||||
Abstract class for ingredient parsers.
|
||||
"""
|
||||
|
||||
def __init__(self, group_id: UUID4, session: Session) -> None:
|
||||
self.group_id = group_id
|
||||
self.session = session
|
||||
self.data_matcher = DataMatcher(
|
||||
self.group_id, self._repos, self.food_fuzzy_match_threshold, self.unit_fuzzy_match_threshold
|
||||
)
|
||||
|
||||
@property
|
||||
def _repos(self) -> AllRepositories:
|
||||
return get_repositories(self.session)
|
||||
|
||||
@property
|
||||
def food_fuzzy_match_threshold(self) -> int:
|
||||
"""Minimum threshold to fuzzy match against a database food search"""
|
||||
|
||||
return 85
|
||||
|
||||
@property
|
||||
def unit_fuzzy_match_threshold(self) -> int:
|
||||
"""Minimum threshold to fuzzy match against a database unit search"""
|
||||
|
||||
return 70
|
||||
|
||||
@abstractmethod
|
||||
async def parse_one(self, ingredient_string: str) -> ParsedIngredient: ...
|
||||
|
||||
@abstractmethod
|
||||
async def parse(self, ingredients: list[str]) -> list[ParsedIngredient]: ...
|
||||
|
||||
def find_ingredient_match(self, ingredient: ParsedIngredient) -> ParsedIngredient:
|
||||
if ingredient.ingredient.food and (food_match := self.find_food_match(ingredient.ingredient.food)):
|
||||
if ingredient.ingredient.food and (food_match := self.data_matcher.find_food_match(ingredient.ingredient.food)):
|
||||
ingredient.ingredient.food = food_match
|
||||
|
||||
if ingredient.ingredient.unit and (unit_match := self.find_unit_match(ingredient.ingredient.unit)):
|
||||
if ingredient.ingredient.unit and (unit_match := self.data_matcher.find_unit_match(ingredient.ingredient.unit)):
|
||||
ingredient.ingredient.unit = unit_match
|
||||
|
||||
# Parser might have wrongly split a food into a unit and food.
|
||||
if isinstance(ingredient.ingredient.food, CreateIngredientFood) and isinstance(
|
||||
ingredient.ingredient.unit, CreateIngredientUnit
|
||||
):
|
||||
if food_match := self.find_food_match(
|
||||
if food_match := self.data_matcher.find_food_match(
|
||||
f"{ingredient.ingredient.unit.name} {ingredient.ingredient.food.name}"
|
||||
):
|
||||
ingredient.ingredient.food = food_match
|
||||
|
@ -194,7 +194,7 @@ def parse(ing_str, parser) -> BruteParsedIngredient:
|
||||
# try to parse as unit and ingredient (e.g. "a tblsp salt"), with unit in first three tokens
|
||||
# won't work for units that have spaces
|
||||
for index, token in enumerate(tokens[:3]):
|
||||
if parser.find_unit_match(token):
|
||||
if parser.data_matcher.find_unit_match(token):
|
||||
unit = token
|
||||
ingredient, note = parse_ingredient(tokens[index + 1 :])
|
||||
break
|
||||
|
@ -45,7 +45,7 @@ class OpenAIParser(ABCIngredientParser):
|
||||
),
|
||||
]
|
||||
|
||||
if service.send_db_data and self.units_by_alias:
|
||||
if service.send_db_data and self.data_matcher.units_by_alias:
|
||||
data_injections.extend(
|
||||
[
|
||||
OpenAIDataInjection(
|
||||
@ -55,7 +55,7 @@ class OpenAIParser(ABCIngredientParser):
|
||||
"find a unit in the input that does not exist in this list. This should not prevent "
|
||||
"you from parsing that text as a unit, however it may lower your confidence level."
|
||||
),
|
||||
value=list(set(self.units_by_alias)),
|
||||
value=list(set(self.data_matcher.units_by_alias)),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
@ -2,22 +2,21 @@ import random
|
||||
from math import ceil, floor
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from pydantic import UUID4
|
||||
|
||||
from mealie.repos.repository_factory import AllRepositories
|
||||
from mealie.schema.group.group_shopping_list import ShoppingListItemOut, ShoppingListOut
|
||||
from mealie.schema.recipe.recipe_ingredient import SaveIngredientFood
|
||||
from tests import utils
|
||||
from tests.utils import api_routes
|
||||
from tests.utils.factories import random_int, random_string
|
||||
from tests.utils.fixture_schemas import TestUser
|
||||
|
||||
|
||||
def create_item(list_id: UUID4) -> dict:
|
||||
return {
|
||||
"shopping_list_id": str(list_id),
|
||||
"note": random_string(10),
|
||||
"quantity": random_int(1, 10),
|
||||
}
|
||||
def create_item(list_id: UUID4, **kwargs) -> dict:
|
||||
return {"shopping_list_id": str(list_id), "note": random_string(10), "quantity": random_int(1, 10), **kwargs}
|
||||
|
||||
|
||||
def serialize_list_items(list_items: list[ShoppingListItemOut]) -> list:
|
||||
@ -89,6 +88,69 @@ def test_shopping_list_items_create_many(
|
||||
assert not created_item_ids
|
||||
|
||||
|
||||
def test_shopping_list_items_auto_assign_label_with_food_without_label(
|
||||
api_client: TestClient, unique_user: TestUser, shopping_list: ShoppingListOut, database: AllRepositories
|
||||
):
|
||||
food = database.ingredient_foods.create(SaveIngredientFood(name=random_string(10), group_id=unique_user.group_id))
|
||||
|
||||
item = create_item(shopping_list.id, food_id=str(food.id))
|
||||
response = api_client.post(api_routes.groups_shopping_items, json=item, headers=unique_user.token)
|
||||
as_json = utils.assert_derserialize(response, 201)
|
||||
assert len(as_json["createdItems"]) == 1
|
||||
|
||||
item_out = ShoppingListItemOut.model_validate(as_json["createdItems"][0])
|
||||
assert item_out.label_id is None
|
||||
assert item_out.label is None
|
||||
|
||||
|
||||
def test_shopping_list_items_auto_assign_label_with_food_with_label(
|
||||
api_client: TestClient, unique_user: TestUser, shopping_list: ShoppingListOut, database: AllRepositories
|
||||
):
|
||||
label = database.group_multi_purpose_labels.create({"name": random_string(10), "group_id": unique_user.group_id})
|
||||
food = database.ingredient_foods.create(
|
||||
SaveIngredientFood(name=random_string(10), group_id=unique_user.group_id, label_id=label.id)
|
||||
)
|
||||
|
||||
item = create_item(shopping_list.id, food_id=str(food.id))
|
||||
response = api_client.post(api_routes.groups_shopping_items, json=item, headers=unique_user.token)
|
||||
as_json = utils.assert_derserialize(response, 201)
|
||||
assert len(as_json["createdItems"]) == 1
|
||||
|
||||
item_out = ShoppingListItemOut.model_validate(as_json["createdItems"][0])
|
||||
assert item_out.label_id == label.id
|
||||
assert item_out.label
|
||||
assert item_out.label.id == label.id
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_fuzzy_name", [True, False])
|
||||
def test_shopping_list_items_auto_assign_label_with_food_search(
|
||||
api_client: TestClient,
|
||||
unique_user: TestUser,
|
||||
shopping_list: ShoppingListOut,
|
||||
database: AllRepositories,
|
||||
use_fuzzy_name: bool,
|
||||
):
|
||||
label = database.group_multi_purpose_labels.create({"name": random_string(10), "group_id": unique_user.group_id})
|
||||
food = database.ingredient_foods.create(
|
||||
SaveIngredientFood(name=random_string(20), group_id=unique_user.group_id, label_id=label.id)
|
||||
)
|
||||
|
||||
item = create_item(shopping_list.id)
|
||||
name = food.name
|
||||
if use_fuzzy_name:
|
||||
name = name + random_string(2)
|
||||
item["note"] = name
|
||||
|
||||
response = api_client.post(api_routes.groups_shopping_items, json=item, headers=unique_user.token)
|
||||
as_json = utils.assert_derserialize(response, 201)
|
||||
assert len(as_json["createdItems"]) == 1
|
||||
|
||||
item_out = ShoppingListItemOut.model_validate(as_json["createdItems"][0])
|
||||
assert item_out.label_id == label.id
|
||||
assert item_out.label
|
||||
assert item_out.label.id == label.id
|
||||
|
||||
|
||||
def test_shopping_list_items_get_one(
|
||||
api_client: TestClient,
|
||||
unique_user: TestUser,
|
||||
|
Loading…
x
Reference in New Issue
Block a user