feat: Update Shopping List Timestamp on List Item Update (#3453)

This commit is contained in:
Michael Genson 2024-06-01 06:07:50 -05:00 committed by GitHub
parent d6ce607a4e
commit 94e91d3602
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 158 additions and 3 deletions

View File

@ -1,7 +1,9 @@
from contextvars import ContextVar
from datetime import datetime
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from pydantic import ConfigDict from pydantic import ConfigDict
from sqlalchemy import Boolean, Float, ForeignKey, Integer, String, UniqueConstraint, orm from sqlalchemy import Boolean, Float, ForeignKey, Integer, String, UniqueConstraint, event, orm
from sqlalchemy.ext.orderinglist import ordering_list from sqlalchemy.ext.orderinglist import ordering_list
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
@ -150,3 +152,60 @@ class ShoppingList(SqlAlchemyBase, BaseMixins):
@auto_init() @auto_init()
def __init__(self, **_) -> None: def __init__(self, **_) -> None:
pass pass
class SessionBuffer:
def __init__(self) -> None:
self.shopping_list_ids: set[GUID] = set()
def add(self, shopping_list_id: GUID) -> None:
self.shopping_list_ids.add(shopping_list_id)
def pop(self) -> GUID | None:
try:
return self.shopping_list_ids.pop()
except KeyError:
return None
def clear(self) -> None:
self.shopping_list_ids.clear()
session_buffer_context = ContextVar("session_buffer", default=SessionBuffer())
@event.listens_for(ShoppingListItem, "after_insert")
@event.listens_for(ShoppingListItem, "after_update")
@event.listens_for(ShoppingListItem, "after_delete")
def buffer_shopping_list_updates(_, connection, target: ShoppingListItem):
"""Adds the shopping list id to the session buffer so its `update_at` property can be updated later"""
session_buffer = session_buffer_context.get()
session_buffer.add(target.shopping_list_id)
@event.listens_for(orm.Session, "after_flush")
def update_shopping_lists(session: orm.Session, _):
"""Pulls all pending shopping list updates from the buffer and updates their `update_at` property"""
session_buffer = session_buffer_context.get()
if not session_buffer.shopping_list_ids:
return
local_session = orm.Session(bind=session.connection())
try:
local_session.begin()
while True:
shopping_list_id = session_buffer.pop()
if not shopping_list_id:
break
shopping_list = local_session.query(ShoppingList).filter(ShoppingList.id == shopping_list_id).first()
if not shopping_list:
continue
shopping_list.update_at = datetime.now()
local_session.commit()
except Exception:
local_session.rollback()
raise

View File

@ -2,11 +2,17 @@ import random
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from mealie.schema.group.group_shopping_list import ShoppingListOut from mealie.repos.repository_factory import AllRepositories
from mealie.schema.group.group_shopping_list import (
ShoppingListItemOut,
ShoppingListItemUpdate,
ShoppingListItemUpdateBulk,
ShoppingListOut,
)
from mealie.schema.recipe.recipe import Recipe from mealie.schema.recipe.recipe import Recipe
from mealie.schema.recipe.recipe_ingredient import RecipeIngredient
from tests import utils from tests import utils
from tests.utils import api_routes from tests.utils import api_routes
from tests.utils.assertion_helpers import assert_derserialize
from tests.utils.factories import random_int, random_string from tests.utils.factories import random_int, random_string
from tests.utils.fixture_schemas import TestUser from tests.utils.fixture_schemas import TestUser
@ -755,3 +761,93 @@ def test_shopping_list_extras(
assert key_str_2 in extras assert key_str_2 in extras
assert extras[key_str_1] == val_str_1 assert extras[key_str_1] == val_str_1
assert extras[key_str_2] == val_str_2 assert extras[key_str_2] == val_str_2
def test_modify_shopping_list_items_updates_shopping_list(
database: AllRepositories, api_client: TestClient, unique_user: TestUser, shopping_lists: list[ShoppingListOut]
):
shopping_list = random.choice(shopping_lists)
last_update_at = shopping_list.update_at
assert last_update_at
# Create
new_item_data = {"note": random_string(), "shopping_list_id": str(shopping_list.id)}
response = api_client.post(api_routes.groups_shopping_items, json=new_item_data, headers=unique_user.token)
data = assert_derserialize(response, 201)
updated_list = database.group_shopping_lists.get_one(shopping_list.id)
assert updated_list and updated_list.update_at
assert updated_list.update_at > last_update_at
last_update_at = updated_list.update_at
list_item_id = data["createdItems"][0]["id"]
list_item = database.group_shopping_list_item.get_one(list_item_id)
assert list_item
# Update
list_item.note = random_string()
response = api_client.put(
api_routes.groups_shopping_items_item_id(list_item_id),
json=utils.jsonify(list_item.cast(ShoppingListItemUpdate).model_dump()),
headers=unique_user.token,
)
assert response.status_code == 200
updated_list = database.group_shopping_lists.get_one(shopping_list.id)
assert updated_list and updated_list.update_at
assert updated_list.update_at > last_update_at
last_update_at = updated_list.update_at
# Delete
response = api_client.delete(api_routes.groups_shopping_items_item_id(list_item_id), headers=unique_user.token)
assert response.status_code == 200
updated_list = database.group_shopping_lists.get_one(shopping_list.id)
assert updated_list and updated_list.update_at
assert updated_list.update_at > last_update_at
def test_bulk_modify_shopping_list_items_updates_shopping_list(
database: AllRepositories, api_client: TestClient, unique_user: TestUser, shopping_lists: list[ShoppingListOut]
):
shopping_list = random.choice(shopping_lists)
last_update_at = shopping_list.update_at
assert last_update_at
# Create
new_item_data = [
{"note": random_string(), "shopping_list_id": str(shopping_list.id)} for _ in range(random_int(3, 5))
]
response = api_client.post(
api_routes.groups_shopping_items_create_bulk, json=new_item_data, headers=unique_user.token
)
data = assert_derserialize(response, 201)
updated_list = database.group_shopping_lists.get_one(shopping_list.id)
assert updated_list and updated_list.update_at
assert updated_list.update_at > last_update_at
last_update_at = updated_list.update_at
# Update
list_item_ids = [item["id"] for item in data["createdItems"]]
list_items: list[ShoppingListItemOut] = []
for list_item_id in list_item_ids:
list_item = database.group_shopping_list_item.get_one(list_item_id)
assert list_item
list_item.note = random_string()
list_items.append(list_item)
payload = [utils.jsonify(list_item.cast(ShoppingListItemUpdateBulk).model_dump()) for list_item in list_items]
response = api_client.put(api_routes.groups_shopping_items, json=payload, headers=unique_user.token)
assert response.status_code == 200
updated_list = database.group_shopping_lists.get_one(shopping_list.id)
assert updated_list and updated_list.update_at
assert updated_list.update_at > last_update_at
last_update_at = updated_list.update_at
# Delete
response = api_client.delete(
api_routes.groups_shopping_items,
params={"ids": [str(list_item.id) for list_item in list_items]},
headers=unique_user.token,
)
assert response.status_code == 200
updated_list = database.group_shopping_lists.get_one(shopping_list.id)
assert updated_list and updated_list.update_at
assert updated_list.update_at > last_update_at