diff --git a/server/src/app.module.ts b/server/src/app.module.ts index 8d261463e7..8079441329 100644 --- a/server/src/app.module.ts +++ b/server/src/app.module.ts @@ -19,6 +19,7 @@ import { ConfigRepository } from 'src/repositories/config.repository'; import { EventRepository } from 'src/repositories/event.repository'; import { LoggingRepository } from 'src/repositories/logging.repository'; import { teardownTelemetry, TelemetryRepository } from 'src/repositories/telemetry.repository'; +import { WebsocketRepository } from 'src/repositories/websocket.repository'; import { services } from 'src/services'; import { AuthService } from 'src/services/auth.service'; import { CliService } from 'src/services/cli.service'; @@ -52,6 +53,7 @@ class BaseModule implements OnModuleInit, OnModuleDestroy { @Inject(IWorker) private worker: ImmichWorker, logger: LoggingRepository, private eventRepository: EventRepository, + private websocketRepository: WebsocketRepository, private jobService: JobService, private telemetryRepository: TelemetryRepository, private authService: AuthService, @@ -64,7 +66,7 @@ class BaseModule implements OnModuleInit, OnModuleDestroy { this.jobService.setServices(services); - this.eventRepository.setAuthFn(async (client) => + this.websocketRepository.setAuthFn(async (client) => this.authService.authenticate({ headers: client.request.headers, queryParams: {}, diff --git a/server/src/repositories/event.repository.ts b/server/src/repositories/event.repository.ts index 420be0e1b4..92479a26dc 100644 --- a/server/src/repositories/event.repository.ts +++ b/server/src/repositories/event.repository.ts @@ -1,27 +1,15 @@ import { Injectable } from '@nestjs/common'; import { ModuleRef, Reflector } from '@nestjs/core'; -import { - OnGatewayConnection, - OnGatewayDisconnect, - OnGatewayInit, - WebSocketGateway, - WebSocketServer, -} from '@nestjs/websockets'; import { ClassConstructor } from 'class-transformer'; import _ from 'lodash'; -import { Server, Socket } from 'socket.io'; +import { Socket } from 'socket.io'; import { SystemConfig } from 'src/config'; import { EventConfig } from 'src/decorators'; -import { AssetResponseDto } from 'src/dtos/asset-response.dto'; import { AuthDto } from 'src/dtos/auth.dto'; -import { NotificationDto } from 'src/dtos/notification.dto'; -import { ReleaseNotification, ServerVersionResponseDto } from 'src/dtos/server.dto'; -import { SyncAssetExifV1, SyncAssetV1 } from 'src/dtos/sync.dto'; import { ImmichWorker, JobStatus, MetadataKey, QueueName, UserAvatarColor, UserStatus } from 'src/enum'; import { ConfigRepository } from 'src/repositories/config.repository'; import { LoggingRepository } from 'src/repositories/logging.repository'; import { JobItem, JobSource } from 'src/types'; -import { handlePromiseError } from 'src/utils/misc'; type EmitHandlers = Partial<{ [T in EmitEvent]: Array> }>; @@ -130,33 +118,11 @@ type UserEvent = { profileChangedAt: Date; }; -export const serverEvents = ['ConfigUpdate'] as const; -export type ServerEvents = (typeof serverEvents)[number]; - export type EmitEvent = keyof EventMap; export type EmitHandler = (...args: ArgsOf) => Promise | void; export type ArgOf = EventMap[T][0]; export type ArgsOf = EventMap[T]; -export interface ClientEventMap { - on_upload_success: [AssetResponseDto]; - on_user_delete: [string]; - on_asset_delete: [string]; - on_asset_trash: [string[]]; - on_asset_update: [AssetResponseDto]; - on_asset_hidden: [string]; - on_asset_restore: [string[]]; - on_asset_stack_update: string[]; - on_person_thumbnail: [string]; - on_server_version: [ServerVersionResponseDto]; - on_config_update: []; - on_new_release: [ReleaseNotification]; - on_notification: [NotificationDto]; - on_session_delete: [string]; - - AssetUploadReadyV1: [{ asset: SyncAssetV1; exif: SyncAssetExifV1 }]; -} - export type EventItem = { event: T; handler: EmitHandler; @@ -165,18 +131,9 @@ export type EventItem = { export type AuthFn = (client: Socket) => Promise; -@WebSocketGateway({ - cors: true, - path: '/api/socket.io', - transports: ['websocket'], -}) @Injectable() -export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect, OnGatewayInit { +export class EventRepository { private emitHandlers: EmitHandlers = {}; - private authFn?: AuthFn; - - @WebSocketServer() - private server?: Server; constructor( private moduleRef: ModuleRef, @@ -237,38 +194,6 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect } } - afterInit(server: Server) { - this.logger.log('Initialized websocket server'); - - for (const event of serverEvents) { - server.on(event, (...args: ArgsOf) => { - this.logger.debug(`Server event: ${event} (receive)`); - handlePromiseError(this.onEvent({ name: event, args, server: true }), this.logger); - }); - } - } - - async handleConnection(client: Socket) { - try { - this.logger.log(`Websocket Connect: ${client.id}`); - const auth = await this.authenticate(client); - await client.join(auth.user.id); - if (auth.session) { - await client.join(auth.session.id); - } - await this.onEvent({ name: 'WebsocketConnect', args: [{ userId: auth.user.id }], server: false }); - } catch (error: Error | any) { - this.logger.error(`Websocket connection error: ${error}`, error?.stack); - client.emit('error', 'unauthorized'); - client.disconnect(); - } - } - - async handleDisconnect(client: Socket) { - this.logger.log(`Websocket Disconnect: ${client.id}`); - await client.leave(client.nsp.name); - } - private addHandler(item: Item): void { const event = item.event; @@ -283,7 +208,7 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect return this.onEvent({ name: event, args, server: false }); } - private async onEvent(event: { name: T; args: ArgsOf; server: boolean }): Promise { + async onEvent(event: { name: T; args: ArgsOf; server: boolean }): Promise { const handlers = this.emitHandlers[event.name] || []; for (const { handler, server } of handlers) { // exclude handlers that ignore server events @@ -294,29 +219,4 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect await handler(...event.args); } } - - clientSend(event: T, room: string, ...data: ClientEventMap[T]) { - this.server?.to(room).emit(event, ...data); - } - - clientBroadcast(event: T, ...data: ClientEventMap[T]) { - this.server?.emit(event, ...data); - } - - serverSend(event: T, ...args: ArgsOf): void { - this.logger.debug(`Server event: ${event} (send)`); - this.server?.serverSideEmit(event, ...args); - } - - setAuthFn(fn: (client: Socket) => Promise) { - this.authFn = fn; - } - - private async authenticate(client: Socket) { - if (!this.authFn) { - throw new Error('Auth function not set'); - } - - return this.authFn(client); - } } diff --git a/server/src/repositories/index.ts b/server/src/repositories/index.ts index d2e1aa08c8..f02806785a 100644 --- a/server/src/repositories/index.ts +++ b/server/src/repositories/index.ts @@ -44,6 +44,7 @@ import { TrashRepository } from 'src/repositories/trash.repository'; import { UserRepository } from 'src/repositories/user.repository'; import { VersionHistoryRepository } from 'src/repositories/version-history.repository'; import { ViewRepository } from 'src/repositories/view-repository'; +import { WebsocketRepository } from 'src/repositories/websocket.repository'; export const repositories = [ AccessRepository, @@ -92,4 +93,5 @@ export const repositories = [ UserRepository, ViewRepository, VersionHistoryRepository, + WebsocketRepository, ]; diff --git a/server/src/repositories/websocket.repository.ts b/server/src/repositories/websocket.repository.ts new file mode 100644 index 0000000000..030659772d --- /dev/null +++ b/server/src/repositories/websocket.repository.ts @@ -0,0 +1,118 @@ +import { Injectable } from '@nestjs/common'; +import { + OnGatewayConnection, + OnGatewayDisconnect, + OnGatewayInit, + WebSocketGateway, + WebSocketServer, +} from '@nestjs/websockets'; +import { Server, Socket } from 'socket.io'; +import { AssetResponseDto } from 'src/dtos/asset-response.dto'; +import { AuthDto } from 'src/dtos/auth.dto'; +import { NotificationDto } from 'src/dtos/notification.dto'; +import { ReleaseNotification, ServerVersionResponseDto } from 'src/dtos/server.dto'; +import { SyncAssetExifV1, SyncAssetV1 } from 'src/dtos/sync.dto'; +import { ArgsOf, EventRepository } from 'src/repositories/event.repository'; +import { LoggingRepository } from 'src/repositories/logging.repository'; +import { handlePromiseError } from 'src/utils/misc'; + +export const serverEvents = ['ConfigUpdate'] as const; +export type ServerEvents = (typeof serverEvents)[number]; + +export interface ClientEventMap { + on_upload_success: [AssetResponseDto]; + on_user_delete: [string]; + on_asset_delete: [string]; + on_asset_trash: [string[]]; + on_asset_update: [AssetResponseDto]; + on_asset_hidden: [string]; + on_asset_restore: [string[]]; + on_asset_stack_update: string[]; + on_person_thumbnail: [string]; + on_server_version: [ServerVersionResponseDto]; + on_config_update: []; + on_new_release: [ReleaseNotification]; + on_notification: [NotificationDto]; + on_session_delete: [string]; + + AssetUploadReadyV1: [{ asset: SyncAssetV1; exif: SyncAssetExifV1 }]; +} + +export type AuthFn = (client: Socket) => Promise; + +@WebSocketGateway({ + cors: true, + path: '/api/socket.io', + transports: ['websocket'], +}) +@Injectable() +export class WebsocketRepository implements OnGatewayConnection, OnGatewayDisconnect, OnGatewayInit { + private authFn?: AuthFn; + + @WebSocketServer() + private server?: Server; + + constructor( + private eventRepository: EventRepository, + private logger: LoggingRepository, + ) { + this.logger.setContext(WebsocketRepository.name); + } + + afterInit(server: Server) { + this.logger.log('Initialized websocket server'); + + for (const event of serverEvents) { + server.on(event, (...args: ArgsOf) => { + this.logger.debug(`Server event: ${event} (receive)`); + handlePromiseError(this.eventRepository.onEvent({ name: event, args, server: true }), this.logger); + }); + } + } + + async handleConnection(client: Socket) { + try { + this.logger.log(`Websocket Connect: ${client.id}`); + const auth = await this.authenticate(client); + await client.join(auth.user.id); + if (auth.session) { + await client.join(auth.session.id); + } + await this.eventRepository.emit('WebsocketConnect', { userId: auth.user.id }); + } catch (error: Error | any) { + this.logger.error(`Websocket connection error: ${error}`, error?.stack); + client.emit('error', 'unauthorized'); + client.disconnect(); + } + } + + async handleDisconnect(client: Socket) { + this.logger.log(`Websocket Disconnect: ${client.id}`); + await client.leave(client.nsp.name); + } + + clientSend(event: T, room: string, ...data: ClientEventMap[T]) { + this.server?.to(room).emit(event, ...data); + } + + clientBroadcast(event: T, ...data: ClientEventMap[T]) { + this.server?.emit(event, ...data); + } + + serverSend(event: T, ...args: ArgsOf): void { + this.logger.debug(`Server event: ${event} (send)`); + this.server?.serverSideEmit(event, ...args); + } + + setAuthFn(fn: (client: Socket) => Promise) { + this.authFn = fn; + } + + private async authenticate(client: Socket) { + if (!this.authFn) { + throw new Error('Auth function not set'); + } + + return this.authFn(client); + } +} diff --git a/server/src/services/base.service.ts b/server/src/services/base.service.ts index de5e5862c7..bc71c6faf1 100644 --- a/server/src/services/base.service.ts +++ b/server/src/services/base.service.ts @@ -51,6 +51,7 @@ import { TrashRepository } from 'src/repositories/trash.repository'; import { UserRepository } from 'src/repositories/user.repository'; import { VersionHistoryRepository } from 'src/repositories/version-history.repository'; import { ViewRepository } from 'src/repositories/view-repository'; +import { WebsocketRepository } from 'src/repositories/websocket.repository'; import { UserTable } from 'src/schema/tables/user.table'; import { AccessRequest, checkAccess, requireAccess } from 'src/utils/access'; import { getConfig, updateConfig } from 'src/utils/config'; @@ -155,6 +156,7 @@ export class BaseService { protected userRepository: UserRepository, protected versionRepository: VersionHistoryRepository, protected viewRepository: ViewRepository, + protected websocketRepository: WebsocketRepository, ) { this.logger.setContext(this.constructor.name); this.storageCore = StorageCore.create( diff --git a/server/src/services/job.service.ts b/server/src/services/job.service.ts index 658e5dbd7a..140f9f3ec9 100644 --- a/server/src/services/job.service.ts +++ b/server/src/services/job.service.ts @@ -331,7 +331,7 @@ export class JobService extends BaseService { const { id } = item.data; const person = await this.personRepository.getById(id); if (person) { - this.eventRepository.clientSend('on_person_thumbnail', person.ownerId, person.id); + this.websocketRepository.clientSend('on_person_thumbnail', person.ownerId, person.id); } break; } @@ -358,10 +358,10 @@ export class JobService extends BaseService { await this.jobRepository.queueAll(jobs); if (asset.visibility === AssetVisibility.Timeline || asset.visibility === AssetVisibility.Archive) { - this.eventRepository.clientSend('on_upload_success', asset.ownerId, mapAsset(asset)); + this.websocketRepository.clientSend('on_upload_success', asset.ownerId, mapAsset(asset)); if (asset.exifInfo) { const exif = asset.exifInfo; - this.eventRepository.clientSend('AssetUploadReadyV1', asset.ownerId, { + this.websocketRepository.clientSend('AssetUploadReadyV1', asset.ownerId, { // TODO remove `on_upload_success` and then modify the query to select only the required fields) asset: { id: asset.id, diff --git a/server/src/services/notification.service.spec.ts b/server/src/services/notification.service.spec.ts index 403ed44631..daa3f221ae 100644 --- a/server/src/services/notification.service.spec.ts +++ b/server/src/services/notification.service.spec.ts @@ -65,8 +65,8 @@ describe(NotificationService.name, () => { it('should emit client and server events', () => { const update = { oldConfig: defaults, newConfig: defaults }; expect(sut.onConfigUpdate(update)).toBeUndefined(); - expect(mocks.event.clientBroadcast).toHaveBeenCalledWith('on_config_update'); - expect(mocks.event.serverSend).toHaveBeenCalledWith('ConfigUpdate', update); + expect(mocks.websocket.clientBroadcast).toHaveBeenCalledWith('on_config_update'); + expect(mocks.websocket.serverSend).toHaveBeenCalledWith('ConfigUpdate', update); }); }); @@ -125,7 +125,7 @@ describe(NotificationService.name, () => { describe('onAssetHide', () => { it('should send connected clients an event', () => { sut.onAssetHide({ assetId: 'asset-id', userId: 'user-id' }); - expect(mocks.event.clientSend).toHaveBeenCalledWith('on_asset_hidden', 'user-id', 'asset-id'); + expect(mocks.websocket.clientSend).toHaveBeenCalledWith('on_asset_hidden', 'user-id', 'asset-id'); }); }); @@ -178,67 +178,67 @@ describe(NotificationService.name, () => { it('should send a on_session_delete client event', () => { vi.useFakeTimers(); sut.onSessionDelete({ sessionId: 'id' }); - expect(mocks.event.clientSend).not.toHaveBeenCalled(); + expect(mocks.websocket.clientSend).not.toHaveBeenCalled(); vi.advanceTimersByTime(500); - expect(mocks.event.clientSend).toHaveBeenCalledWith('on_session_delete', 'id', 'id'); + expect(mocks.websocket.clientSend).toHaveBeenCalledWith('on_session_delete', 'id', 'id'); }); }); describe('onAssetTrash', () => { - it('should send connected clients an event', () => { + it('should send connected clients an websocket', () => { sut.onAssetTrash({ assetId: 'asset-id', userId: 'user-id' }); - expect(mocks.event.clientSend).toHaveBeenCalledWith('on_asset_trash', 'user-id', ['asset-id']); + expect(mocks.websocket.clientSend).toHaveBeenCalledWith('on_asset_trash', 'user-id', ['asset-id']); }); }); describe('onAssetDelete', () => { it('should send connected clients an event', () => { sut.onAssetDelete({ assetId: 'asset-id', userId: 'user-id' }); - expect(mocks.event.clientSend).toHaveBeenCalledWith('on_asset_delete', 'user-id', 'asset-id'); + expect(mocks.websocket.clientSend).toHaveBeenCalledWith('on_asset_delete', 'user-id', 'asset-id'); }); }); describe('onAssetsTrash', () => { it('should send connected clients an event', () => { sut.onAssetsTrash({ assetIds: ['asset-id'], userId: 'user-id' }); - expect(mocks.event.clientSend).toHaveBeenCalledWith('on_asset_trash', 'user-id', ['asset-id']); + expect(mocks.websocket.clientSend).toHaveBeenCalledWith('on_asset_trash', 'user-id', ['asset-id']); }); }); describe('onAssetsRestore', () => { it('should send connected clients an event', () => { sut.onAssetsRestore({ assetIds: ['asset-id'], userId: 'user-id' }); - expect(mocks.event.clientSend).toHaveBeenCalledWith('on_asset_restore', 'user-id', ['asset-id']); + expect(mocks.websocket.clientSend).toHaveBeenCalledWith('on_asset_restore', 'user-id', ['asset-id']); }); }); describe('onStackCreate', () => { it('should send connected clients an event', () => { sut.onStackCreate({ stackId: 'stack-id', userId: 'user-id' }); - expect(mocks.event.clientSend).toHaveBeenCalledWith('on_asset_stack_update', 'user-id'); + expect(mocks.websocket.clientSend).toHaveBeenCalledWith('on_asset_stack_update', 'user-id'); }); }); describe('onStackUpdate', () => { it('should send connected clients an event', () => { sut.onStackUpdate({ stackId: 'stack-id', userId: 'user-id' }); - expect(mocks.event.clientSend).toHaveBeenCalledWith('on_asset_stack_update', 'user-id'); + expect(mocks.websocket.clientSend).toHaveBeenCalledWith('on_asset_stack_update', 'user-id'); }); }); describe('onStackDelete', () => { it('should send connected clients an event', () => { sut.onStackDelete({ stackId: 'stack-id', userId: 'user-id' }); - expect(mocks.event.clientSend).toHaveBeenCalledWith('on_asset_stack_update', 'user-id'); + expect(mocks.websocket.clientSend).toHaveBeenCalledWith('on_asset_stack_update', 'user-id'); }); }); describe('onStacksDelete', () => { it('should send connected clients an event', () => { sut.onStacksDelete({ stackIds: ['stack-id'], userId: 'user-id' }); - expect(mocks.event.clientSend).toHaveBeenCalledWith('on_asset_stack_update', 'user-id'); + expect(mocks.websocket.clientSend).toHaveBeenCalledWith('on_asset_stack_update', 'user-id'); }); }); diff --git a/server/src/services/notification.service.ts b/server/src/services/notification.service.ts index 51cd51811f..8276f141a0 100644 --- a/server/src/services/notification.service.ts +++ b/server/src/services/notification.service.ts @@ -98,7 +98,7 @@ export class NotificationService extends BaseService { description: `Job ${[job.name]} failed with error: ${errorMessage}`, }); - this.eventRepository.clientSend('on_notification', admin.id, mapNotification(item)); + this.websocketRepository.clientSend('on_notification', admin.id, mapNotification(item)); break; } @@ -110,8 +110,8 @@ export class NotificationService extends BaseService { @OnEvent({ name: 'ConfigUpdate' }) onConfigUpdate({ oldConfig, newConfig }: ArgOf<'ConfigUpdate'>) { - this.eventRepository.clientBroadcast('on_config_update'); - this.eventRepository.serverSend('ConfigUpdate', { oldConfig, newConfig }); + this.websocketRepository.clientBroadcast('on_config_update'); + this.websocketRepository.serverSend('ConfigUpdate', { oldConfig, newConfig }); } @OnEvent({ name: 'ConfigValidate', priority: -100 }) @@ -131,7 +131,7 @@ export class NotificationService extends BaseService { @OnEvent({ name: 'AssetHide' }) onAssetHide({ assetId, userId }: ArgOf<'AssetHide'>) { - this.eventRepository.clientSend('on_asset_hidden', userId, assetId); + this.websocketRepository.clientSend('on_asset_hidden', userId, assetId); } @OnEvent({ name: 'AssetShow' }) @@ -141,17 +141,17 @@ export class NotificationService extends BaseService { @OnEvent({ name: 'AssetTrash' }) onAssetTrash({ assetId, userId }: ArgOf<'AssetTrash'>) { - this.eventRepository.clientSend('on_asset_trash', userId, [assetId]); + this.websocketRepository.clientSend('on_asset_trash', userId, [assetId]); } @OnEvent({ name: 'AssetDelete' }) onAssetDelete({ assetId, userId }: ArgOf<'AssetDelete'>) { - this.eventRepository.clientSend('on_asset_delete', userId, assetId); + this.websocketRepository.clientSend('on_asset_delete', userId, assetId); } @OnEvent({ name: 'AssetTrashAll' }) onAssetsTrash({ assetIds, userId }: ArgOf<'AssetTrashAll'>) { - this.eventRepository.clientSend('on_asset_trash', userId, assetIds); + this.websocketRepository.clientSend('on_asset_trash', userId, assetIds); } @OnEvent({ name: 'AssetMetadataExtracted' }) @@ -162,7 +162,7 @@ export class NotificationService extends BaseService { const [asset] = await this.assetRepository.getByIdsWithAllRelationsButStacks([assetId]); if (asset) { - this.eventRepository.clientSend( + this.websocketRepository.clientSend( 'on_asset_update', userId, mapAsset(asset, { auth: { user: { id: userId } } as AuthDto }), @@ -172,27 +172,27 @@ export class NotificationService extends BaseService { @OnEvent({ name: 'AssetRestoreAll' }) onAssetsRestore({ assetIds, userId }: ArgOf<'AssetRestoreAll'>) { - this.eventRepository.clientSend('on_asset_restore', userId, assetIds); + this.websocketRepository.clientSend('on_asset_restore', userId, assetIds); } @OnEvent({ name: 'StackCreate' }) onStackCreate({ userId }: ArgOf<'StackCreate'>) { - this.eventRepository.clientSend('on_asset_stack_update', userId); + this.websocketRepository.clientSend('on_asset_stack_update', userId); } @OnEvent({ name: 'StackUpdate' }) onStackUpdate({ userId }: ArgOf<'StackUpdate'>) { - this.eventRepository.clientSend('on_asset_stack_update', userId); + this.websocketRepository.clientSend('on_asset_stack_update', userId); } @OnEvent({ name: 'StackDelete' }) onStackDelete({ userId }: ArgOf<'StackDelete'>) { - this.eventRepository.clientSend('on_asset_stack_update', userId); + this.websocketRepository.clientSend('on_asset_stack_update', userId); } @OnEvent({ name: 'StackDeleteAll' }) onStacksDelete({ userId }: ArgOf<'StackDeleteAll'>) { - this.eventRepository.clientSend('on_asset_stack_update', userId); + this.websocketRepository.clientSend('on_asset_stack_update', userId); } @OnEvent({ name: 'UserSignup' }) @@ -204,7 +204,7 @@ export class NotificationService extends BaseService { @OnEvent({ name: 'UserDelete' }) onUserDelete({ id }: ArgOf<'UserDelete'>) { - this.eventRepository.clientBroadcast('on_user_delete', id); + this.websocketRepository.clientBroadcast('on_user_delete', id); } @OnEvent({ name: 'AlbumUpdate' }) @@ -224,7 +224,7 @@ export class NotificationService extends BaseService { @OnEvent({ name: 'SessionDelete' }) onSessionDelete({ sessionId }: ArgOf<'SessionDelete'>) { // after the response is sent - setTimeout(() => this.eventRepository.clientSend('on_session_delete', sessionId, sessionId), 500); + setTimeout(() => this.websocketRepository.clientSend('on_session_delete', sessionId, sessionId), 500); } async sendTestEmail(id: string, dto: SystemConfigSmtpDto, tempTemplate?: string) { @@ -464,6 +464,6 @@ export class NotificationService extends BaseService { data: JSON.stringify({ albumId: album.id }), }); - this.eventRepository.clientSend('on_notification', userId, mapNotification(item)); + this.websocketRepository.clientSend('on_notification', userId, mapNotification(item)); } } diff --git a/server/src/services/version.service.spec.ts b/server/src/services/version.service.spec.ts index 73794275ea..84c7b578dd 100644 --- a/server/src/services/version.service.spec.ts +++ b/server/src/services/version.service.spec.ts @@ -108,7 +108,7 @@ describe(VersionService.name, () => { await expect(sut.handleVersionCheck()).resolves.toEqual(JobStatus.Success); expect(mocks.systemMetadata.set).toHaveBeenCalled(); expect(mocks.logger.log).toHaveBeenCalled(); - expect(mocks.event.clientBroadcast).toHaveBeenCalled(); + expect(mocks.websocket.clientBroadcast).toHaveBeenCalled(); }); it('should not notify if the version is equal', async () => { @@ -118,14 +118,14 @@ describe(VersionService.name, () => { checkedAt: expect.any(String), releaseVersion: serverVersion.toString(), }); - expect(mocks.event.clientBroadcast).not.toHaveBeenCalled(); + expect(mocks.websocket.clientBroadcast).not.toHaveBeenCalled(); }); it('should handle a github error', async () => { mocks.serverInfo.getGitHubRelease.mockRejectedValue(new Error('GitHub is down')); await expect(sut.handleVersionCheck()).resolves.toEqual(JobStatus.Failed); expect(mocks.systemMetadata.set).not.toHaveBeenCalled(); - expect(mocks.event.clientBroadcast).not.toHaveBeenCalled(); + expect(mocks.websocket.clientBroadcast).not.toHaveBeenCalled(); expect(mocks.logger.warn).toHaveBeenCalled(); }); }); @@ -133,15 +133,15 @@ describe(VersionService.name, () => { describe('onWebsocketConnectionEvent', () => { it('should send on_server_version client event', async () => { await sut.onWebsocketConnection({ userId: '42' }); - expect(mocks.event.clientSend).toHaveBeenCalledWith('on_server_version', '42', expect.any(SemVer)); - expect(mocks.event.clientSend).toHaveBeenCalledTimes(1); + expect(mocks.websocket.clientSend).toHaveBeenCalledWith('on_server_version', '42', expect.any(SemVer)); + expect(mocks.websocket.clientSend).toHaveBeenCalledTimes(1); }); it('should also send a new release notification', async () => { mocks.systemMetadata.get.mockResolvedValue({ checkedAt: '2024-01-01', releaseVersion: 'v1.42.0' }); await sut.onWebsocketConnection({ userId: '42' }); - expect(mocks.event.clientSend).toHaveBeenCalledWith('on_server_version', '42', expect.any(SemVer)); - expect(mocks.event.clientSend).toHaveBeenCalledWith('on_new_release', '42', expect.any(Object)); + expect(mocks.websocket.clientSend).toHaveBeenCalledWith('on_server_version', '42', expect.any(SemVer)); + expect(mocks.websocket.clientSend).toHaveBeenCalledWith('on_new_release', '42', expect.any(Object)); }); }); }); diff --git a/server/src/services/version.service.ts b/server/src/services/version.service.ts index b817363eac..2d3924bc49 100644 --- a/server/src/services/version.service.ts +++ b/server/src/services/version.service.ts @@ -92,7 +92,7 @@ export class VersionService extends BaseService { if (semver.gt(releaseVersion, serverVersion)) { this.logger.log(`Found ${releaseVersion}, released at ${new Date(publishedAt).toLocaleString()}`); - this.eventRepository.clientBroadcast('on_new_release', asNotification(metadata)); + this.websocketRepository.clientBroadcast('on_new_release', asNotification(metadata)); } } catch (error: Error | any) { this.logger.warn(`Unable to run version check: ${error}\n${error?.stack}`); @@ -104,10 +104,10 @@ export class VersionService extends BaseService { @OnEvent({ name: 'WebsocketConnect' }) async onWebsocketConnection({ userId }: ArgOf<'WebsocketConnect'>) { - this.eventRepository.clientSend('on_server_version', userId, serverVersion); + this.websocketRepository.clientSend('on_server_version', userId, serverVersion); const metadata = await this.systemMetadataRepository.get(SystemMetadataKey.VersionCheckState); if (metadata) { - this.eventRepository.clientSend('on_new_release', userId, asNotification(metadata)); + this.websocketRepository.clientSend('on_new_release', userId, asNotification(metadata)); } } } diff --git a/server/test/utils.ts b/server/test/utils.ts index bae9163b80..746bab7682 100644 --- a/server/test/utils.ts +++ b/server/test/utils.ts @@ -60,6 +60,7 @@ import { TrashRepository } from 'src/repositories/trash.repository'; import { UserRepository } from 'src/repositories/user.repository'; import { VersionHistoryRepository } from 'src/repositories/version-history.repository'; import { ViewRepository } from 'src/repositories/view-repository'; +import { WebsocketRepository } from 'src/repositories/websocket.repository'; import { DB } from 'src/schema'; import { AuthService } from 'src/services/auth.service'; import { BaseService } from 'src/services/base.service'; @@ -249,6 +250,7 @@ export type ServiceOverrides = { user: UserRepository; versionHistory: VersionHistoryRepository; view: ViewRepository; + websocket: WebsocketRepository; }; type As = T extends RepositoryInterface ? U : never; @@ -323,6 +325,8 @@ export const newTestService = ( user: automock(UserRepository, { strict: false }), versionHistory: automock(VersionHistoryRepository), view: automock(ViewRepository), + // eslint-disable-next-line no-sparse-arrays + websocket: automock(WebsocketRepository, { args: [, loggerMock], strict: false }), }; const sut = new Service( @@ -372,6 +376,7 @@ export const newTestService = ( overrides.user || (mocks.user as As), overrides.versionHistory || (mocks.versionHistory as As), overrides.view || (mocks.view as As), + overrides.websocket || (mocks.websocket as As), ); return {