fix: unclosed sessions (#1734)

* resolve session leak

* cleanup session management functions
This commit is contained in:
Hayden 2022-10-17 14:11:40 -08:00 committed by GitHub
parent a3904c45d8
commit e516a2e801
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 132 additions and 90 deletions

View File

@@ -31,22 +31,31 @@ SessionLocal, engine = sql_global_init(settings.DB_URL) # type: ignore
@contextmanager
def with_session() -> Session:
def session_context() -> Session:
"""
session_context() provides a managed session to the database that is automatically
closed when the context is exited. This is the preferred method of accessing the
database.
Note: use `generate_session` when using the `Depends` function from FastAPI
"""
global SessionLocal
sess = SessionLocal()
try:
yield sess
finally:
sess.close()
def create_session() -> Session:
global SessionLocal
return SessionLocal()
def generate_session() -> Generator[Session, None, None]:
"""
WARNING: This function should _only_ be called when using
the `Depends` function from FastAPI. This function will leak
sessions if used outside of the context of a request.
Use `session_context` instead. That function will allow you to use the
session within a context manager.
"""
global SessionLocal
db = SessionLocal()
try:

View File

@@ -9,7 +9,7 @@ from alembic.config import Config
from alembic.runtime import migration
from mealie.core import root_logger
from mealie.core.config import get_app_settings
from mealie.db.db_setup import create_session
from mealie.db.db_setup import session_context
from mealie.db.fixes.fix_slug_foods import fix_slug_food_names
from mealie.repos.all_repositories import get_repositories
from mealie.repos.repository_factory import AllRepositories
@@ -67,41 +67,40 @@ def connect(session: orm.Session) -> bool:
def main():
session = create_session()
# Wait for database to connect
max_retry = 10
wait_seconds = 1
while True:
if connect(session):
logger.info("Database connection established.")
break
with session_context() as session:
while True:
if connect(session):
logger.info("Database connection established.")
break
logger.error(f"Database connection failed. Retrying in {wait_seconds} seconds...")
max_retry -= 1
logger.error(f"Database connection failed. Retrying in {wait_seconds} seconds...")
max_retry -= 1
sleep(wait_seconds)
sleep(wait_seconds)
if max_retry == 0:
raise ConnectionError("Database connection failed - exiting application.")
if max_retry == 0:
raise ConnectionError("Database connection failed - exiting application.")
alembic_cfg = Config(str(PROJECT_DIR / "alembic.ini"))
if db_is_at_head(alembic_cfg):
logger.info("Migration not needed.")
else:
logger.info("Migration needed. Performing migration...")
command.upgrade(alembic_cfg, "head")
alembic_cfg = Config(str(PROJECT_DIR / "alembic.ini"))
if db_is_at_head(alembic_cfg):
logger.info("Migration not needed.")
else:
logger.info("Migration needed. Performing migration...")
command.upgrade(alembic_cfg, "head")
db = get_repositories(session)
db = get_repositories(session)
if db.users.get_all():
logger.info("Database exists")
else:
logger.info("Database contains no users, initializing...")
init_db(db)
if db.users.get_all():
logger.info("Database exists")
else:
logger.info("Database contains no users, initializing...")
init_db(db)
safe_try(lambda: fix_slug_food_names(db))
safe_try(lambda: fix_slug_food_names(db))
if __name__ == "__main__":

View File

@@ -1,5 +1,5 @@
from mealie.core import root_logger
from mealie.db.db_setup import with_session
from mealie.db.db_setup import session_context
from mealie.repos.repository_factory import AllRepositories
from mealie.services.user_services.user_service import UserService
@@ -13,7 +13,7 @@ def main():
logger = root_logger.get_logger()
with with_session() as session:
with session_context() as session:
repos = AllRepositories(session)
user_service = UserService(repos)

View File

@@ -1,4 +1,7 @@
import contextlib
import json
from abc import ABC, abstractmethod
from collections.abc import Generator
from datetime import datetime, timezone
from typing import cast
from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit
@@ -7,6 +10,7 @@ from fastapi.encoders import jsonable_encoder
from pydantic import UUID4
from sqlalchemy.orm.session import Session
from mealie.db.db_setup import session_context
from mealie.db.models.group.webhooks import GroupWebhooksModel
from mealie.repos.all_repositories import get_repositories
from mealie.repos.repository_factory import AllRepositories
@@ -18,34 +22,58 @@ from .event_types import Event, EventDocumentType, EventTypes, EventWebhookData
from .publisher import ApprisePublisher, PublisherLike, WebhookPublisher
class EventListenerBase:
class EventListenerBase(ABC):
session: Session | None
def __init__(self, session: Session, group_id: UUID4, publisher: PublisherLike) -> None:
self.session = session
self.group_id = group_id
self.publisher = publisher
@abstractmethod
def get_subscribers(self, event: Event) -> list:
"""Get a list of all subscribers to this event"""
...
@abstractmethod
def publish_to_subscribers(self, event: Event, subscribers: list) -> None:
"""Publishes the event to all subscribers"""
...
@contextlib.contextmanager
def ensure_session(self) -> Generator[None, None, None]:
"""
ensure_session ensures that a session is available for the caller by checking if a session
was provided during construction, and if not, creating a new session with the `session_context`
function and closing it when the context manager exits.
This is _required_ when working with sessions inside an event bus listener where the listener
may be constructed during a request where the session is provided by the request, but when
run as a scheduled task, the session is not provided and must be created.
"""
if self.session is None:
with session_context() as session:
self.session = session
yield
else:
yield
class AppriseEventListener(EventListenerBase):
def __init__(self, session: Session, group_id: UUID4) -> None:
super().__init__(session, group_id, ApprisePublisher())
def get_subscribers(self, event: Event) -> list[str]:
repos = AllRepositories(self.session)
with self.ensure_session():
repos = AllRepositories(self.session)
notifiers: list[GroupEventNotifierPrivate] = repos.group_event_notifier.by_group( # type: ignore
self.group_id
).multi_query({"enabled": True}, override_schema=GroupEventNotifierPrivate)
notifiers: list[GroupEventNotifierPrivate] = repos.group_event_notifier.by_group( # type: ignore
self.group_id
).multi_query({"enabled": True}, override_schema=GroupEventNotifierPrivate)
urls = [notifier.apprise_url for notifier in notifiers if getattr(notifier.options, event.event_type.name)]
urls = AppriseEventListener.update_urls_with_event_data(urls, event)
urls = [notifier.apprise_url for notifier in notifiers if getattr(notifier.options, event.event_type.name)]
urls = AppriseEventListener.update_urls_with_event_data(urls, event)
return urls
@@ -120,12 +148,13 @@ class WebhookEventListener(EventListenerBase):
def get_scheduled_webhooks(self, start_dt: datetime, end_dt: datetime) -> list[ReadWebhook]:
"""Fetches all scheduled webhooks from the database"""
return (
self.session.query(GroupWebhooksModel)
.where(
GroupWebhooksModel.enabled == True, # noqa: E712 - required for SQLAlchemy comparison
GroupWebhooksModel.scheduled_time > start_dt.astimezone(timezone.utc).time(),
GroupWebhooksModel.scheduled_time <= end_dt.astimezone(timezone.utc).time(),
with self.ensure_session():
return (
self.session.query(GroupWebhooksModel)
.where(
GroupWebhooksModel.enabled == True, # noqa: E712 - required for SQLAlchemy comparison
GroupWebhooksModel.scheduled_time > start_dt.astimezone(timezone.utc).time(),
GroupWebhooksModel.scheduled_time <= end_dt.astimezone(timezone.utc).time(),
)
.all()
)
.all()
)

View File

@@ -40,12 +40,13 @@ class EventSource:
class EventBusService:
bg: BackgroundTasks | None
session: Session | None
group_id: UUID4 | None
def __init__(
self, bg: Optional[BackgroundTasks] = None, session: Optional[Session] = None, group_id: UUID4 | None = None
) -> None:
if not session:
session = next(generate_session())
self.bg = bg
self.session = session
self.group_id = group_id

View File

@@ -3,7 +3,7 @@ from typing import Optional
from pydantic import UUID4
from mealie.db.db_setup import create_session
from mealie.db.db_setup import session_context
from mealie.repos.all_repositories import get_repositories
from mealie.schema.response.pagination import PaginationQuery
from mealie.services.event_bus_service.event_bus_service import EventBusService
@@ -31,10 +31,11 @@ def post_group_webhooks(start_dt: Optional[datetime] = None, group_id: Optional[
if group_id is None:
# publish the webhook event to each group's event bus
session = create_session()
repos = get_repositories(session)
groups_data = repos.groups.page_all(PaginationQuery(page=1, per_page=-1))
group_ids = [group.id for group in groups_data.items]
with session_context() as session:
repos = get_repositories(session)
groups_data = repos.groups.page_all(PaginationQuery(page=1, per_page=-1))
group_ids = [group.id for group in groups_data.items]
else:
group_ids = [group_id]

View File

@@ -3,7 +3,7 @@ from pathlib import Path
from mealie.core import root_logger
from mealie.core.config import get_app_dirs
from mealie.db.db_setup import create_session
from mealie.db.db_setup import session_context
from mealie.db.models.group.exports import GroupDataExportsModel
ONE_DAY_AS_MINUTES = 1440
@@ -15,20 +15,19 @@ def purge_group_data_exports(max_minutes_old=ONE_DAY_AS_MINUTES):
logger.info("purging group data exports")
limit = datetime.datetime.now() - datetime.timedelta(minutes=max_minutes_old)
session = create_session()
results = session.query(GroupDataExportsModel).filter(GroupDataExportsModel.expires <= limit)
with session_context() as session:
results = session.query(GroupDataExportsModel).filter(GroupDataExportsModel.expires <= limit)
total_removed = 0
for result in results:
session.delete(result)
Path(result.path).unlink(missing_ok=True)
total_removed += 1
total_removed = 0
for result in results:
session.delete(result)
Path(result.path).unlink(missing_ok=True)
total_removed += 1
session.commit()
session.close()
session.commit()
logger.info(f"finished purging group data exports. {total_removed} exports removed from group data")
logger.info(f"finished purging group data exports. {total_removed} exports removed from group data")
def purge_excess_files() -> None:

View File

@@ -1,7 +1,7 @@
import datetime
from mealie.core import root_logger
from mealie.db.db_setup import create_session
from mealie.db.db_setup import session_context
from mealie.db.models.users.password_reset import PasswordResetModel
logger = root_logger.get_logger()
@@ -13,8 +13,9 @@ def purge_password_reset_tokens():
"""Purges all events after x days"""
logger.info("purging password reset tokens")
limit = datetime.datetime.now() - datetime.timedelta(days=MAX_DAYS_OLD)
session = create_session()
session.query(PasswordResetModel).filter(PasswordResetModel.created_at <= limit).delete()
session.commit()
session.close()
logger.info("password reset tokens purges")
with session_context() as session:
session.query(PasswordResetModel).filter(PasswordResetModel.created_at <= limit).delete()
session.commit()
session.close()
logger.info("password reset tokens purges")

View File

@@ -1,7 +1,7 @@
import datetime
from mealie.core import root_logger
from mealie.db.db_setup import create_session
from mealie.db.db_setup import session_context
from mealie.db.models.group import GroupInviteToken
logger = root_logger.get_logger()
@@ -13,8 +13,10 @@ def purge_group_registration():
"""Purges all events after x days"""
logger.info("purging expired registration tokens")
limit = datetime.datetime.now() - datetime.timedelta(days=MAX_DAYS_OLD)
session = create_session()
session.query(GroupInviteToken).filter(GroupInviteToken.created_at <= limit).delete()
session.commit()
session.close()
with session_context() as session:
session.query(GroupInviteToken).filter(GroupInviteToken.created_at <= limit).delete()
session.commit()
session.close()
logger.info("registration token purged")

View File

@@ -1,5 +1,5 @@
from mealie.core import root_logger
from mealie.db.db_setup import with_session
from mealie.db.db_setup import session_context
from mealie.repos.repository_factory import AllRepositories
from mealie.services.user_services.user_service import UserService
@@ -8,7 +8,7 @@ def locked_user_reset():
logger = root_logger.get_logger()
logger.info("resetting locked users")
with with_session() as session:
with session_context() as session:
repos = AllRepositories(session)
user_service = UserService(repos)

View File

@@ -3,7 +3,7 @@ import json
import pytest
from fastapi.testclient import TestClient
from mealie.db.db_setup import create_session
from mealie.db.db_setup import session_context
from mealie.services.user_services.password_reset_service import PasswordResetService
from tests.utils.factories import random_string
from tests.utils.fixture_schemas import TestUser
@@ -31,10 +31,10 @@ def test_password_reset(api_client: TestClient, unique_user: TestUser, casing: s
cased_email += l.lower()
cased_email
session = create_session()
service = PasswordResetService(session)
token = service.generate_reset_token(cased_email)
assert token is not None
with session_context() as session:
service = PasswordResetService(session)
token = service.generate_reset_token(cased_email)
assert token is not None
new_password = random_string(15)
@@ -59,8 +59,6 @@ def test_password_reset(api_client: TestClient, unique_user: TestUser, casing: s
response = api_client.get(Routes.self, headers={"Authorization": f"Bearer {new_token}"})
assert response.status_code == 200
session.close()
# Test successful password reset
response = api_client.post(Routes.base, json=payload)
assert response.status_code == 400

View File

@@ -5,7 +5,7 @@ from pytest import MonkeyPatch
from mealie.core import security
from mealie.core.config import get_app_settings
from mealie.core.dependencies import validate_file_token
from mealie.db.db_setup import create_session
from mealie.db.db_setup import session_context
from tests.utils.factories import random_string
@@ -47,5 +47,8 @@ def test_ldap_authentication_mocked(monkeypatch: MonkeyPatch):
monkeypatch.setattr(ldap, "initialize", ldap_initialize_mock)
get_app_settings.cache_clear()
result = security.authenticate_user(create_session(), user, password)
with session_context() as session:
result = security.authenticate_user(session, user, password)
assert result is False