From 8e2d50054c8b97406f107ad89a53904fc0298b78 Mon Sep 17 00:00:00 2001 From: Michael Genson <71845777+michael-genson@users.noreply.github.com> Date: Fri, 12 May 2023 01:28:14 -0500 Subject: [PATCH] Fix: Query Filter Date Comparisons Are Off By One Date (#2389) * fixed erroneous date -> datetime conversion * added tests for date and datetime bounds --- mealie/schema/response/query_filter.py | 3 +- .../repository_tests/test_pagination.py | 184 +++++++++++++++++- 2 files changed, 180 insertions(+), 7 deletions(-) diff --git a/mealie/schema/response/query_filter.py b/mealie/schema/response/query_filter.py index 5ff5324c040d..ad351225ecaa 100644 --- a/mealie/schema/response/query_filter.py +++ b/mealie/schema/response/query_filter.py @@ -177,7 +177,8 @@ class QueryFilterComponent: if isinstance(model_attr_type, sqltypes.Date | sqltypes.DateTime): try: - sanitized_values[i] = date_parser.parse(v) + dt = date_parser.parse(v) + sanitized_values[i] = dt.date() if isinstance(model_attr_type, sqltypes.Date) else dt except ParserError as e: raise ValueError(f"invalid query string: unknown date or datetime format '{v}'") from e diff --git a/tests/unit_tests/repository_tests/test_pagination.py b/tests/unit_tests/repository_tests/test_pagination.py index eaee104bcb4c..dab953f66fbe 100644 --- a/tests/unit_tests/repository_tests/test_pagination.py +++ b/tests/unit_tests/repository_tests/test_pagination.py @@ -1,6 +1,6 @@ import time from collections import defaultdict -from datetime import datetime +from datetime import date, datetime, timedelta from random import randint from urllib.parse import parse_qsl, urlsplit @@ -10,6 +10,7 @@ from humps import camelize from mealie.repos.repository_factory import AllRepositories from mealie.repos.repository_units import RepositoryUnit +from mealie.schema.meal_plan.new_meal import CreatePlanEntry from mealie.schema.recipe import Recipe from mealie.schema.recipe.recipe_category import CategorySave, TagSave from mealie.schema.recipe.recipe_ingredient import IngredientUnit, SaveIngredientUnit @@ -429,15 +430,186 @@ def test_pagination_filter_logical_namespace_conflict(database: AllRepositories, def test_pagination_filter_datetimes( query_units: tuple[RepositoryUnit, IngredientUnit, IngredientUnit, IngredientUnit] ): - units_repo = query_units[0] - unit_1 = query_units[1] - unit_2 = query_units[2] + # units are created in order with increasing createdAt values + units_repo, unit_1, unit_2, unit_3 = query_units + + ## GT + past_dt: datetime = unit_1.created_at - timedelta(seconds=1) # type: ignore + dt = past_dt.isoformat() + query = PaginationQuery(page=1, per_page=-1, query_filter=f'createdAt>"{dt}"') + unit_results = units_repo.page_all(query).items + unit_ids = set(unit.id for unit in unit_results) + assert len(unit_ids) == 3 + assert unit_1.id in unit_ids + assert unit_2.id in unit_ids + assert unit_3.id in unit_ids + + dt = unit_1.created_at.isoformat() # type: ignore + query = PaginationQuery(page=1, per_page=-1, query_filter=f'createdAt>"{dt}"') + unit_results = units_repo.page_all(query).items + unit_ids = set(unit.id for unit in unit_results) + assert len(unit_ids) == 2 + assert unit_1.id not in unit_ids + assert unit_2.id in unit_ids + assert unit_3.id in unit_ids + + dt = unit_2.created_at.isoformat() # type: ignore + query = PaginationQuery(page=1, per_page=-1, query_filter=f'createdAt>"{dt}"') + unit_results = units_repo.page_all(query).items + unit_ids = set(unit.id for unit in unit_results) + assert len(unit_ids) == 1 + assert unit_1.id not in unit_ids + assert unit_2.id not in unit_ids + assert unit_3.id in unit_ids + + dt = unit_3.created_at.isoformat() # type: ignore + query = PaginationQuery(page=1, per_page=-1, query_filter=f'createdAt>"{dt}"') + unit_results = units_repo.page_all(query).items + unit_ids = set(unit.id for unit in unit_results) + assert len(unit_ids) == 0 + + future_dt: datetime = unit_3.created_at + timedelta(seconds=1) # type: ignore + dt = future_dt.isoformat() + query = PaginationQuery(page=1, per_page=-1, query_filter=f'createdAt>"{dt}"') + unit_results = units_repo.page_all(query).items + unit_ids = set(unit.id for unit in unit_results) + assert len(unit_ids) == 0 + + ## GTE + past_dt = unit_1.created_at - timedelta(seconds=1) # type: ignore + dt = past_dt.isoformat() + query = PaginationQuery(page=1, per_page=-1, query_filter=f'createdAt>="{dt}"') + unit_results = units_repo.page_all(query).items + unit_ids = set(unit.id for unit in unit_results) + assert len(unit_ids) == 3 + assert unit_1.id in unit_ids + assert unit_2.id in unit_ids + assert unit_3.id in unit_ids + + dt = unit_1.created_at.isoformat() # type: ignore + query = PaginationQuery(page=1, per_page=-1, query_filter=f'createdAt>="{dt}"') + unit_results = units_repo.page_all(query).items + unit_ids = set(unit.id for unit in unit_results) + assert len(unit_ids) == 3 + assert unit_1.id in unit_ids + assert unit_2.id in unit_ids + assert unit_3.id in unit_ids dt = unit_2.created_at.isoformat() # type: ignore query = PaginationQuery(page=1, per_page=-1, query_filter=f'createdAt>="{dt}"') unit_results = units_repo.page_all(query).items - assert len(unit_results) == 2 - assert unit_1.id not in [unit.id for unit in unit_results] + unit_ids = set(unit.id for unit in unit_results) + assert len(unit_ids) == 2 + assert unit_1.id not in unit_ids + assert unit_2.id in unit_ids + assert unit_3.id in unit_ids + + dt = unit_3.created_at.isoformat() # type: ignore + query = PaginationQuery(page=1, per_page=-1, query_filter=f'createdAt>="{dt}"') + unit_results = units_repo.page_all(query).items + unit_ids = set(unit.id for unit in unit_results) + assert len(unit_ids) == 1 + assert unit_1.id not in unit_ids + assert unit_2.id not in unit_ids + assert unit_3.id in unit_ids + + future_dt = unit_3.created_at + timedelta(seconds=1) # type: ignore + dt = future_dt.isoformat() + query = PaginationQuery(page=1, per_page=-1, query_filter=f'createdAt>="{dt}"') + unit_results = units_repo.page_all(query).items + unit_ids = set(unit.id for unit in unit_results) + assert len(unit_ids) == 0 + + +def test_pagination_filter_dates(api_client: TestClient, unique_user: TestUser): + yesterday = date.today() - timedelta(days=1) + today = date.today() + tomorrow = date.today() + timedelta(days=1) + day_after_tomorrow = date.today() + timedelta(days=2) + + mealplan_today = CreatePlanEntry(date=today, entry_type="breakfast", title=random_string(), text=random_string()) + mealplan_tomorrow = CreatePlanEntry( + date=tomorrow, entry_type="breakfast", title=random_string(), text=random_string() + ) + + for mealplan_to_create in [mealplan_today, mealplan_tomorrow]: + data = mealplan_to_create.dict() + data["date"] = data["date"].strftime("%Y-%m-%d") + response = api_client.post(api_routes.groups_mealplans, json=data, headers=unique_user.token) + assert response.status_code == 201 + + ## Yesterday + params = {f"page": 1, "perPage": -1, "queryFilter": f"date >= {yesterday.strftime('%Y-%m-%d')}"} + response = api_client.get(api_routes.groups_mealplans, params=params, headers=unique_user.token) + assert response.status_code == 200 + response_json = response.json() + + assert len(response_json["items"]) == 2 + fetched_mealplan_titles = set(mp["title"] for mp in response_json["items"]) + assert mealplan_today.title in fetched_mealplan_titles + assert mealplan_tomorrow.title in fetched_mealplan_titles + + params = {f"page": 1, "perPage": -1, "queryFilter": f"date > {yesterday.strftime('%Y-%m-%d')}"} + response = api_client.get(api_routes.groups_mealplans, params=params, headers=unique_user.token) + assert response.status_code == 200 + response_json = response.json() + + assert len(response_json["items"]) == 2 + fetched_mealplan_titles = set(mp["title"] for mp in response_json["items"]) + assert mealplan_today.title in fetched_mealplan_titles + assert mealplan_tomorrow.title in fetched_mealplan_titles + + ## Today + params = {f"page": 1, "perPage": -1, "queryFilter": f"date >= {today.strftime('%Y-%m-%d')}"} + response = api_client.get(api_routes.groups_mealplans, params=params, headers=unique_user.token) + assert response.status_code == 200 + response_json = response.json() + + assert len(response_json["items"]) == 2 + fetched_mealplan_titles = set(mp["title"] for mp in response_json["items"]) + assert mealplan_today.title in fetched_mealplan_titles + assert mealplan_tomorrow.title in fetched_mealplan_titles + + params = {f"page": 1, "perPage": -1, "queryFilter": f"date > {today.strftime('%Y-%m-%d')}"} + response = api_client.get(api_routes.groups_mealplans, params=params, headers=unique_user.token) + assert response.status_code == 200 + response_json = response.json() + + assert len(response_json["items"]) == 1 + fetched_mealplan_titles = set(mp["title"] for mp in response_json["items"]) + assert mealplan_today.title not in fetched_mealplan_titles + assert mealplan_tomorrow.title in fetched_mealplan_titles + + ## Tomorrow + params = {f"page": 1, "perPage": -1, "queryFilter": f"date >= {tomorrow.strftime('%Y-%m-%d')}"} + response = api_client.get(api_routes.groups_mealplans, params=params, headers=unique_user.token) + assert response.status_code == 200 + response_json = response.json() + + assert len(response_json["items"]) == 1 + fetched_mealplan_titles = set(mp["title"] for mp in response_json["items"]) + assert mealplan_today.title not in fetched_mealplan_titles + assert mealplan_tomorrow.title in fetched_mealplan_titles + + params = {f"page": 1, "perPage": -1, "queryFilter": f"date > {tomorrow.strftime('%Y-%m-%d')}"} + response = api_client.get(api_routes.groups_mealplans, params=params, headers=unique_user.token) + assert response.status_code == 200 + response_json = response.json() + + assert len(response_json["items"]) == 0 + + ## Day After Tomorrow + params = {f"page": 1, "perPage": -1, "queryFilter": f"date >= {day_after_tomorrow.strftime('%Y-%m-%d')}"} + response = api_client.get(api_routes.groups_mealplans, params=params, headers=unique_user.token) + assert response.status_code == 200 + response_json = response.json() + assert len(response_json["items"]) == 0 + + params = {f"page": 1, "perPage": -1, "queryFilter": f"date > {day_after_tomorrow.strftime('%Y-%m-%d')}"} + response = api_client.get(api_routes.groups_mealplans, params=params, headers=unique_user.token) + assert response.status_code == 200 + response_json = response.json() + assert len(response_json["items"]) == 0 def test_pagination_filter_booleans(query_units: tuple[RepositoryUnit, IngredientUnit, IngredientUnit, IngredientUnit]):