diff --git a/scanner/scanner/__init__.py b/scanner/scanner/__init__.py index 0ebe94db..e9334a32 100644 --- a/scanner/scanner/__init__.py +++ b/scanner/scanner/__init__.py @@ -1,5 +1,5 @@ import logging -from asyncio import CancelledError, TaskGroup, create_task +from asyncio import CancelledError, TaskGroup, create_task, sleep from contextlib import asynccontextmanager from fastapi import FastAPI @@ -25,15 +25,21 @@ async def lifespan(_): get_db() as db, KyooClient() as client, TheMovieDatabase() as tmdb, - RequestProcessor(db, client, tmdb) as processor, ): # there's no way someone else used the same id, right? is_master = await db.fetchval("select pg_try_advisory_lock(198347)") if is_master: await migrate() - async with get_db() as db: - scanner = FsScanner(client, RequestCreator(db)) - tasks = create_task(background_startup(scanner, processor, is_master)) + async with get_db() as scanner_db: + processor = RequestProcessor(db, client, tmdb) + scanner = FsScanner(client, RequestCreator(scanner_db)) + tasks = create_task( + background_startup( + scanner, + processor, + is_master, + ) + ) yield _ = tasks.cancel() @@ -44,7 +50,7 @@ async def background_startup( is_master: bool | None, ): async with TaskGroup() as tg: - _ = tg.create_task(processor.listen()) + _ = tg.create_task(processor.listen(tg)) if is_master: _ = tg.create_task(scanner.monitor()) _ = tg.create_task(scanner.scan(remove_deleted=True)) diff --git a/scanner/scanner/requests.py b/scanner/scanner/requests.py index c45c107d..a4d86dde 100644 --- a/scanner/scanner/requests.py +++ b/scanner/scanner/requests.py @@ -1,10 +1,11 @@ from __future__ import annotations +from asyncio import CancelledError, Future, TaskGroup, sleep from logging import getLogger from types import TracebackType -from typing import Literal +from typing import Literal, cast -from asyncpg import Connection +from asyncpg import Connection, Pool from pydantic import Field, TypeAdapter from .client import KyooClient @@ -68,18 +69,17 @@ class RequestProcessor: self._client = client self._providers = providers - async def __aenter__(self): - logger.info("Listening for requestes") - await self._database.add_listener("scanner_requests", self.process_all) - return self + async def listen(self, tg: TaskGroup): + def process(*_): + _ = tg.create_task(self.process_all()) - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_value: BaseException | None, - traceback: TracebackType | None, - ): - await self._database.remove_listener("scanner_requests", self.process_all) + try: + logger.info("Listening for requestes") + await self._database.add_listener("scanner_requests", process) + await Future() + except CancelledError as e: + logger.info("Stopped listening for requsets") + await self._database.remove_listener("scanner_requests", process) async def process_all(self): found = True