fix: strict optional errors (#1759)

* fix strict optional errors

* fix typing in repository

* fix backup db files location

* update workspace settings
This commit is contained in:
Hayden 2022-10-23 13:04:04 -08:00 committed by GitHub
parent 97d9e2a109
commit 84c23765cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 253 additions and 139 deletions

View File

@ -1,10 +1,4 @@
{ {
"conventionalCommits.scopes": [
"frontend",
"docs",
"backend",
"code-generation"
],
"cSpell.enableFiletypes": ["!javascript", "!python", "!yaml"], "cSpell.enableFiletypes": ["!javascript", "!python", "!yaml"],
"cSpell.words": [ "cSpell.words": [
"chowdown", "chowdown",
@ -44,7 +38,7 @@
"python.testing.unittestEnabled": false, "python.testing.unittestEnabled": false,
"python.analysis.typeCheckingMode": "off", "python.analysis.typeCheckingMode": "off",
"python.linting.mypyEnabled": true, "python.linting.mypyEnabled": true,
"python.sortImports.path": "${workspaceFolder}/.venv/bin/isort", "isort.path": ["${workspaceFolder}/.venv/bin/isort"],
"search.mode": "reuseEditor", "search.mode": "reuseEditor",
"python.testing.unittestArgs": ["-v", "-s", "./tests", "-p", "test_*.py"], "python.testing.unittestArgs": ["-v", "-s", "./tests", "-p", "test_*.py"],
"explorer.fileNesting.enabled": true, "explorer.fileNesting.enabled": true,

View File

@ -117,35 +117,58 @@ def validate_long_live_token(session: Session, client_token: str, user_id: str)
def validate_file_token(token: Optional[str] = None) -> Path: def validate_file_token(token: Optional[str] = None) -> Path:
credentials_exception = HTTPException( """
status_code=status.HTTP_401_UNAUTHORIZED, Args:
detail="could not validate file token", token (Optional[str], optional): _description_. Defaults to None.
)
Raises:
HTTPException: 400 Bad Request when no token or the file doesn't exist
HTTPException: 401 Unauthorized when the token is invalid
"""
if not token: if not token:
return None raise HTTPException(status.HTTP_400_BAD_REQUEST)
try: try:
payload = jwt.decode(token, settings.SECRET, algorithms=[ALGORITHM]) payload = jwt.decode(token, settings.SECRET, algorithms=[ALGORITHM])
file_path = Path(payload.get("file")) file_path = Path(payload.get("file"))
except JWTError as e: except JWTError as e:
raise credentials_exception from e raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="could not validate file token",
) from e
if not file_path.exists():
raise HTTPException(status.HTTP_400_BAD_REQUEST)
return file_path return file_path
def validate_recipe_token(token: Optional[str] = None) -> str: def validate_recipe_token(token: Optional[str] = None) -> str:
credentials_exception = HTTPException( """
status_code=status.HTTP_401_UNAUTHORIZED, Args:
detail="could not validate file token", token (Optional[str], optional): _description_. Defaults to None.
)
Raises:
HTTPException: 400 Bad Request when no token or the recipe doesn't exist
HTTPException: 401 JWTError when token is invalid
Returns:
str: token data
"""
if not token: if not token:
return None raise HTTPException(status.HTTP_400_BAD_REQUEST)
try: try:
payload = jwt.decode(token, settings.SECRET, algorithms=[ALGORITHM]) payload = jwt.decode(token, settings.SECRET, algorithms=[ALGORITHM])
slug = payload.get("slug") slug: str | None = payload.get("slug")
except JWTError as e: except JWTError as e:
raise credentials_exception from e raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="could not validate file token",
) from e
if slug is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST)
return slug return slug

View File

@ -3,6 +3,17 @@ from sqlite3 import IntegrityError
from mealie.lang.providers import Translator from mealie.lang.providers import Translator
class UnexpectedNone(Exception):
"""Exception raised when a value is None when it should not be."""
def __init__(self, message: str = "Unexpected None Value"):
self.message = message
super().__init__(self.message)
def __str__(self):
return f"{self.message}"
class PermissionDenied(Exception): class PermissionDenied(Exception):
""" """
This exception is raised when a user tries to access a resource that they do not have permission to access. This exception is raised when a user tries to access a resource that they do not have permission to access.

View File

@ -62,11 +62,16 @@ def user_from_ldap(db: AllRepositories, username: str, password: str) -> Private
conn.set_option(ldap.OPT_X_TLS_CACERTFILE, settings.LDAP_TLS_CACERTFILE) conn.set_option(ldap.OPT_X_TLS_CACERTFILE, settings.LDAP_TLS_CACERTFILE)
conn.set_option(ldap.OPT_X_TLS_NEWCTX, 0) conn.set_option(ldap.OPT_X_TLS_NEWCTX, 0)
user = db.users.get_one(username, "email", any_case=True) user = db.users.get_one(username, "email", any_case=True)
if not settings.LDAP_BIND_TEMPLATE:
return False
if not user: if not user:
user_bind = settings.LDAP_BIND_TEMPLATE.format(username) user_bind = settings.LDAP_BIND_TEMPLATE.format(username)
user = db.users.get_one(username, "username", any_case=True) user = db.users.get_one(username, "username", any_case=True)
else: else:
user_bind = settings.LDAP_BIND_TEMPLATE.format(user.username) user_bind = settings.LDAP_BIND_TEMPLATE.format(user.username)
try: try:
conn.simple_bind_s(user_bind, password) conn.simple_bind_s(user_bind, password)
except (ldap.INVALID_CREDENTIALS, ldap.NO_SUCH_OBJECT): except (ldap.INVALID_CREDENTIALS, ldap.NO_SUCH_OBJECT):
@ -86,7 +91,7 @@ def user_from_ldap(db: AllRepositories, username: str, password: str) -> Private
else: else:
return False return False
if not user: if user is None:
user = db.users.create( user = db.users.create(
{ {
"username": username, "username": username,
@ -96,6 +101,7 @@ def user_from_ldap(db: AllRepositories, username: str, password: str) -> Private
"admin": False, "admin": False,
}, },
) )
if settings.LDAP_ADMIN_FILTER: if settings.LDAP_ADMIN_FILTER:
user.admin = len(conn.search_s(user_dn, ldap.SCOPE_BASE, settings.LDAP_ADMIN_FILTER, [])) > 0 user.admin = len(conn.search_s(user_dn, ldap.SCOPE_BASE, settings.LDAP_ADMIN_FILTER, [])) > 0
db.users.update(user.id, user) db.users.update(user.id, user)

View File

@ -93,18 +93,18 @@ class AppSettings(BaseSettings):
@staticmethod @staticmethod
def validate_smtp( def validate_smtp(
host: str, host: str | None,
port: str, port: str | None,
from_name: str, from_name: str | None,
from_email: str, from_email: str | None,
strategy: str, strategy: str | None,
user: str | None = None, user: str | None = None,
password: str | None = None, password: str | None = None,
) -> bool: ) -> bool:
"""Validates all SMTP variables are set""" """Validates all SMTP variables are set"""
required = {host, port, from_name, from_email, strategy} required = {host, port, from_name, from_email, strategy}
if strategy.upper() in {"TLS", "SSL"}: if strategy and strategy.upper() in {"TLS", "SSL"}:
required.add(user) required.add(user)
required.add(password) required.add(password)

View File

@ -138,13 +138,19 @@ class RecipeModel(SqlAlchemyBase, BaseMixins):
**_, **_,
) -> None: ) -> None:
self.nutrition = Nutrition(**nutrition) if nutrition else Nutrition() self.nutrition = Nutrition(**nutrition) if nutrition else Nutrition()
self.recipe_instructions = [RecipeInstruction(**step, session=session) for step in recipe_instructions]
self.recipe_ingredient = [RecipeIngredient(**ingr, session=session) for ingr in recipe_ingredient]
self.assets = [RecipeAsset(**a) for a in assets]
# Mealie Specific if recipe_instructions:
self.recipe_instructions = [RecipeInstruction(**step, session=session) for step in recipe_instructions]
if recipe_ingredient:
self.recipe_ingredient = [RecipeIngredient(**ingr, session=session) for ingr in recipe_ingredient]
if assets:
self.assets = [RecipeAsset(**a) for a in assets]
self.settings = RecipeSettings(**settings) if settings else RecipeSettings() self.settings = RecipeSettings(**settings) if settings else RecipeSettings()
self.notes = [Note(**note) for note in notes]
# Time Stampes if notes:
self.notes = [Note(**n) for n in notes]
self.date_updated = datetime.datetime.now() self.date_updated = datetime.datetime.now()

View File

@ -1 +1,5 @@
from .repository_factory import AllRepositories from .repository_factory import AllRepositories
__all__ = [
"AllRepositories",
]

View File

@ -1,5 +1,7 @@
from __future__ import annotations
from math import ceil from math import ceil
from typing import Any, Generic, TypeVar, Union from typing import Any, Generic, TypeVar
from fastapi import HTTPException from fastapi import HTTPException
from pydantic import UUID4, BaseModel from pydantic import UUID4, BaseModel
@ -24,8 +26,8 @@ class RepositoryGeneric(Generic[Schema, Model]):
Generic ([Model]): Represents the SqlAlchemyModel Model Generic ([Model]): Represents the SqlAlchemyModel Model
""" """
user_id: UUID4 = None user_id: UUID4 | None = None
group_id: UUID4 = None group_id: UUID4 | None = None
def __init__(self, session: Session, primary_key: str, sql_model: type[Model], schema: type[Schema]) -> None: def __init__(self, session: Session, primary_key: str, sql_model: type[Model], schema: type[Schema]) -> None:
self.session = session self.session = session
@ -35,11 +37,11 @@ class RepositoryGeneric(Generic[Schema, Model]):
self.logger = get_logger() self.logger = get_logger()
def by_user(self, user_id: UUID4) -> "RepositoryGeneric[Schema, Model]": def by_user(self, user_id: UUID4) -> RepositoryGeneric[Schema, Model]:
self.user_id = user_id self.user_id = user_id
return self return self
def by_group(self, group_id: UUID4) -> "RepositoryGeneric[Schema, Model]": def by_group(self, group_id: UUID4) -> RepositoryGeneric[Schema, Model]:
self.group_id = group_id self.group_id = group_id
return self return self
@ -221,7 +223,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
attr_match: str = None, attr_match: str = None,
count=True, count=True,
override_schema=None, override_schema=None,
) -> Union[int, list[Schema]]: # sourcery skip: assign-if-exp ) -> int | list[Schema]: # sourcery skip: assign-if-exp
eff_schema = override_schema or self.schema eff_schema = override_schema or self.schema
q = self._query().filter(attribute_name == attr_match) q = self._query().filter(attribute_name == attr_match)

View File

@ -19,6 +19,10 @@ class UserFavoritesController(BaseUserController):
def add_favorite(self, id: UUID4, slug: str): def add_favorite(self, id: UUID4, slug: str):
"""Adds a Recipe to the users favorites""" """Adds a Recipe to the users favorites"""
assert_user_change_allowed(id, self.user) assert_user_change_allowed(id, self.user)
if not self.user.favorite_recipes:
self.user.favorite_recipes = []
self.user.favorite_recipes.append(slug) self.user.favorite_recipes.append(slug)
self.repos.users.update(self.user.id, self.user) self.repos.users.update(self.user.id, self.user)
@ -26,6 +30,10 @@ class UserFavoritesController(BaseUserController):
def remove_favorite(self, id: UUID4, slug: str): def remove_favorite(self, id: UUID4, slug: str):
"""Adds a Recipe to the users favorites""" """Adds a Recipe to the users favorites"""
assert_user_change_allowed(id, self.user) assert_user_change_allowed(id, self.user)
if not self.user.favorite_recipes:
self.user.favorite_recipes = []
self.user.favorite_recipes = [x for x in self.user.favorite_recipes if x != slug] self.user.favorite_recipes = [x for x in self.user.favorite_recipes if x != slug]
self.repos.users.update(self.user.id, self.user) self.repos.users.update(self.user.id, self.user)
return return

View File

@ -1,5 +1,4 @@
from pathlib import Path from pathlib import Path
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from starlette.responses import FileResponse from starlette.responses import FileResponse
@ -10,7 +9,7 @@ router = APIRouter(prefix="/api/utils", tags=["Utils"], include_in_schema=True)
@router.get("/download") @router.get("/download")
async def download_file(file_path: Optional[Path] = Depends(validate_file_token)): async def download_file(file_path: Path = Depends(validate_file_token)):
"""Uses a file token obtained by an active user to retrieve a file from the operating """Uses a file token obtained by an active user to retrieve a file from the operating
system.""" system."""
if not file_path.is_file(): if not file_path.is_file():

View File

@ -7,7 +7,7 @@ class EmailReady(MealieModel):
class EmailSuccess(MealieModel): class EmailSuccess(MealieModel):
success: bool success: bool
error: str = None error: str | None = None
class EmailTest(MealieModel): class EmailTest(MealieModel):

View File

@ -34,9 +34,9 @@ class ShoppingListItemCreate(MealieModel):
note: str | None = "" note: str | None = ""
quantity: float = 1 quantity: float = 1
unit_id: UUID4 = None unit_id: UUID4 | None = None
unit: IngredientUnit | None unit: IngredientUnit | None
food_id: UUID4 = None food_id: UUID4 | None = None
food: IngredientFood | None food: IngredientFood | None
label_id: UUID4 | None = None label_id: UUID4 | None = None
@ -67,7 +67,7 @@ class ShoppingListItemOut(ShoppingListItemUpdate):
class ShoppingListCreate(MealieModel): class ShoppingListCreate(MealieModel):
name: str = None name: str | None = None
extras: dict | None = {} extras: dict | None = {}
created_at: datetime | None created_at: datetime | None

View File

@ -25,7 +25,7 @@ app_dirs = get_app_dirs()
class RecipeTag(MealieModel): class RecipeTag(MealieModel):
id: UUID4 = None id: UUID4 | None = None
name: str name: str
slug: str slug: str
@ -56,8 +56,8 @@ class RecipeToolPagination(PaginationBase):
class CreateRecipeBulk(BaseModel): class CreateRecipeBulk(BaseModel):
url: str url: str
categories: list[RecipeCategory] = None categories: list[RecipeCategory] | None = None
tags: list[RecipeTag] = None tags: list[RecipeTag] | None = None
class CreateRecipeByUrlBulk(BaseModel): class CreateRecipeByUrlBulk(BaseModel):

View File

@ -21,7 +21,7 @@ class PaginationQuery(MealieModel):
per_page: int = 50 per_page: int = 50
order_by: str = "created_at" order_by: str = "created_at"
order_direction: OrderDirection = OrderDirection.desc order_direction: OrderDirection = OrderDirection.desc
query_filter: str = None query_filter: str | None = None
class PaginationBase(GenericModel, Generic[DataT]): class PaginationBase(GenericModel, Generic[DataT]):
@ -52,15 +52,12 @@ class PaginationBase(GenericModel, Generic[DataT]):
self.previous = PaginationBase.merge_query_parameters(route, query_params) self.previous = PaginationBase.merge_query_parameters(route, query_params)
def set_pagination_guides(self, route: str, query_params: dict[str, Any] | None) -> None: def set_pagination_guides(self, route: str, query_params: dict[str, Any] | None) -> None:
if not query_params: valid_dict: dict[str, Any] = camelize(query_params) if query_params else {}
query_params = {}
query_params = camelize(query_params)
# sanitize user input # sanitize user input
self.page = max(self.page, 1) self.page = max(self.page, 1)
self._set_next(route, query_params) self._set_next(route, valid_dict)
self._set_prev(route, query_params) self._set_prev(route, valid_dict)
@staticmethod @staticmethod
def merge_query_parameters(url: str, params: dict[str, Any]): def merge_query_parameters(url: str, params: dict[str, Any]):

View File

@ -21,9 +21,9 @@ class AlchemyExporter(BaseService):
look_for_time = {"scheduled_time"} look_for_time = {"scheduled_time"}
class DateTimeParser(BaseModel): class DateTimeParser(BaseModel):
date: datetime.date = None date: datetime.date | None = None
dt: datetime.datetime = None dt: datetime.datetime | None = None
time: datetime.time = None time: datetime.time | None = None
def __init__(self, connection_str: str) -> None: def __init__(self, connection_str: str) -> None:
super().__init__() super().__init__()

View File

@ -5,7 +5,7 @@ from pathlib import Path
class BackupContents: class BackupContents:
_tables: dict = None _tables: dict | None = None
def __init__(self, file: Path) -> None: def __init__(self, file: Path) -> None:
self.base = file self.base = file

View File

@ -17,15 +17,17 @@ class BackupV2(BaseService):
def __init__(self, db_url: str = None) -> None: def __init__(self, db_url: str = None) -> None:
super().__init__() super().__init__()
self.db_url = db_url or self.settings.DB_URL # type - one of these has to be a string
self.db_url: str = db_url or self.settings.DB_URL # type: ignore
self.db_exporter = AlchemyExporter(self.db_url) self.db_exporter = AlchemyExporter(self.db_url)
def _sqlite(self) -> None: def _sqlite(self) -> None:
db_file = self.settings.DB_URL.removeprefix("sqlite:///") db_file = self.settings.DB_URL.removeprefix("sqlite:///") # type: ignore
# Create a backup of the SQLite database # Create a backup of the SQLite database
timestamp = datetime.datetime.now().strftime("%Y.%m.%d") timestamp = datetime.datetime.now().strftime("%Y.%m.%d")
shutil.copy(db_file, f"mealie_{timestamp}.bak.db") shutil.copy(db_file, self.directories.DATA_DIR.joinpath(f"mealie_{timestamp}.bak.db"))
def _postgres(self) -> None: def _postgres(self) -> None:
pass pass

View File

@ -11,8 +11,8 @@ from mealie.services._base_service import BaseService
class EmailOptions: class EmailOptions:
host: str host: str
port: int port: int
username: str = None username: str | None = None
password: str = None password: str | None = None
tls: bool = False tls: bool = False
ssl: bool = False ssl: bool = False
@ -39,7 +39,9 @@ class Message:
if smtp.ssl: if smtp.ssl:
with smtplib.SMTP_SSL(smtp.host, smtp.port) as server: with smtplib.SMTP_SSL(smtp.host, smtp.port) as server:
server.login(smtp.username, smtp.password) if smtp.username and smtp.password:
server.login(smtp.username, smtp.password)
errors = server.send_message(msg) errors = server.send_message(msg)
else: else:
with smtplib.SMTP(smtp.host, smtp.port) as server: with smtplib.SMTP(smtp.host, smtp.port) as server:
@ -66,17 +68,24 @@ class DefaultEmailSender(ABCEmailSender, BaseService):
""" """
def send(self, email_to: str, subject: str, html: str) -> bool: def send(self, email_to: str, subject: str, html: str) -> bool:
if self.settings.SMTP_FROM_EMAIL is None or self.settings.SMTP_FROM_NAME is None:
raise ValueError("SMTP_FROM_EMAIL and SMTP_FROM_NAME must be set in the config file.")
message = Message( message = Message(
subject=subject, subject=subject,
html=html, html=html,
mail_from=(self.settings.SMTP_FROM_NAME, self.settings.SMTP_FROM_EMAIL), mail_from=(self.settings.SMTP_FROM_NAME, self.settings.SMTP_FROM_EMAIL),
) )
if self.settings.SMTP_HOST is None or self.settings.SMTP_PORT is None:
raise ValueError("SMTP_HOST, SMTP_PORT must be set in the config file.")
smtp_options = EmailOptions( smtp_options = EmailOptions(
self.settings.SMTP_HOST, self.settings.SMTP_HOST,
int(self.settings.SMTP_PORT), int(self.settings.SMTP_PORT),
tls=self.settings.SMTP_AUTH_STRATEGY.upper() == "TLS", tls=self.settings.SMTP_AUTH_STRATEGY.upper() == "TLS" if self.settings.SMTP_AUTH_STRATEGY else False,
ssl=self.settings.SMTP_AUTH_STRATEGY.upper() == "SSL", ssl=self.settings.SMTP_AUTH_STRATEGY.upper() == "SSL" if self.settings.SMTP_AUTH_STRATEGY else False,
) )
if self.settings.SMTP_USER: if self.settings.SMTP_USER:

View File

@ -41,7 +41,7 @@ class EventListenerBase(ABC):
... ...
@contextlib.contextmanager @contextlib.contextmanager
def ensure_session(self) -> Generator[None, None, None]: def ensure_session(self) -> Generator[Session, None, None]:
""" """
ensure_session ensures that a session is available for the caller by checking if a session ensure_session ensures that a session is available for the caller by checking if a session
was provided during construction, and if not, creating a new session with the `with_session` was provided during construction, and if not, creating a new session with the `with_session`
@ -54,10 +54,9 @@ class EventListenerBase(ABC):
if self.session is None: if self.session is None:
with session_context() as session: with session_context() as session:
self.session = session self.session = session
yield yield self.session
else: else:
yield yield self.session
class AppriseEventListener(EventListenerBase): class AppriseEventListener(EventListenerBase):
@ -87,7 +86,7 @@ class AppriseEventListener(EventListenerBase):
"integration_id": event.integration_id, "integration_id": event.integration_id,
"document_data": json.dumps(jsonable_encoder(event.document_data)), "document_data": json.dumps(jsonable_encoder(event.document_data)),
"event_id": str(event.event_id), "event_id": str(event.event_id),
"timestamp": event.timestamp.isoformat(), "timestamp": event.timestamp.isoformat() if event.timestamp else None,
} }
return [ return [
@ -148,9 +147,9 @@ class WebhookEventListener(EventListenerBase):
def get_scheduled_webhooks(self, start_dt: datetime, end_dt: datetime) -> list[ReadWebhook]: def get_scheduled_webhooks(self, start_dt: datetime, end_dt: datetime) -> list[ReadWebhook]:
"""Fetches all scheduled webhooks from the database""" """Fetches all scheduled webhooks from the database"""
with self.ensure_session(): with self.ensure_session() as session:
return ( return (
self.session.query(GroupWebhooksModel) session.query(GroupWebhooksModel)
.where( .where(
GroupWebhooksModel.enabled == True, # noqa: E712 - required for SQLAlchemy comparison GroupWebhooksModel.enabled == True, # noqa: E712 - required for SQLAlchemy comparison
GroupWebhooksModel.scheduled_time > start_dt.astimezone(timezone.utc).time(), GroupWebhooksModel.scheduled_time > start_dt.astimezone(timezone.utc).time(),

View File

@ -1,29 +0,0 @@
import random
from pydantic import UUID4
from mealie.repos.repository_factory import AllRepositories
from mealie.schema.recipe.recipe import Recipe, RecipeCategory
from mealie.services._base_service import BaseService
class MealPlanService(BaseService):
def __init__(self, group_id: UUID4, repos: AllRepositories):
self.group_id = group_id
self.repos = repos
def get_random_recipe(self, categories: list[RecipeCategory] = None) -> Recipe:
"""get_random_recipe returns a single recipe matching a specific criteria of
categories. if no categories are provided, a single recipe is returned from the
entire recipe database.
Note that the recipe must contain ALL categories in the list provided.
Args:
categories (list[RecipeCategory], optional): [description]. Defaults to None.
Returns:
Recipe: [description]
"""
recipes = self.repos.recipes.by_group(self.group_id).get_by_categories(categories)
return random.choice(recipes)

View File

@ -1,5 +1,6 @@
from pydantic import UUID4 from pydantic import UUID4
from mealie.core.exceptions import UnexpectedNone
from mealie.repos.repository_factory import AllRepositories from mealie.repos.repository_factory import AllRepositories
from mealie.schema.group import ShoppingListItemCreate, ShoppingListOut from mealie.schema.group import ShoppingListItemCreate, ShoppingListOut
from mealie.schema.group.group_shopping_list import ( from mealie.schema.group.group_shopping_list import (
@ -120,8 +121,10 @@ class ShoppingListService:
- deleted_shopping_list_items - deleted_shopping_list_items
""" """
recipe = self.repos.recipes.get_one(recipe_id, "id") recipe = self.repos.recipes.get_one(recipe_id, "id")
to_create = [] if not recipe:
raise UnexpectedNone("Recipe not found")
to_create = []
for ingredient in recipe.recipe_ingredient: for ingredient in recipe.recipe_ingredient:
food_id = None food_id = None
try: try:
@ -144,7 +147,7 @@ class ShoppingListService:
to_create.append( to_create.append(
ShoppingListItemCreate( ShoppingListItemCreate(
shopping_list_id=list_id, shopping_list_id=list_id,
is_food=not recipe.settings.disable_amount, is_food=not recipe.settings.disable_amount if recipe.settings else False,
food_id=food_id, food_id=food_id,
unit_id=unit_id, unit_id=unit_id,
quantity=ingredient.quantity, quantity=ingredient.quantity,
@ -163,6 +166,9 @@ class ShoppingListService:
new_shopping_list_items = [self.repos.group_shopping_list_item.create(item) for item in to_create] new_shopping_list_items = [self.repos.group_shopping_list_item.create(item) for item in to_create]
updated_shopping_list = self.shopping_lists.get_one(list_id) updated_shopping_list = self.shopping_lists.get_one(list_id)
if not updated_shopping_list:
raise UnexpectedNone("Shopping List not found")
updated_shopping_list_items, deleted_shopping_list_items = self.consolidate_and_save(updated_shopping_list.list_items) # type: ignore updated_shopping_list_items, deleted_shopping_list_items = self.consolidate_and_save(updated_shopping_list.list_items) # type: ignore
updated_shopping_list.list_items = updated_shopping_list_items updated_shopping_list.list_items = updated_shopping_list_items
@ -219,13 +225,16 @@ class ShoppingListService:
""" """
shopping_list = self.shopping_lists.get_one(list_id) shopping_list = self.shopping_lists.get_one(list_id)
if shopping_list is None:
raise UnexpectedNone("Shopping list not found, cannot remove recipe ingredients")
updated_shopping_list_items = [] updated_shopping_list_items = []
deleted_shopping_list_items = [] deleted_shopping_list_items = []
for item in shopping_list.list_items: for item in shopping_list.list_items:
found = False found = False
for ref in item.recipe_references: for ref in item.recipe_references:
remove_qty = 0.0 remove_qty: None | float = 0.0
if ref.recipe_id == recipe_id: if ref.recipe_id == recipe_id:
self.list_item_refs.delete(ref.id) # type: ignore self.list_item_refs.delete(ref.id) # type: ignore
@ -236,7 +245,9 @@ class ShoppingListService:
# If the item was found decrement the quantity by the remove_qty # If the item was found decrement the quantity by the remove_qty
if found: if found:
item.quantity = item.quantity - remove_qty
if remove_qty is not None:
item.quantity = item.quantity - remove_qty
if item.quantity <= 0: if item.quantity <= 0:
self.list_items.delete(item.id) self.list_items.delete(item.id)
@ -246,16 +257,16 @@ class ShoppingListService:
updated_shopping_list_items.append(item) updated_shopping_list_items.append(item)
# Decrement the list recipe reference count # Decrement the list recipe reference count
for ref in shopping_list.recipe_references: # type: ignore for recipe_ref in shopping_list.recipe_references:
if ref.recipe_id == recipe_id: if recipe_ref.recipe_id == recipe_id and recipe_ref.recipe_quantity is not None:
ref.recipe_quantity -= 1 recipe_ref.recipe_quantity -= 1.0
if ref.recipe_quantity <= 0: if recipe_ref.recipe_quantity <= 0.0:
self.list_refs.delete(ref.id) # type: ignore self.list_refs.delete(recipe_ref.id)
else: else:
self.list_refs.update(ref.id, ref) # type: ignore self.list_refs.update(recipe_ref.id, ref)
break break
# Save Changes # Save Changes
return self.shopping_lists.get_one(shopping_list.id), updated_shopping_list_items, deleted_shopping_list_items return self.shopping_lists.get_one(shopping_list.id), updated_shopping_list_items, deleted_shopping_list_items # type: ignore

View File

@ -96,7 +96,12 @@ class BaseMigrator(BaseService):
self._migrate() self._migrate()
self._save_all_entries() self._save_all_entries()
return self.db.group_reports.get_one(self.report_id) result = self.db.group_reports.get_one(self.report_id)
if not result:
raise ValueError("Report not found")
return result
def import_recipes_to_database(self, validated_recipes: list[Recipe]) -> list[tuple[str, UUID4, bool]]: def import_recipes_to_database(self, validated_recipes: list[Recipe]) -> list[tuple[str, UUID4, bool]]:
""" """
@ -111,10 +116,13 @@ class BaseMigrator(BaseService):
if self.add_migration_tag: if self.add_migration_tag:
migration_tag = self.helpers.get_or_set_tags([self.name])[0] migration_tag = self.helpers.get_or_set_tags([self.name])[0]
return_vars = [] return_vars: list[tuple[str, UUID4, bool]] = []
group = self.db.groups.get_one(self.group_id) group = self.db.groups.get_one(self.group_id)
if not group or not group.preferences:
raise ValueError("Group preferences not found")
default_settings = RecipeSettings( default_settings = RecipeSettings(
public=group.preferences.recipe_public, public=group.preferences.recipe_public,
show_nutrition=group.preferences.recipe_show_nutrition, show_nutrition=group.preferences.recipe_show_nutrition,
@ -132,6 +140,8 @@ class BaseMigrator(BaseService):
if recipe.tags: if recipe.tags:
recipe.tags = self.helpers.get_or_set_tags(x.name for x in recipe.tags) recipe.tags = self.helpers.get_or_set_tags(x.name for x in recipe.tags)
else:
recipe.tags = []
if recipe.recipe_category: if recipe.recipe_category:
recipe.recipe_category = self.helpers.get_or_set_category(x.name for x in recipe.recipe_category) recipe.recipe_category = self.helpers.get_or_set_category(x.name for x in recipe.recipe_category)
@ -155,7 +165,7 @@ class BaseMigrator(BaseService):
else: else:
message = f"Failed to import {recipe.name}" message = f"Failed to import {recipe.name}"
return_vars.append((recipe.slug, recipe.id, status)) return_vars.append((recipe.slug, recipe.id, status)) # type: ignore
self.report_entries.append( self.report_entries.append(
ReportEntryCreate( ReportEntryCreate(

View File

@ -42,8 +42,13 @@ class ChowdownMigrator(BaseMigrator):
for slug, recipe_id, status in results: for slug, recipe_id, status in results:
if status: if status:
try: try:
original_image = recipe_lookup.get(slug).image r = recipe_lookup.get(slug)
cd_image = image_dir.joinpath(original_image)
if not r:
continue
if r.image:
cd_image = image_dir.joinpath(r.image)
except StopIteration: except StopIteration:
continue continue
if cd_image: if cd_image:

View File

@ -1,5 +1,6 @@
from pathlib import Path from pathlib import Path
from mealie.core.exceptions import UnexpectedNone
from mealie.repos.repository_factory import AllRepositories from mealie.repos.repository_factory import AllRepositories
from mealie.schema.group.group_exports import GroupDataExport from mealie.schema.group.group_exports import GroupDataExport
from mealie.schema.recipe import CategoryBase from mealie.schema.recipe import CategoryBase
@ -41,6 +42,9 @@ class RecipeBulkActionsService(BaseService):
group = self.repos.groups.get_one(self.group.id) group = self.repos.groups.get_one(self.group.id)
if group is None:
raise UnexpectedNone("Failed to purge exports for group, no group found")
for match in group.directory.glob("**/export/*zip"): for match in group.directory.glob("**/export/*zip"):
if match.is_file(): if match.is_file():
match.unlink() match.unlink()
@ -52,8 +56,8 @@ class RecipeBulkActionsService(BaseService):
for slug in recipes: for slug in recipes:
recipe = self.repos.recipes.get_one(slug) recipe = self.repos.recipes.get_one(slug)
if recipe is None: if recipe is None or recipe.settings is None:
self.logger.error(f"Failed to set settings for recipe {slug}, no recipe found") raise UnexpectedNone(f"Failed to set settings for recipe {slug}, no recipe found")
settings.locked = recipe.settings.locked settings.locked = recipe.settings.locked
recipe.settings = settings recipe.settings = settings
@ -69,9 +73,12 @@ class RecipeBulkActionsService(BaseService):
recipe = self.repos.recipes.get_one(slug) recipe = self.repos.recipes.get_one(slug)
if recipe is None: if recipe is None:
self.logger.error(f"Failed to tag recipe {slug}, no recipe found") raise UnexpectedNone(f"Failed to tag recipe {slug}, no recipe found")
recipe.tags += tags if recipe.tags is None:
recipe.tags = []
recipe.tags += tags # type: ignore
try: try:
self.repos.recipes.update(slug, recipe) self.repos.recipes.update(slug, recipe)
@ -84,9 +91,12 @@ class RecipeBulkActionsService(BaseService):
recipe = self.repos.recipes.get_one(slug) recipe = self.repos.recipes.get_one(slug)
if recipe is None: if recipe is None:
self.logger.error(f"Failed to categorize recipe {slug}, no recipe found") raise UnexpectedNone(f"Failed to categorize recipe {slug}, no recipe found")
recipe.recipe_category += categories if recipe.recipe_category is None:
recipe.recipe_category = []
recipe.recipe_category += categories # type: ignore
try: try:
self.repos.recipes.update(slug, recipe) self.repos.recipes.update(slug, recipe)

View File

@ -50,6 +50,8 @@ class RecipeService(BaseService):
return recipe return recipe
def can_update(self, recipe: Recipe) -> bool: def can_update(self, recipe: Recipe) -> bool:
if recipe.settings is None:
raise exceptions.UnexpectedNone("Recipe Settings is None")
return recipe.settings.locked is False or self.user.id == recipe.user_id return recipe.settings.locked is False or self.user.id == recipe.user_id
def can_lock_unlock(self, recipe: Recipe) -> bool: def can_lock_unlock(self, recipe: Recipe) -> bool:
@ -66,6 +68,9 @@ class RecipeService(BaseService):
except FileNotFoundError: except FileNotFoundError:
self.logger.error(f"Recipe Directory not Found: {original_slug}") self.logger.error(f"Recipe Directory not Found: {original_slug}")
if recipe.assets is None:
recipe.assets = []
all_asset_files = [x.file_name for x in recipe.assets] all_asset_files = [x.file_name for x in recipe.assets]
for file in recipe.asset_dir.iterdir(): for file in recipe.asset_dir.iterdir():
@ -92,7 +97,7 @@ class RecipeService(BaseService):
additional_attrs["group_id"] = user.group_id additional_attrs["group_id"] = user.group_id
if additional_attrs.get("tags"): if additional_attrs.get("tags"):
for i in range(len(additional_attrs.get("tags"))): for i in range(len(additional_attrs.get("tags", []))):
additional_attrs["tags"][i]["group_id"] = user.group_id additional_attrs["tags"][i]["group_id"] = user.group_id
if not additional_attrs.get("recipe_ingredient"): if not additional_attrs.get("recipe_ingredient"):
@ -105,6 +110,9 @@ class RecipeService(BaseService):
def create_one(self, create_data: Union[Recipe, CreateRecipe]) -> Recipe: def create_one(self, create_data: Union[Recipe, CreateRecipe]) -> Recipe:
if create_data.name is None:
create_data.name = "New Recipe"
data: Recipe = self._recipe_creation_factory( data: Recipe = self._recipe_creation_factory(
self.user, self.user,
name=create_data.name, name=create_data.name,
@ -134,8 +142,8 @@ class RecipeService(BaseService):
with temp_path.open("wb") as buffer: with temp_path.open("wb") as buffer:
shutil.copyfileobj(archive.file, buffer) shutil.copyfileobj(archive.file, buffer)
recipe_dict = None recipe_dict: dict | None = None
recipe_image = None recipe_image: bytes | None = None
with ZipFile(temp_path) as myzip: with ZipFile(temp_path) as myzip:
for file in myzip.namelist(): for file in myzip.namelist():
@ -146,10 +154,15 @@ class RecipeService(BaseService):
with myzip.open(file) as myfile: with myzip.open(file) as myfile:
recipe_image = myfile.read() recipe_image = myfile.read()
if recipe_dict is None:
raise exceptions.UnexpectedNone("No json data found in Zip")
recipe = self.create_one(Recipe(**recipe_dict)) recipe = self.create_one(Recipe(**recipe_dict))
if recipe: if recipe and recipe.id:
data_service = RecipeDataService(recipe.id) data_service = RecipeDataService(recipe.id)
if recipe_image:
data_service.write_image(recipe_image, "webp") data_service.write_image(recipe_image, "webp")
return recipe return recipe
@ -172,6 +185,10 @@ class RecipeService(BaseService):
""" """
recipe = self._get_recipe(slug) recipe = self._get_recipe(slug)
if recipe is None or recipe.settings is None:
raise exceptions.NoEntryFound("Recipe not found.")
if not self.can_update(recipe): if not self.can_update(recipe):
raise exceptions.PermissionDenied("You do not have permission to edit this recipe.") raise exceptions.PermissionDenied("You do not have permission to edit this recipe.")
@ -189,9 +206,12 @@ class RecipeService(BaseService):
return new_data return new_data
def patch_one(self, slug: str, patch_data: Recipe) -> Recipe: def patch_one(self, slug: str, patch_data: Recipe) -> Recipe:
recipe = self._pre_update_check(slug, patch_data) recipe: Recipe | None = self._pre_update_check(slug, patch_data)
recipe = self.repos.recipes.by_group(self.group.id).get_one(slug) recipe = self.repos.recipes.by_group(self.group.id).get_one(slug)
if recipe is None:
raise exceptions.NoEntryFound("Recipe not found.")
new_data = self.repos.recipes.by_group(self.group.id).patch(recipe.slug, patch_data.dict(exclude_unset=True)) new_data = self.repos.recipes.by_group(self.group.id).patch(recipe.slug, patch_data.dict(exclude_unset=True))
self.check_assets(new_data, recipe.slug) self.check_assets(new_data, recipe.slug)
@ -210,6 +230,6 @@ class RecipeService(BaseService):
# ================================================================= # =================================================================
# Recipe Template Methods # Recipe Template Methods
def render_template(self, recipe: Recipe, temp_dir: Path, template: str = None) -> Path: def render_template(self, recipe: Recipe, temp_dir: Path, template: str) -> Path:
t_service = TemplateService(temp_dir) t_service = TemplateService(temp_dir)
return t_service.render(recipe, template) return t_service.render(recipe, template)

View File

@ -16,7 +16,7 @@ class TemplateType(str, enum.Enum):
class TemplateService(BaseService): class TemplateService(BaseService):
def __init__(self, temp: Path = None) -> None: def __init__(self, temp: Path | None = None) -> None:
"""Creates a template service that can be used for multiple template generations """Creates a template service that can be used for multiple template generations
A temporary directory must be provided as a place holder for where to render all templates A temporary directory must be provided as a place holder for where to render all templates
Args: Args:
@ -58,7 +58,7 @@ class TemplateService(BaseService):
return TemplateType(t_type) return TemplateType(t_type)
def render(self, recipe: Recipe, template: str = None) -> Path: def render(self, recipe: Recipe, template: str) -> Path:
""" """
Renders a TemplateType in a temporary directory and returns the path to the file. Renders a TemplateType in a temporary directory and returns the path to the file.
@ -87,6 +87,9 @@ class TemplateService(BaseService):
""" """
self.__check_temp(self._render_json) self.__check_temp(self._render_json)
if self.temp is None:
raise ValueError("Temporary directory must be provided for method _render_json")
save_path = self.temp.joinpath(f"{recipe.slug}.json") save_path = self.temp.joinpath(f"{recipe.slug}.json")
with open(save_path, "w") as f: with open(save_path, "w") as f:
f.write(recipe.json(indent=4, by_alias=True)) f.write(recipe.json(indent=4, by_alias=True))
@ -100,6 +103,9 @@ class TemplateService(BaseService):
""" """
self.__check_temp(self._render_jinja2) self.__check_temp(self._render_jinja2)
if j2_template is None:
raise ValueError("Template must be provided for method _render_jinja2")
j2_path: Path = self.directories.TEMPLATE_DIR / j2_template j2_path: Path = self.directories.TEMPLATE_DIR / j2_template
if not j2_path.is_file(): if not j2_path.is_file():
@ -113,6 +119,9 @@ class TemplateService(BaseService):
save_name = f"{recipe.slug}{j2_path.suffix}" save_name = f"{recipe.slug}{j2_path.suffix}"
if self.temp is None:
raise ValueError("Temporary directory must be provided for method _render_jinja2")
save_path = self.temp.joinpath(save_name) save_path = self.temp.joinpath(save_name)
with open(save_path, "w") as f: with open(save_path, "w") as f:
@ -124,6 +133,10 @@ class TemplateService(BaseService):
self.__check_temp(self._render_jinja2) self.__check_temp(self._render_jinja2)
image_asset = recipe.image_dir.joinpath(RecipeImageTypes.original.value) image_asset = recipe.image_dir.joinpath(RecipeImageTypes.original.value)
if self.temp is None:
raise ValueError("Temporary directory must be provided for method _render_zip")
zip_temp = self.temp.joinpath(f"{recipe.slug}.zip") zip_temp = self.temp.joinpath(f"{recipe.slug}.zip")
with ZipFile(zip_temp, "w") as myzip: with ZipFile(zip_temp, "w") as myzip:

View File

@ -19,7 +19,7 @@ class ParserErrors(str, Enum):
CONNECTION_ERROR = "CONNECTION_ERROR" CONNECTION_ERROR = "CONNECTION_ERROR"
def create_from_url(url: str) -> tuple[Recipe, ScrapedExtras]: def create_from_url(url: str) -> tuple[Recipe, ScrapedExtras | None]:
"""Main entry point for generating a recipe from a URL. Pass in a URL and """Main entry point for generating a recipe from a URL. Pass in a URL and
a Recipe object will be returned if successful. a Recipe object will be returned if successful.
@ -43,6 +43,10 @@ def create_from_url(url: str) -> tuple[Recipe, ScrapedExtras]:
try: try:
recipe_data_service.scrape_image(new_recipe.image) recipe_data_service.scrape_image(new_recipe.image)
if new_recipe.name is None:
new_recipe.name = "Untitled"
new_recipe.slug = slugify(new_recipe.name) new_recipe.slug = slugify(new_recipe.name)
new_recipe.image = cache.new_key(4) new_recipe.image = cache.new_key(4)
except Exception as e: except Exception as e:

View File

@ -176,7 +176,7 @@ class RecipeScraperPackage(ABCScraperStrategy):
ingredients = [] ingredients = []
try: try:
instruct = scraped_schema.instructions() instruct: list | str = scraped_schema.instructions()
except Exception: except Exception:
instruct = [] instruct = []
@ -212,7 +212,7 @@ class RecipeScraperOpenGraph(ABCScraperStrategy):
""" """
def og_field(properties: dict, field_name: str) -> str: def og_field(properties: dict, field_name: str) -> str:
return next((val for name, val in properties if name == field_name), None) return next((val for name, val in properties if name == field_name), "")
def og_fields(properties: list[tuple[str, str]], field_name: str) -> list[str]: def og_fields(properties: list[tuple[str, str]], field_name: str) -> list[str]:
return list({val for name, val in properties if name == field_name}) return list({val for name, val in properties if name == field_name})

View File

@ -31,6 +31,9 @@ class PasswordResetService(BaseService):
def send_reset_email(self, email: str): def send_reset_email(self, email: str):
token_entry = self.generate_reset_token(email) token_entry = self.generate_reset_token(email)
if token_entry is None:
return None
# Send Email # Send Email
email_servive = EmailService() email_servive = EmailService()
reset_url = f"{self.settings.BASE_URL}/reset-password?token={token_entry.token}" reset_url = f"{self.settings.BASE_URL}/reset-password?token={token_entry.token}"

View File

@ -35,7 +35,8 @@ class RegistrationService:
can_organize=new_group, can_organize=new_group,
) )
return self.repos.users.create(new_user) # TODO: problem with repository type, not type here
return self.repos.users.create(new_user) # type: ignore
def _register_new_group(self) -> GroupInDB: def _register_new_group(self) -> GroupInDB:
group_data = GroupBase(name=self.registration.group) group_data = GroupBase(name=self.registration.group)
@ -74,7 +75,13 @@ class RegistrationService:
token_entry = self.repos.group_invite_tokens.get_one(registration.group_token) token_entry = self.repos.group_invite_tokens.get_one(registration.group_token)
if not token_entry: if not token_entry:
raise HTTPException(status.HTTP_400_BAD_REQUEST, {"message": "Invalid group token"}) raise HTTPException(status.HTTP_400_BAD_REQUEST, {"message": "Invalid group token"})
group = self.repos.groups.get_one(token_entry.group_id)
maybe_none_group = self.repos.groups.get_one(token_entry.group_id)
if maybe_none_group is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, {"message": "Invalid group token"})
group = maybe_none_group
else: else:
raise HTTPException(status.HTTP_400_BAD_REQUEST, {"message": "Missing group"}) raise HTTPException(status.HTTP_400_BAD_REQUEST, {"message": "Missing group"})

View File

@ -105,5 +105,5 @@ pgsql = ["psycopg2-binary"]
python_version = "3.10" python_version = "3.10"
ignore_missing_imports = true ignore_missing_imports = true
follow_imports = "skip" follow_imports = "skip"
strict_optional = false # TODO: Fix none type checks - temporary stop-gap to implement mypy strict_optional = true
plugins = "pydantic.mypy" plugins = "pydantic.mypy"