fix: unclosed sessions (#1734)

* resolve session leak

* cleanup session management functions
This commit is contained in:
Hayden 2022-10-17 14:11:40 -08:00 committed by GitHub
parent a3904c45d8
commit e516a2e801
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 132 additions and 90 deletions

View File

@ -31,22 +31,31 @@ SessionLocal, engine = sql_global_init(settings.DB_URL) # type: ignore
@contextmanager @contextmanager
def with_session() -> Session: def session_context() -> Session:
"""
session_context() provides a managed session to the database that is automatically
closed when the context is exited. This is the preferred method of accessing the
database.
Note: use `generate_session` when using the `Depends` function from FastAPI
"""
global SessionLocal global SessionLocal
sess = SessionLocal() sess = SessionLocal()
try: try:
yield sess yield sess
finally: finally:
sess.close() sess.close()
def create_session() -> Session:
global SessionLocal
return SessionLocal()
def generate_session() -> Generator[Session, None, None]: def generate_session() -> Generator[Session, None, None]:
"""
WARNING: This function should _only_ be called when used with
using the `Depends` function from FastAPI. This function will leak
sessions if used outside of the context of a request.
Use `with_session` instead. That function will allow you to use the
session within a context manager
"""
global SessionLocal global SessionLocal
db = SessionLocal() db = SessionLocal()
try: try:

View File

@ -9,7 +9,7 @@ from alembic.config import Config
from alembic.runtime import migration from alembic.runtime import migration
from mealie.core import root_logger from mealie.core import root_logger
from mealie.core.config import get_app_settings from mealie.core.config import get_app_settings
from mealie.db.db_setup import create_session from mealie.db.db_setup import session_context
from mealie.db.fixes.fix_slug_foods import fix_slug_food_names from mealie.db.fixes.fix_slug_foods import fix_slug_food_names
from mealie.repos.all_repositories import get_repositories from mealie.repos.all_repositories import get_repositories
from mealie.repos.repository_factory import AllRepositories from mealie.repos.repository_factory import AllRepositories
@ -67,41 +67,40 @@ def connect(session: orm.Session) -> bool:
def main(): def main():
session = create_session()
# Wait for database to connect # Wait for database to connect
max_retry = 10 max_retry = 10
wait_seconds = 1 wait_seconds = 1
while True: with session_context() as session:
if connect(session): while True:
logger.info("Database connection established.") if connect(session):
break logger.info("Database connection established.")
break
logger.error(f"Database connection failed. Retrying in {wait_seconds} seconds...") logger.error(f"Database connection failed. Retrying in {wait_seconds} seconds...")
max_retry -= 1 max_retry -= 1
sleep(wait_seconds) sleep(wait_seconds)
if max_retry == 0: if max_retry == 0:
raise ConnectionError("Database connection failed - exiting application.") raise ConnectionError("Database connection failed - exiting application.")
alembic_cfg = Config(str(PROJECT_DIR / "alembic.ini")) alembic_cfg = Config(str(PROJECT_DIR / "alembic.ini"))
if db_is_at_head(alembic_cfg): if db_is_at_head(alembic_cfg):
logger.info("Migration not needed.") logger.info("Migration not needed.")
else: else:
logger.info("Migration needed. Performing migration...") logger.info("Migration needed. Performing migration...")
command.upgrade(alembic_cfg, "head") command.upgrade(alembic_cfg, "head")
db = get_repositories(session) db = get_repositories(session)
if db.users.get_all(): if db.users.get_all():
logger.info("Database exists") logger.info("Database exists")
else: else:
logger.info("Database contains no users, initializing...") logger.info("Database contains no users, initializing...")
init_db(db) init_db(db)
safe_try(lambda: fix_slug_food_names(db)) safe_try(lambda: fix_slug_food_names(db))
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,5 +1,5 @@
from mealie.core import root_logger from mealie.core import root_logger
from mealie.db.db_setup import with_session from mealie.db.db_setup import session_context
from mealie.repos.repository_factory import AllRepositories from mealie.repos.repository_factory import AllRepositories
from mealie.services.user_services.user_service import UserService from mealie.services.user_services.user_service import UserService
@ -13,7 +13,7 @@ def main():
logger = root_logger.get_logger() logger = root_logger.get_logger()
with with_session() as session: with session_context() as session:
repos = AllRepositories(session) repos = AllRepositories(session)
user_service = UserService(repos) user_service = UserService(repos)

View File

@ -1,4 +1,7 @@
import contextlib
import json import json
from abc import ABC, abstractmethod
from collections.abc import Generator
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import cast from typing import cast
from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit
@ -7,6 +10,7 @@ from fastapi.encoders import jsonable_encoder
from pydantic import UUID4 from pydantic import UUID4
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
from mealie.db.db_setup import session_context
from mealie.db.models.group.webhooks import GroupWebhooksModel from mealie.db.models.group.webhooks import GroupWebhooksModel
from mealie.repos.all_repositories import get_repositories from mealie.repos.all_repositories import get_repositories
from mealie.repos.repository_factory import AllRepositories from mealie.repos.repository_factory import AllRepositories
@ -18,34 +22,58 @@ from .event_types import Event, EventDocumentType, EventTypes, EventWebhookData
from .publisher import ApprisePublisher, PublisherLike, WebhookPublisher from .publisher import ApprisePublisher, PublisherLike, WebhookPublisher
class EventListenerBase: class EventListenerBase(ABC):
session: Session | None
def __init__(self, session: Session, group_id: UUID4, publisher: PublisherLike) -> None: def __init__(self, session: Session, group_id: UUID4, publisher: PublisherLike) -> None:
self.session = session self.session = session
self.group_id = group_id self.group_id = group_id
self.publisher = publisher self.publisher = publisher
@abstractmethod
def get_subscribers(self, event: Event) -> list: def get_subscribers(self, event: Event) -> list:
"""Get a list of all subscribers to this event""" """Get a list of all subscribers to this event"""
... ...
@abstractmethod
def publish_to_subscribers(self, event: Event, subscribers: list) -> None: def publish_to_subscribers(self, event: Event, subscribers: list) -> None:
"""Publishes the event to all subscribers""" """Publishes the event to all subscribers"""
... ...
@contextlib.contextmanager
def ensure_session(self) -> Generator[None, None, None]:
"""
ensure_session ensures that a session is available for the caller by checking if a session
was provided during construction, and if not, creating a new session with the `with_session`
function and closing it when the context manager exits.
This is _required_ when working with sessions inside an event bus listener where the listener
may be constructed during a request where the session is provided by the request, but the when
run as a scheduled task, the session is not provided and must be created.
"""
if self.session is None:
with session_context() as session:
self.session = session
yield
else:
yield
class AppriseEventListener(EventListenerBase): class AppriseEventListener(EventListenerBase):
def __init__(self, session: Session, group_id: UUID4) -> None: def __init__(self, session: Session, group_id: UUID4) -> None:
super().__init__(session, group_id, ApprisePublisher()) super().__init__(session, group_id, ApprisePublisher())
def get_subscribers(self, event: Event) -> list[str]: def get_subscribers(self, event: Event) -> list[str]:
repos = AllRepositories(self.session) with self.ensure_session():
repos = AllRepositories(self.session)
notifiers: list[GroupEventNotifierPrivate] = repos.group_event_notifier.by_group( # type: ignore notifiers: list[GroupEventNotifierPrivate] = repos.group_event_notifier.by_group( # type: ignore
self.group_id self.group_id
).multi_query({"enabled": True}, override_schema=GroupEventNotifierPrivate) ).multi_query({"enabled": True}, override_schema=GroupEventNotifierPrivate)
urls = [notifier.apprise_url for notifier in notifiers if getattr(notifier.options, event.event_type.name)] urls = [notifier.apprise_url for notifier in notifiers if getattr(notifier.options, event.event_type.name)]
urls = AppriseEventListener.update_urls_with_event_data(urls, event) urls = AppriseEventListener.update_urls_with_event_data(urls, event)
return urls return urls
@ -120,12 +148,13 @@ class WebhookEventListener(EventListenerBase):
def get_scheduled_webhooks(self, start_dt: datetime, end_dt: datetime) -> list[ReadWebhook]: def get_scheduled_webhooks(self, start_dt: datetime, end_dt: datetime) -> list[ReadWebhook]:
"""Fetches all scheduled webhooks from the database""" """Fetches all scheduled webhooks from the database"""
return ( with self.ensure_session():
self.session.query(GroupWebhooksModel) return (
.where( self.session.query(GroupWebhooksModel)
GroupWebhooksModel.enabled == True, # noqa: E712 - required for SQLAlchemy comparison .where(
GroupWebhooksModel.scheduled_time > start_dt.astimezone(timezone.utc).time(), GroupWebhooksModel.enabled == True, # noqa: E712 - required for SQLAlchemy comparison
GroupWebhooksModel.scheduled_time <= end_dt.astimezone(timezone.utc).time(), GroupWebhooksModel.scheduled_time > start_dt.astimezone(timezone.utc).time(),
GroupWebhooksModel.scheduled_time <= end_dt.astimezone(timezone.utc).time(),
)
.all()
) )
.all()
)

View File

@ -40,12 +40,13 @@ class EventSource:
class EventBusService: class EventBusService:
bg: BackgroundTasks | None
session: Session | None
group_id: UUID4 | None
def __init__( def __init__(
self, bg: Optional[BackgroundTasks] = None, session: Optional[Session] = None, group_id: UUID4 | None = None self, bg: Optional[BackgroundTasks] = None, session: Optional[Session] = None, group_id: UUID4 | None = None
) -> None: ) -> None:
if not session:
session = next(generate_session())
self.bg = bg self.bg = bg
self.session = session self.session = session
self.group_id = group_id self.group_id = group_id

View File

@ -3,7 +3,7 @@ from typing import Optional
from pydantic import UUID4 from pydantic import UUID4
from mealie.db.db_setup import create_session from mealie.db.db_setup import session_context
from mealie.repos.all_repositories import get_repositories from mealie.repos.all_repositories import get_repositories
from mealie.schema.response.pagination import PaginationQuery from mealie.schema.response.pagination import PaginationQuery
from mealie.services.event_bus_service.event_bus_service import EventBusService from mealie.services.event_bus_service.event_bus_service import EventBusService
@ -31,10 +31,11 @@ def post_group_webhooks(start_dt: Optional[datetime] = None, group_id: Optional[
if group_id is None: if group_id is None:
# publish the webhook event to each group's event bus # publish the webhook event to each group's event bus
session = create_session()
repos = get_repositories(session) with session_context() as session:
groups_data = repos.groups.page_all(PaginationQuery(page=1, per_page=-1)) repos = get_repositories(session)
group_ids = [group.id for group in groups_data.items] groups_data = repos.groups.page_all(PaginationQuery(page=1, per_page=-1))
group_ids = [group.id for group in groups_data.items]
else: else:
group_ids = [group_id] group_ids = [group_id]

View File

@ -3,7 +3,7 @@ from pathlib import Path
from mealie.core import root_logger from mealie.core import root_logger
from mealie.core.config import get_app_dirs from mealie.core.config import get_app_dirs
from mealie.db.db_setup import create_session from mealie.db.db_setup import session_context
from mealie.db.models.group.exports import GroupDataExportsModel from mealie.db.models.group.exports import GroupDataExportsModel
ONE_DAY_AS_MINUTES = 1440 ONE_DAY_AS_MINUTES = 1440
@ -15,20 +15,19 @@ def purge_group_data_exports(max_minutes_old=ONE_DAY_AS_MINUTES):
logger.info("purging group data exports") logger.info("purging group data exports")
limit = datetime.datetime.now() - datetime.timedelta(minutes=max_minutes_old) limit = datetime.datetime.now() - datetime.timedelta(minutes=max_minutes_old)
session = create_session()
results = session.query(GroupDataExportsModel).filter(GroupDataExportsModel.expires <= limit) with session_context() as session:
results = session.query(GroupDataExportsModel).filter(GroupDataExportsModel.expires <= limit)
total_removed = 0 total_removed = 0
for result in results: for result in results:
session.delete(result) session.delete(result)
Path(result.path).unlink(missing_ok=True) Path(result.path).unlink(missing_ok=True)
total_removed += 1 total_removed += 1
session.commit() session.commit()
session.close()
logger.info(f"finished purging group data exports. {total_removed} exports removed from group data") logger.info(f"finished purging group data exports. {total_removed} exports removed from group data")
def purge_excess_files() -> None: def purge_excess_files() -> None:

View File

@ -1,7 +1,7 @@
import datetime import datetime
from mealie.core import root_logger from mealie.core import root_logger
from mealie.db.db_setup import create_session from mealie.db.db_setup import session_context
from mealie.db.models.users.password_reset import PasswordResetModel from mealie.db.models.users.password_reset import PasswordResetModel
logger = root_logger.get_logger() logger = root_logger.get_logger()
@ -13,8 +13,9 @@ def purge_password_reset_tokens():
"""Purges all events after x days""" """Purges all events after x days"""
logger.info("purging password reset tokens") logger.info("purging password reset tokens")
limit = datetime.datetime.now() - datetime.timedelta(days=MAX_DAYS_OLD) limit = datetime.datetime.now() - datetime.timedelta(days=MAX_DAYS_OLD)
session = create_session()
session.query(PasswordResetModel).filter(PasswordResetModel.created_at <= limit).delete() with session_context() as session:
session.commit() session.query(PasswordResetModel).filter(PasswordResetModel.created_at <= limit).delete()
session.close() session.commit()
logger.info("password reset tokens purges") session.close()
logger.info("password reset tokens purges")

View File

@ -1,7 +1,7 @@
import datetime import datetime
from mealie.core import root_logger from mealie.core import root_logger
from mealie.db.db_setup import create_session from mealie.db.db_setup import session_context
from mealie.db.models.group import GroupInviteToken from mealie.db.models.group import GroupInviteToken
logger = root_logger.get_logger() logger = root_logger.get_logger()
@ -13,8 +13,10 @@ def purge_group_registration():
"""Purges all events after x days""" """Purges all events after x days"""
logger.info("purging expired registration tokens") logger.info("purging expired registration tokens")
limit = datetime.datetime.now() - datetime.timedelta(days=MAX_DAYS_OLD) limit = datetime.datetime.now() - datetime.timedelta(days=MAX_DAYS_OLD)
session = create_session()
session.query(GroupInviteToken).filter(GroupInviteToken.created_at <= limit).delete() with session_context() as session:
session.commit() session.query(GroupInviteToken).filter(GroupInviteToken.created_at <= limit).delete()
session.close() session.commit()
session.close()
logger.info("registration token purged") logger.info("registration token purged")

View File

@ -1,5 +1,5 @@
from mealie.core import root_logger from mealie.core import root_logger
from mealie.db.db_setup import with_session from mealie.db.db_setup import session_context
from mealie.repos.repository_factory import AllRepositories from mealie.repos.repository_factory import AllRepositories
from mealie.services.user_services.user_service import UserService from mealie.services.user_services.user_service import UserService
@ -8,7 +8,7 @@ def locked_user_reset():
logger = root_logger.get_logger() logger = root_logger.get_logger()
logger.info("resetting locked users") logger.info("resetting locked users")
with with_session() as session: with session_context() as session:
repos = AllRepositories(session) repos = AllRepositories(session)
user_service = UserService(repos) user_service = UserService(repos)

View File

@ -3,7 +3,7 @@ import json
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from mealie.db.db_setup import create_session from mealie.db.db_setup import session_context
from mealie.services.user_services.password_reset_service import PasswordResetService from mealie.services.user_services.password_reset_service import PasswordResetService
from tests.utils.factories import random_string from tests.utils.factories import random_string
from tests.utils.fixture_schemas import TestUser from tests.utils.fixture_schemas import TestUser
@ -31,10 +31,10 @@ def test_password_reset(api_client: TestClient, unique_user: TestUser, casing: s
cased_email += l.lower() cased_email += l.lower()
cased_email cased_email
session = create_session() with session_context() as session:
service = PasswordResetService(session) service = PasswordResetService(session)
token = service.generate_reset_token(cased_email) token = service.generate_reset_token(cased_email)
assert token is not None assert token is not None
new_password = random_string(15) new_password = random_string(15)
@ -59,8 +59,6 @@ def test_password_reset(api_client: TestClient, unique_user: TestUser, casing: s
response = api_client.get(Routes.self, headers={"Authorization": f"Bearer {new_token}"}) response = api_client.get(Routes.self, headers={"Authorization": f"Bearer {new_token}"})
assert response.status_code == 200 assert response.status_code == 200
session.close()
# Test successful password reset # Test successful password reset
response = api_client.post(Routes.base, json=payload) response = api_client.post(Routes.base, json=payload)
assert response.status_code == 400 assert response.status_code == 400

View File

@ -5,7 +5,7 @@ from pytest import MonkeyPatch
from mealie.core import security from mealie.core import security
from mealie.core.config import get_app_settings from mealie.core.config import get_app_settings
from mealie.core.dependencies import validate_file_token from mealie.core.dependencies import validate_file_token
from mealie.db.db_setup import create_session from mealie.db.db_setup import session_context
from tests.utils.factories import random_string from tests.utils.factories import random_string
@ -47,5 +47,8 @@ def test_ldap_authentication_mocked(monkeypatch: MonkeyPatch):
monkeypatch.setattr(ldap, "initialize", ldap_initialize_mock) monkeypatch.setattr(ldap, "initialize", ldap_initialize_mock)
get_app_settings.cache_clear() get_app_settings.cache_clear()
result = security.authenticate_user(create_session(), user, password)
with session_context() as session:
result = security.authenticate_user(session, user, password)
assert result is False assert result is False