mirror of
https://github.com/mealie-recipes/mealie.git
synced 2025-05-24 01:12:54 -04:00
fix: unclosed sessions (#1734)
* resolve session leak * cleanup session management functions
This commit is contained in:
parent
a3904c45d8
commit
e516a2e801
@ -31,22 +31,31 @@ SessionLocal, engine = sql_global_init(settings.DB_URL) # type: ignore
|
||||
|
||||
|
||||
@contextmanager
|
||||
def with_session() -> Session:
|
||||
def session_context() -> Session:
|
||||
"""
|
||||
session_context() provides a managed session to the database that is automatically
|
||||
closed when the context is exited. This is the preferred method of accessing the
|
||||
database.
|
||||
|
||||
Note: use `generate_session` when using the `Depends` function from FastAPI
|
||||
"""
|
||||
global SessionLocal
|
||||
sess = SessionLocal()
|
||||
|
||||
try:
|
||||
yield sess
|
||||
finally:
|
||||
sess.close()
|
||||
|
||||
|
||||
def create_session() -> Session:
|
||||
global SessionLocal
|
||||
return SessionLocal()
|
||||
|
||||
|
||||
def generate_session() -> Generator[Session, None, None]:
|
||||
"""
|
||||
WARNING: This function should _only_ be called when used with
|
||||
using the `Depends` function from FastAPI. This function will leak
|
||||
sessions if used outside of the context of a request.
|
||||
|
||||
Use `with_session` instead. That function will allow you to use the
|
||||
session within a context manager
|
||||
"""
|
||||
global SessionLocal
|
||||
db = SessionLocal()
|
||||
try:
|
||||
|
@ -9,7 +9,7 @@ from alembic.config import Config
|
||||
from alembic.runtime import migration
|
||||
from mealie.core import root_logger
|
||||
from mealie.core.config import get_app_settings
|
||||
from mealie.db.db_setup import create_session
|
||||
from mealie.db.db_setup import session_context
|
||||
from mealie.db.fixes.fix_slug_foods import fix_slug_food_names
|
||||
from mealie.repos.all_repositories import get_repositories
|
||||
from mealie.repos.repository_factory import AllRepositories
|
||||
@ -67,41 +67,40 @@ def connect(session: orm.Session) -> bool:
|
||||
|
||||
|
||||
def main():
|
||||
session = create_session()
|
||||
|
||||
# Wait for database to connect
|
||||
max_retry = 10
|
||||
wait_seconds = 1
|
||||
|
||||
while True:
|
||||
if connect(session):
|
||||
logger.info("Database connection established.")
|
||||
break
|
||||
with session_context() as session:
|
||||
while True:
|
||||
if connect(session):
|
||||
logger.info("Database connection established.")
|
||||
break
|
||||
|
||||
logger.error(f"Database connection failed. Retrying in {wait_seconds} seconds...")
|
||||
max_retry -= 1
|
||||
logger.error(f"Database connection failed. Retrying in {wait_seconds} seconds...")
|
||||
max_retry -= 1
|
||||
|
||||
sleep(wait_seconds)
|
||||
sleep(wait_seconds)
|
||||
|
||||
if max_retry == 0:
|
||||
raise ConnectionError("Database connection failed - exiting application.")
|
||||
if max_retry == 0:
|
||||
raise ConnectionError("Database connection failed - exiting application.")
|
||||
|
||||
alembic_cfg = Config(str(PROJECT_DIR / "alembic.ini"))
|
||||
if db_is_at_head(alembic_cfg):
|
||||
logger.info("Migration not needed.")
|
||||
else:
|
||||
logger.info("Migration needed. Performing migration...")
|
||||
command.upgrade(alembic_cfg, "head")
|
||||
alembic_cfg = Config(str(PROJECT_DIR / "alembic.ini"))
|
||||
if db_is_at_head(alembic_cfg):
|
||||
logger.info("Migration not needed.")
|
||||
else:
|
||||
logger.info("Migration needed. Performing migration...")
|
||||
command.upgrade(alembic_cfg, "head")
|
||||
|
||||
db = get_repositories(session)
|
||||
db = get_repositories(session)
|
||||
|
||||
if db.users.get_all():
|
||||
logger.info("Database exists")
|
||||
else:
|
||||
logger.info("Database contains no users, initializing...")
|
||||
init_db(db)
|
||||
if db.users.get_all():
|
||||
logger.info("Database exists")
|
||||
else:
|
||||
logger.info("Database contains no users, initializing...")
|
||||
init_db(db)
|
||||
|
||||
safe_try(lambda: fix_slug_food_names(db))
|
||||
safe_try(lambda: fix_slug_food_names(db))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,5 +1,5 @@
|
||||
from mealie.core import root_logger
|
||||
from mealie.db.db_setup import with_session
|
||||
from mealie.db.db_setup import session_context
|
||||
from mealie.repos.repository_factory import AllRepositories
|
||||
from mealie.services.user_services.user_service import UserService
|
||||
|
||||
@ -13,7 +13,7 @@ def main():
|
||||
|
||||
logger = root_logger.get_logger()
|
||||
|
||||
with with_session() as session:
|
||||
with session_context() as session:
|
||||
repos = AllRepositories(session)
|
||||
user_service = UserService(repos)
|
||||
|
||||
|
@ -1,4 +1,7 @@
|
||||
import contextlib
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime, timezone
|
||||
from typing import cast
|
||||
from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit
|
||||
@ -7,6 +10,7 @@ from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import UUID4
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from mealie.db.db_setup import session_context
|
||||
from mealie.db.models.group.webhooks import GroupWebhooksModel
|
||||
from mealie.repos.all_repositories import get_repositories
|
||||
from mealie.repos.repository_factory import AllRepositories
|
||||
@ -18,34 +22,58 @@ from .event_types import Event, EventDocumentType, EventTypes, EventWebhookData
|
||||
from .publisher import ApprisePublisher, PublisherLike, WebhookPublisher
|
||||
|
||||
|
||||
class EventListenerBase:
|
||||
class EventListenerBase(ABC):
|
||||
session: Session | None
|
||||
|
||||
def __init__(self, session: Session, group_id: UUID4, publisher: PublisherLike) -> None:
|
||||
self.session = session
|
||||
self.group_id = group_id
|
||||
self.publisher = publisher
|
||||
|
||||
@abstractmethod
|
||||
def get_subscribers(self, event: Event) -> list:
|
||||
"""Get a list of all subscribers to this event"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def publish_to_subscribers(self, event: Event, subscribers: list) -> None:
|
||||
"""Publishes the event to all subscribers"""
|
||||
...
|
||||
|
||||
@contextlib.contextmanager
|
||||
def ensure_session(self) -> Generator[None, None, None]:
|
||||
"""
|
||||
ensure_session ensures that a session is available for the caller by checking if a session
|
||||
was provided during construction, and if not, creating a new session with the `with_session`
|
||||
function and closing it when the context manager exits.
|
||||
|
||||
This is _required_ when working with sessions inside an event bus listener where the listener
|
||||
may be constructed during a request where the session is provided by the request, but the when
|
||||
run as a scheduled task, the session is not provided and must be created.
|
||||
"""
|
||||
if self.session is None:
|
||||
with session_context() as session:
|
||||
self.session = session
|
||||
yield
|
||||
|
||||
else:
|
||||
yield
|
||||
|
||||
|
||||
class AppriseEventListener(EventListenerBase):
|
||||
def __init__(self, session: Session, group_id: UUID4) -> None:
|
||||
super().__init__(session, group_id, ApprisePublisher())
|
||||
|
||||
def get_subscribers(self, event: Event) -> list[str]:
|
||||
repos = AllRepositories(self.session)
|
||||
with self.ensure_session():
|
||||
repos = AllRepositories(self.session)
|
||||
|
||||
notifiers: list[GroupEventNotifierPrivate] = repos.group_event_notifier.by_group( # type: ignore
|
||||
self.group_id
|
||||
).multi_query({"enabled": True}, override_schema=GroupEventNotifierPrivate)
|
||||
notifiers: list[GroupEventNotifierPrivate] = repos.group_event_notifier.by_group( # type: ignore
|
||||
self.group_id
|
||||
).multi_query({"enabled": True}, override_schema=GroupEventNotifierPrivate)
|
||||
|
||||
urls = [notifier.apprise_url for notifier in notifiers if getattr(notifier.options, event.event_type.name)]
|
||||
urls = AppriseEventListener.update_urls_with_event_data(urls, event)
|
||||
urls = [notifier.apprise_url for notifier in notifiers if getattr(notifier.options, event.event_type.name)]
|
||||
urls = AppriseEventListener.update_urls_with_event_data(urls, event)
|
||||
|
||||
return urls
|
||||
|
||||
@ -120,12 +148,13 @@ class WebhookEventListener(EventListenerBase):
|
||||
|
||||
def get_scheduled_webhooks(self, start_dt: datetime, end_dt: datetime) -> list[ReadWebhook]:
|
||||
"""Fetches all scheduled webhooks from the database"""
|
||||
return (
|
||||
self.session.query(GroupWebhooksModel)
|
||||
.where(
|
||||
GroupWebhooksModel.enabled == True, # noqa: E712 - required for SQLAlchemy comparison
|
||||
GroupWebhooksModel.scheduled_time > start_dt.astimezone(timezone.utc).time(),
|
||||
GroupWebhooksModel.scheduled_time <= end_dt.astimezone(timezone.utc).time(),
|
||||
with self.ensure_session():
|
||||
return (
|
||||
self.session.query(GroupWebhooksModel)
|
||||
.where(
|
||||
GroupWebhooksModel.enabled == True, # noqa: E712 - required for SQLAlchemy comparison
|
||||
GroupWebhooksModel.scheduled_time > start_dt.astimezone(timezone.utc).time(),
|
||||
GroupWebhooksModel.scheduled_time <= end_dt.astimezone(timezone.utc).time(),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
@ -40,12 +40,13 @@ class EventSource:
|
||||
|
||||
|
||||
class EventBusService:
|
||||
bg: BackgroundTasks | None
|
||||
session: Session | None
|
||||
group_id: UUID4 | None
|
||||
|
||||
def __init__(
|
||||
self, bg: Optional[BackgroundTasks] = None, session: Optional[Session] = None, group_id: UUID4 | None = None
|
||||
) -> None:
|
||||
if not session:
|
||||
session = next(generate_session())
|
||||
|
||||
self.bg = bg
|
||||
self.session = session
|
||||
self.group_id = group_id
|
||||
|
@ -3,7 +3,7 @@ from typing import Optional
|
||||
|
||||
from pydantic import UUID4
|
||||
|
||||
from mealie.db.db_setup import create_session
|
||||
from mealie.db.db_setup import session_context
|
||||
from mealie.repos.all_repositories import get_repositories
|
||||
from mealie.schema.response.pagination import PaginationQuery
|
||||
from mealie.services.event_bus_service.event_bus_service import EventBusService
|
||||
@ -31,10 +31,11 @@ def post_group_webhooks(start_dt: Optional[datetime] = None, group_id: Optional[
|
||||
|
||||
if group_id is None:
|
||||
# publish the webhook event to each group's event bus
|
||||
session = create_session()
|
||||
repos = get_repositories(session)
|
||||
groups_data = repos.groups.page_all(PaginationQuery(page=1, per_page=-1))
|
||||
group_ids = [group.id for group in groups_data.items]
|
||||
|
||||
with session_context() as session:
|
||||
repos = get_repositories(session)
|
||||
groups_data = repos.groups.page_all(PaginationQuery(page=1, per_page=-1))
|
||||
group_ids = [group.id for group in groups_data.items]
|
||||
|
||||
else:
|
||||
group_ids = [group_id]
|
||||
|
@ -3,7 +3,7 @@ from pathlib import Path
|
||||
|
||||
from mealie.core import root_logger
|
||||
from mealie.core.config import get_app_dirs
|
||||
from mealie.db.db_setup import create_session
|
||||
from mealie.db.db_setup import session_context
|
||||
from mealie.db.models.group.exports import GroupDataExportsModel
|
||||
|
||||
ONE_DAY_AS_MINUTES = 1440
|
||||
@ -15,20 +15,19 @@ def purge_group_data_exports(max_minutes_old=ONE_DAY_AS_MINUTES):
|
||||
|
||||
logger.info("purging group data exports")
|
||||
limit = datetime.datetime.now() - datetime.timedelta(minutes=max_minutes_old)
|
||||
session = create_session()
|
||||
|
||||
results = session.query(GroupDataExportsModel).filter(GroupDataExportsModel.expires <= limit)
|
||||
with session_context() as session:
|
||||
results = session.query(GroupDataExportsModel).filter(GroupDataExportsModel.expires <= limit)
|
||||
|
||||
total_removed = 0
|
||||
for result in results:
|
||||
session.delete(result)
|
||||
Path(result.path).unlink(missing_ok=True)
|
||||
total_removed += 1
|
||||
total_removed = 0
|
||||
for result in results:
|
||||
session.delete(result)
|
||||
Path(result.path).unlink(missing_ok=True)
|
||||
total_removed += 1
|
||||
|
||||
session.commit()
|
||||
session.close()
|
||||
session.commit()
|
||||
|
||||
logger.info(f"finished purging group data exports. {total_removed} exports removed from group data")
|
||||
logger.info(f"finished purging group data exports. {total_removed} exports removed from group data")
|
||||
|
||||
|
||||
def purge_excess_files() -> None:
|
||||
|
@ -1,7 +1,7 @@
|
||||
import datetime
|
||||
|
||||
from mealie.core import root_logger
|
||||
from mealie.db.db_setup import create_session
|
||||
from mealie.db.db_setup import session_context
|
||||
from mealie.db.models.users.password_reset import PasswordResetModel
|
||||
|
||||
logger = root_logger.get_logger()
|
||||
@ -13,8 +13,9 @@ def purge_password_reset_tokens():
|
||||
"""Purges all events after x days"""
|
||||
logger.info("purging password reset tokens")
|
||||
limit = datetime.datetime.now() - datetime.timedelta(days=MAX_DAYS_OLD)
|
||||
session = create_session()
|
||||
session.query(PasswordResetModel).filter(PasswordResetModel.created_at <= limit).delete()
|
||||
session.commit()
|
||||
session.close()
|
||||
logger.info("password reset tokens purges")
|
||||
|
||||
with session_context() as session:
|
||||
session.query(PasswordResetModel).filter(PasswordResetModel.created_at <= limit).delete()
|
||||
session.commit()
|
||||
session.close()
|
||||
logger.info("password reset tokens purges")
|
||||
|
@ -1,7 +1,7 @@
|
||||
import datetime
|
||||
|
||||
from mealie.core import root_logger
|
||||
from mealie.db.db_setup import create_session
|
||||
from mealie.db.db_setup import session_context
|
||||
from mealie.db.models.group import GroupInviteToken
|
||||
|
||||
logger = root_logger.get_logger()
|
||||
@ -13,8 +13,10 @@ def purge_group_registration():
|
||||
"""Purges all events after x days"""
|
||||
logger.info("purging expired registration tokens")
|
||||
limit = datetime.datetime.now() - datetime.timedelta(days=MAX_DAYS_OLD)
|
||||
session = create_session()
|
||||
session.query(GroupInviteToken).filter(GroupInviteToken.created_at <= limit).delete()
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
with session_context() as session:
|
||||
session.query(GroupInviteToken).filter(GroupInviteToken.created_at <= limit).delete()
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
logger.info("registration token purged")
|
||||
|
@ -1,5 +1,5 @@
|
||||
from mealie.core import root_logger
|
||||
from mealie.db.db_setup import with_session
|
||||
from mealie.db.db_setup import session_context
|
||||
from mealie.repos.repository_factory import AllRepositories
|
||||
from mealie.services.user_services.user_service import UserService
|
||||
|
||||
@ -8,7 +8,7 @@ def locked_user_reset():
|
||||
logger = root_logger.get_logger()
|
||||
logger.info("resetting locked users")
|
||||
|
||||
with with_session() as session:
|
||||
with session_context() as session:
|
||||
repos = AllRepositories(session)
|
||||
user_service = UserService(repos)
|
||||
|
||||
|
@ -3,7 +3,7 @@ import json
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from mealie.db.db_setup import create_session
|
||||
from mealie.db.db_setup import session_context
|
||||
from mealie.services.user_services.password_reset_service import PasswordResetService
|
||||
from tests.utils.factories import random_string
|
||||
from tests.utils.fixture_schemas import TestUser
|
||||
@ -31,10 +31,10 @@ def test_password_reset(api_client: TestClient, unique_user: TestUser, casing: s
|
||||
cased_email += l.lower()
|
||||
cased_email
|
||||
|
||||
session = create_session()
|
||||
service = PasswordResetService(session)
|
||||
token = service.generate_reset_token(cased_email)
|
||||
assert token is not None
|
||||
with session_context() as session:
|
||||
service = PasswordResetService(session)
|
||||
token = service.generate_reset_token(cased_email)
|
||||
assert token is not None
|
||||
|
||||
new_password = random_string(15)
|
||||
|
||||
@ -59,8 +59,6 @@ def test_password_reset(api_client: TestClient, unique_user: TestUser, casing: s
|
||||
response = api_client.get(Routes.self, headers={"Authorization": f"Bearer {new_token}"})
|
||||
assert response.status_code == 200
|
||||
|
||||
session.close()
|
||||
|
||||
# Test successful password reset
|
||||
response = api_client.post(Routes.base, json=payload)
|
||||
assert response.status_code == 400
|
||||
|
@ -5,7 +5,7 @@ from pytest import MonkeyPatch
|
||||
from mealie.core import security
|
||||
from mealie.core.config import get_app_settings
|
||||
from mealie.core.dependencies import validate_file_token
|
||||
from mealie.db.db_setup import create_session
|
||||
from mealie.db.db_setup import session_context
|
||||
from tests.utils.factories import random_string
|
||||
|
||||
|
||||
@ -47,5 +47,8 @@ def test_ldap_authentication_mocked(monkeypatch: MonkeyPatch):
|
||||
monkeypatch.setattr(ldap, "initialize", ldap_initialize_mock)
|
||||
|
||||
get_app_settings.cache_clear()
|
||||
result = security.authenticate_user(create_session(), user, password)
|
||||
|
||||
with session_context() as session:
|
||||
result = security.authenticate_user(session, user, password)
|
||||
|
||||
assert result is False
|
||||
|
Loading…
x
Reference in New Issue
Block a user