mirror of
https://github.com/mealie-recipes/mealie.git
synced 2025-07-09 03:04:54 -04:00
feat: Auto-label new shopping list items (#3800)
This commit is contained in:
parent
9795b4c553
commit
da11204cd7
@ -17,12 +17,19 @@ from mealie.schema.group.group_shopping_list import (
|
|||||||
ShoppingListMultiPurposeLabelCreate,
|
ShoppingListMultiPurposeLabelCreate,
|
||||||
ShoppingListSave,
|
ShoppingListSave,
|
||||||
)
|
)
|
||||||
from mealie.schema.recipe.recipe_ingredient import IngredientFood, IngredientUnit, RecipeIngredient
|
from mealie.schema.recipe.recipe_ingredient import (
|
||||||
|
IngredientFood,
|
||||||
|
IngredientUnit,
|
||||||
|
RecipeIngredient,
|
||||||
|
)
|
||||||
from mealie.schema.response.pagination import OrderDirection, PaginationQuery
|
from mealie.schema.response.pagination import OrderDirection, PaginationQuery
|
||||||
from mealie.schema.user.user import GroupInDB, UserOut
|
from mealie.schema.user.user import GroupInDB, UserOut
|
||||||
|
from mealie.services.parser_services._base import DataMatcher
|
||||||
|
|
||||||
|
|
||||||
class ShoppingListService:
|
class ShoppingListService:
|
||||||
|
DEFAULT_FOOD_FUZZY_MATCH_THRESHOLD = 80
|
||||||
|
|
||||||
def __init__(self, repos: AllRepositories, group: GroupInDB, user: UserOut):
|
def __init__(self, repos: AllRepositories, group: GroupInDB, user: UserOut):
|
||||||
self.repos = repos
|
self.repos = repos
|
||||||
self.group = group
|
self.group = group
|
||||||
@ -31,6 +38,9 @@ class ShoppingListService:
|
|||||||
self.list_items = repos.group_shopping_list_item
|
self.list_items = repos.group_shopping_list_item
|
||||||
self.list_item_refs = repos.group_shopping_list_item_references
|
self.list_item_refs = repos.group_shopping_list_item_references
|
||||||
self.list_refs = repos.group_shopping_list_recipe_refs
|
self.list_refs = repos.group_shopping_list_recipe_refs
|
||||||
|
self.data_matcher = DataMatcher(
|
||||||
|
self.group.id, self.repos, food_fuzzy_match_threshold=self.DEFAULT_FOOD_FUZZY_MATCH_THRESHOLD
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def can_merge(item1: ShoppingListItemBase, item2: ShoppingListItemBase) -> bool:
|
def can_merge(item1: ShoppingListItemBase, item2: ShoppingListItemBase) -> bool:
|
||||||
@ -108,7 +118,23 @@ class ShoppingListService:
|
|||||||
if list_refs_to_delete:
|
if list_refs_to_delete:
|
||||||
self.list_refs.delete_many(list_refs_to_delete)
|
self.list_refs.delete_many(list_refs_to_delete)
|
||||||
|
|
||||||
def bulk_create_items(self, create_items: list[ShoppingListItemCreate]) -> ShoppingListItemsCollectionOut:
|
def find_matching_label(self, item: ShoppingListItemBase) -> UUID4 | None:
|
||||||
|
if item.label_id:
|
||||||
|
return item.label_id
|
||||||
|
if item.food:
|
||||||
|
return item.food.label_id
|
||||||
|
|
||||||
|
food_search = self.data_matcher.find_food_match(item.display)
|
||||||
|
return food_search.label_id if food_search else None
|
||||||
|
|
||||||
|
def bulk_create_items(
|
||||||
|
self, create_items: list[ShoppingListItemCreate], auto_find_labels=True
|
||||||
|
) -> ShoppingListItemsCollectionOut:
|
||||||
|
"""
|
||||||
|
Create a list of items, merging into existing ones where possible.
|
||||||
|
Optionally try to find a label for each item if one isn't provided using the item's food data or display name.
|
||||||
|
"""
|
||||||
|
|
||||||
# consolidate items to be created
|
# consolidate items to be created
|
||||||
consolidated_create_items: list[ShoppingListItemCreate] = []
|
consolidated_create_items: list[ShoppingListItemCreate] = []
|
||||||
for create_item in create_items:
|
for create_item in create_items:
|
||||||
@ -157,18 +183,13 @@ class ShoppingListService:
|
|||||||
if create_item.checked:
|
if create_item.checked:
|
||||||
# checked items should not have recipe references
|
# checked items should not have recipe references
|
||||||
create_item.recipe_references = []
|
create_item.recipe_references = []
|
||||||
|
if auto_find_labels:
|
||||||
|
create_item.label_id = self.find_matching_label(create_item)
|
||||||
|
|
||||||
filtered_create_items.append(create_item)
|
filtered_create_items.append(create_item)
|
||||||
|
|
||||||
created_items = cast(
|
created_items = self.list_items.create_many(filtered_create_items) if filtered_create_items else []
|
||||||
list[ShoppingListItemOut],
|
updated_items = self.list_items.update_many(update_items) if update_items else []
|
||||||
self.list_items.create_many(filtered_create_items) if filtered_create_items else [], # type: ignore
|
|
||||||
)
|
|
||||||
|
|
||||||
updated_items = cast(
|
|
||||||
list[ShoppingListItemOut],
|
|
||||||
self.list_items.update_many(update_items) if update_items else [], # type: ignore
|
|
||||||
)
|
|
||||||
|
|
||||||
for list_id in set(item.shopping_list_id for item in created_items + updated_items):
|
for list_id in set(item.shopping_list_id for item in created_items + updated_items):
|
||||||
self.remove_unused_recipe_references(list_id)
|
self.remove_unused_recipe_references(list_id)
|
||||||
|
@ -20,26 +20,26 @@ from mealie.schema.response.pagination import PaginationQuery
|
|||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
class ABCIngredientParser(ABC):
|
class DataMatcher:
|
||||||
"""
|
def __init__(
|
||||||
Abstract class for ingredient parsers.
|
self,
|
||||||
"""
|
group_id: UUID4,
|
||||||
|
repos: AllRepositories,
|
||||||
def __init__(self, group_id: UUID4, session: Session) -> None:
|
food_fuzzy_match_threshold: int = 85,
|
||||||
|
unit_fuzzy_match_threshold: int = 70,
|
||||||
|
) -> None:
|
||||||
self.group_id = group_id
|
self.group_id = group_id
|
||||||
self.session = session
|
self.repos = repos
|
||||||
|
|
||||||
|
self._food_fuzzy_match_threshold = food_fuzzy_match_threshold
|
||||||
|
self._unit_fuzzy_match_threshold = unit_fuzzy_match_threshold
|
||||||
self._foods_by_alias: dict[str, IngredientFood] | None = None
|
self._foods_by_alias: dict[str, IngredientFood] | None = None
|
||||||
self._units_by_alias: dict[str, IngredientUnit] | None = None
|
self._units_by_alias: dict[str, IngredientUnit] | None = None
|
||||||
|
|
||||||
@property
|
|
||||||
def _repos(self) -> AllRepositories:
|
|
||||||
return get_repositories(self.session)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def foods_by_alias(self) -> dict[str, IngredientFood]:
|
def foods_by_alias(self) -> dict[str, IngredientFood]:
|
||||||
if self._foods_by_alias is None:
|
if self._foods_by_alias is None:
|
||||||
foods_repo = self._repos.ingredient_foods.by_group(self.group_id)
|
foods_repo = self.repos.ingredient_foods.by_group(self.group_id)
|
||||||
query = PaginationQuery(page=1, per_page=-1)
|
query = PaginationQuery(page=1, per_page=-1)
|
||||||
all_foods = foods_repo.page_all(query).items
|
all_foods = foods_repo.page_all(query).items
|
||||||
|
|
||||||
@ -61,7 +61,7 @@ class ABCIngredientParser(ABC):
|
|||||||
@property
|
@property
|
||||||
def units_by_alias(self) -> dict[str, IngredientUnit]:
|
def units_by_alias(self) -> dict[str, IngredientUnit]:
|
||||||
if self._units_by_alias is None:
|
if self._units_by_alias is None:
|
||||||
units_repo = self._repos.ingredient_units.by_group(self.group_id)
|
units_repo = self.repos.ingredient_units.by_group(self.group_id)
|
||||||
query = PaginationQuery(page=1, per_page=-1)
|
query = PaginationQuery(page=1, per_page=-1)
|
||||||
all_units = units_repo.page_all(query).items
|
all_units = units_repo.page_all(query).items
|
||||||
|
|
||||||
@ -84,24 +84,6 @@ class ABCIngredientParser(ABC):
|
|||||||
|
|
||||||
return self._units_by_alias
|
return self._units_by_alias
|
||||||
|
|
||||||
@property
|
|
||||||
def food_fuzzy_match_threshold(self) -> int:
|
|
||||||
"""Minimum threshold to fuzzy match against a database food search"""
|
|
||||||
|
|
||||||
return 85
|
|
||||||
|
|
||||||
@property
|
|
||||||
def unit_fuzzy_match_threshold(self) -> int:
|
|
||||||
"""Minimum threshold to fuzzy match against a database unit search"""
|
|
||||||
|
|
||||||
return 70
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def parse_one(self, ingredient_string: str) -> ParsedIngredient: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def parse(self, ingredients: list[str]) -> list[ParsedIngredient]: ...
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def find_match(cls, match_value: str, *, store_map: dict[str, T], fuzzy_match_threshold: int = 0) -> T | None:
|
def find_match(cls, match_value: str, *, store_map: dict[str, T], fuzzy_match_threshold: int = 0) -> T | None:
|
||||||
# check for literal matches
|
# check for literal matches
|
||||||
@ -126,7 +108,7 @@ class ABCIngredientParser(ABC):
|
|||||||
return self.find_match(
|
return self.find_match(
|
||||||
match_value,
|
match_value,
|
||||||
store_map=self.foods_by_alias,
|
store_map=self.foods_by_alias,
|
||||||
fuzzy_match_threshold=self.food_fuzzy_match_threshold,
|
fuzzy_match_threshold=self._food_fuzzy_match_threshold,
|
||||||
)
|
)
|
||||||
|
|
||||||
def find_unit_match(self, unit: IngredientUnit | CreateIngredientUnit | str) -> IngredientUnit | None:
|
def find_unit_match(self, unit: IngredientUnit | CreateIngredientUnit | str) -> IngredientUnit | None:
|
||||||
@ -138,21 +120,56 @@ class ABCIngredientParser(ABC):
|
|||||||
return self.find_match(
|
return self.find_match(
|
||||||
match_value,
|
match_value,
|
||||||
store_map=self.units_by_alias,
|
store_map=self.units_by_alias,
|
||||||
fuzzy_match_threshold=self.unit_fuzzy_match_threshold,
|
fuzzy_match_threshold=self._unit_fuzzy_match_threshold,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ABCIngredientParser(ABC):
|
||||||
|
"""
|
||||||
|
Abstract class for ingredient parsers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, group_id: UUID4, session: Session) -> None:
|
||||||
|
self.group_id = group_id
|
||||||
|
self.session = session
|
||||||
|
self.data_matcher = DataMatcher(
|
||||||
|
self.group_id, self._repos, self.food_fuzzy_match_threshold, self.unit_fuzzy_match_threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _repos(self) -> AllRepositories:
|
||||||
|
return get_repositories(self.session)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def food_fuzzy_match_threshold(self) -> int:
|
||||||
|
"""Minimum threshold to fuzzy match against a database food search"""
|
||||||
|
|
||||||
|
return 85
|
||||||
|
|
||||||
|
@property
|
||||||
|
def unit_fuzzy_match_threshold(self) -> int:
|
||||||
|
"""Minimum threshold to fuzzy match against a database unit search"""
|
||||||
|
|
||||||
|
return 70
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def parse_one(self, ingredient_string: str) -> ParsedIngredient: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def parse(self, ingredients: list[str]) -> list[ParsedIngredient]: ...
|
||||||
|
|
||||||
def find_ingredient_match(self, ingredient: ParsedIngredient) -> ParsedIngredient:
|
def find_ingredient_match(self, ingredient: ParsedIngredient) -> ParsedIngredient:
|
||||||
if ingredient.ingredient.food and (food_match := self.find_food_match(ingredient.ingredient.food)):
|
if ingredient.ingredient.food and (food_match := self.data_matcher.find_food_match(ingredient.ingredient.food)):
|
||||||
ingredient.ingredient.food = food_match
|
ingredient.ingredient.food = food_match
|
||||||
|
|
||||||
if ingredient.ingredient.unit and (unit_match := self.find_unit_match(ingredient.ingredient.unit)):
|
if ingredient.ingredient.unit and (unit_match := self.data_matcher.find_unit_match(ingredient.ingredient.unit)):
|
||||||
ingredient.ingredient.unit = unit_match
|
ingredient.ingredient.unit = unit_match
|
||||||
|
|
||||||
# Parser might have wrongly split a food into a unit and food.
|
# Parser might have wrongly split a food into a unit and food.
|
||||||
if isinstance(ingredient.ingredient.food, CreateIngredientFood) and isinstance(
|
if isinstance(ingredient.ingredient.food, CreateIngredientFood) and isinstance(
|
||||||
ingredient.ingredient.unit, CreateIngredientUnit
|
ingredient.ingredient.unit, CreateIngredientUnit
|
||||||
):
|
):
|
||||||
if food_match := self.find_food_match(
|
if food_match := self.data_matcher.find_food_match(
|
||||||
f"{ingredient.ingredient.unit.name} {ingredient.ingredient.food.name}"
|
f"{ingredient.ingredient.unit.name} {ingredient.ingredient.food.name}"
|
||||||
):
|
):
|
||||||
ingredient.ingredient.food = food_match
|
ingredient.ingredient.food = food_match
|
||||||
|
@ -194,7 +194,7 @@ def parse(ing_str, parser) -> BruteParsedIngredient:
|
|||||||
# try to parse as unit and ingredient (e.g. "a tblsp salt"), with unit in first three tokens
|
# try to parse as unit and ingredient (e.g. "a tblsp salt"), with unit in first three tokens
|
||||||
# won't work for units that have spaces
|
# won't work for units that have spaces
|
||||||
for index, token in enumerate(tokens[:3]):
|
for index, token in enumerate(tokens[:3]):
|
||||||
if parser.find_unit_match(token):
|
if parser.data_matcher.find_unit_match(token):
|
||||||
unit = token
|
unit = token
|
||||||
ingredient, note = parse_ingredient(tokens[index + 1 :])
|
ingredient, note = parse_ingredient(tokens[index + 1 :])
|
||||||
break
|
break
|
||||||
|
@ -45,7 +45,7 @@ class OpenAIParser(ABCIngredientParser):
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
if service.send_db_data and self.units_by_alias:
|
if service.send_db_data and self.data_matcher.units_by_alias:
|
||||||
data_injections.extend(
|
data_injections.extend(
|
||||||
[
|
[
|
||||||
OpenAIDataInjection(
|
OpenAIDataInjection(
|
||||||
@ -55,7 +55,7 @@ class OpenAIParser(ABCIngredientParser):
|
|||||||
"find a unit in the input that does not exist in this list. This should not prevent "
|
"find a unit in the input that does not exist in this list. This should not prevent "
|
||||||
"you from parsing that text as a unit, however it may lower your confidence level."
|
"you from parsing that text as a unit, however it may lower your confidence level."
|
||||||
),
|
),
|
||||||
value=list(set(self.units_by_alias)),
|
value=list(set(self.data_matcher.units_by_alias)),
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
@ -2,22 +2,21 @@ import random
|
|||||||
from math import ceil, floor
|
from math import ceil, floor
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from pydantic import UUID4
|
from pydantic import UUID4
|
||||||
|
|
||||||
|
from mealie.repos.repository_factory import AllRepositories
|
||||||
from mealie.schema.group.group_shopping_list import ShoppingListItemOut, ShoppingListOut
|
from mealie.schema.group.group_shopping_list import ShoppingListItemOut, ShoppingListOut
|
||||||
|
from mealie.schema.recipe.recipe_ingredient import SaveIngredientFood
|
||||||
from tests import utils
|
from tests import utils
|
||||||
from tests.utils import api_routes
|
from tests.utils import api_routes
|
||||||
from tests.utils.factories import random_int, random_string
|
from tests.utils.factories import random_int, random_string
|
||||||
from tests.utils.fixture_schemas import TestUser
|
from tests.utils.fixture_schemas import TestUser
|
||||||
|
|
||||||
|
|
||||||
def create_item(list_id: UUID4) -> dict:
|
def create_item(list_id: UUID4, **kwargs) -> dict:
|
||||||
return {
|
return {"shopping_list_id": str(list_id), "note": random_string(10), "quantity": random_int(1, 10), **kwargs}
|
||||||
"shopping_list_id": str(list_id),
|
|
||||||
"note": random_string(10),
|
|
||||||
"quantity": random_int(1, 10),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_list_items(list_items: list[ShoppingListItemOut]) -> list:
|
def serialize_list_items(list_items: list[ShoppingListItemOut]) -> list:
|
||||||
@ -89,6 +88,69 @@ def test_shopping_list_items_create_many(
|
|||||||
assert not created_item_ids
|
assert not created_item_ids
|
||||||
|
|
||||||
|
|
||||||
|
def test_shopping_list_items_auto_assign_label_with_food_without_label(
|
||||||
|
api_client: TestClient, unique_user: TestUser, shopping_list: ShoppingListOut, database: AllRepositories
|
||||||
|
):
|
||||||
|
food = database.ingredient_foods.create(SaveIngredientFood(name=random_string(10), group_id=unique_user.group_id))
|
||||||
|
|
||||||
|
item = create_item(shopping_list.id, food_id=str(food.id))
|
||||||
|
response = api_client.post(api_routes.groups_shopping_items, json=item, headers=unique_user.token)
|
||||||
|
as_json = utils.assert_derserialize(response, 201)
|
||||||
|
assert len(as_json["createdItems"]) == 1
|
||||||
|
|
||||||
|
item_out = ShoppingListItemOut.model_validate(as_json["createdItems"][0])
|
||||||
|
assert item_out.label_id is None
|
||||||
|
assert item_out.label is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_shopping_list_items_auto_assign_label_with_food_with_label(
|
||||||
|
api_client: TestClient, unique_user: TestUser, shopping_list: ShoppingListOut, database: AllRepositories
|
||||||
|
):
|
||||||
|
label = database.group_multi_purpose_labels.create({"name": random_string(10), "group_id": unique_user.group_id})
|
||||||
|
food = database.ingredient_foods.create(
|
||||||
|
SaveIngredientFood(name=random_string(10), group_id=unique_user.group_id, label_id=label.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
item = create_item(shopping_list.id, food_id=str(food.id))
|
||||||
|
response = api_client.post(api_routes.groups_shopping_items, json=item, headers=unique_user.token)
|
||||||
|
as_json = utils.assert_derserialize(response, 201)
|
||||||
|
assert len(as_json["createdItems"]) == 1
|
||||||
|
|
||||||
|
item_out = ShoppingListItemOut.model_validate(as_json["createdItems"][0])
|
||||||
|
assert item_out.label_id == label.id
|
||||||
|
assert item_out.label
|
||||||
|
assert item_out.label.id == label.id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("use_fuzzy_name", [True, False])
|
||||||
|
def test_shopping_list_items_auto_assign_label_with_food_search(
|
||||||
|
api_client: TestClient,
|
||||||
|
unique_user: TestUser,
|
||||||
|
shopping_list: ShoppingListOut,
|
||||||
|
database: AllRepositories,
|
||||||
|
use_fuzzy_name: bool,
|
||||||
|
):
|
||||||
|
label = database.group_multi_purpose_labels.create({"name": random_string(10), "group_id": unique_user.group_id})
|
||||||
|
food = database.ingredient_foods.create(
|
||||||
|
SaveIngredientFood(name=random_string(20), group_id=unique_user.group_id, label_id=label.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
item = create_item(shopping_list.id)
|
||||||
|
name = food.name
|
||||||
|
if use_fuzzy_name:
|
||||||
|
name = name + random_string(2)
|
||||||
|
item["note"] = name
|
||||||
|
|
||||||
|
response = api_client.post(api_routes.groups_shopping_items, json=item, headers=unique_user.token)
|
||||||
|
as_json = utils.assert_derserialize(response, 201)
|
||||||
|
assert len(as_json["createdItems"]) == 1
|
||||||
|
|
||||||
|
item_out = ShoppingListItemOut.model_validate(as_json["createdItems"][0])
|
||||||
|
assert item_out.label_id == label.id
|
||||||
|
assert item_out.label
|
||||||
|
assert item_out.label.id == label.id
|
||||||
|
|
||||||
|
|
||||||
def test_shopping_list_items_get_one(
|
def test_shopping_list_items_get_one(
|
||||||
api_client: TestClient,
|
api_client: TestClient,
|
||||||
unique_user: TestUser,
|
unique_user: TestUser,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user