diff --git a/mealie/db/models/group/shopping_list.py b/mealie/db/models/group/shopping_list.py index 9e76861c780a..e2b3b245da2b 100644 --- a/mealie/db/models/group/shopping_list.py +++ b/mealie/db/models/group/shopping_list.py @@ -1,7 +1,9 @@ +from contextvars import ContextVar +from datetime import datetime from typing import TYPE_CHECKING, Optional from pydantic import ConfigDict -from sqlalchemy import Boolean, Float, ForeignKey, Integer, String, UniqueConstraint, orm +from sqlalchemy import Boolean, Float, ForeignKey, Integer, String, UniqueConstraint, event, orm from sqlalchemy.ext.orderinglist import ordering_list from sqlalchemy.orm import Mapped, mapped_column @@ -150,3 +152,60 @@ class ShoppingList(SqlAlchemyBase, BaseMixins): @auto_init() def __init__(self, **_) -> None: pass + + +class SessionBuffer: + def __init__(self) -> None: + self.shopping_list_ids: set[GUID] = set() + + def add(self, shopping_list_id: GUID) -> None: + self.shopping_list_ids.add(shopping_list_id) + + def pop(self) -> GUID | None: + try: + return self.shopping_list_ids.pop() + except KeyError: + return None + + def clear(self) -> None: + self.shopping_list_ids.clear() + + +session_buffer_context = ContextVar("session_buffer", default=SessionBuffer()) + + +@event.listens_for(ShoppingListItem, "after_insert") +@event.listens_for(ShoppingListItem, "after_update") +@event.listens_for(ShoppingListItem, "after_delete") +def buffer_shopping_list_updates(_, connection, target: ShoppingListItem): + """Adds the shopping list id to the session buffer so its `update_at` property can be updated later""" + + session_buffer = session_buffer_context.get() + session_buffer.add(target.shopping_list_id) + + +@event.listens_for(orm.Session, "after_flush") +def update_shopping_lists(session: orm.Session, _): + """Pulls all pending shopping list updates from the buffer and updates their `update_at` property""" + + session_buffer = session_buffer_context.get() + if not session_buffer.shopping_list_ids: + return + + local_session = orm.Session(bind=session.connection()) + try: + local_session.begin() + while True: + shopping_list_id = session_buffer.pop() + if not shopping_list_id: + break + + shopping_list = local_session.query(ShoppingList).filter(ShoppingList.id == shopping_list_id).first() + if not shopping_list: + continue + + shopping_list.update_at = datetime.now() + local_session.commit() + except Exception: + local_session.rollback() + raise diff --git a/tests/integration_tests/user_group_tests/test_group_shopping_lists.py b/tests/integration_tests/user_group_tests/test_group_shopping_lists.py index dd6481089033..769e120d0207 100644 --- a/tests/integration_tests/user_group_tests/test_group_shopping_lists.py +++ b/tests/integration_tests/user_group_tests/test_group_shopping_lists.py @@ -2,11 +2,17 @@ import random from fastapi.testclient import TestClient -from mealie.schema.group.group_shopping_list import ShoppingListOut +from mealie.repos.repository_factory import AllRepositories +from mealie.schema.group.group_shopping_list import ( + ShoppingListItemOut, + ShoppingListItemUpdate, + ShoppingListItemUpdateBulk, + ShoppingListOut, +) from mealie.schema.recipe.recipe import Recipe -from mealie.schema.recipe.recipe_ingredient import RecipeIngredient from tests import utils from tests.utils import api_routes +from tests.utils.assertion_helpers import assert_derserialize from tests.utils.factories import random_int, random_string from tests.utils.fixture_schemas import TestUser @@ -755,3 +761,93 @@ def test_shopping_list_extras( assert key_str_2 in extras assert extras[key_str_1] == val_str_1 assert extras[key_str_2] == val_str_2 + + +def test_modify_shopping_list_items_updates_shopping_list( + database: AllRepositories, api_client: TestClient, unique_user: TestUser, shopping_lists: list[ShoppingListOut] +): + shopping_list = random.choice(shopping_lists) + last_update_at = shopping_list.update_at + assert last_update_at + + # Create + new_item_data = {"note": random_string(), "shopping_list_id": str(shopping_list.id)} + response = api_client.post(api_routes.groups_shopping_items, json=new_item_data, headers=unique_user.token) + data = assert_derserialize(response, 201) + updated_list = database.group_shopping_lists.get_one(shopping_list.id) + assert updated_list and updated_list.update_at + assert updated_list.update_at > last_update_at + last_update_at = updated_list.update_at + + list_item_id = data["createdItems"][0]["id"] + list_item = database.group_shopping_list_item.get_one(list_item_id) + assert list_item + + # Update + list_item.note = random_string() + response = api_client.put( + api_routes.groups_shopping_items_item_id(list_item_id), + json=utils.jsonify(list_item.cast(ShoppingListItemUpdate).model_dump()), + headers=unique_user.token, + ) + assert response.status_code == 200 + updated_list = database.group_shopping_lists.get_one(shopping_list.id) + assert updated_list and updated_list.update_at + assert updated_list.update_at > last_update_at + last_update_at = updated_list.update_at + + # Delete + response = api_client.delete(api_routes.groups_shopping_items_item_id(list_item_id), headers=unique_user.token) + assert response.status_code == 200 + updated_list = database.group_shopping_lists.get_one(shopping_list.id) + assert updated_list and updated_list.update_at + assert updated_list.update_at > last_update_at + + +def test_bulk_modify_shopping_list_items_updates_shopping_list( + database: AllRepositories, api_client: TestClient, unique_user: TestUser, shopping_lists: list[ShoppingListOut] +): + shopping_list = random.choice(shopping_lists) + last_update_at = shopping_list.update_at + assert last_update_at + + # Create + new_item_data = [ + {"note": random_string(), "shopping_list_id": str(shopping_list.id)} for _ in range(random_int(3, 5)) + ] + response = api_client.post( + api_routes.groups_shopping_items_create_bulk, json=new_item_data, headers=unique_user.token + ) + data = assert_derserialize(response, 201) + updated_list = database.group_shopping_lists.get_one(shopping_list.id) + assert updated_list and updated_list.update_at + assert updated_list.update_at > last_update_at + last_update_at = updated_list.update_at + + # Update + list_item_ids = [item["id"] for item in data["createdItems"]] + list_items: list[ShoppingListItemOut] = [] + for list_item_id in list_item_ids: + list_item = database.group_shopping_list_item.get_one(list_item_id) + assert list_item + list_item.note = random_string() + list_items.append(list_item) + + payload = [utils.jsonify(list_item.cast(ShoppingListItemUpdateBulk).model_dump()) for list_item in list_items] + response = api_client.put(api_routes.groups_shopping_items, json=payload, headers=unique_user.token) + assert response.status_code == 200 + updated_list = database.group_shopping_lists.get_one(shopping_list.id) + assert updated_list and updated_list.update_at + assert updated_list.update_at > last_update_at + last_update_at = updated_list.update_at + + # Delete + response = api_client.delete( + api_routes.groups_shopping_items, + params={"ids": [str(list_item.id) for list_item in list_items]}, + headers=unique_user.token, + ) + assert response.status_code == 200 + updated_list = database.group_shopping_lists.get_one(shopping_list.id) + assert updated_list and updated_list.update_at + assert updated_list.update_at > last_update_at