fix: group creation (#1126)

* fix: unify group creation - closes #1100

* tests: disable password hashing during testing

* tests: fix email config tests
This commit is contained in:
Hayden 2022-04-02 19:33:15 -08:00 committed by GitHub
parent e9bb39c744
commit c988de1921
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 113 additions and 33 deletions

View File

@ -0,0 +1 @@
from .security import *

View File

@ -0,0 +1,43 @@
from functools import lru_cache
from typing import Protocol
from passlib.context import CryptContext
from mealie.core.config import get_app_settings
class Hasher(Protocol):
def hash(self, password: str) -> str:
...
def verify(self, password: str, hashed: str) -> bool:
...
class FakeHasher:
def hash(self, password: str) -> str:
return password
def verify(self, password: str, hashed: str) -> bool:
return password == hashed
class PasslibHasher:
def __init__(self) -> None:
self.ctx = CryptContext(schemes=["bcrypt"], deprecated="auto")
def hash(self, password: str) -> str:
return self.ctx.hash(password)
def verify(self, password: str, hashed: str) -> bool:
return self.ctx.verify(password, hashed)
@lru_cache(maxsize=1)
def get_hasher() -> Hasher:
settings = get_app_settings()
if settings.TESTING:
return FakeHasher()
return PasslibHasher()

View File

@ -1,16 +1,15 @@
import secrets import secrets
from datetime import datetime, timedelta from datetime import datetime, timedelta, timezone
from pathlib import Path from pathlib import Path
from jose import jwt from jose import jwt
from passlib.context import CryptContext
from mealie.core.config import get_app_settings from mealie.core.config import get_app_settings
from mealie.core.security.hasher import get_hasher
from mealie.repos.all_repositories import get_repositories from mealie.repos.all_repositories import get_repositories
from mealie.repos.repository_factory import AllRepositories from mealie.repos.repository_factory import AllRepositories
from mealie.schema.user import PrivateUser from mealie.schema.user import PrivateUser
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
ALGORITHM = "HS256" ALGORITHM = "HS256"
@ -20,7 +19,7 @@ def create_access_token(data: dict, expires_delta: timedelta = None) -> str:
to_encode = data.copy() to_encode = data.copy()
expires_delta = expires_delta or timedelta(hours=settings.TOKEN_TIME) expires_delta = expires_delta or timedelta(hours=settings.TOKEN_TIME)
expire = datetime.utcnow() + expires_delta expire = datetime.now(timezone.utc) + expires_delta
to_encode["exp"] = expire to_encode["exp"] = expire
return jwt.encode(to_encode, settings.SECRET, algorithm=ALGORITHM) return jwt.encode(to_encode, settings.SECRET, algorithm=ALGORITHM)
@ -31,7 +30,7 @@ def create_file_token(file_path: Path) -> str:
return create_access_token(token_data, expires_delta=timedelta(minutes=30)) return create_access_token(token_data, expires_delta=timedelta(minutes=30))
def create_recipe_slug_token(file_path: str) -> str: def create_recipe_slug_token(file_path: str | Path) -> str:
token_data = {"slug": str(file_path)} token_data = {"slug": str(file_path)}
return create_access_token(token_data, expires_delta=timedelta(minutes=30)) return create_access_token(token_data, expires_delta=timedelta(minutes=30))
@ -96,12 +95,12 @@ def authenticate_user(session, email: str, password: str) -> PrivateUser | bool:
def verify_password(plain_password: str, hashed_password: str) -> bool: def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Compares a plain string to a hashed password""" """Compares a plain string to a hashed password"""
return pwd_context.verify(plain_password, hashed_password) return get_hasher().verify(plain_password, hashed_password)
def hash_password(password: str) -> str: def hash_password(password: str) -> str:
"""Takes in a raw password and hashes it. Used prior to saving a new password to the database.""" """Takes in a raw password and hashes it. Used prior to saving a new password to the database."""
return pwd_context.hash(password) return get_hasher().hash(password)
def url_safe_token() -> str: def url_safe_token() -> str:

View File

@ -16,7 +16,7 @@ from mealie.repos.repository_factory import AllRepositories
from mealie.repos.seed.init_users import default_user_init from mealie.repos.seed.init_users import default_user_init
from mealie.repos.seed.seeders import IngredientFoodsSeeder, IngredientUnitsSeeder, MultiPurposeLabelSeeder from mealie.repos.seed.seeders import IngredientFoodsSeeder, IngredientUnitsSeeder, MultiPurposeLabelSeeder
from mealie.schema.user.user import GroupBase from mealie.schema.user.user import GroupBase
from mealie.services.group_services.group_utils import create_new_group from mealie.services.group_services.group_service import GroupService
PROJECT_DIR = Path(__file__).parent.parent.parent PROJECT_DIR = Path(__file__).parent.parent.parent
@ -44,7 +44,8 @@ def default_group_init(db: AllRepositories):
settings = get_app_settings() settings = get_app_settings()
logger.info("Generating Default Group") logger.info("Generating Default Group")
create_new_group(db, GroupBase(name=settings.DEFAULT_GROUP))
GroupService.create_group(db, GroupBase(name=settings.DEFAULT_GROUP))
# Adapted from https://alembic.sqlalchemy.org/en/latest/cookbook.html#test-current-database-revision-is-at-head-s # Adapted from https://alembic.sqlalchemy.org/en/latest/cookbook.html#test-current-database-revision-is-at-head-s

View File

@ -8,6 +8,7 @@ from mealie.schema.mapper import mapper
from mealie.schema.query import GetAll from mealie.schema.query import GetAll
from mealie.schema.response.responses import ErrorResponse from mealie.schema.response.responses import ErrorResponse
from mealie.schema.user.user import GroupBase, GroupInDB from mealie.schema.user.user import GroupBase, GroupInDB
from mealie.services.group_services.group_service import GroupService
from .._base import BaseAdminController, controller from .._base import BaseAdminController, controller
from .._base.dependencies import SharedDependencies from .._base.dependencies import SharedDependencies
@ -44,7 +45,7 @@ class AdminUserManagementRoutes(BaseAdminController):
@router.post("", response_model=GroupInDB, status_code=status.HTTP_201_CREATED) @router.post("", response_model=GroupInDB, status_code=status.HTTP_201_CREATED)
def create_one(self, data: GroupBase): def create_one(self, data: GroupBase):
return self.mixins.create_one(data) return GroupService.create_group(self.deps.repos, data)
@router.get("/{item_id}", response_model=GroupInDB) @router.get("/{item_id}", response_model=GroupInDB)
def get_one(self, item_id: UUID4): def get_one(self, item_id: UUID4):
@ -69,7 +70,7 @@ class AdminUserManagementRoutes(BaseAdminController):
def delete_one(self, item_id: UUID4): def delete_one(self, item_id: UUID4):
item = self.repo.get_one(item_id) item = self.repo.get_one(item_id)
if len(item.users) > 0: if item and len(item.users) > 0:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=ErrorResponse.respond(message="Cannot delete group with users"), detail=ErrorResponse.respond(message="Cannot delete group with users"),

View File

@ -2,7 +2,9 @@ from pydantic import UUID4
from mealie.pkgs.stats import fs_stats from mealie.pkgs.stats import fs_stats
from mealie.repos.repository_factory import AllRepositories from mealie.repos.repository_factory import AllRepositories
from mealie.schema.group.group_preferences import CreateGroupPreferences
from mealie.schema.group.group_statistics import GroupStatistics, GroupStorage from mealie.schema.group.group_statistics import GroupStatistics, GroupStorage
from mealie.schema.user.user import GroupBase
from mealie.services._base_service import BaseService from mealie.services._base_service import BaseService
ALLOWED_SIZE = 500 * fs_stats.megabyte ALLOWED_SIZE = 500 * fs_stats.megabyte
@ -14,6 +16,23 @@ class GroupService(BaseService):
self.repos = repos self.repos = repos
super().__init__() super().__init__()
@staticmethod
def create_group(repos: AllRepositories, g_base: GroupBase, prefs: CreateGroupPreferences | None = None):
"""
Creates a new group in the database with the required associated table references to ensure
the group includes required preferences.
"""
new_group = repos.groups.create(g_base)
if prefs is None:
prefs = CreateGroupPreferences(group_id=new_group.id)
else:
prefs.group_id = new_group.id
repos.group_preferences.create(prefs)
return new_group
def calculate_statistics(self, group_id: None | UUID4 = None) -> GroupStatistics: def calculate_statistics(self, group_id: None | UUID4 = None) -> GroupStatistics:
""" """
calculate_statistics calculates the statistics for the group and returns calculate_statistics calculates the statistics for the group and returns

View File

@ -1,18 +0,0 @@
from uuid import uuid4
from mealie.repos.repository_factory import AllRepositories
from mealie.schema.group.group_preferences import CreateGroupPreferences
from mealie.schema.user.user import GroupBase, GroupInDB
def create_new_group(db: AllRepositories, g_base: GroupBase, g_preferences: CreateGroupPreferences = None) -> GroupInDB:
created_group = db.groups.create(g_base)
# Assign Temporary ID before group is created
g_preferences = g_preferences or CreateGroupPreferences(group_id=uuid4())
g_preferences.group_id = created_group.id
db.group_preferences.create(g_preferences)
return created_group

View File

@ -8,7 +8,7 @@ from mealie.repos.repository_factory import AllRepositories
from mealie.schema.group.group_preferences import CreateGroupPreferences from mealie.schema.group.group_preferences import CreateGroupPreferences
from mealie.schema.user.registration import CreateUserRegistration from mealie.schema.user.registration import CreateUserRegistration
from mealie.schema.user.user import GroupBase, GroupInDB, PrivateUser, UserIn from mealie.schema.user.user import GroupBase, GroupInDB, PrivateUser, UserIn
from mealie.services.group_services.group_utils import create_new_group from mealie.services.group_services.group_service import GroupService
class RegistrationService: class RegistrationService:
@ -19,7 +19,7 @@ class RegistrationService:
self.logger = logger self.logger = logger
self.repos = db self.repos = db
def _create_new_user(self, group: GroupInDB, new_group=bool) -> PrivateUser: def _create_new_user(self, group: GroupInDB, new_group: bool) -> PrivateUser:
new_user = UserIn( new_user = UserIn(
email=self.registration.email, email=self.registration.email,
username=self.registration.username, username=self.registration.username,
@ -49,7 +49,7 @@ class RegistrationService:
recipe_disable_amount=self.registration.advanced, recipe_disable_amount=self.registration.advanced,
) )
return create_new_group(self.repos, group_data, group_preferences) return GroupService.create_group(self.repos, group_data, group_preferences)
def register_user(self, registration: CreateUserRegistration) -> PrivateUser: def register_user(self, registration: CreateUserRegistration) -> PrivateUser:
self.registration = registration self.registration = registration

View File

@ -0,0 +1,22 @@
from pytest import MonkeyPatch
from mealie.core.config import get_app_settings
from mealie.core.security.hasher import FakeHasher, PasslibHasher, get_hasher
def test_get_hasher(monkeypatch: MonkeyPatch):
hasher = get_hasher()
assert isinstance(hasher, FakeHasher)
monkeypatch.setenv("TESTING", "0")
get_hasher.cache_clear()
get_app_settings.cache_clear()
hasher = get_hasher()
assert isinstance(hasher, PasslibHasher)
get_app_settings.cache_clear()
get_hasher.cache_clear()

View File

@ -44,8 +44,11 @@ def email_service(monkeypatch) -> EmailService:
return email_service return email_service
def test_email_disabled(): def test_email_disabled(monkeypatch):
email_service = EmailService(TestEmailSender()) email_service = EmailService(TestEmailSender())
monkeypatch.setenv("SMTP_HOST", "") # disable email
get_app_settings.cache_clear() get_app_settings.cache_clear()
email_service.settings = get_app_settings() email_service.settings = get_app_settings()
success = email_service.send_test_email(FAKE_ADDRESS) success = email_service.send_test_email(FAKE_ADDRESS)

View File

@ -60,8 +60,17 @@ def test_pg_connection_args(monkeypatch):
def test_smtp_enable(monkeypatch): def test_smtp_enable(monkeypatch):
monkeypatch.setenv("SMTP_HOST", "")
monkeypatch.setenv("SMTP_PORT", "")
monkeypatch.setenv("SMTP_TLS", "true")
monkeypatch.setenv("SMTP_FROM_NAME", "")
monkeypatch.setenv("SMTP_FROM_EMAIL", "")
monkeypatch.setenv("SMTP_USER", "")
monkeypatch.setenv("SMTP_PASSWORD", "")
get_app_settings.cache_clear() get_app_settings.cache_clear()
app_settings = get_app_settings() app_settings = get_app_settings()
assert app_settings.SMTP_ENABLE is False assert app_settings.SMTP_ENABLE is False
monkeypatch.setenv("SMTP_HOST", "email.mealie.io") monkeypatch.setenv("SMTP_HOST", "email.mealie.io")