diff --git a/scanner/matcher/subscriber.py b/scanner/matcher/subscriber.py index acbe3e36..b2db3366 100644 --- a/scanner/matcher/subscriber.py +++ b/scanner/matcher/subscriber.py @@ -1,7 +1,6 @@ import asyncio from typing import Union, Literal from msgspec import Struct, json -import os from logging import getLogger from aio_pika import connect_robust from aio_pika.abc import AbstractIncomingMessage @@ -28,38 +27,24 @@ class Refresh(Message): id: str + + decoder = json.Decoder(Union[Scan, Delete, Refresh]) -class Subscriber: - QUEUE = "scanner" - - async def __aenter__(self): - self._con = await connect_robust( - host=os.environ.get("RABBITMQ_HOST", "rabbitmq"), - port=int(os.environ.get("RABBITMQ_PORT", "5672")), - login=os.environ.get("RABBITMQ_DEFAULT_USER", "guest"), - password=os.environ.get("RABBITMQ_DEFAULT_PASS", "guest"), - ) - self._channel = await self._con.channel() - self._queue = await self._channel.declare_queue(self.QUEUE) - return self - - async def __aexit__(self, exc_type, exc_value, exc_tb): - await self._con.close() - - async def listen(self, scanner: Matcher): +class Subscriber(Publisher): + async def listen(self, matcher: Matcher): async def on_message(message: AbstractIncomingMessage): try: msg = decoder.decode(message.body) ack = False match msg: case Scan(path): - ack = await scanner.identify(path) + ack = await matcher.identify(path) case Delete(path): - ack = await scanner.delete(path) + ack = await matcher.delete(path) case Refresh(kind, id): - ack = await scanner.refresh(kind, id) + ack = await matcher.refresh(kind, id) case _: logger.error(f"Invalid action: {msg.action}") if ack: