diff --git a/mealie/services/recipe/recipe_service.py b/mealie/services/recipe/recipe_service.py index 53138b73be84..533120a07711 100644 --- a/mealie/services/recipe/recipe_service.py +++ b/mealie/services/recipe/recipe_service.py @@ -3,6 +3,7 @@ import shutil from datetime import datetime from pathlib import Path from shutil import copytree, rmtree +from typing import Any from uuid import UUID, uuid4 from zipfile import ZipFile @@ -12,6 +13,7 @@ from slugify import slugify from mealie.core import exceptions from mealie.pkgs import cache from mealie.repos.repository_factory import AllRepositories +from mealie.repos.repository_generic import RepositoryGeneric from mealie.schema.recipe.recipe import CreateRecipe, Recipe from mealie.schema.recipe.recipe_ingredient import RecipeIngredient from mealie.schema.recipe.recipe_settings import RecipeSettings @@ -157,6 +159,60 @@ class RecipeService(BaseService): self.repos.recipe_timeline_events.create(timeline_event_data) return new_recipe + def _transform_user_id(self, user_id: str) -> str: + query = self.repos.users.by_group(self.group.id).get_one(user_id) + if query: + return user_id + else: + # default to the current user + return str(self.user.id) + + def _transform_category_or_tag(self, data: dict, repo: RepositoryGeneric) -> dict: + slug = data.get("slug") + if not slug: + return data + + # if the item exists, return the actual data + query = repo.get_one(slug, "slug") + if query: + return query.dict() + + # otherwise, create the item + new_item = repo.create(data) + return new_item.dict() + + def _process_recipe_data(self, key: str, data: list | dict | Any): + if isinstance(data, list): + return [self._process_recipe_data(key, item) for item in data] + + elif isinstance(data, str): + # make sure the user is valid + if key == "user_id": + return self._transform_user_id(str(data)) + + return data + + elif not isinstance(data, dict): + return data + + # force group_id to match the group id of the current user + data["group_id"] = str(self.group.id) + + # make sure categories and tags are valid + if key == "recipe_category": + return self._transform_category_or_tag(data, self.repos.categories.by_group(self.group.id)) + elif key == "tags": + return self._transform_category_or_tag(data, self.repos.tags.by_group(self.group.id)) + + # recursively process other objects + for k, v in data.items(): + data[k] = self._process_recipe_data(k, v) + + return data + + def clean_recipe_dict(self, recipe: dict[str, Any]) -> dict[str, Any]: + return self._process_recipe_data("recipe", recipe) + def create_from_zip(self, archive: UploadFile, temp_path: Path) -> Recipe: """ `create_from_zip` creates a recipe in the database from a zip file exported from Mealie. This is NOT @@ -180,7 +236,7 @@ class RecipeService(BaseService): if recipe_dict is None: raise exceptions.UnexpectedNone("No json data found in Zip") - recipe = self.create_one(Recipe(**recipe_dict)) + recipe = self.create_one(Recipe(**self.clean_recipe_dict(recipe_dict))) if recipe and recipe.id: data_service = RecipeDataService(recipe.id) diff --git a/tests/integration_tests/user_recipe_tests/test_recipe_crud.py b/tests/integration_tests/user_recipe_tests/test_recipe_crud.py index 03aee6944f64..e3c8b33b417e 100644 --- a/tests/integration_tests/user_recipe_tests/test_recipe_crud.py +++ b/tests/integration_tests/user_recipe_tests/test_recipe_crud.py @@ -1,5 +1,12 @@ import json +import os +import random +import shutil +import tempfile from pathlib import Path +from typing import Generator +from uuid import uuid4 +from zipfile import ZipFile import pytest from bs4 import BeautifulSoup @@ -9,18 +16,37 @@ from recipe_scrapers._abstract import AbstractScraper from recipe_scrapers._schemaorg import SchemaOrg from slugify import slugify -from mealie.schema.recipe.recipe import RecipeCategory +from mealie.repos.repository_factory import AllRepositories +from mealie.schema.recipe.recipe import RecipeCategory, RecipeSummary, RecipeTag from mealie.services.recipe.recipe_data_service import RecipeDataService from mealie.services.scraper.recipe_scraper import DEFAULT_SCRAPER_STRATEGIES from tests import data, utils from tests.utils import api_routes -from tests.utils.factories import random_string +from tests.utils.factories import random_int, random_string from tests.utils.fixture_schemas import TestUser from tests.utils.recipe_data import RecipeSiteTestCase, get_recipe_test_cases recipe_test_data = get_recipe_test_cases() +@pytest.fixture(scope="module") +def tempdir() -> Generator[str, None, None]: + with tempfile.TemporaryDirectory() as td: + yield td + + +def zip_recipe(tempdir: str, recipe: RecipeSummary) -> dict: + data_file = tempfile.NamedTemporaryFile(mode="w+", dir=tempdir, suffix=".json", delete=False) + json.dump(json.loads(recipe.json()), data_file) + data_file.flush() + + zip_file = shutil.make_archive(os.path.join(tempdir, "zipfile"), "zip") + with ZipFile(zip_file, "w") as zf: + zf.write(data_file.name) + + return {"archive": Path(zip_file).read_bytes()} + + def get_init(html_path: Path): """ Override the init method of the abstract scraper to return a bootstrapped init function that @@ -163,6 +189,257 @@ def test_create_by_url_with_tags( assert tag["name"] in expected_tags +def test_create_recipe_from_zip(database: AllRepositories, api_client: TestClient, unique_user: TestUser, tempdir: str): + recipe_name = random_string() + recipe = RecipeSummary( + id=uuid4(), + user_id=unique_user.user_id, + group_id=unique_user.group_id, + name=recipe_name, + slug=recipe_name, + ) + + r = api_client.post( + api_routes.recipes_create_from_zip, files=zip_recipe(tempdir, recipe), headers=unique_user.token + ) + assert r.status_code == 201 + + fetched_recipe = database.recipes.get_by_slug(unique_user.group_id, recipe.slug) + assert fetched_recipe + + +def test_create_recipe_from_zip_invalid_group( + database: AllRepositories, api_client: TestClient, unique_user: TestUser, tempdir: str +): + recipe_name = random_string() + recipe = RecipeSummary( + id=uuid4(), + user_id=unique_user.user_id, + group_id=uuid4(), + name=recipe_name, + slug=recipe_name, + ) + + r = api_client.post( + api_routes.recipes_create_from_zip, files=zip_recipe(tempdir, recipe), headers=unique_user.token + ) + assert r.status_code == 201 + + fetched_recipe = database.recipes.get_by_slug(unique_user.group_id, recipe.slug) + assert fetched_recipe + + # the group should always be set to the current user's group + assert str(fetched_recipe.group_id) == str(unique_user.group_id) + + +def test_create_recipe_from_zip_invalid_user( + database: AllRepositories, api_client: TestClient, unique_user: TestUser, tempdir: str +): + recipe_name = random_string() + recipe = RecipeSummary( + id=uuid4(), + user_id=uuid4(), + group_id=unique_user.group_id, + name=recipe_name, + slug=recipe_name, + ) + + r = api_client.post( + api_routes.recipes_create_from_zip, files=zip_recipe(tempdir, recipe), headers=unique_user.token + ) + assert r.status_code == 201 + + fetched_recipe = database.recipes.get_by_slug(unique_user.group_id, recipe.slug) + assert fetched_recipe + + # invalid users should default to the current user + assert str(fetched_recipe.user_id) == str(unique_user.user_id) + + +def test_create_recipe_from_zip_existing_category( + database: AllRepositories, api_client: TestClient, unique_user: TestUser, tempdir: str +): + categories = database.categories.by_group(unique_user.group_id).create_many( + [{"name": random_string(), "group_id": unique_user.group_id} for _ in range(random_int(5, 10))] + ) + category = random.choice(categories) + + recipe_name = random_string() + recipe = RecipeSummary( + id=uuid4(), + user_id=unique_user.user_id, + group_id=unique_user.group_id, + name=recipe_name, + slug=recipe_name, + recipe_category=[category], + ) + + r = api_client.post( + api_routes.recipes_create_from_zip, files=zip_recipe(tempdir, recipe), headers=unique_user.token + ) + assert r.status_code == 201 + + fetched_recipe = database.recipes.get_by_slug(unique_user.group_id, recipe.slug) + assert fetched_recipe + assert fetched_recipe.recipe_category + assert len(fetched_recipe.recipe_category) == 1 + assert str(fetched_recipe.recipe_category[0].id) == str(category.id) + + +def test_create_recipe_from_zip_existing_tag( + database: AllRepositories, api_client: TestClient, unique_user: TestUser, tempdir: str +): + tags = database.tags.by_group(unique_user.group_id).create_many( + [{"name": random_string(), "group_id": unique_user.group_id} for _ in range(random_int(5, 10))] + ) + tag = random.choice(tags) + + recipe_name = random_string() + recipe = RecipeSummary( + id=uuid4(), + user_id=unique_user.user_id, + group_id=unique_user.group_id, + name=recipe_name, + slug=recipe_name, + tags=[tag], + ) + + r = api_client.post( + api_routes.recipes_create_from_zip, files=zip_recipe(tempdir, recipe), headers=unique_user.token + ) + assert r.status_code == 201 + + fetched_recipe = database.recipes.get_by_slug(unique_user.group_id, recipe.slug) + assert fetched_recipe + assert fetched_recipe.tags + assert len(fetched_recipe.tags) == 1 + assert str(fetched_recipe.tags[0].id) == str(tag.id) + + +def test_create_recipe_from_zip_existing_category_wrong_ids( + database: AllRepositories, api_client: TestClient, unique_user: TestUser, tempdir: str +): + categories = database.categories.by_group(unique_user.group_id).create_many( + [{"name": random_string(), "group_id": unique_user.group_id} for _ in range(random_int(5, 10))] + ) + category = random.choice(categories) + invalid_category = RecipeCategory(id=uuid4(), name=category.name, slug=category.slug) + + recipe_name = random_string() + recipe = RecipeSummary( + id=uuid4(), + user_id=unique_user.user_id, + group_id=unique_user.group_id, + name=recipe_name, + slug=recipe_name, + recipe_category=[invalid_category], + ) + + r = api_client.post( + api_routes.recipes_create_from_zip, files=zip_recipe(tempdir, recipe), headers=unique_user.token + ) + assert r.status_code == 201 + + fetched_recipe = database.recipes.get_by_slug(unique_user.group_id, recipe.slug) + assert fetched_recipe + assert fetched_recipe.recipe_category + assert len(fetched_recipe.recipe_category) == 1 + assert str(fetched_recipe.recipe_category[0].id) == str(category.id) + + +def test_create_recipe_from_zip_existing_tag_wrong_ids( + database: AllRepositories, api_client: TestClient, unique_user: TestUser, tempdir: str +): + tags = database.tags.by_group(unique_user.group_id).create_many( + [{"name": random_string(), "group_id": unique_user.group_id} for _ in range(random_int(5, 10))] + ) + tag = random.choice(tags) + invalid_tag = RecipeTag(id=uuid4(), name=tag.name, slug=tag.slug) + + recipe_name = random_string() + recipe = RecipeSummary( + id=uuid4(), + user_id=unique_user.user_id, + group_id=unique_user.group_id, + name=recipe_name, + slug=recipe_name, + tags=[invalid_tag], + ) + + r = api_client.post( + api_routes.recipes_create_from_zip, files=zip_recipe(tempdir, recipe), headers=unique_user.token + ) + assert r.status_code == 201 + + fetched_recipe = database.recipes.get_by_slug(unique_user.group_id, recipe.slug) + assert fetched_recipe + assert fetched_recipe.tags + assert len(fetched_recipe.tags) == 1 + assert str(fetched_recipe.tags[0].id) == str(tag.id) + + +def test_create_recipe_from_zip_invalid_category( + database: AllRepositories, api_client: TestClient, unique_user: TestUser, tempdir: str +): + invalid_name = random_string() + invalid_category = RecipeCategory(id=uuid4(), name=invalid_name, slug=invalid_name) + + recipe_name = random_string() + recipe = RecipeSummary( + id=uuid4(), + user_id=unique_user.user_id, + group_id=unique_user.group_id, + name=recipe_name, + slug=recipe_name, + recipe_category=[invalid_category], + ) + + r = api_client.post( + api_routes.recipes_create_from_zip, files=zip_recipe(tempdir, recipe), headers=unique_user.token + ) + assert r.status_code == 201 + + fetched_recipe = database.recipes.get_by_slug(unique_user.group_id, recipe.slug) + assert fetched_recipe + assert fetched_recipe.recipe_category + assert len(fetched_recipe.recipe_category) == 1 + + # a new category should be created + assert fetched_recipe.recipe_category[0].name == invalid_name + assert fetched_recipe.recipe_category[0].slug == invalid_name + + +def test_create_recipe_from_zip_invalid_tag( + database: AllRepositories, api_client: TestClient, unique_user: TestUser, tempdir: str +): + invalid_name = random_string() + invalid_tag = RecipeTag(id=uuid4(), name=invalid_name, slug=invalid_name) + + recipe_name = random_string() + recipe = RecipeSummary( + id=uuid4(), + user_id=unique_user.user_id, + group_id=unique_user.group_id, + name=recipe_name, + slug=recipe_name, + tags=[invalid_tag], + ) + + r = api_client.post( + api_routes.recipes_create_from_zip, files=zip_recipe(tempdir, recipe), headers=unique_user.token + ) + assert r.status_code == 201 + + fetched_recipe = database.recipes.get_by_slug(unique_user.group_id, recipe.slug) + assert fetched_recipe + assert fetched_recipe.tags + assert len(fetched_recipe.tags) == 1 + + # a new tag should be created + assert fetched_recipe.tags[0].name == invalid_name + assert fetched_recipe.tags[0].slug == invalid_name + + @pytest.mark.parametrize("recipe_data", recipe_test_data) def test_read_update( api_client: TestClient,