fix: strict optional errors (#1759)

* fix strict optional errors

* fix typing in repository

* fix backup db files location

* update workspace settings
This commit is contained in:
Hayden 2022-10-23 13:04:04 -08:00 committed by GitHub
parent 97d9e2a109
commit 84c23765cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 253 additions and 139 deletions

View File

@ -1,10 +1,4 @@
{
"conventionalCommits.scopes": [
"frontend",
"docs",
"backend",
"code-generation"
],
"cSpell.enableFiletypes": ["!javascript", "!python", "!yaml"],
"cSpell.words": [
"chowdown",
@ -44,7 +38,7 @@
"python.testing.unittestEnabled": false,
"python.analysis.typeCheckingMode": "off",
"python.linting.mypyEnabled": true,
"python.sortImports.path": "${workspaceFolder}/.venv/bin/isort",
"isort.path": ["${workspaceFolder}/.venv/bin/isort"],
"search.mode": "reuseEditor",
"python.testing.unittestArgs": ["-v", "-s", "./tests", "-p", "test_*.py"],
"explorer.fileNesting.enabled": true,

View File

@ -117,35 +117,58 @@ def validate_long_live_token(session: Session, client_token: str, user_id: str)
def validate_file_token(token: Optional[str] = None) -> Path:
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="could not validate file token",
)
"""
Args:
token (Optional[str], optional): _description_. Defaults to None.
Raises:
HTTPException: 400 Bad Request when no token or the file doesn't exist
HTTPException: 401 Unauthorized when the token is invalid
"""
if not token:
return None
raise HTTPException(status.HTTP_400_BAD_REQUEST)
try:
payload = jwt.decode(token, settings.SECRET, algorithms=[ALGORITHM])
file_path = Path(payload.get("file"))
except JWTError as e:
raise credentials_exception from e
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="could not validate file token",
) from e
if not file_path.exists():
raise HTTPException(status.HTTP_400_BAD_REQUEST)
return file_path
def validate_recipe_token(token: Optional[str] = None) -> str:
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="could not validate file token",
)
"""
Args:
token (Optional[str], optional): _description_. Defaults to None.
Raises:
HTTPException: 400 Bad Request when no token or the recipe doesn't exist
HTTPException: 401 JWTError when token is invalid
Returns:
str: token data
"""
if not token:
return None
raise HTTPException(status.HTTP_400_BAD_REQUEST)
try:
payload = jwt.decode(token, settings.SECRET, algorithms=[ALGORITHM])
slug = payload.get("slug")
slug: str | None = payload.get("slug")
except JWTError as e:
raise credentials_exception from e
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="could not validate file token",
) from e
if slug is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST)
return slug

View File

@ -3,6 +3,17 @@ from sqlite3 import IntegrityError
from mealie.lang.providers import Translator
class UnexpectedNone(Exception):
"""Exception raised when a value is None when it should not be."""
def __init__(self, message: str = "Unexpected None Value"):
self.message = message
super().__init__(self.message)
def __str__(self):
return f"{self.message}"
class PermissionDenied(Exception):
"""
This exception is raised when a user tries to access a resource that they do not have permission to access.

View File

@ -62,11 +62,16 @@ def user_from_ldap(db: AllRepositories, username: str, password: str) -> Private
conn.set_option(ldap.OPT_X_TLS_CACERTFILE, settings.LDAP_TLS_CACERTFILE)
conn.set_option(ldap.OPT_X_TLS_NEWCTX, 0)
user = db.users.get_one(username, "email", any_case=True)
if not settings.LDAP_BIND_TEMPLATE:
return False
if not user:
user_bind = settings.LDAP_BIND_TEMPLATE.format(username)
user = db.users.get_one(username, "username", any_case=True)
else:
user_bind = settings.LDAP_BIND_TEMPLATE.format(user.username)
try:
conn.simple_bind_s(user_bind, password)
except (ldap.INVALID_CREDENTIALS, ldap.NO_SUCH_OBJECT):
@ -86,7 +91,7 @@ def user_from_ldap(db: AllRepositories, username: str, password: str) -> Private
else:
return False
if not user:
if user is None:
user = db.users.create(
{
"username": username,
@ -96,6 +101,7 @@ def user_from_ldap(db: AllRepositories, username: str, password: str) -> Private
"admin": False,
},
)
if settings.LDAP_ADMIN_FILTER:
user.admin = len(conn.search_s(user_dn, ldap.SCOPE_BASE, settings.LDAP_ADMIN_FILTER, [])) > 0
db.users.update(user.id, user)

View File

@ -93,18 +93,18 @@ class AppSettings(BaseSettings):
@staticmethod
def validate_smtp(
host: str,
port: str,
from_name: str,
from_email: str,
strategy: str,
host: str | None,
port: str | None,
from_name: str | None,
from_email: str | None,
strategy: str | None,
user: str | None = None,
password: str | None = None,
) -> bool:
"""Validates all SMTP variables are set"""
required = {host, port, from_name, from_email, strategy}
if strategy.upper() in {"TLS", "SSL"}:
if strategy and strategy.upper() in {"TLS", "SSL"}:
required.add(user)
required.add(password)

View File

@ -138,13 +138,19 @@ class RecipeModel(SqlAlchemyBase, BaseMixins):
**_,
) -> None:
self.nutrition = Nutrition(**nutrition) if nutrition else Nutrition()
self.recipe_instructions = [RecipeInstruction(**step, session=session) for step in recipe_instructions]
self.recipe_ingredient = [RecipeIngredient(**ingr, session=session) for ingr in recipe_ingredient]
self.assets = [RecipeAsset(**a) for a in assets]
# Mealie Specific
if recipe_instructions:
self.recipe_instructions = [RecipeInstruction(**step, session=session) for step in recipe_instructions]
if recipe_ingredient:
self.recipe_ingredient = [RecipeIngredient(**ingr, session=session) for ingr in recipe_ingredient]
if assets:
self.assets = [RecipeAsset(**a) for a in assets]
self.settings = RecipeSettings(**settings) if settings else RecipeSettings()
self.notes = [Note(**note) for note in notes]
# Time Stampes
if notes:
self.notes = [Note(**n) for n in notes]
self.date_updated = datetime.datetime.now()

View File

@ -1 +1,5 @@
from .repository_factory import AllRepositories
__all__ = [
"AllRepositories",
]

View File

@ -1,5 +1,7 @@
from __future__ import annotations
from math import ceil
from typing import Any, Generic, TypeVar, Union
from typing import Any, Generic, TypeVar
from fastapi import HTTPException
from pydantic import UUID4, BaseModel
@ -24,8 +26,8 @@ class RepositoryGeneric(Generic[Schema, Model]):
Generic ([Model]): Represents the SqlAlchemyModel Model
"""
user_id: UUID4 = None
group_id: UUID4 = None
user_id: UUID4 | None = None
group_id: UUID4 | None = None
def __init__(self, session: Session, primary_key: str, sql_model: type[Model], schema: type[Schema]) -> None:
self.session = session
@ -35,11 +37,11 @@ class RepositoryGeneric(Generic[Schema, Model]):
self.logger = get_logger()
def by_user(self, user_id: UUID4) -> "RepositoryGeneric[Schema, Model]":
def by_user(self, user_id: UUID4) -> RepositoryGeneric[Schema, Model]:
self.user_id = user_id
return self
def by_group(self, group_id: UUID4) -> "RepositoryGeneric[Schema, Model]":
def by_group(self, group_id: UUID4) -> RepositoryGeneric[Schema, Model]:
self.group_id = group_id
return self
@ -221,7 +223,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
attr_match: str = None,
count=True,
override_schema=None,
) -> Union[int, list[Schema]]: # sourcery skip: assign-if-exp
) -> int | list[Schema]: # sourcery skip: assign-if-exp
eff_schema = override_schema or self.schema
q = self._query().filter(attribute_name == attr_match)

View File

@ -19,6 +19,10 @@ class UserFavoritesController(BaseUserController):
def add_favorite(self, id: UUID4, slug: str):
"""Adds a Recipe to the users favorites"""
assert_user_change_allowed(id, self.user)
if not self.user.favorite_recipes:
self.user.favorite_recipes = []
self.user.favorite_recipes.append(slug)
self.repos.users.update(self.user.id, self.user)
@ -26,6 +30,10 @@ class UserFavoritesController(BaseUserController):
def remove_favorite(self, id: UUID4, slug: str):
"""Adds a Recipe to the users favorites"""
assert_user_change_allowed(id, self.user)
if not self.user.favorite_recipes:
self.user.favorite_recipes = []
self.user.favorite_recipes = [x for x in self.user.favorite_recipes if x != slug]
self.repos.users.update(self.user.id, self.user)
return

View File

@ -1,5 +1,4 @@
from pathlib import Path
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, status
from starlette.responses import FileResponse
@ -10,7 +9,7 @@ router = APIRouter(prefix="/api/utils", tags=["Utils"], include_in_schema=True)
@router.get("/download")
async def download_file(file_path: Optional[Path] = Depends(validate_file_token)):
async def download_file(file_path: Path = Depends(validate_file_token)):
"""Uses a file token obtained by an active user to retrieve a file from the operating
system."""
if not file_path.is_file():

View File

@ -7,7 +7,7 @@ class EmailReady(MealieModel):
class EmailSuccess(MealieModel):
success: bool
error: str = None
error: str | None = None
class EmailTest(MealieModel):

View File

@ -34,9 +34,9 @@ class ShoppingListItemCreate(MealieModel):
note: str | None = ""
quantity: float = 1
unit_id: UUID4 = None
unit_id: UUID4 | None = None
unit: IngredientUnit | None
food_id: UUID4 = None
food_id: UUID4 | None = None
food: IngredientFood | None
label_id: UUID4 | None = None
@ -67,7 +67,7 @@ class ShoppingListItemOut(ShoppingListItemUpdate):
class ShoppingListCreate(MealieModel):
name: str = None
name: str | None = None
extras: dict | None = {}
created_at: datetime | None

View File

@ -25,7 +25,7 @@ app_dirs = get_app_dirs()
class RecipeTag(MealieModel):
id: UUID4 = None
id: UUID4 | None = None
name: str
slug: str
@ -56,8 +56,8 @@ class RecipeToolPagination(PaginationBase):
class CreateRecipeBulk(BaseModel):
url: str
categories: list[RecipeCategory] = None
tags: list[RecipeTag] = None
categories: list[RecipeCategory] | None = None
tags: list[RecipeTag] | None = None
class CreateRecipeByUrlBulk(BaseModel):

View File

@ -21,7 +21,7 @@ class PaginationQuery(MealieModel):
per_page: int = 50
order_by: str = "created_at"
order_direction: OrderDirection = OrderDirection.desc
query_filter: str = None
query_filter: str | None = None
class PaginationBase(GenericModel, Generic[DataT]):
@ -52,15 +52,12 @@ class PaginationBase(GenericModel, Generic[DataT]):
self.previous = PaginationBase.merge_query_parameters(route, query_params)
def set_pagination_guides(self, route: str, query_params: dict[str, Any] | None) -> None:
if not query_params:
query_params = {}
query_params = camelize(query_params)
valid_dict: dict[str, Any] = camelize(query_params) if query_params else {}
# sanitize user input
self.page = max(self.page, 1)
self._set_next(route, query_params)
self._set_prev(route, query_params)
self._set_next(route, valid_dict)
self._set_prev(route, valid_dict)
@staticmethod
def merge_query_parameters(url: str, params: dict[str, Any]):

View File

@ -21,9 +21,9 @@ class AlchemyExporter(BaseService):
look_for_time = {"scheduled_time"}
class DateTimeParser(BaseModel):
date: datetime.date = None
dt: datetime.datetime = None
time: datetime.time = None
date: datetime.date | None = None
dt: datetime.datetime | None = None
time: datetime.time | None = None
def __init__(self, connection_str: str) -> None:
super().__init__()

View File

@ -5,7 +5,7 @@ from pathlib import Path
class BackupContents:
_tables: dict = None
_tables: dict | None = None
def __init__(self, file: Path) -> None:
self.base = file

View File

@ -17,15 +17,17 @@ class BackupV2(BaseService):
def __init__(self, db_url: str = None) -> None:
super().__init__()
self.db_url = db_url or self.settings.DB_URL
# type - one of these has to be a string
self.db_url: str = db_url or self.settings.DB_URL # type: ignore
self.db_exporter = AlchemyExporter(self.db_url)
def _sqlite(self) -> None:
db_file = self.settings.DB_URL.removeprefix("sqlite:///")
db_file = self.settings.DB_URL.removeprefix("sqlite:///") # type: ignore
# Create a backup of the SQLite database
timestamp = datetime.datetime.now().strftime("%Y.%m.%d")
shutil.copy(db_file, f"mealie_{timestamp}.bak.db")
shutil.copy(db_file, self.directories.DATA_DIR.joinpath(f"mealie_{timestamp}.bak.db"))
def _postgres(self) -> None:
pass

View File

@ -11,8 +11,8 @@ from mealie.services._base_service import BaseService
class EmailOptions:
host: str
port: int
username: str = None
password: str = None
username: str | None = None
password: str | None = None
tls: bool = False
ssl: bool = False
@ -39,7 +39,9 @@ class Message:
if smtp.ssl:
with smtplib.SMTP_SSL(smtp.host, smtp.port) as server:
server.login(smtp.username, smtp.password)
if smtp.username and smtp.password:
server.login(smtp.username, smtp.password)
errors = server.send_message(msg)
else:
with smtplib.SMTP(smtp.host, smtp.port) as server:
@ -66,17 +68,24 @@ class DefaultEmailSender(ABCEmailSender, BaseService):
"""
def send(self, email_to: str, subject: str, html: str) -> bool:
if self.settings.SMTP_FROM_EMAIL is None or self.settings.SMTP_FROM_NAME is None:
raise ValueError("SMTP_FROM_EMAIL and SMTP_FROM_NAME must be set in the config file.")
message = Message(
subject=subject,
html=html,
mail_from=(self.settings.SMTP_FROM_NAME, self.settings.SMTP_FROM_EMAIL),
)
if self.settings.SMTP_HOST is None or self.settings.SMTP_PORT is None:
raise ValueError("SMTP_HOST, SMTP_PORT must be set in the config file.")
smtp_options = EmailOptions(
self.settings.SMTP_HOST,
int(self.settings.SMTP_PORT),
tls=self.settings.SMTP_AUTH_STRATEGY.upper() == "TLS",
ssl=self.settings.SMTP_AUTH_STRATEGY.upper() == "SSL",
tls=self.settings.SMTP_AUTH_STRATEGY.upper() == "TLS" if self.settings.SMTP_AUTH_STRATEGY else False,
ssl=self.settings.SMTP_AUTH_STRATEGY.upper() == "SSL" if self.settings.SMTP_AUTH_STRATEGY else False,
)
if self.settings.SMTP_USER:

View File

@ -41,7 +41,7 @@ class EventListenerBase(ABC):
...
@contextlib.contextmanager
def ensure_session(self) -> Generator[None, None, None]:
def ensure_session(self) -> Generator[Session, None, None]:
"""
ensure_session ensures that a session is available for the caller by checking if a session
was provided during construction, and if not, creating a new session with the `with_session`
@ -54,10 +54,9 @@ class EventListenerBase(ABC):
if self.session is None:
with session_context() as session:
self.session = session
yield
yield self.session
else:
yield
yield self.session
class AppriseEventListener(EventListenerBase):
@ -87,7 +86,7 @@ class AppriseEventListener(EventListenerBase):
"integration_id": event.integration_id,
"document_data": json.dumps(jsonable_encoder(event.document_data)),
"event_id": str(event.event_id),
"timestamp": event.timestamp.isoformat(),
"timestamp": event.timestamp.isoformat() if event.timestamp else None,
}
return [
@ -148,9 +147,9 @@ class WebhookEventListener(EventListenerBase):
def get_scheduled_webhooks(self, start_dt: datetime, end_dt: datetime) -> list[ReadWebhook]:
"""Fetches all scheduled webhooks from the database"""
with self.ensure_session():
with self.ensure_session() as session:
return (
self.session.query(GroupWebhooksModel)
session.query(GroupWebhooksModel)
.where(
GroupWebhooksModel.enabled == True, # noqa: E712 - required for SQLAlchemy comparison
GroupWebhooksModel.scheduled_time > start_dt.astimezone(timezone.utc).time(),

View File

@ -1,29 +0,0 @@
import random
from pydantic import UUID4
from mealie.repos.repository_factory import AllRepositories
from mealie.schema.recipe.recipe import Recipe, RecipeCategory
from mealie.services._base_service import BaseService
class MealPlanService(BaseService):
def __init__(self, group_id: UUID4, repos: AllRepositories):
self.group_id = group_id
self.repos = repos
def get_random_recipe(self, categories: list[RecipeCategory] = None) -> Recipe:
"""get_random_recipe returns a single recipe matching a specific criteria of
categories. if no categories are provided, a single recipe is returned from the
entire recipe database.
Note that the recipe must contain ALL categories in the list provided.
Args:
categories (list[RecipeCategory], optional): [description]. Defaults to None.
Returns:
Recipe: [description]
"""
recipes = self.repos.recipes.by_group(self.group_id).get_by_categories(categories)
return random.choice(recipes)

View File

@ -1,5 +1,6 @@
from pydantic import UUID4
from mealie.core.exceptions import UnexpectedNone
from mealie.repos.repository_factory import AllRepositories
from mealie.schema.group import ShoppingListItemCreate, ShoppingListOut
from mealie.schema.group.group_shopping_list import (
@ -120,8 +121,10 @@ class ShoppingListService:
- deleted_shopping_list_items
"""
recipe = self.repos.recipes.get_one(recipe_id, "id")
to_create = []
if not recipe:
raise UnexpectedNone("Recipe not found")
to_create = []
for ingredient in recipe.recipe_ingredient:
food_id = None
try:
@ -144,7 +147,7 @@ class ShoppingListService:
to_create.append(
ShoppingListItemCreate(
shopping_list_id=list_id,
is_food=not recipe.settings.disable_amount,
is_food=not recipe.settings.disable_amount if recipe.settings else False,
food_id=food_id,
unit_id=unit_id,
quantity=ingredient.quantity,
@ -163,6 +166,9 @@ class ShoppingListService:
new_shopping_list_items = [self.repos.group_shopping_list_item.create(item) for item in to_create]
updated_shopping_list = self.shopping_lists.get_one(list_id)
if not updated_shopping_list:
raise UnexpectedNone("Shopping List not found")
updated_shopping_list_items, deleted_shopping_list_items = self.consolidate_and_save(updated_shopping_list.list_items) # type: ignore
updated_shopping_list.list_items = updated_shopping_list_items
@ -219,13 +225,16 @@ class ShoppingListService:
"""
shopping_list = self.shopping_lists.get_one(list_id)
if shopping_list is None:
raise UnexpectedNone("Shopping list not found, cannot remove recipe ingredients")
updated_shopping_list_items = []
deleted_shopping_list_items = []
for item in shopping_list.list_items:
found = False
for ref in item.recipe_references:
remove_qty = 0.0
remove_qty: None | float = 0.0
if ref.recipe_id == recipe_id:
self.list_item_refs.delete(ref.id) # type: ignore
@ -236,7 +245,9 @@ class ShoppingListService:
# If the item was found decrement the quantity by the remove_qty
if found:
item.quantity = item.quantity - remove_qty
if remove_qty is not None:
item.quantity = item.quantity - remove_qty
if item.quantity <= 0:
self.list_items.delete(item.id)
@ -246,16 +257,16 @@ class ShoppingListService:
updated_shopping_list_items.append(item)
# Decrement the list recipe reference count
for ref in shopping_list.recipe_references: # type: ignore
if ref.recipe_id == recipe_id:
ref.recipe_quantity -= 1
for recipe_ref in shopping_list.recipe_references:
if recipe_ref.recipe_id == recipe_id and recipe_ref.recipe_quantity is not None:
recipe_ref.recipe_quantity -= 1.0
if ref.recipe_quantity <= 0:
self.list_refs.delete(ref.id) # type: ignore
if recipe_ref.recipe_quantity <= 0.0:
self.list_refs.delete(recipe_ref.id)
else:
self.list_refs.update(ref.id, ref) # type: ignore
self.list_refs.update(recipe_ref.id, ref)
break
# Save Changes
return self.shopping_lists.get_one(shopping_list.id), updated_shopping_list_items, deleted_shopping_list_items
return self.shopping_lists.get_one(shopping_list.id), updated_shopping_list_items, deleted_shopping_list_items # type: ignore

View File

@ -96,7 +96,12 @@ class BaseMigrator(BaseService):
self._migrate()
self._save_all_entries()
return self.db.group_reports.get_one(self.report_id)
result = self.db.group_reports.get_one(self.report_id)
if not result:
raise ValueError("Report not found")
return result
def import_recipes_to_database(self, validated_recipes: list[Recipe]) -> list[tuple[str, UUID4, bool]]:
"""
@ -111,10 +116,13 @@ class BaseMigrator(BaseService):
if self.add_migration_tag:
migration_tag = self.helpers.get_or_set_tags([self.name])[0]
return_vars = []
return_vars: list[tuple[str, UUID4, bool]] = []
group = self.db.groups.get_one(self.group_id)
if not group or not group.preferences:
raise ValueError("Group preferences not found")
default_settings = RecipeSettings(
public=group.preferences.recipe_public,
show_nutrition=group.preferences.recipe_show_nutrition,
@ -132,6 +140,8 @@ class BaseMigrator(BaseService):
if recipe.tags:
recipe.tags = self.helpers.get_or_set_tags(x.name for x in recipe.tags)
else:
recipe.tags = []
if recipe.recipe_category:
recipe.recipe_category = self.helpers.get_or_set_category(x.name for x in recipe.recipe_category)
@ -155,7 +165,7 @@ class BaseMigrator(BaseService):
else:
message = f"Failed to import {recipe.name}"
return_vars.append((recipe.slug, recipe.id, status))
return_vars.append((recipe.slug, recipe.id, status)) # type: ignore
self.report_entries.append(
ReportEntryCreate(

View File

@ -42,8 +42,13 @@ class ChowdownMigrator(BaseMigrator):
for slug, recipe_id, status in results:
if status:
try:
original_image = recipe_lookup.get(slug).image
cd_image = image_dir.joinpath(original_image)
r = recipe_lookup.get(slug)
if not r:
continue
if r.image:
cd_image = image_dir.joinpath(r.image)
except StopIteration:
continue
if cd_image:

View File

@ -1,5 +1,6 @@
from pathlib import Path
from mealie.core.exceptions import UnexpectedNone
from mealie.repos.repository_factory import AllRepositories
from mealie.schema.group.group_exports import GroupDataExport
from mealie.schema.recipe import CategoryBase
@ -41,6 +42,9 @@ class RecipeBulkActionsService(BaseService):
group = self.repos.groups.get_one(self.group.id)
if group is None:
raise UnexpectedNone("Failed to purge exports for group, no group found")
for match in group.directory.glob("**/export/*zip"):
if match.is_file():
match.unlink()
@ -52,8 +56,8 @@ class RecipeBulkActionsService(BaseService):
for slug in recipes:
recipe = self.repos.recipes.get_one(slug)
if recipe is None:
self.logger.error(f"Failed to set settings for recipe {slug}, no recipe found")
if recipe is None or recipe.settings is None:
raise UnexpectedNone(f"Failed to set settings for recipe {slug}, no recipe found")
settings.locked = recipe.settings.locked
recipe.settings = settings
@ -69,9 +73,12 @@ class RecipeBulkActionsService(BaseService):
recipe = self.repos.recipes.get_one(slug)
if recipe is None:
self.logger.error(f"Failed to tag recipe {slug}, no recipe found")
raise UnexpectedNone(f"Failed to tag recipe {slug}, no recipe found")
recipe.tags += tags
if recipe.tags is None:
recipe.tags = []
recipe.tags += tags # type: ignore
try:
self.repos.recipes.update(slug, recipe)
@ -84,9 +91,12 @@ class RecipeBulkActionsService(BaseService):
recipe = self.repos.recipes.get_one(slug)
if recipe is None:
self.logger.error(f"Failed to categorize recipe {slug}, no recipe found")
raise UnexpectedNone(f"Failed to categorize recipe {slug}, no recipe found")
recipe.recipe_category += categories
if recipe.recipe_category is None:
recipe.recipe_category = []
recipe.recipe_category += categories # type: ignore
try:
self.repos.recipes.update(slug, recipe)

View File

@ -50,6 +50,8 @@ class RecipeService(BaseService):
return recipe
def can_update(self, recipe: Recipe) -> bool:
if recipe.settings is None:
raise exceptions.UnexpectedNone("Recipe Settings is None")
return recipe.settings.locked is False or self.user.id == recipe.user_id
def can_lock_unlock(self, recipe: Recipe) -> bool:
@ -66,6 +68,9 @@ class RecipeService(BaseService):
except FileNotFoundError:
self.logger.error(f"Recipe Directory not Found: {original_slug}")
if recipe.assets is None:
recipe.assets = []
all_asset_files = [x.file_name for x in recipe.assets]
for file in recipe.asset_dir.iterdir():
@ -92,7 +97,7 @@ class RecipeService(BaseService):
additional_attrs["group_id"] = user.group_id
if additional_attrs.get("tags"):
for i in range(len(additional_attrs.get("tags"))):
for i in range(len(additional_attrs.get("tags", []))):
additional_attrs["tags"][i]["group_id"] = user.group_id
if not additional_attrs.get("recipe_ingredient"):
@ -105,6 +110,9 @@ class RecipeService(BaseService):
def create_one(self, create_data: Union[Recipe, CreateRecipe]) -> Recipe:
if create_data.name is None:
create_data.name = "New Recipe"
data: Recipe = self._recipe_creation_factory(
self.user,
name=create_data.name,
@ -134,8 +142,8 @@ class RecipeService(BaseService):
with temp_path.open("wb") as buffer:
shutil.copyfileobj(archive.file, buffer)
recipe_dict = None
recipe_image = None
recipe_dict: dict | None = None
recipe_image: bytes | None = None
with ZipFile(temp_path) as myzip:
for file in myzip.namelist():
@ -146,10 +154,15 @@ class RecipeService(BaseService):
with myzip.open(file) as myfile:
recipe_image = myfile.read()
if recipe_dict is None:
raise exceptions.UnexpectedNone("No json data found in Zip")
recipe = self.create_one(Recipe(**recipe_dict))
if recipe:
if recipe and recipe.id:
data_service = RecipeDataService(recipe.id)
if recipe_image:
data_service.write_image(recipe_image, "webp")
return recipe
@ -172,6 +185,10 @@ class RecipeService(BaseService):
"""
recipe = self._get_recipe(slug)
if recipe is None or recipe.settings is None:
raise exceptions.NoEntryFound("Recipe not found.")
if not self.can_update(recipe):
raise exceptions.PermissionDenied("You do not have permission to edit this recipe.")
@ -189,9 +206,12 @@ class RecipeService(BaseService):
return new_data
def patch_one(self, slug: str, patch_data: Recipe) -> Recipe:
recipe = self._pre_update_check(slug, patch_data)
recipe: Recipe | None = self._pre_update_check(slug, patch_data)
recipe = self.repos.recipes.by_group(self.group.id).get_one(slug)
if recipe is None:
raise exceptions.NoEntryFound("Recipe not found.")
new_data = self.repos.recipes.by_group(self.group.id).patch(recipe.slug, patch_data.dict(exclude_unset=True))
self.check_assets(new_data, recipe.slug)
@ -210,6 +230,6 @@ class RecipeService(BaseService):
# =================================================================
# Recipe Template Methods
def render_template(self, recipe: Recipe, temp_dir: Path, template: str = None) -> Path:
def render_template(self, recipe: Recipe, temp_dir: Path, template: str) -> Path:
t_service = TemplateService(temp_dir)
return t_service.render(recipe, template)

View File

@ -16,7 +16,7 @@ class TemplateType(str, enum.Enum):
class TemplateService(BaseService):
def __init__(self, temp: Path = None) -> None:
def __init__(self, temp: Path | None = None) -> None:
"""Creates a template service that can be used for multiple template generations
A temporary directory must be provided as a place holder for where to render all templates
Args:
@ -58,7 +58,7 @@ class TemplateService(BaseService):
return TemplateType(t_type)
def render(self, recipe: Recipe, template: str = None) -> Path:
def render(self, recipe: Recipe, template: str) -> Path:
"""
Renders a TemplateType in a temporary directory and returns the path to the file.
@ -87,6 +87,9 @@ class TemplateService(BaseService):
"""
self.__check_temp(self._render_json)
if self.temp is None:
raise ValueError("Temporary directory must be provided for method _render_json")
save_path = self.temp.joinpath(f"{recipe.slug}.json")
with open(save_path, "w") as f:
f.write(recipe.json(indent=4, by_alias=True))
@ -100,6 +103,9 @@ class TemplateService(BaseService):
"""
self.__check_temp(self._render_jinja2)
if j2_template is None:
raise ValueError("Template must be provided for method _render_jinja2")
j2_path: Path = self.directories.TEMPLATE_DIR / j2_template
if not j2_path.is_file():
@ -113,6 +119,9 @@ class TemplateService(BaseService):
save_name = f"{recipe.slug}{j2_path.suffix}"
if self.temp is None:
raise ValueError("Temporary directory must be provided for method _render_jinja2")
save_path = self.temp.joinpath(save_name)
with open(save_path, "w") as f:
@ -124,6 +133,10 @@ class TemplateService(BaseService):
self.__check_temp(self._render_jinja2)
image_asset = recipe.image_dir.joinpath(RecipeImageTypes.original.value)
if self.temp is None:
raise ValueError("Temporary directory must be provided for method _render_zip")
zip_temp = self.temp.joinpath(f"{recipe.slug}.zip")
with ZipFile(zip_temp, "w") as myzip:

View File

@ -19,7 +19,7 @@ class ParserErrors(str, Enum):
CONNECTION_ERROR = "CONNECTION_ERROR"
def create_from_url(url: str) -> tuple[Recipe, ScrapedExtras]:
def create_from_url(url: str) -> tuple[Recipe, ScrapedExtras | None]:
"""Main entry point for generating a recipe from a URL. Pass in a URL and
a Recipe object will be returned if successful.
@ -43,6 +43,10 @@ def create_from_url(url: str) -> tuple[Recipe, ScrapedExtras]:
try:
recipe_data_service.scrape_image(new_recipe.image)
if new_recipe.name is None:
new_recipe.name = "Untitled"
new_recipe.slug = slugify(new_recipe.name)
new_recipe.image = cache.new_key(4)
except Exception as e:

View File

@ -176,7 +176,7 @@ class RecipeScraperPackage(ABCScraperStrategy):
ingredients = []
try:
instruct = scraped_schema.instructions()
instruct: list | str = scraped_schema.instructions()
except Exception:
instruct = []
@ -212,7 +212,7 @@ class RecipeScraperOpenGraph(ABCScraperStrategy):
"""
def og_field(properties: dict, field_name: str) -> str:
return next((val for name, val in properties if name == field_name), None)
return next((val for name, val in properties if name == field_name), "")
def og_fields(properties: list[tuple[str, str]], field_name: str) -> list[str]:
return list({val for name, val in properties if name == field_name})

View File

@ -31,6 +31,9 @@ class PasswordResetService(BaseService):
def send_reset_email(self, email: str):
token_entry = self.generate_reset_token(email)
if token_entry is None:
return None
# Send Email
email_servive = EmailService()
reset_url = f"{self.settings.BASE_URL}/reset-password?token={token_entry.token}"

View File

@ -35,7 +35,8 @@ class RegistrationService:
can_organize=new_group,
)
return self.repos.users.create(new_user)
# TODO: problem with repository type, not type here
return self.repos.users.create(new_user) # type: ignore
def _register_new_group(self) -> GroupInDB:
group_data = GroupBase(name=self.registration.group)
@ -74,7 +75,13 @@ class RegistrationService:
token_entry = self.repos.group_invite_tokens.get_one(registration.group_token)
if not token_entry:
raise HTTPException(status.HTTP_400_BAD_REQUEST, {"message": "Invalid group token"})
group = self.repos.groups.get_one(token_entry.group_id)
maybe_none_group = self.repos.groups.get_one(token_entry.group_id)
if maybe_none_group is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, {"message": "Invalid group token"})
group = maybe_none_group
else:
raise HTTPException(status.HTTP_400_BAD_REQUEST, {"message": "Missing group"})

View File

@ -105,5 +105,5 @@ pgsql = ["psycopg2-binary"]
python_version = "3.10"
ignore_missing_imports = true
follow_imports = "skip"
strict_optional = false # TODO: Fix none type checks - temporary stop-gap to implement mypy
strict_optional = true
plugins = "pydantic.mypy"