Mirror of https://github.com/mealie-recipes/mealie.git (synced 2025-07-09 03:04:54 -04:00)
fix: strict optional errors (#1759)
* fix strict optional errors
* fix typing in repository
* fix backup db files location
* update workspace settings
This commit is contained in:
parent 97d9e2a109
commit 84c23765cd
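Note on the pattern: most of the edits below take one of two shapes. Annotations widen from "X = None" to "X | None = None", and values that may be None are narrowed with an explicit check before they are used. A condensed sketch of the narrowing shape (invented names, not code from this diff):

from typing import Optional

def auth_strategy_is_tls(strategy: Optional[str]) -> bool:
    # Calling strategy.upper() directly is an error under strict_optional,
    # because strategy may be None; the truthiness check narrows it first.
    return strategy.upper() == "TLS" if strategy else False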
.vscode/settings.json (vendored), 8 changes
@@ -1,10 +1,4 @@
{
  "conventionalCommits.scopes": [
    "frontend",
    "docs",
    "backend",
    "code-generation"
  ],
  "cSpell.enableFiletypes": ["!javascript", "!python", "!yaml"],
  "cSpell.words": [
    "chowdown",
@@ -44,7 +38,7 @@
  "python.testing.unittestEnabled": false,
  "python.analysis.typeCheckingMode": "off",
  "python.linting.mypyEnabled": true,
  "python.sortImports.path": "${workspaceFolder}/.venv/bin/isort",
  "isort.path": ["${workspaceFolder}/.venv/bin/isort"],
  "search.mode": "reuseEditor",
  "python.testing.unittestArgs": ["-v", "-s", "./tests", "-p", "test_*.py"],
  "explorer.fileNesting.enabled": true,
@@ -117,35 +117,58 @@ def validate_long_live_token(session: Session, client_token: str, user_id: str)


def validate_file_token(token: Optional[str] = None) -> Path:
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="could not validate file token",
    )
    """
    Args:
        token (Optional[str], optional): _description_. Defaults to None.

    Raises:
        HTTPException: 400 Bad Request when no token or the file doesn't exist
        HTTPException: 401 Unauthorized when the token is invalid
    """
    if not token:
        return None
        raise HTTPException(status.HTTP_400_BAD_REQUEST)

    try:
        payload = jwt.decode(token, settings.SECRET, algorithms=[ALGORITHM])
        file_path = Path(payload.get("file"))
    except JWTError as e:
        raise credentials_exception from e
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="could not validate file token",
        ) from e

    if not file_path.exists():
        raise HTTPException(status.HTTP_400_BAD_REQUEST)

    return file_path


def validate_recipe_token(token: Optional[str] = None) -> str:
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="could not validate file token",
    )
    """
    Args:
        token (Optional[str], optional): _description_. Defaults to None.

    Raises:
        HTTPException: 400 Bad Request when no token or the recipe doesn't exist
        HTTPException: 401 JWTError when token is invalid

    Returns:
        str: token data
    """
    if not token:
        return None
        raise HTTPException(status.HTTP_400_BAD_REQUEST)

    try:
        payload = jwt.decode(token, settings.SECRET, algorithms=[ALGORITHM])
        slug = payload.get("slug")
        slug: str | None = payload.get("slug")
    except JWTError as e:
        raise credentials_exception from e
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="could not validate file token",
        ) from e

    if slug is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST)

    return slug
@@ -3,6 +3,17 @@ from sqlite3 import IntegrityError
from mealie.lang.providers import Translator


class UnexpectedNone(Exception):
    """Exception raised when a value is None when it should not be."""

    def __init__(self, message: str = "Unexpected None Value"):
        self.message = message
        super().__init__(self.message)

    def __str__(self):
        return f"{self.message}"


class PermissionDenied(Exception):
    """
    This exception is raised when a user tries to access a resource that they do not have permission to access.
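The UnexpectedNone exception added above is what the service-layer changes later in this diff raise when a repository lookup returns nothing. A small hedged sketch of that usage (hypothetical helper; repos is assumed to be an AllRepositories instance):

from mealie.core.exceptions import UnexpectedNone

def get_required_recipe(repos, slug: str):
    # get_one may return None; raising here keeps the return type narrowed to Recipe
    recipe = repos.recipes.get_one(slug)
    if recipe is None:
        raise UnexpectedNone(f"Failed to load recipe {slug}, no recipe found")
    return recipe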
@@ -62,11 +62,16 @@ def user_from_ldap(db: AllRepositories, username: str, password: str) -> Private
        conn.set_option(ldap.OPT_X_TLS_CACERTFILE, settings.LDAP_TLS_CACERTFILE)
        conn.set_option(ldap.OPT_X_TLS_NEWCTX, 0)
    user = db.users.get_one(username, "email", any_case=True)

    if not settings.LDAP_BIND_TEMPLATE:
        return False

    if not user:
        user_bind = settings.LDAP_BIND_TEMPLATE.format(username)
        user = db.users.get_one(username, "username", any_case=True)
    else:
        user_bind = settings.LDAP_BIND_TEMPLATE.format(user.username)

    try:
        conn.simple_bind_s(user_bind, password)
    except (ldap.INVALID_CREDENTIALS, ldap.NO_SUCH_OBJECT):
@@ -86,7 +91,7 @@ def user_from_ldap(db: AllRepositories, username: str, password: str) -> Private
    else:
        return False

    if not user:
    if user is None:
        user = db.users.create(
            {
                "username": username,
@@ -96,6 +101,7 @@ def user_from_ldap(db: AllRepositories, username: str, password: str) -> Private
                "admin": False,
            },
        )

    if settings.LDAP_ADMIN_FILTER:
        user.admin = len(conn.search_s(user_dn, ldap.SCOPE_BASE, settings.LDAP_ADMIN_FILTER, [])) > 0
        db.users.update(user.id, user)
@@ -93,18 +93,18 @@ class AppSettings(BaseSettings):

    @staticmethod
    def validate_smtp(
        host: str,
        port: str,
        from_name: str,
        from_email: str,
        strategy: str,
        host: str | None,
        port: str | None,
        from_name: str | None,
        from_email: str | None,
        strategy: str | None,
        user: str | None = None,
        password: str | None = None,
    ) -> bool:
        """Validates all SMTP variables are set"""
        required = {host, port, from_name, from_email, strategy}

        if strategy.upper() in {"TLS", "SSL"}:
        if strategy and strategy.upper() in {"TLS", "SSL"}:
            required.add(user)
            required.add(password)

@@ -138,13 +138,19 @@ class RecipeModel(SqlAlchemyBase, BaseMixins):
        **_,
    ) -> None:
        self.nutrition = Nutrition(**nutrition) if nutrition else Nutrition()
        self.recipe_instructions = [RecipeInstruction(**step, session=session) for step in recipe_instructions]
        self.recipe_ingredient = [RecipeIngredient(**ingr, session=session) for ingr in recipe_ingredient]
        self.assets = [RecipeAsset(**a) for a in assets]

        # Mealie Specific
        if recipe_instructions:
            self.recipe_instructions = [RecipeInstruction(**step, session=session) for step in recipe_instructions]

        if recipe_ingredient:
            self.recipe_ingredient = [RecipeIngredient(**ingr, session=session) for ingr in recipe_ingredient]

        if assets:
            self.assets = [RecipeAsset(**a) for a in assets]

        self.settings = RecipeSettings(**settings) if settings else RecipeSettings()
        self.notes = [Note(**note) for note in notes]

        # Time Stampes
        if notes:
            self.notes = [Note(**n) for n in notes]

        self.date_updated = datetime.datetime.now()
@@ -1 +1,5 @@
from .repository_factory import AllRepositories

__all__ = [
    "AllRepositories",
]
@@ -1,5 +1,7 @@
from __future__ import annotations

from math import ceil
from typing import Any, Generic, TypeVar, Union
from typing import Any, Generic, TypeVar

from fastapi import HTTPException
from pydantic import UUID4, BaseModel
@@ -24,8 +26,8 @@ class RepositoryGeneric(Generic[Schema, Model]):
        Generic ([Model]): Represents the SqlAlchemyModel Model
    """

    user_id: UUID4 = None
    group_id: UUID4 = None
    user_id: UUID4 | None = None
    group_id: UUID4 | None = None

    def __init__(self, session: Session, primary_key: str, sql_model: type[Model], schema: type[Schema]) -> None:
        self.session = session
@@ -35,11 +37,11 @@ class RepositoryGeneric(Generic[Schema, Model]):

        self.logger = get_logger()

    def by_user(self, user_id: UUID4) -> "RepositoryGeneric[Schema, Model]":
    def by_user(self, user_id: UUID4) -> RepositoryGeneric[Schema, Model]:
        self.user_id = user_id
        return self

    def by_group(self, group_id: UUID4) -> "RepositoryGeneric[Schema, Model]":
    def by_group(self, group_id: UUID4) -> RepositoryGeneric[Schema, Model]:
        self.group_id = group_id
        return self

@@ -221,7 +223,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
        attr_match: str = None,
        count=True,
        override_schema=None,
    ) -> Union[int, list[Schema]]: # sourcery skip: assign-if-exp
    ) -> int | list[Schema]: # sourcery skip: assign-if-exp
        eff_schema = override_schema or self.schema

        q = self._query().filter(attribute_name == attr_match)
@@ -19,6 +19,10 @@ class UserFavoritesController(BaseUserController):
    def add_favorite(self, id: UUID4, slug: str):
        """Adds a Recipe to the users favorites"""
        assert_user_change_allowed(id, self.user)

        if not self.user.favorite_recipes:
            self.user.favorite_recipes = []

        self.user.favorite_recipes.append(slug)
        self.repos.users.update(self.user.id, self.user)

@@ -26,6 +30,10 @@ class UserFavoritesController(BaseUserController):
    def remove_favorite(self, id: UUID4, slug: str):
        """Adds a Recipe to the users favorites"""
        assert_user_change_allowed(id, self.user)

        if not self.user.favorite_recipes:
            self.user.favorite_recipes = []

        self.user.favorite_recipes = [x for x in self.user.favorite_recipes if x != slug]
        self.repos.users.update(self.user.id, self.user)
        return
@@ -1,5 +1,4 @@
from pathlib import Path
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, status
from starlette.responses import FileResponse
@@ -10,7 +9,7 @@ router = APIRouter(prefix="/api/utils", tags=["Utils"], include_in_schema=True)


@router.get("/download")
async def download_file(file_path: Optional[Path] = Depends(validate_file_token)):
async def download_file(file_path: Path = Depends(validate_file_token)):
    """Uses a file token obtained by an active user to retrieve a file from the operating
    system."""
    if not file_path.is_file():
@@ -7,7 +7,7 @@ class EmailReady(MealieModel):

class EmailSuccess(MealieModel):
    success: bool
    error: str = None
    error: str | None = None


class EmailTest(MealieModel):
@@ -34,9 +34,9 @@ class ShoppingListItemCreate(MealieModel):

    note: str | None = ""
    quantity: float = 1
    unit_id: UUID4 = None
    unit_id: UUID4 | None = None
    unit: IngredientUnit | None
    food_id: UUID4 = None
    food_id: UUID4 | None = None
    food: IngredientFood | None

    label_id: UUID4 | None = None
@@ -67,7 +67,7 @@ class ShoppingListItemOut(ShoppingListItemUpdate):


class ShoppingListCreate(MealieModel):
    name: str = None
    name: str | None = None
    extras: dict | None = {}

    created_at: datetime | None
@@ -25,7 +25,7 @@ app_dirs = get_app_dirs()


class RecipeTag(MealieModel):
    id: UUID4 = None
    id: UUID4 | None = None
    name: str
    slug: str

@@ -56,8 +56,8 @@ class RecipeToolPagination(PaginationBase):

class CreateRecipeBulk(BaseModel):
    url: str
    categories: list[RecipeCategory] = None
    tags: list[RecipeTag] = None
    categories: list[RecipeCategory] | None = None
    tags: list[RecipeTag] | None = None


class CreateRecipeByUrlBulk(BaseModel):
@@ -21,7 +21,7 @@ class PaginationQuery(MealieModel):
    per_page: int = 50
    order_by: str = "created_at"
    order_direction: OrderDirection = OrderDirection.desc
    query_filter: str = None
    query_filter: str | None = None


class PaginationBase(GenericModel, Generic[DataT]):
@@ -52,15 +52,12 @@ class PaginationBase(GenericModel, Generic[DataT]):
        self.previous = PaginationBase.merge_query_parameters(route, query_params)

    def set_pagination_guides(self, route: str, query_params: dict[str, Any] | None) -> None:
        if not query_params:
            query_params = {}

        query_params = camelize(query_params)
        valid_dict: dict[str, Any] = camelize(query_params) if query_params else {}

        # sanitize user input
        self.page = max(self.page, 1)
        self._set_next(route, query_params)
        self._set_prev(route, query_params)
        self._set_next(route, valid_dict)
        self._set_prev(route, valid_dict)

    @staticmethod
    def merge_query_parameters(url: str, params: dict[str, Any]):
@@ -21,9 +21,9 @@ class AlchemyExporter(BaseService):
    look_for_time = {"scheduled_time"}

    class DateTimeParser(BaseModel):
        date: datetime.date = None
        dt: datetime.datetime = None
        time: datetime.time = None
        date: datetime.date | None = None
        dt: datetime.datetime | None = None
        time: datetime.time | None = None

    def __init__(self, connection_str: str) -> None:
        super().__init__()
@@ -5,7 +5,7 @@ from pathlib import Path


class BackupContents:
    _tables: dict = None
    _tables: dict | None = None

    def __init__(self, file: Path) -> None:
        self.base = file
@@ -17,15 +17,17 @@ class BackupV2(BaseService):
    def __init__(self, db_url: str = None) -> None:
        super().__init__()

        self.db_url = db_url or self.settings.DB_URL
        # type - one of these has to be a string
        self.db_url: str = db_url or self.settings.DB_URL # type: ignore

        self.db_exporter = AlchemyExporter(self.db_url)

    def _sqlite(self) -> None:
        db_file = self.settings.DB_URL.removeprefix("sqlite:///")
        db_file = self.settings.DB_URL.removeprefix("sqlite:///") # type: ignore

        # Create a backup of the SQLite database
        timestamp = datetime.datetime.now().strftime("%Y.%m.%d")
        shutil.copy(db_file, f"mealie_{timestamp}.bak.db")
        shutil.copy(db_file, self.directories.DATA_DIR.joinpath(f"mealie_{timestamp}.bak.db"))

    def _postgres(self) -> None:
        pass
@@ -11,8 +11,8 @@ from mealie.services._base_service import BaseService
class EmailOptions:
    host: str
    port: int
    username: str = None
    password: str = None
    username: str | None = None
    password: str | None = None
    tls: bool = False
    ssl: bool = False

@@ -39,7 +39,9 @@ class Message:

        if smtp.ssl:
            with smtplib.SMTP_SSL(smtp.host, smtp.port) as server:
                server.login(smtp.username, smtp.password)
                if smtp.username and smtp.password:
                    server.login(smtp.username, smtp.password)

                errors = server.send_message(msg)
        else:
            with smtplib.SMTP(smtp.host, smtp.port) as server:
@@ -66,17 +68,24 @@ class DefaultEmailSender(ABCEmailSender, BaseService):
    """

    def send(self, email_to: str, subject: str, html: str) -> bool:

        if self.settings.SMTP_FROM_EMAIL is None or self.settings.SMTP_FROM_NAME is None:
            raise ValueError("SMTP_FROM_EMAIL and SMTP_FROM_NAME must be set in the config file.")

        message = Message(
            subject=subject,
            html=html,
            mail_from=(self.settings.SMTP_FROM_NAME, self.settings.SMTP_FROM_EMAIL),
        )

        if self.settings.SMTP_HOST is None or self.settings.SMTP_PORT is None:
            raise ValueError("SMTP_HOST, SMTP_PORT must be set in the config file.")

        smtp_options = EmailOptions(
            self.settings.SMTP_HOST,
            int(self.settings.SMTP_PORT),
            tls=self.settings.SMTP_AUTH_STRATEGY.upper() == "TLS",
            ssl=self.settings.SMTP_AUTH_STRATEGY.upper() == "SSL",
            tls=self.settings.SMTP_AUTH_STRATEGY.upper() == "TLS" if self.settings.SMTP_AUTH_STRATEGY else False,
            ssl=self.settings.SMTP_AUTH_STRATEGY.upper() == "SSL" if self.settings.SMTP_AUTH_STRATEGY else False,
        )

        if self.settings.SMTP_USER:
@@ -41,7 +41,7 @@ class EventListenerBase(ABC):
        ...

    @contextlib.contextmanager
    def ensure_session(self) -> Generator[None, None, None]:
    def ensure_session(self) -> Generator[Session, None, None]:
        """
        ensure_session ensures that a session is available for the caller by checking if a session
        was provided during construction, and if not, creating a new session with the `with_session`
@@ -54,10 +54,9 @@ class EventListenerBase(ABC):
        if self.session is None:
            with session_context() as session:
                self.session = session
                yield

                yield self.session
        else:
            yield
            yield self.session


class AppriseEventListener(EventListenerBase):
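The ensure_session change above makes the helper yield a live Session, so callers can write "with self.ensure_session() as session:" instead of reaching back through self.session. A condensed, self-contained sketch of the same context-manager pattern (FakeSession stands in for sqlalchemy.orm.Session):

import contextlib
from typing import Generator, Optional

class FakeSession:
    """Stand-in for sqlalchemy.orm.Session in this sketch."""

@contextlib.contextmanager
def ensure_session(existing: Optional[FakeSession]) -> Generator[FakeSession, None, None]:
    # Yield the session provided up front if there is one, otherwise create a fresh one;
    # either way the caller always receives a usable session object.
    if existing is None:
        yield FakeSession()
    else:
        yield existing

# Usage mirrors the listener code below: with ensure_session(None) as session: ...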
@@ -87,7 +86,7 @@ class AppriseEventListener(EventListenerBase):
            "integration_id": event.integration_id,
            "document_data": json.dumps(jsonable_encoder(event.document_data)),
            "event_id": str(event.event_id),
            "timestamp": event.timestamp.isoformat(),
            "timestamp": event.timestamp.isoformat() if event.timestamp else None,
        }

        return [
@@ -148,9 +147,9 @@ class WebhookEventListener(EventListenerBase):

    def get_scheduled_webhooks(self, start_dt: datetime, end_dt: datetime) -> list[ReadWebhook]:
        """Fetches all scheduled webhooks from the database"""
        with self.ensure_session():
        with self.ensure_session() as session:
            return (
                self.session.query(GroupWebhooksModel)
                session.query(GroupWebhooksModel)
                .where(
                    GroupWebhooksModel.enabled == True, # noqa: E712 - required for SQLAlchemy comparison
                    GroupWebhooksModel.scheduled_time > start_dt.astimezone(timezone.utc).time(),
@@ -1,29 +0,0 @@
import random

from pydantic import UUID4

from mealie.repos.repository_factory import AllRepositories
from mealie.schema.recipe.recipe import Recipe, RecipeCategory
from mealie.services._base_service import BaseService


class MealPlanService(BaseService):
    def __init__(self, group_id: UUID4, repos: AllRepositories):
        self.group_id = group_id
        self.repos = repos

    def get_random_recipe(self, categories: list[RecipeCategory] = None) -> Recipe:
        """get_random_recipe returns a single recipe matching a specific criteria of
        categories. if no categories are provided, a single recipe is returned from the
        entire recipe database.

        Note that the recipe must contain ALL categories in the list provided.

        Args:
            categories (list[RecipeCategory], optional): [description]. Defaults to None.

        Returns:
            Recipe: [description]
        """
        recipes = self.repos.recipes.by_group(self.group_id).get_by_categories(categories)
        return random.choice(recipes)
@@ -1,5 +1,6 @@
from pydantic import UUID4

from mealie.core.exceptions import UnexpectedNone
from mealie.repos.repository_factory import AllRepositories
from mealie.schema.group import ShoppingListItemCreate, ShoppingListOut
from mealie.schema.group.group_shopping_list import (
@@ -120,8 +121,10 @@ class ShoppingListService:
        - deleted_shopping_list_items
        """
        recipe = self.repos.recipes.get_one(recipe_id, "id")
        to_create = []
        if not recipe:
            raise UnexpectedNone("Recipe not found")

        to_create = []
        for ingredient in recipe.recipe_ingredient:
            food_id = None
            try:
@@ -144,7 +147,7 @@ class ShoppingListService:
            to_create.append(
                ShoppingListItemCreate(
                    shopping_list_id=list_id,
                    is_food=not recipe.settings.disable_amount,
                    is_food=not recipe.settings.disable_amount if recipe.settings else False,
                    food_id=food_id,
                    unit_id=unit_id,
                    quantity=ingredient.quantity,
@@ -163,6 +166,9 @@ class ShoppingListService:
        new_shopping_list_items = [self.repos.group_shopping_list_item.create(item) for item in to_create]

        updated_shopping_list = self.shopping_lists.get_one(list_id)
        if not updated_shopping_list:
            raise UnexpectedNone("Shopping List not found")

        updated_shopping_list_items, deleted_shopping_list_items = self.consolidate_and_save(updated_shopping_list.list_items) # type: ignore
        updated_shopping_list.list_items = updated_shopping_list_items

@@ -219,13 +225,16 @@ class ShoppingListService:
        """
        shopping_list = self.shopping_lists.get_one(list_id)

        if shopping_list is None:
            raise UnexpectedNone("Shopping list not found, cannot remove recipe ingredients")

        updated_shopping_list_items = []
        deleted_shopping_list_items = []
        for item in shopping_list.list_items:
            found = False

            for ref in item.recipe_references:
                remove_qty = 0.0
                remove_qty: None | float = 0.0

                if ref.recipe_id == recipe_id:
                    self.list_item_refs.delete(ref.id) # type: ignore
@@ -236,7 +245,9 @@ class ShoppingListService:

            # If the item was found decrement the quantity by the remove_qty
            if found:
                item.quantity = item.quantity - remove_qty

                if remove_qty is not None:
                    item.quantity = item.quantity - remove_qty

                if item.quantity <= 0:
                    self.list_items.delete(item.id)
@@ -246,16 +257,16 @@ class ShoppingListService:
                    updated_shopping_list_items.append(item)

        # Decrement the list recipe reference count
        for ref in shopping_list.recipe_references: # type: ignore
            if ref.recipe_id == recipe_id:
                ref.recipe_quantity -= 1
        for recipe_ref in shopping_list.recipe_references:
            if recipe_ref.recipe_id == recipe_id and recipe_ref.recipe_quantity is not None:
                recipe_ref.recipe_quantity -= 1.0

                if ref.recipe_quantity <= 0:
                    self.list_refs.delete(ref.id) # type: ignore
                if recipe_ref.recipe_quantity <= 0.0:
                    self.list_refs.delete(recipe_ref.id)

                else:
                    self.list_refs.update(ref.id, ref) # type: ignore
                    self.list_refs.update(recipe_ref.id, ref)
                break

        # Save Changes
        return self.shopping_lists.get_one(shopping_list.id), updated_shopping_list_items, deleted_shopping_list_items
        return self.shopping_lists.get_one(shopping_list.id), updated_shopping_list_items, deleted_shopping_list_items # type: ignore
@@ -96,7 +96,12 @@ class BaseMigrator(BaseService):
        self._migrate()
        self._save_all_entries()

        return self.db.group_reports.get_one(self.report_id)
        result = self.db.group_reports.get_one(self.report_id)

        if not result:
            raise ValueError("Report not found")

        return result

    def import_recipes_to_database(self, validated_recipes: list[Recipe]) -> list[tuple[str, UUID4, bool]]:
        """
@@ -111,10 +116,13 @@ class BaseMigrator(BaseService):
        if self.add_migration_tag:
            migration_tag = self.helpers.get_or_set_tags([self.name])[0]

        return_vars = []
        return_vars: list[tuple[str, UUID4, bool]] = []

        group = self.db.groups.get_one(self.group_id)

        if not group or not group.preferences:
            raise ValueError("Group preferences not found")

        default_settings = RecipeSettings(
            public=group.preferences.recipe_public,
            show_nutrition=group.preferences.recipe_show_nutrition,
@@ -132,6 +140,8 @@ class BaseMigrator(BaseService):

            if recipe.tags:
                recipe.tags = self.helpers.get_or_set_tags(x.name for x in recipe.tags)
            else:
                recipe.tags = []

            if recipe.recipe_category:
                recipe.recipe_category = self.helpers.get_or_set_category(x.name for x in recipe.recipe_category)
@@ -155,7 +165,7 @@ class BaseMigrator(BaseService):
            else:
                message = f"Failed to import {recipe.name}"

            return_vars.append((recipe.slug, recipe.id, status))
            return_vars.append((recipe.slug, recipe.id, status)) # type: ignore

            self.report_entries.append(
                ReportEntryCreate(
@@ -42,8 +42,13 @@ class ChowdownMigrator(BaseMigrator):
        for slug, recipe_id, status in results:
            if status:
                try:
                    original_image = recipe_lookup.get(slug).image
                    cd_image = image_dir.joinpath(original_image)
                    r = recipe_lookup.get(slug)

                    if not r:
                        continue

                    if r.image:
                        cd_image = image_dir.joinpath(r.image)
                except StopIteration:
                    continue
                if cd_image:
@@ -1,5 +1,6 @@
from pathlib import Path

from mealie.core.exceptions import UnexpectedNone
from mealie.repos.repository_factory import AllRepositories
from mealie.schema.group.group_exports import GroupDataExport
from mealie.schema.recipe import CategoryBase
@@ -41,6 +42,9 @@ class RecipeBulkActionsService(BaseService):

        group = self.repos.groups.get_one(self.group.id)

        if group is None:
            raise UnexpectedNone("Failed to purge exports for group, no group found")

        for match in group.directory.glob("**/export/*zip"):
            if match.is_file():
                match.unlink()
@@ -52,8 +56,8 @@ class RecipeBulkActionsService(BaseService):
        for slug in recipes:
            recipe = self.repos.recipes.get_one(slug)

            if recipe is None:
                self.logger.error(f"Failed to set settings for recipe {slug}, no recipe found")
            if recipe is None or recipe.settings is None:
                raise UnexpectedNone(f"Failed to set settings for recipe {slug}, no recipe found")

            settings.locked = recipe.settings.locked
            recipe.settings = settings
@@ -69,9 +73,12 @@ class RecipeBulkActionsService(BaseService):
            recipe = self.repos.recipes.get_one(slug)

            if recipe is None:
                self.logger.error(f"Failed to tag recipe {slug}, no recipe found")
                raise UnexpectedNone(f"Failed to tag recipe {slug}, no recipe found")

            recipe.tags += tags
            if recipe.tags is None:
                recipe.tags = []

            recipe.tags += tags # type: ignore

            try:
                self.repos.recipes.update(slug, recipe)
@@ -84,9 +91,12 @@ class RecipeBulkActionsService(BaseService):
            recipe = self.repos.recipes.get_one(slug)

            if recipe is None:
                self.logger.error(f"Failed to categorize recipe {slug}, no recipe found")
                raise UnexpectedNone(f"Failed to categorize recipe {slug}, no recipe found")

            recipe.recipe_category += categories
            if recipe.recipe_category is None:
                recipe.recipe_category = []

            recipe.recipe_category += categories # type: ignore

            try:
                self.repos.recipes.update(slug, recipe)
@@ -50,6 +50,8 @@ class RecipeService(BaseService):
        return recipe

    def can_update(self, recipe: Recipe) -> bool:
        if recipe.settings is None:
            raise exceptions.UnexpectedNone("Recipe Settings is None")
        return recipe.settings.locked is False or self.user.id == recipe.user_id

    def can_lock_unlock(self, recipe: Recipe) -> bool:
@@ -66,6 +68,9 @@ class RecipeService(BaseService):
        except FileNotFoundError:
            self.logger.error(f"Recipe Directory not Found: {original_slug}")

        if recipe.assets is None:
            recipe.assets = []

        all_asset_files = [x.file_name for x in recipe.assets]

        for file in recipe.asset_dir.iterdir():
@@ -92,7 +97,7 @@ class RecipeService(BaseService):
        additional_attrs["group_id"] = user.group_id

        if additional_attrs.get("tags"):
            for i in range(len(additional_attrs.get("tags"))):
            for i in range(len(additional_attrs.get("tags", []))):
                additional_attrs["tags"][i]["group_id"] = user.group_id

        if not additional_attrs.get("recipe_ingredient"):
@@ -105,6 +110,9 @@ class RecipeService(BaseService):

    def create_one(self, create_data: Union[Recipe, CreateRecipe]) -> Recipe:

        if create_data.name is None:
            create_data.name = "New Recipe"

        data: Recipe = self._recipe_creation_factory(
            self.user,
            name=create_data.name,
@@ -134,8 +142,8 @@ class RecipeService(BaseService):
        with temp_path.open("wb") as buffer:
            shutil.copyfileobj(archive.file, buffer)

        recipe_dict = None
        recipe_image = None
        recipe_dict: dict | None = None
        recipe_image: bytes | None = None

        with ZipFile(temp_path) as myzip:
            for file in myzip.namelist():
@@ -146,10 +154,15 @@ class RecipeService(BaseService):
                with myzip.open(file) as myfile:
                    recipe_image = myfile.read()

        if recipe_dict is None:
            raise exceptions.UnexpectedNone("No json data found in Zip")

        recipe = self.create_one(Recipe(**recipe_dict))

        if recipe:
        if recipe and recipe.id:
            data_service = RecipeDataService(recipe.id)

            if recipe_image:
                data_service.write_image(recipe_image, "webp")

        return recipe
@@ -172,6 +185,10 @@ class RecipeService(BaseService):
        """

        recipe = self._get_recipe(slug)

        if recipe is None or recipe.settings is None:
            raise exceptions.NoEntryFound("Recipe not found.")

        if not self.can_update(recipe):
            raise exceptions.PermissionDenied("You do not have permission to edit this recipe.")

@@ -189,9 +206,12 @@ class RecipeService(BaseService):
        return new_data

    def patch_one(self, slug: str, patch_data: Recipe) -> Recipe:
        recipe = self._pre_update_check(slug, patch_data)
        recipe: Recipe | None = self._pre_update_check(slug, patch_data)
        recipe = self.repos.recipes.by_group(self.group.id).get_one(slug)

        if recipe is None:
            raise exceptions.NoEntryFound("Recipe not found.")

        new_data = self.repos.recipes.by_group(self.group.id).patch(recipe.slug, patch_data.dict(exclude_unset=True))

        self.check_assets(new_data, recipe.slug)
@@ -210,6 +230,6 @@ class RecipeService(BaseService):
    # =================================================================
    # Recipe Template Methods

    def render_template(self, recipe: Recipe, temp_dir: Path, template: str = None) -> Path:
    def render_template(self, recipe: Recipe, temp_dir: Path, template: str) -> Path:
        t_service = TemplateService(temp_dir)
        return t_service.render(recipe, template)
@@ -16,7 +16,7 @@ class TemplateType(str, enum.Enum):


class TemplateService(BaseService):
    def __init__(self, temp: Path = None) -> None:
    def __init__(self, temp: Path | None = None) -> None:
        """Creates a template service that can be used for multiple template generations
        A temporary directory must be provided as a place holder for where to render all templates
        Args:
@@ -58,7 +58,7 @@ class TemplateService(BaseService):

        return TemplateType(t_type)

    def render(self, recipe: Recipe, template: str = None) -> Path:
    def render(self, recipe: Recipe, template: str) -> Path:
        """
        Renders a TemplateType in a temporary directory and returns the path to the file.

@@ -87,6 +87,9 @@ class TemplateService(BaseService):
        """
        self.__check_temp(self._render_json)

        if self.temp is None:
            raise ValueError("Temporary directory must be provided for method _render_json")

        save_path = self.temp.joinpath(f"{recipe.slug}.json")
        with open(save_path, "w") as f:
            f.write(recipe.json(indent=4, by_alias=True))
@@ -100,6 +103,9 @@ class TemplateService(BaseService):
        """
        self.__check_temp(self._render_jinja2)

        if j2_template is None:
            raise ValueError("Template must be provided for method _render_jinja2")

        j2_path: Path = self.directories.TEMPLATE_DIR / j2_template

        if not j2_path.is_file():
@@ -113,6 +119,9 @@ class TemplateService(BaseService):

        save_name = f"{recipe.slug}{j2_path.suffix}"

        if self.temp is None:
            raise ValueError("Temporary directory must be provided for method _render_jinja2")

        save_path = self.temp.joinpath(save_name)

        with open(save_path, "w") as f:
@@ -124,6 +133,10 @@ class TemplateService(BaseService):
        self.__check_temp(self._render_jinja2)

        image_asset = recipe.image_dir.joinpath(RecipeImageTypes.original.value)

        if self.temp is None:
            raise ValueError("Temporary directory must be provided for method _render_zip")

        zip_temp = self.temp.joinpath(f"{recipe.slug}.zip")

        with ZipFile(zip_temp, "w") as myzip:
@@ -19,7 +19,7 @@ class ParserErrors(str, Enum):
    CONNECTION_ERROR = "CONNECTION_ERROR"


def create_from_url(url: str) -> tuple[Recipe, ScrapedExtras]:
def create_from_url(url: str) -> tuple[Recipe, ScrapedExtras | None]:
    """Main entry point for generating a recipe from a URL. Pass in a URL and
    a Recipe object will be returned if successful.

@@ -43,6 +43,10 @@ def create_from_url(url: str) -> tuple[Recipe, ScrapedExtras]:

    try:
        recipe_data_service.scrape_image(new_recipe.image)

        if new_recipe.name is None:
            new_recipe.name = "Untitled"

        new_recipe.slug = slugify(new_recipe.name)
        new_recipe.image = cache.new_key(4)
    except Exception as e:
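Because create_from_url now returns tuple[Recipe, ScrapedExtras | None], callers have to handle the None case for the second element. A small usage sketch (placeholder URL; import path assumed):

from mealie.services.scraper.scraper import create_from_url # import path assumed

recipe, extras = create_from_url("https://example.com/some-recipe") # placeholder URL
print(recipe.slug)
if extras is not None:
    # extras are only present when the scraping strategy collected extra data
    print(extras)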
@@ -176,7 +176,7 @@ class RecipeScraperPackage(ABCScraperStrategy):
        ingredients = []

        try:
            instruct = scraped_schema.instructions()
            instruct: list | str = scraped_schema.instructions()
        except Exception:
            instruct = []

@@ -212,7 +212,7 @@ class RecipeScraperOpenGraph(ABCScraperStrategy):
        """

        def og_field(properties: dict, field_name: str) -> str:
            return next((val for name, val in properties if name == field_name), None)
            return next((val for name, val in properties if name == field_name), "")

        def og_fields(properties: list[tuple[str, str]], field_name: str) -> list[str]:
            return list({val for name, val in properties if name == field_name})
@@ -31,6 +31,9 @@ class PasswordResetService(BaseService):
    def send_reset_email(self, email: str):
        token_entry = self.generate_reset_token(email)

        if token_entry is None:
            return None

        # Send Email
        email_servive = EmailService()
        reset_url = f"{self.settings.BASE_URL}/reset-password?token={token_entry.token}"
@@ -35,7 +35,8 @@ class RegistrationService:
            can_organize=new_group,
        )

        return self.repos.users.create(new_user)
        # TODO: problem with repository type, not type here
        return self.repos.users.create(new_user) # type: ignore

    def _register_new_group(self) -> GroupInDB:
        group_data = GroupBase(name=self.registration.group)
@@ -74,7 +75,13 @@ class RegistrationService:
            token_entry = self.repos.group_invite_tokens.get_one(registration.group_token)
            if not token_entry:
                raise HTTPException(status.HTTP_400_BAD_REQUEST, {"message": "Invalid group token"})
            group = self.repos.groups.get_one(token_entry.group_id)

            maybe_none_group = self.repos.groups.get_one(token_entry.group_id)

            if maybe_none_group is None:
                raise HTTPException(status.HTTP_400_BAD_REQUEST, {"message": "Invalid group token"})

            group = maybe_none_group
        else:
            raise HTTPException(status.HTTP_400_BAD_REQUEST, {"message": "Missing group"})
@@ -105,5 +105,5 @@ pgsql = ["psycopg2-binary"]
python_version = "3.10"
ignore_missing_imports = true
follow_imports = "skip"
strict_optional = false # TODO: Fix none type checks - temporary stop-gap to implement mypy
strict_optional = true
plugins = "pydantic.mypy"
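With strict_optional now enabled, mypy stops treating None as a member of every type, which is what forced the explicit "X | None" annotations and None checks throughout this commit. A minimal illustration of the kind of error it reports and the narrowing that satisfies it (example snippet, not from the repository):

from typing import Optional

def shout(name: Optional[str]) -> str:
    # Without the check below, name.upper() is rejected under strict_optional,
    # roughly: Item "None" of "Optional[str]" has no attribute "upper"
    if name is None:
        return ""
    return name.upper()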