diff --git a/back/src/Kyoo.RabbitMq/ScannerProducer.cs b/back/src/Kyoo.RabbitMq/ScannerProducer.cs index 75c575a0..1034a2c5 100644 --- a/back/src/Kyoo.RabbitMq/ScannerProducer.cs +++ b/back/src/Kyoo.RabbitMq/ScannerProducer.cs @@ -32,19 +32,20 @@ public class ScannerProducer : IScanner { _channel = rabbitConnection.CreateModel(); _channel.QueueDeclare("scanner", exclusive: false, autoDelete: false); + _channel.QueueDeclare("scanner.rescan", exclusive: false, autoDelete: false); } - private Task _Publish(T message) + private Task _Publish(T message, string queue = "scanner") { var body = Encoding.UTF8.GetBytes(JsonSerializer.Serialize(message, Utility.JsonOptions)); - _channel.BasicPublish("", routingKey: "scanner", body: body); + _channel.BasicPublish("", routingKey: queue, body: body); return Task.CompletedTask; } public Task SendRescanRequest() { var message = new { Action = "rescan", }; - return _Publish(message); + return _Publish(message, queue: "scanner.rescan"); } public Task SendRefreshRequest(string kind, Guid id) diff --git a/scanner/matcher/subscriber.py b/scanner/matcher/subscriber.py index bf94af0a..167654c8 100644 --- a/scanner/matcher/subscriber.py +++ b/scanner/matcher/subscriber.py @@ -4,9 +4,7 @@ from msgspec import Struct, json from logging import getLogger from aio_pika.abc import AbstractIncomingMessage -from scanner.publisher import Publisher -from scanner.scanner import scan - +from providers.rabbit_base import RabbitBase from matcher.matcher import Matcher logger = getLogger(__name__) @@ -29,14 +27,10 @@ class Refresh(Message): id: str -class Rescan(Message): - pass +decoder = json.Decoder(Union[Scan, Delete, Refresh]) -decoder = json.Decoder(Union[Scan, Delete, Refresh, Rescan]) - - -class Subscriber(Publisher): +class Subscriber(RabbitBase): async def listen(self, matcher: Matcher): async def on_message(message: AbstractIncomingMessage): try: @@ -49,9 +43,6 @@ class Subscriber(Publisher): ack = await matcher.delete(path) case Refresh(kind, id): ack = await matcher.refresh(kind, id) - case Rescan(): - await scan(None, self, matcher._client, remove_deleted=True) - ack = True case _: logger.error(f"Invalid action: {msg.action}") if ack: diff --git a/scanner/providers/rabbit_base.py b/scanner/providers/rabbit_base.py new file mode 100644 index 00000000..ff29c4f1 --- /dev/null +++ b/scanner/providers/rabbit_base.py @@ -0,0 +1,20 @@ +import os +from aio_pika import connect_robust + + +class RabbitBase: + QUEUE = "scanner" + + async def __aenter__(self): + self._con = await connect_robust( + host=os.environ.get("RABBITMQ_HOST", "rabbitmq"), + port=int(os.environ.get("RABBITMQ_PORT", "5672")), + login=os.environ.get("RABBITMQ_DEFAULT_USER", "guest"), + password=os.environ.get("RABBITMQ_DEFAULT_PASS", "guest"), + ) + self._channel = await self._con.channel() + self._queue = await self._channel.declare_queue(self.QUEUE) + return self + + async def __aexit__(self, exc_type, exc_value, exc_tb): + await self._con.close() diff --git a/scanner/scanner/__init__.py b/scanner/scanner/__init__.py index bce4273e..c7ffd55a 100644 --- a/scanner/scanner/__init__.py +++ b/scanner/scanner/__init__.py @@ -13,8 +13,13 @@ async def main(): async with Publisher() as publisher, KyooClient() as client: path = os.environ.get("SCANNER_LIBRARY_ROOT", "/video") + + async def scan_all(): + await scan(path, publisher, client, remove_deleted=True) + await asyncio.gather( monitor(path, publisher, client), - scan(path, publisher, client, remove_deleted=True), + scan_all(), refresh(publisher, client), + publisher.listen(scan_all), ) diff --git a/scanner/scanner/publisher.py b/scanner/scanner/publisher.py index 315855e1..0132a781 100644 --- a/scanner/scanner/publisher.py +++ b/scanner/scanner/publisher.py @@ -1,26 +1,23 @@ -import os +import asyncio from guessit.jsonutils import json -from aio_pika import Message, connect_robust +from aio_pika import Message +from aio_pika.abc import AbstractIncomingMessage +from logging import getLogger from typing import Literal +from providers.rabbit_base import RabbitBase -class Publisher: - QUEUE = "scanner" +logger = getLogger(__name__) + + +class Publisher(RabbitBase): + QUEUE_RESCAN = "scanner.rescan" async def __aenter__(self): - self._con = await connect_robust( - host=os.environ.get("RABBITMQ_HOST", "rabbitmq"), - port=int(os.environ.get("RABBITMQ_PORT", "5672")), - login=os.environ.get("RABBITMQ_DEFAULT_USER", "guest"), - password=os.environ.get("RABBITMQ_DEFAULT_PASS", "guest"), - ) - self._channel = await self._con.channel() - self._queue = await self._channel.declare_queue(self.QUEUE) + await super().__aenter__() + self._queue = await self._channel.declare_queue(self.QUEUE_RESCAN) return self - async def __aexit__(self, exc_type, exc_value, exc_tb): - await self._con.close() - async def _publish(self, data: dict): await self._channel.default_exchange.publish( Message(json.dumps(data).encode()), @@ -40,3 +37,15 @@ class Publisher: **_kwargs, ): await self._publish({"action": "refresh", "kind": kind, "id": id}) + + async def listen(self, scan): + async def on_message(message: AbstractIncomingMessage): + try: + await scan() + await message.ack() + except Exception as e: + logger.exception("Unhandled error", exc_info=e) + await message.reject() + + await self._queue.consume(on_message) + await asyncio.Future()