fix: Recipe Zip Export Can't Be Imported (#2585)

* clean recipe data when importing via zip

* added tests

* simplified recursive logic
This commit is contained in:
Michael Genson 2023-10-07 14:18:55 -05:00 committed by GitHub
parent 84b477edf6
commit 26ef351ae7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 336 additions and 3 deletions

View File

@ -3,6 +3,7 @@ import shutil
from datetime import datetime
from pathlib import Path
from shutil import copytree, rmtree
from typing import Any
from uuid import UUID, uuid4
from zipfile import ZipFile
@ -12,6 +13,7 @@ from slugify import slugify
from mealie.core import exceptions
from mealie.pkgs import cache
from mealie.repos.repository_factory import AllRepositories
from mealie.repos.repository_generic import RepositoryGeneric
from mealie.schema.recipe.recipe import CreateRecipe, Recipe
from mealie.schema.recipe.recipe_ingredient import RecipeIngredient
from mealie.schema.recipe.recipe_settings import RecipeSettings
@ -157,6 +159,60 @@ class RecipeService(BaseService):
self.repos.recipe_timeline_events.create(timeline_event_data)
return new_recipe
def _transform_user_id(self, user_id: str) -> str:
query = self.repos.users.by_group(self.group.id).get_one(user_id)
if query:
return user_id
else:
# default to the current user
return str(self.user.id)
def _transform_category_or_tag(self, data: dict, repo: RepositoryGeneric) -> dict:
slug = data.get("slug")
if not slug:
return data
# if the item exists, return the actual data
query = repo.get_one(slug, "slug")
if query:
return query.dict()
# otherwise, create the item
new_item = repo.create(data)
return new_item.dict()
def _process_recipe_data(self, key: str, data: list | dict | Any):
if isinstance(data, list):
return [self._process_recipe_data(key, item) for item in data]
elif isinstance(data, str):
# make sure the user is valid
if key == "user_id":
return self._transform_user_id(str(data))
return data
elif not isinstance(data, dict):
return data
# force group_id to match the group id of the current user
data["group_id"] = str(self.group.id)
# make sure categories and tags are valid
if key == "recipe_category":
return self._transform_category_or_tag(data, self.repos.categories.by_group(self.group.id))
elif key == "tags":
return self._transform_category_or_tag(data, self.repos.tags.by_group(self.group.id))
# recursively process other objects
for k, v in data.items():
data[k] = self._process_recipe_data(k, v)
return data
def clean_recipe_dict(self, recipe: dict[str, Any]) -> dict[str, Any]:
return self._process_recipe_data("recipe", recipe)
def create_from_zip(self, archive: UploadFile, temp_path: Path) -> Recipe:
"""
`create_from_zip` creates a recipe in the database from a zip file exported from Mealie. This is NOT
@ -180,7 +236,7 @@ class RecipeService(BaseService):
if recipe_dict is None:
raise exceptions.UnexpectedNone("No json data found in Zip")
recipe = self.create_one(Recipe(**recipe_dict))
recipe = self.create_one(Recipe(**self.clean_recipe_dict(recipe_dict)))
if recipe and recipe.id:
data_service = RecipeDataService(recipe.id)

View File

@ -1,5 +1,12 @@
import json
import os
import random
import shutil
import tempfile
from pathlib import Path
from typing import Generator
from uuid import uuid4
from zipfile import ZipFile
import pytest
from bs4 import BeautifulSoup
@ -9,18 +16,37 @@ from recipe_scrapers._abstract import AbstractScraper
from recipe_scrapers._schemaorg import SchemaOrg
from slugify import slugify
from mealie.schema.recipe.recipe import RecipeCategory
from mealie.repos.repository_factory import AllRepositories
from mealie.schema.recipe.recipe import RecipeCategory, RecipeSummary, RecipeTag
from mealie.services.recipe.recipe_data_service import RecipeDataService
from mealie.services.scraper.recipe_scraper import DEFAULT_SCRAPER_STRATEGIES
from tests import data, utils
from tests.utils import api_routes
from tests.utils.factories import random_string
from tests.utils.factories import random_int, random_string
from tests.utils.fixture_schemas import TestUser
from tests.utils.recipe_data import RecipeSiteTestCase, get_recipe_test_cases
recipe_test_data = get_recipe_test_cases()
@pytest.fixture(scope="module")
def tempdir() -> Generator[str, None, None]:
with tempfile.TemporaryDirectory() as td:
yield td
def zip_recipe(tempdir: str, recipe: RecipeSummary) -> dict:
data_file = tempfile.NamedTemporaryFile(mode="w+", dir=tempdir, suffix=".json", delete=False)
json.dump(json.loads(recipe.json()), data_file)
data_file.flush()
zip_file = shutil.make_archive(os.path.join(tempdir, "zipfile"), "zip")
with ZipFile(zip_file, "w") as zf:
zf.write(data_file.name)
return {"archive": Path(zip_file).read_bytes()}
def get_init(html_path: Path):
"""
Override the init method of the abstract scraper to return a bootstrapped init function that
@ -163,6 +189,257 @@ def test_create_by_url_with_tags(
assert tag["name"] in expected_tags
def test_create_recipe_from_zip(database: AllRepositories, api_client: TestClient, unique_user: TestUser, tempdir: str):
recipe_name = random_string()
recipe = RecipeSummary(
id=uuid4(),
user_id=unique_user.user_id,
group_id=unique_user.group_id,
name=recipe_name,
slug=recipe_name,
)
r = api_client.post(
api_routes.recipes_create_from_zip, files=zip_recipe(tempdir, recipe), headers=unique_user.token
)
assert r.status_code == 201
fetched_recipe = database.recipes.get_by_slug(unique_user.group_id, recipe.slug)
assert fetched_recipe
def test_create_recipe_from_zip_invalid_group(
database: AllRepositories, api_client: TestClient, unique_user: TestUser, tempdir: str
):
recipe_name = random_string()
recipe = RecipeSummary(
id=uuid4(),
user_id=unique_user.user_id,
group_id=uuid4(),
name=recipe_name,
slug=recipe_name,
)
r = api_client.post(
api_routes.recipes_create_from_zip, files=zip_recipe(tempdir, recipe), headers=unique_user.token
)
assert r.status_code == 201
fetched_recipe = database.recipes.get_by_slug(unique_user.group_id, recipe.slug)
assert fetched_recipe
# the group should always be set to the current user's group
assert str(fetched_recipe.group_id) == str(unique_user.group_id)
def test_create_recipe_from_zip_invalid_user(
database: AllRepositories, api_client: TestClient, unique_user: TestUser, tempdir: str
):
recipe_name = random_string()
recipe = RecipeSummary(
id=uuid4(),
user_id=uuid4(),
group_id=unique_user.group_id,
name=recipe_name,
slug=recipe_name,
)
r = api_client.post(
api_routes.recipes_create_from_zip, files=zip_recipe(tempdir, recipe), headers=unique_user.token
)
assert r.status_code == 201
fetched_recipe = database.recipes.get_by_slug(unique_user.group_id, recipe.slug)
assert fetched_recipe
# invalid users should default to the current user
assert str(fetched_recipe.user_id) == str(unique_user.user_id)
def test_create_recipe_from_zip_existing_category(
database: AllRepositories, api_client: TestClient, unique_user: TestUser, tempdir: str
):
categories = database.categories.by_group(unique_user.group_id).create_many(
[{"name": random_string(), "group_id": unique_user.group_id} for _ in range(random_int(5, 10))]
)
category = random.choice(categories)
recipe_name = random_string()
recipe = RecipeSummary(
id=uuid4(),
user_id=unique_user.user_id,
group_id=unique_user.group_id,
name=recipe_name,
slug=recipe_name,
recipe_category=[category],
)
r = api_client.post(
api_routes.recipes_create_from_zip, files=zip_recipe(tempdir, recipe), headers=unique_user.token
)
assert r.status_code == 201
fetched_recipe = database.recipes.get_by_slug(unique_user.group_id, recipe.slug)
assert fetched_recipe
assert fetched_recipe.recipe_category
assert len(fetched_recipe.recipe_category) == 1
assert str(fetched_recipe.recipe_category[0].id) == str(category.id)
def test_create_recipe_from_zip_existing_tag(
database: AllRepositories, api_client: TestClient, unique_user: TestUser, tempdir: str
):
tags = database.tags.by_group(unique_user.group_id).create_many(
[{"name": random_string(), "group_id": unique_user.group_id} for _ in range(random_int(5, 10))]
)
tag = random.choice(tags)
recipe_name = random_string()
recipe = RecipeSummary(
id=uuid4(),
user_id=unique_user.user_id,
group_id=unique_user.group_id,
name=recipe_name,
slug=recipe_name,
tags=[tag],
)
r = api_client.post(
api_routes.recipes_create_from_zip, files=zip_recipe(tempdir, recipe), headers=unique_user.token
)
assert r.status_code == 201
fetched_recipe = database.recipes.get_by_slug(unique_user.group_id, recipe.slug)
assert fetched_recipe
assert fetched_recipe.tags
assert len(fetched_recipe.tags) == 1
assert str(fetched_recipe.tags[0].id) == str(tag.id)
def test_create_recipe_from_zip_existing_category_wrong_ids(
database: AllRepositories, api_client: TestClient, unique_user: TestUser, tempdir: str
):
categories = database.categories.by_group(unique_user.group_id).create_many(
[{"name": random_string(), "group_id": unique_user.group_id} for _ in range(random_int(5, 10))]
)
category = random.choice(categories)
invalid_category = RecipeCategory(id=uuid4(), name=category.name, slug=category.slug)
recipe_name = random_string()
recipe = RecipeSummary(
id=uuid4(),
user_id=unique_user.user_id,
group_id=unique_user.group_id,
name=recipe_name,
slug=recipe_name,
recipe_category=[invalid_category],
)
r = api_client.post(
api_routes.recipes_create_from_zip, files=zip_recipe(tempdir, recipe), headers=unique_user.token
)
assert r.status_code == 201
fetched_recipe = database.recipes.get_by_slug(unique_user.group_id, recipe.slug)
assert fetched_recipe
assert fetched_recipe.recipe_category
assert len(fetched_recipe.recipe_category) == 1
assert str(fetched_recipe.recipe_category[0].id) == str(category.id)
def test_create_recipe_from_zip_existing_tag_wrong_ids(
database: AllRepositories, api_client: TestClient, unique_user: TestUser, tempdir: str
):
tags = database.tags.by_group(unique_user.group_id).create_many(
[{"name": random_string(), "group_id": unique_user.group_id} for _ in range(random_int(5, 10))]
)
tag = random.choice(tags)
invalid_tag = RecipeTag(id=uuid4(), name=tag.name, slug=tag.slug)
recipe_name = random_string()
recipe = RecipeSummary(
id=uuid4(),
user_id=unique_user.user_id,
group_id=unique_user.group_id,
name=recipe_name,
slug=recipe_name,
tags=[invalid_tag],
)
r = api_client.post(
api_routes.recipes_create_from_zip, files=zip_recipe(tempdir, recipe), headers=unique_user.token
)
assert r.status_code == 201
fetched_recipe = database.recipes.get_by_slug(unique_user.group_id, recipe.slug)
assert fetched_recipe
assert fetched_recipe.tags
assert len(fetched_recipe.tags) == 1
assert str(fetched_recipe.tags[0].id) == str(tag.id)
def test_create_recipe_from_zip_invalid_category(
database: AllRepositories, api_client: TestClient, unique_user: TestUser, tempdir: str
):
invalid_name = random_string()
invalid_category = RecipeCategory(id=uuid4(), name=invalid_name, slug=invalid_name)
recipe_name = random_string()
recipe = RecipeSummary(
id=uuid4(),
user_id=unique_user.user_id,
group_id=unique_user.group_id,
name=recipe_name,
slug=recipe_name,
recipe_category=[invalid_category],
)
r = api_client.post(
api_routes.recipes_create_from_zip, files=zip_recipe(tempdir, recipe), headers=unique_user.token
)
assert r.status_code == 201
fetched_recipe = database.recipes.get_by_slug(unique_user.group_id, recipe.slug)
assert fetched_recipe
assert fetched_recipe.recipe_category
assert len(fetched_recipe.recipe_category) == 1
# a new category should be created
assert fetched_recipe.recipe_category[0].name == invalid_name
assert fetched_recipe.recipe_category[0].slug == invalid_name
def test_create_recipe_from_zip_invalid_tag(
database: AllRepositories, api_client: TestClient, unique_user: TestUser, tempdir: str
):
invalid_name = random_string()
invalid_tag = RecipeTag(id=uuid4(), name=invalid_name, slug=invalid_name)
recipe_name = random_string()
recipe = RecipeSummary(
id=uuid4(),
user_id=unique_user.user_id,
group_id=unique_user.group_id,
name=recipe_name,
slug=recipe_name,
tags=[invalid_tag],
)
r = api_client.post(
api_routes.recipes_create_from_zip, files=zip_recipe(tempdir, recipe), headers=unique_user.token
)
assert r.status_code == 201
fetched_recipe = database.recipes.get_by_slug(unique_user.group_id, recipe.slug)
assert fetched_recipe
assert fetched_recipe.tags
assert len(fetched_recipe.tags) == 1
# a new tag should be created
assert fetched_recipe.tags[0].name == invalid_name
assert fetched_recipe.tags[0].slug == invalid_name
@pytest.mark.parametrize("recipe_data", recipe_test_data)
def test_read_update(
api_client: TestClient,