diff --git a/mealie/db/db_setup.py b/mealie/db/db_setup.py index 6e92751be5f5..2fb343319965 100644 --- a/mealie/db/db_setup.py +++ b/mealie/db/db_setup.py @@ -31,22 +31,31 @@ SessionLocal, engine = sql_global_init(settings.DB_URL) # type: ignore @contextmanager -def with_session() -> Session: +def session_context() -> Session: + """ + session_context() provides a managed session to the database that is automatically + closed when the context is exited. This is the preferred method of accessing the + database. + + Note: use `generate_session` when using the `Depends` function from FastAPI + """ global SessionLocal sess = SessionLocal() - try: yield sess finally: sess.close() -def create_session() -> Session: - global SessionLocal - return SessionLocal() - - def generate_session() -> Generator[Session, None, None]: + """ + WARNING: This function should _only_ be called when used with + using the `Depends` function from FastAPI. This function will leak + sessions if used outside of the context of a request. + + Use `with_session` instead. That function will allow you to use the + session within a context manager + """ global SessionLocal db = SessionLocal() try: diff --git a/mealie/db/init_db.py b/mealie/db/init_db.py index dab910191dd1..259bf09a4dcb 100644 --- a/mealie/db/init_db.py +++ b/mealie/db/init_db.py @@ -9,7 +9,7 @@ from alembic.config import Config from alembic.runtime import migration from mealie.core import root_logger from mealie.core.config import get_app_settings -from mealie.db.db_setup import create_session +from mealie.db.db_setup import session_context from mealie.db.fixes.fix_slug_foods import fix_slug_food_names from mealie.repos.all_repositories import get_repositories from mealie.repos.repository_factory import AllRepositories @@ -67,41 +67,40 @@ def connect(session: orm.Session) -> bool: def main(): - session = create_session() - # Wait for database to connect max_retry = 10 wait_seconds = 1 - while True: - if connect(session): - logger.info("Database connection established.") - break + with session_context() as session: + while True: + if connect(session): + logger.info("Database connection established.") + break - logger.error(f"Database connection failed. Retrying in {wait_seconds} seconds...") - max_retry -= 1 + logger.error(f"Database connection failed. Retrying in {wait_seconds} seconds...") + max_retry -= 1 - sleep(wait_seconds) + sleep(wait_seconds) - if max_retry == 0: - raise ConnectionError("Database connection failed - exiting application.") + if max_retry == 0: + raise ConnectionError("Database connection failed - exiting application.") - alembic_cfg = Config(str(PROJECT_DIR / "alembic.ini")) - if db_is_at_head(alembic_cfg): - logger.info("Migration not needed.") - else: - logger.info("Migration needed. Performing migration...") - command.upgrade(alembic_cfg, "head") + alembic_cfg = Config(str(PROJECT_DIR / "alembic.ini")) + if db_is_at_head(alembic_cfg): + logger.info("Migration not needed.") + else: + logger.info("Migration needed. Performing migration...") + command.upgrade(alembic_cfg, "head") - db = get_repositories(session) + db = get_repositories(session) - if db.users.get_all(): - logger.info("Database exists") - else: - logger.info("Database contains no users, initializing...") - init_db(db) + if db.users.get_all(): + logger.info("Database exists") + else: + logger.info("Database contains no users, initializing...") + init_db(db) - safe_try(lambda: fix_slug_food_names(db)) + safe_try(lambda: fix_slug_food_names(db)) if __name__ == "__main__": diff --git a/mealie/scripts/reset_locked_users.py b/mealie/scripts/reset_locked_users.py index 34cfb5545505..ca2e5d6e79d3 100644 --- a/mealie/scripts/reset_locked_users.py +++ b/mealie/scripts/reset_locked_users.py @@ -1,5 +1,5 @@ from mealie.core import root_logger -from mealie.db.db_setup import with_session +from mealie.db.db_setup import session_context from mealie.repos.repository_factory import AllRepositories from mealie.services.user_services.user_service import UserService @@ -13,7 +13,7 @@ def main(): logger = root_logger.get_logger() - with with_session() as session: + with session_context() as session: repos = AllRepositories(session) user_service = UserService(repos) diff --git a/mealie/services/event_bus_service/event_bus_listeners.py b/mealie/services/event_bus_service/event_bus_listeners.py index 3c129af5385b..9c73993e2478 100644 --- a/mealie/services/event_bus_service/event_bus_listeners.py +++ b/mealie/services/event_bus_service/event_bus_listeners.py @@ -1,4 +1,7 @@ +import contextlib import json +from abc import ABC, abstractmethod +from collections.abc import Generator from datetime import datetime, timezone from typing import cast from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit @@ -7,6 +10,7 @@ from fastapi.encoders import jsonable_encoder from pydantic import UUID4 from sqlalchemy.orm.session import Session +from mealie.db.db_setup import session_context from mealie.db.models.group.webhooks import GroupWebhooksModel from mealie.repos.all_repositories import get_repositories from mealie.repos.repository_factory import AllRepositories @@ -18,34 +22,58 @@ from .event_types import Event, EventDocumentType, EventTypes, EventWebhookData from .publisher import ApprisePublisher, PublisherLike, WebhookPublisher -class EventListenerBase: +class EventListenerBase(ABC): + session: Session | None + def __init__(self, session: Session, group_id: UUID4, publisher: PublisherLike) -> None: self.session = session self.group_id = group_id self.publisher = publisher + @abstractmethod def get_subscribers(self, event: Event) -> list: """Get a list of all subscribers to this event""" ... + @abstractmethod def publish_to_subscribers(self, event: Event, subscribers: list) -> None: """Publishes the event to all subscribers""" ... + @contextlib.contextmanager + def ensure_session(self) -> Generator[None, None, None]: + """ + ensure_session ensures that a session is available for the caller by checking if a session + was provided during construction, and if not, creating a new session with the `with_session` + function and closing it when the context manager exits. + + This is _required_ when working with sessions inside an event bus listener where the listener + may be constructed during a request where the session is provided by the request, but the when + run as a scheduled task, the session is not provided and must be created. + """ + if self.session is None: + with session_context() as session: + self.session = session + yield + + else: + yield + class AppriseEventListener(EventListenerBase): def __init__(self, session: Session, group_id: UUID4) -> None: super().__init__(session, group_id, ApprisePublisher()) def get_subscribers(self, event: Event) -> list[str]: - repos = AllRepositories(self.session) + with self.ensure_session(): + repos = AllRepositories(self.session) - notifiers: list[GroupEventNotifierPrivate] = repos.group_event_notifier.by_group( # type: ignore - self.group_id - ).multi_query({"enabled": True}, override_schema=GroupEventNotifierPrivate) + notifiers: list[GroupEventNotifierPrivate] = repos.group_event_notifier.by_group( # type: ignore + self.group_id + ).multi_query({"enabled": True}, override_schema=GroupEventNotifierPrivate) - urls = [notifier.apprise_url for notifier in notifiers if getattr(notifier.options, event.event_type.name)] - urls = AppriseEventListener.update_urls_with_event_data(urls, event) + urls = [notifier.apprise_url for notifier in notifiers if getattr(notifier.options, event.event_type.name)] + urls = AppriseEventListener.update_urls_with_event_data(urls, event) return urls @@ -120,12 +148,13 @@ class WebhookEventListener(EventListenerBase): def get_scheduled_webhooks(self, start_dt: datetime, end_dt: datetime) -> list[ReadWebhook]: """Fetches all scheduled webhooks from the database""" - return ( - self.session.query(GroupWebhooksModel) - .where( - GroupWebhooksModel.enabled == True, # noqa: E712 - required for SQLAlchemy comparison - GroupWebhooksModel.scheduled_time > start_dt.astimezone(timezone.utc).time(), - GroupWebhooksModel.scheduled_time <= end_dt.astimezone(timezone.utc).time(), + with self.ensure_session(): + return ( + self.session.query(GroupWebhooksModel) + .where( + GroupWebhooksModel.enabled == True, # noqa: E712 - required for SQLAlchemy comparison + GroupWebhooksModel.scheduled_time > start_dt.astimezone(timezone.utc).time(), + GroupWebhooksModel.scheduled_time <= end_dt.astimezone(timezone.utc).time(), + ) + .all() ) - .all() - ) diff --git a/mealie/services/event_bus_service/event_bus_service.py b/mealie/services/event_bus_service/event_bus_service.py index f89c8442a75e..1fd0c59b619b 100644 --- a/mealie/services/event_bus_service/event_bus_service.py +++ b/mealie/services/event_bus_service/event_bus_service.py @@ -40,12 +40,13 @@ class EventSource: class EventBusService: + bg: BackgroundTasks | None + session: Session | None + group_id: UUID4 | None + def __init__( self, bg: Optional[BackgroundTasks] = None, session: Optional[Session] = None, group_id: UUID4 | None = None ) -> None: - if not session: - session = next(generate_session()) - self.bg = bg self.session = session self.group_id = group_id diff --git a/mealie/services/scheduler/tasks/post_webhooks.py b/mealie/services/scheduler/tasks/post_webhooks.py index 1e533e9f4319..b62b5f2c4b71 100644 --- a/mealie/services/scheduler/tasks/post_webhooks.py +++ b/mealie/services/scheduler/tasks/post_webhooks.py @@ -3,7 +3,7 @@ from typing import Optional from pydantic import UUID4 -from mealie.db.db_setup import create_session +from mealie.db.db_setup import session_context from mealie.repos.all_repositories import get_repositories from mealie.schema.response.pagination import PaginationQuery from mealie.services.event_bus_service.event_bus_service import EventBusService @@ -31,10 +31,11 @@ def post_group_webhooks(start_dt: Optional[datetime] = None, group_id: Optional[ if group_id is None: # publish the webhook event to each group's event bus - session = create_session() - repos = get_repositories(session) - groups_data = repos.groups.page_all(PaginationQuery(page=1, per_page=-1)) - group_ids = [group.id for group in groups_data.items] + + with session_context() as session: + repos = get_repositories(session) + groups_data = repos.groups.page_all(PaginationQuery(page=1, per_page=-1)) + group_ids = [group.id for group in groups_data.items] else: group_ids = [group_id] diff --git a/mealie/services/scheduler/tasks/purge_group_exports.py b/mealie/services/scheduler/tasks/purge_group_exports.py index 1c34a126c417..648215f83291 100644 --- a/mealie/services/scheduler/tasks/purge_group_exports.py +++ b/mealie/services/scheduler/tasks/purge_group_exports.py @@ -3,7 +3,7 @@ from pathlib import Path from mealie.core import root_logger from mealie.core.config import get_app_dirs -from mealie.db.db_setup import create_session +from mealie.db.db_setup import session_context from mealie.db.models.group.exports import GroupDataExportsModel ONE_DAY_AS_MINUTES = 1440 @@ -15,20 +15,19 @@ def purge_group_data_exports(max_minutes_old=ONE_DAY_AS_MINUTES): logger.info("purging group data exports") limit = datetime.datetime.now() - datetime.timedelta(minutes=max_minutes_old) - session = create_session() - results = session.query(GroupDataExportsModel).filter(GroupDataExportsModel.expires <= limit) + with session_context() as session: + results = session.query(GroupDataExportsModel).filter(GroupDataExportsModel.expires <= limit) - total_removed = 0 - for result in results: - session.delete(result) - Path(result.path).unlink(missing_ok=True) - total_removed += 1 + total_removed = 0 + for result in results: + session.delete(result) + Path(result.path).unlink(missing_ok=True) + total_removed += 1 - session.commit() - session.close() + session.commit() - logger.info(f"finished purging group data exports. {total_removed} exports removed from group data") + logger.info(f"finished purging group data exports. {total_removed} exports removed from group data") def purge_excess_files() -> None: diff --git a/mealie/services/scheduler/tasks/purge_password_reset.py b/mealie/services/scheduler/tasks/purge_password_reset.py index fdbeacfa863a..ab8e4808bda5 100644 --- a/mealie/services/scheduler/tasks/purge_password_reset.py +++ b/mealie/services/scheduler/tasks/purge_password_reset.py @@ -1,7 +1,7 @@ import datetime from mealie.core import root_logger -from mealie.db.db_setup import create_session +from mealie.db.db_setup import session_context from mealie.db.models.users.password_reset import PasswordResetModel logger = root_logger.get_logger() @@ -13,8 +13,9 @@ def purge_password_reset_tokens(): """Purges all events after x days""" logger.info("purging password reset tokens") limit = datetime.datetime.now() - datetime.timedelta(days=MAX_DAYS_OLD) - session = create_session() - session.query(PasswordResetModel).filter(PasswordResetModel.created_at <= limit).delete() - session.commit() - session.close() - logger.info("password reset tokens purges") + + with session_context() as session: + session.query(PasswordResetModel).filter(PasswordResetModel.created_at <= limit).delete() + session.commit() + session.close() + logger.info("password reset tokens purges") diff --git a/mealie/services/scheduler/tasks/purge_registration.py b/mealie/services/scheduler/tasks/purge_registration.py index 8a093eee5906..33f9efc78dbc 100644 --- a/mealie/services/scheduler/tasks/purge_registration.py +++ b/mealie/services/scheduler/tasks/purge_registration.py @@ -1,7 +1,7 @@ import datetime from mealie.core import root_logger -from mealie.db.db_setup import create_session +from mealie.db.db_setup import session_context from mealie.db.models.group import GroupInviteToken logger = root_logger.get_logger() @@ -13,8 +13,10 @@ def purge_group_registration(): """Purges all events after x days""" logger.info("purging expired registration tokens") limit = datetime.datetime.now() - datetime.timedelta(days=MAX_DAYS_OLD) - session = create_session() - session.query(GroupInviteToken).filter(GroupInviteToken.created_at <= limit).delete() - session.commit() - session.close() + + with session_context() as session: + session.query(GroupInviteToken).filter(GroupInviteToken.created_at <= limit).delete() + session.commit() + session.close() + logger.info("registration token purged") diff --git a/mealie/services/scheduler/tasks/reset_locked_users.py b/mealie/services/scheduler/tasks/reset_locked_users.py index 1e3c51322638..3dd1ab7f82ea 100644 --- a/mealie/services/scheduler/tasks/reset_locked_users.py +++ b/mealie/services/scheduler/tasks/reset_locked_users.py @@ -1,5 +1,5 @@ from mealie.core import root_logger -from mealie.db.db_setup import with_session +from mealie.db.db_setup import session_context from mealie.repos.repository_factory import AllRepositories from mealie.services.user_services.user_service import UserService @@ -8,7 +8,7 @@ def locked_user_reset(): logger = root_logger.get_logger() logger.info("resetting locked users") - with with_session() as session: + with session_context() as session: repos = AllRepositories(session) user_service = UserService(repos) diff --git a/tests/integration_tests/user_tests/test_user_password_reset_service.py b/tests/integration_tests/user_tests/test_user_password_reset_service.py index d2aa48975d18..c7c0f5a5c873 100644 --- a/tests/integration_tests/user_tests/test_user_password_reset_service.py +++ b/tests/integration_tests/user_tests/test_user_password_reset_service.py @@ -3,7 +3,7 @@ import json import pytest from fastapi.testclient import TestClient -from mealie.db.db_setup import create_session +from mealie.db.db_setup import session_context from mealie.services.user_services.password_reset_service import PasswordResetService from tests.utils.factories import random_string from tests.utils.fixture_schemas import TestUser @@ -31,10 +31,10 @@ def test_password_reset(api_client: TestClient, unique_user: TestUser, casing: s cased_email += l.lower() cased_email - session = create_session() - service = PasswordResetService(session) - token = service.generate_reset_token(cased_email) - assert token is not None + with session_context() as session: + service = PasswordResetService(session) + token = service.generate_reset_token(cased_email) + assert token is not None new_password = random_string(15) @@ -59,8 +59,6 @@ def test_password_reset(api_client: TestClient, unique_user: TestUser, casing: s response = api_client.get(Routes.self, headers={"Authorization": f"Bearer {new_token}"}) assert response.status_code == 200 - session.close() - # Test successful password reset response = api_client.post(Routes.base, json=payload) assert response.status_code == 400 diff --git a/tests/unit_tests/test_security.py b/tests/unit_tests/test_security.py index dca3734120ec..0fb3de04fbfa 100644 --- a/tests/unit_tests/test_security.py +++ b/tests/unit_tests/test_security.py @@ -5,7 +5,7 @@ from pytest import MonkeyPatch from mealie.core import security from mealie.core.config import get_app_settings from mealie.core.dependencies import validate_file_token -from mealie.db.db_setup import create_session +from mealie.db.db_setup import session_context from tests.utils.factories import random_string @@ -47,5 +47,8 @@ def test_ldap_authentication_mocked(monkeypatch: MonkeyPatch): monkeypatch.setattr(ldap, "initialize", ldap_initialize_mock) get_app_settings.cache_clear() - result = security.authenticate_user(create_session(), user, password) + + with session_context() as session: + result = security.authenticate_user(session, user, password) + assert result is False