diff --git a/machine-learning/app/main.py b/machine-learning/app/main.py index 000119937e74a..6d359ec2de250 100644 --- a/machine-learning/app/main.py +++ b/machine-learning/app/main.py @@ -11,7 +11,7 @@ from typing import Any, AsyncGenerator, Callable, Iterator from zipfile import BadZipFile import orjson -from fastapi import Depends, FastAPI, File, Form, HTTPException +from fastapi import Depends, FastAPI, File, Form, HTTPException, Response from fastapi.responses import ORJSONResponse from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile from PIL.Image import Image @@ -124,6 +124,23 @@ def get_entries(entries: str = Form()) -> InferenceEntries: raise HTTPException(422, "Invalid request format.") +def get_entry(entries: str = Form()) -> InferenceEntry: + try: + request: PipelineRequest = orjson.loads(entries) + for task, types in request.items(): + for type, entry in types.items(): + parsed: InferenceEntry = { + "name": entry["modelName"], + "task": task, + "type": type, + "options": entry.get("options", {}), + } + return parsed + except (orjson.JSONDecodeError, ValidationError, KeyError, AttributeError) as e: + log.error(f"Invalid request format: {e}") + raise HTTPException(422, "Invalid request format.") + + app = FastAPI(lifespan=lifespan) @@ -137,6 +154,20 @@ def ping() -> str: return "pong" +@app.post("/load", response_model=TextResponse) +async def load_model(entry: InferenceEntry = Depends(get_entry)) -> None: + model = await model_cache.get(entry["name"], entry["type"], entry["task"], ttl=settings.model_ttl) + model = await load(model) + return Response(status_code=200) + + +@app.post("/unload", response_model=TextResponse) +async def unload_model(entry: InferenceEntry = Depends(get_entry)) -> None: + await model_cache.unload(entry["name"], entry["type"], entry["task"]) + print("unload") + return Response(status_code=200) + + @app.post("/predict", dependencies=[Depends(update_state)]) async def predict( entries: 
InferenceEntries = Depends(get_entries), diff --git a/machine-learning/app/models/cache.py b/machine-learning/app/models/cache.py index bf8e8a6352145..34c9fd2a41fde 100644 --- a/machine-learning/app/models/cache.py +++ b/machine-learning/app/models/cache.py @@ -58,3 +58,10 @@ class ModelCache: async def revalidate(self, key: str, ttl: int | None) -> None: if ttl is not None and key in self.cache._handlers: await self.cache.expire(key, ttl) + + async def unload(self, model_name: str, model_type: ModelType, model_task: ModelTask) -> None: + key = f"{model_name}{model_type}{model_task}" + async with OptimisticLock(self.cache, key): + value = await self.cache.get(key) + if value is not None: + await self.cache.delete(key) diff --git a/mobile/openapi/README.md b/mobile/openapi/README.md index 36f442fd88b51..2504da1282710 100644 --- a/mobile/openapi/README.md +++ b/mobile/openapi/README.md @@ -116,7 +116,6 @@ Class | Method | HTTP request | Description *AuthenticationApi* | [**signUpAdmin**](doc//AuthenticationApi.md#signupadmin) | **POST** /auth/admin-sign-up | *AuthenticationApi* | [**validateAccessToken**](doc//AuthenticationApi.md#validateaccesstoken) | **POST** /auth/validateToken | *DeprecatedApi* | [**getPersonAssets**](doc//DeprecatedApi.md#getpersonassets) | **GET** /people/{id}/assets | -*DeprecatedApi* | [**getRandom**](doc//DeprecatedApi.md#getrandom) | **GET** /assets/random | *DownloadApi* | [**downloadArchive**](doc//DownloadApi.md#downloadarchive) | **POST** /download/archive | *DownloadApi* | [**getDownloadInfo**](doc//DownloadApi.md#getdownloadinfo) | **POST** /download/info | *DuplicatesApi* | [**getAssetDuplicates**](doc//DuplicatesApi.md#getassetduplicates) | **GET** /duplicates | @@ -125,7 +124,6 @@ Class | Method | HTTP request | Description *FileReportsApi* | [**fixAuditFiles**](doc//FileReportsApi.md#fixauditfiles) | **POST** /reports/fix | *FileReportsApi* | [**getAuditFiles**](doc//FileReportsApi.md#getauditfiles) | **GET** /reports | 
*FileReportsApi* | [**getFileChecksums**](doc//FileReportsApi.md#getfilechecksums) | **POST** /reports/checksum | -*JobsApi* | [**createJob**](doc//JobsApi.md#createjob) | **POST** /jobs | *JobsApi* | [**getAllJobsStatus**](doc//JobsApi.md#getalljobsstatus) | **GET** /jobs | *JobsApi* | [**sendJobCommand**](doc//JobsApi.md#sendjobcommand) | **PUT** /jobs/{id} | *LibrariesApi* | [**createLibrary**](doc//LibrariesApi.md#createlibrary) | **POST** /libraries | @@ -137,6 +135,7 @@ Class | Method | HTTP request | Description *LibrariesApi* | [**updateLibrary**](doc//LibrariesApi.md#updatelibrary) | **PUT** /libraries/{id} | *LibrariesApi* | [**validate**](doc//LibrariesApi.md#validate) | **POST** /libraries/{id}/validate | *MapApi* | [**getMapMarkers**](doc//MapApi.md#getmapmarkers) | **GET** /map/markers | +*MapApi* | [**getMapStyle**](doc//MapApi.md#getmapstyle) | **GET** /map/style.json | *MapApi* | [**reverseGeocode**](doc//MapApi.md#reversegeocode) | **GET** /map/reverse-geocode | *MemoriesApi* | [**addMemoryAssets**](doc//MemoriesApi.md#addmemoryassets) | **PUT** /memories/{id}/assets | *MemoriesApi* | [**createMemory**](doc//MemoriesApi.md#creatememory) | **POST** /memories | @@ -171,7 +170,6 @@ Class | Method | HTTP request | Description *SearchApi* | [**searchMetadata**](doc//SearchApi.md#searchmetadata) | **POST** /search/metadata | *SearchApi* | [**searchPerson**](doc//SearchApi.md#searchperson) | **GET** /search/person | *SearchApi* | [**searchPlaces**](doc//SearchApi.md#searchplaces) | **GET** /search/places | -*SearchApi* | [**searchRandom**](doc//SearchApi.md#searchrandom) | **POST** /search/random | *SearchApi* | [**searchSmart**](doc//SearchApi.md#searchsmart) | **POST** /search/smart | *ServerApi* | [**deleteServerLicense**](doc//ServerApi.md#deleteserverlicense) | **DELETE** /server/license | *ServerApi* | [**getAboutInfo**](doc//ServerApi.md#getaboutinfo) | **GET** /server/about | @@ -332,7 +330,6 @@ Class | Method | HTTP request | Description - 
[JobCommand](doc//JobCommand.md) - [JobCommandDto](doc//JobCommandDto.md) - [JobCountsDto](doc//JobCountsDto.md) - - [JobCreateDto](doc//JobCreateDto.md) - [JobName](doc//JobName.md) - [JobSettingsDto](doc//JobSettingsDto.md) - [JobStatusDto](doc//JobStatusDto.md) @@ -340,13 +337,14 @@ Class | Method | HTTP request | Description - [LibraryStatsResponseDto](doc//LibraryStatsResponseDto.md) - [LicenseKeyDto](doc//LicenseKeyDto.md) - [LicenseResponseDto](doc//LicenseResponseDto.md) + - [LoadTextualModelOnConnection](doc//LoadTextualModelOnConnection.md) - [LogLevel](doc//LogLevel.md) - [LoginCredentialDto](doc//LoginCredentialDto.md) - [LoginResponseDto](doc//LoginResponseDto.md) - [LogoutResponseDto](doc//LogoutResponseDto.md) - - [ManualJobName](doc//ManualJobName.md) - [MapMarkerResponseDto](doc//MapMarkerResponseDto.md) - [MapReverseGeocodeResponseDto](doc//MapReverseGeocodeResponseDto.md) + - [MapTheme](doc//MapTheme.md) - [MemoriesResponse](doc//MemoriesResponse.md) - [MemoriesUpdate](doc//MemoriesUpdate.md) - [MemoryCreateDto](doc//MemoryCreateDto.md) @@ -379,7 +377,6 @@ Class | Method | HTTP request | Description - [PurchaseResponse](doc//PurchaseResponse.md) - [PurchaseUpdate](doc//PurchaseUpdate.md) - [QueueStatusDto](doc//QueueStatusDto.md) - - [RandomSearchDto](doc//RandomSearchDto.md) - [RatingsResponse](doc//RatingsResponse.md) - [RatingsUpdate](doc//RatingsUpdate.md) - [ReactionLevel](doc//ReactionLevel.md) @@ -455,7 +452,6 @@ Class | Method | HTTP request | Description - [ToneMapping](doc//ToneMapping.md) - [TranscodeHWAccel](doc//TranscodeHWAccel.md) - [TranscodePolicy](doc//TranscodePolicy.md) - - [TrashResponseDto](doc//TrashResponseDto.md) - [UpdateAlbumDto](doc//UpdateAlbumDto.md) - [UpdateAlbumUserDto](doc//UpdateAlbumUserDto.md) - [UpdateAssetDto](doc//UpdateAssetDto.md) diff --git a/mobile/openapi/lib/api.dart b/mobile/openapi/lib/api.dart index 6fb7478d04bf2..1c1508fd3485e 100644 --- a/mobile/openapi/lib/api.dart +++ 
b/mobile/openapi/lib/api.dart @@ -144,7 +144,6 @@ part 'model/image_format.dart'; part 'model/job_command.dart'; part 'model/job_command_dto.dart'; part 'model/job_counts_dto.dart'; -part 'model/job_create_dto.dart'; part 'model/job_name.dart'; part 'model/job_settings_dto.dart'; part 'model/job_status_dto.dart'; @@ -152,13 +151,14 @@ part 'model/library_response_dto.dart'; part 'model/library_stats_response_dto.dart'; part 'model/license_key_dto.dart'; part 'model/license_response_dto.dart'; +part 'model/load_textual_model_on_connection.dart'; part 'model/log_level.dart'; part 'model/login_credential_dto.dart'; part 'model/login_response_dto.dart'; part 'model/logout_response_dto.dart'; -part 'model/manual_job_name.dart'; part 'model/map_marker_response_dto.dart'; part 'model/map_reverse_geocode_response_dto.dart'; +part 'model/map_theme.dart'; part 'model/memories_response.dart'; part 'model/memories_update.dart'; part 'model/memory_create_dto.dart'; @@ -191,7 +191,6 @@ part 'model/places_response_dto.dart'; part 'model/purchase_response.dart'; part 'model/purchase_update.dart'; part 'model/queue_status_dto.dart'; -part 'model/random_search_dto.dart'; part 'model/ratings_response.dart'; part 'model/ratings_update.dart'; part 'model/reaction_level.dart'; @@ -267,7 +266,6 @@ part 'model/time_bucket_size.dart'; part 'model/tone_mapping.dart'; part 'model/transcode_hw_accel.dart'; part 'model/transcode_policy.dart'; -part 'model/trash_response_dto.dart'; part 'model/update_album_dto.dart'; part 'model/update_album_user_dto.dart'; part 'model/update_asset_dto.dart'; diff --git a/mobile/openapi/lib/api_client.dart b/mobile/openapi/lib/api_client.dart index c1025b0bd4820..165f7887d427e 100644 --- a/mobile/openapi/lib/api_client.dart +++ b/mobile/openapi/lib/api_client.dart @@ -166,6 +166,7 @@ class ApiClient { /// Returns a native instance of an OpenAPI class matching the [specified type][targetType]. 
static dynamic fromJson(dynamic value, String targetType, {bool growable = false,}) { + upgradeDto(value, targetType); try { switch (targetType) { case 'String': @@ -342,8 +343,6 @@ class ApiClient { return JobCommandDto.fromJson(value); case 'JobCountsDto': return JobCountsDto.fromJson(value); - case 'JobCreateDto': - return JobCreateDto.fromJson(value); case 'JobName': return JobNameTypeTransformer().decode(value); case 'JobSettingsDto': @@ -358,6 +357,8 @@ class ApiClient { return LicenseKeyDto.fromJson(value); case 'LicenseResponseDto': return LicenseResponseDto.fromJson(value); + case 'LoadTextualModelOnConnection': + return LoadTextualModelOnConnection.fromJson(value); case 'LogLevel': return LogLevelTypeTransformer().decode(value); case 'LoginCredentialDto': @@ -366,12 +367,12 @@ class ApiClient { return LoginResponseDto.fromJson(value); case 'LogoutResponseDto': return LogoutResponseDto.fromJson(value); - case 'ManualJobName': - return ManualJobNameTypeTransformer().decode(value); case 'MapMarkerResponseDto': return MapMarkerResponseDto.fromJson(value); case 'MapReverseGeocodeResponseDto': return MapReverseGeocodeResponseDto.fromJson(value); + case 'MapTheme': + return MapThemeTypeTransformer().decode(value); case 'MemoriesResponse': return MemoriesResponse.fromJson(value); case 'MemoriesUpdate': @@ -436,8 +437,6 @@ class ApiClient { return PurchaseUpdate.fromJson(value); case 'QueueStatusDto': return QueueStatusDto.fromJson(value); - case 'RandomSearchDto': - return RandomSearchDto.fromJson(value); case 'RatingsResponse': return RatingsResponse.fromJson(value); case 'RatingsUpdate': @@ -588,8 +587,6 @@ class ApiClient { return TranscodeHWAccelTypeTransformer().decode(value); case 'TranscodePolicy': return TranscodePolicyTypeTransformer().decode(value); - case 'TrashResponseDto': - return TrashResponseDto.fromJson(value); case 'UpdateAlbumDto': return UpdateAlbumDto.fromJson(value); case 'UpdateAlbumUserDto': diff --git 
a/mobile/openapi/lib/model/clip_config.dart b/mobile/openapi/lib/model/clip_config.dart index b500d20f2e6eb..f41fb2b6ba80a 100644 --- a/mobile/openapi/lib/model/clip_config.dart +++ b/mobile/openapi/lib/model/clip_config.dart @@ -14,30 +14,36 @@ class CLIPConfig { /// Returns a new [CLIPConfig] instance. CLIPConfig({ required this.enabled, + required this.loadTextualModelOnConnection, required this.modelName, }); bool enabled; + LoadTextualModelOnConnection loadTextualModelOnConnection; + String modelName; @override bool operator ==(Object other) => identical(this, other) || other is CLIPConfig && other.enabled == enabled && + other.loadTextualModelOnConnection == loadTextualModelOnConnection && other.modelName == modelName; @override int get hashCode => // ignore: unnecessary_parenthesis (enabled.hashCode) + + (loadTextualModelOnConnection.hashCode) + (modelName.hashCode); @override - String toString() => 'CLIPConfig[enabled=$enabled, modelName=$modelName]'; + String toString() => 'CLIPConfig[enabled=$enabled, loadTextualModelOnConnection=$loadTextualModelOnConnection, modelName=$modelName]'; Map toJson() { final json = {}; json[r'enabled'] = this.enabled; + json[r'loadTextualModelOnConnection'] = this.loadTextualModelOnConnection; json[r'modelName'] = this.modelName; return json; } @@ -46,12 +52,12 @@ class CLIPConfig { /// [value] if it's a [Map], null otherwise. // ignore: prefer_constructors_over_static_methods static CLIPConfig? fromJson(dynamic value) { - upgradeDto(value, "CLIPConfig"); if (value is Map) { final json = value.cast(); return CLIPConfig( enabled: mapValueOfType(json, r'enabled')!, + loadTextualModelOnConnection: LoadTextualModelOnConnection.fromJson(json[r'loadTextualModelOnConnection'])!, modelName: mapValueOfType(json, r'modelName')!, ); } @@ -101,6 +107,7 @@ class CLIPConfig { /// The list of required keys that must be present in a JSON. 
static const requiredKeys = { 'enabled', + 'loadTextualModelOnConnection', 'modelName', }; } diff --git a/mobile/openapi/lib/model/load_textual_model_on_connection.dart b/mobile/openapi/lib/model/load_textual_model_on_connection.dart new file mode 100644 index 0000000000000..460799d4fa31e --- /dev/null +++ b/mobile/openapi/lib/model/load_textual_model_on_connection.dart @@ -0,0 +1,107 @@ +// +// AUTO-GENERATED FILE, DO NOT MODIFY! +// +// @dart=2.18 + +// ignore_for_file: unused_element, unused_import +// ignore_for_file: always_put_required_named_parameters_first +// ignore_for_file: constant_identifier_names +// ignore_for_file: lines_longer_than_80_chars + +part of openapi.api; + +class LoadTextualModelOnConnection { + /// Returns a new [LoadTextualModelOnConnection] instance. + LoadTextualModelOnConnection({ + required this.enabled, + required this.ttl, + }); + + bool enabled; + + /// Minimum value: 0 + num ttl; + + @override + bool operator ==(Object other) => identical(this, other) || other is LoadTextualModelOnConnection && + other.enabled == enabled && + other.ttl == ttl; + + @override + int get hashCode => + // ignore: unnecessary_parenthesis + (enabled.hashCode) + + (ttl.hashCode); + + @override + String toString() => 'LoadTextualModelOnConnection[enabled=$enabled, ttl=$ttl]'; + + Map toJson() { + final json = {}; + json[r'enabled'] = this.enabled; + json[r'ttl'] = this.ttl; + return json; + } + + /// Returns a new [LoadTextualModelOnConnection] instance and imports its values from + /// [value] if it's a [Map], null otherwise. + // ignore: prefer_constructors_over_static_methods + static LoadTextualModelOnConnection? 
fromJson(dynamic value) { + if (value is Map) { + final json = value.cast(); + + return LoadTextualModelOnConnection( + enabled: mapValueOfType(json, r'enabled')!, + ttl: num.parse('${json[r'ttl']}'), + ); + } + return null; + } + + static List listFromJson(dynamic json, {bool growable = false,}) { + final result = []; + if (json is List && json.isNotEmpty) { + for (final row in json) { + final value = LoadTextualModelOnConnection.fromJson(row); + if (value != null) { + result.add(value); + } + } + } + return result.toList(growable: growable); + } + + static Map mapFromJson(dynamic json) { + final map = {}; + if (json is Map && json.isNotEmpty) { + json = json.cast(); // ignore: parameter_assignments + for (final entry in json.entries) { + final value = LoadTextualModelOnConnection.fromJson(entry.value); + if (value != null) { + map[entry.key] = value; + } + } + } + return map; + } + + // maps a json object with a list of LoadTextualModelOnConnection-objects as value to a dart map + static Map> mapListFromJson(dynamic json, {bool growable = false,}) { + final map = >{}; + if (json is Map && json.isNotEmpty) { + // ignore: parameter_assignments + json = json.cast(); + for (final entry in json.entries) { + map[entry.key] = LoadTextualModelOnConnection.listFromJson(entry.value, growable: growable,); + } + } + return map; + } + + /// The list of required keys that must be present in a JSON. 
+ static const requiredKeys = { + 'enabled', + 'ttl', + }; +} + diff --git a/open-api/immich-openapi-specs.json b/open-api/immich-openapi-specs.json index d28effd6c5676..844bba542ee88 100644 --- a/open-api/immich-openapi-specs.json +++ b/open-api/immich-openapi-specs.json @@ -5296,8 +5296,8 @@ "name": "password", "required": false, "in": "query", + "example": "password", "schema": { - "example": "password", "type": "string" } }, @@ -8642,12 +8642,16 @@ "enabled": { "type": "boolean" }, + "loadTextualModelOnConnection": { + "$ref": "#/components/schemas/LoadTextualModelOnConnection" + }, "modelName": { "type": "string" } }, "required": [ "enabled", + "loadTextualModelOnConnection", "modelName" ], "type": "object" @@ -9488,6 +9492,17 @@ ], "type": "object" }, + "LoadTextualModelOnConnection": { + "properties": { + "enabled": { + "type": "boolean" + } + }, + "required": [ + "enabled" + ], + "type": "object" + }, "LogLevel": { "enum": [ "verbose", diff --git a/open-api/typescript-sdk/src/fetch-client.ts b/open-api/typescript-sdk/src/fetch-client.ts index 4f5eed0d13e21..fc4b33a81d356 100644 --- a/open-api/typescript-sdk/src/fetch-client.ts +++ b/open-api/typescript-sdk/src/fetch-client.ts @@ -1150,8 +1150,13 @@ export type SystemConfigLoggingDto = { enabled: boolean; level: LogLevel; }; +export type LoadTextualModelOnConnection = { + enabled: boolean; + ttl: number; +}; export type ClipConfig = { enabled: boolean; + loadTextualModelOnConnection: LoadTextualModelOnConnection; modelName: string; }; export type DuplicateDetectionConfig = { diff --git a/server/src/config.ts b/server/src/config.ts index 2e11f740d372c..8d01da87a3998 100644 --- a/server/src/config.ts +++ b/server/src/config.ts @@ -58,6 +58,9 @@ export interface SystemConfig { clip: { enabled: boolean; modelName: string; + loadTextualModelOnConnection: { + enabled: boolean; + }; }; duplicateDetection: { enabled: boolean; @@ -205,6 +208,9 @@ export const defaults = Object.freeze({ clip: { enabled: true, 
modelName: 'ViT-B-32__openai', + loadTextualModelOnConnection: { + enabled: false, + }, }, duplicateDetection: { enabled: true, diff --git a/server/src/dtos/model-config.dto.ts b/server/src/dtos/model-config.dto.ts index f8b9e2043f3ea..0c1630e5311eb 100644 --- a/server/src/dtos/model-config.dto.ts +++ b/server/src/dtos/model-config.dto.ts @@ -1,6 +1,6 @@ import { ApiProperty } from '@nestjs/swagger'; import { Type } from 'class-transformer'; -import { IsNotEmpty, IsNumber, IsString, Max, Min } from 'class-validator'; +import { IsNotEmpty, IsNumber, IsObject, IsString, Max, Min, ValidateNested } from 'class-validator'; import { ValidateBoolean } from 'src/validation'; export class TaskConfig { @@ -14,7 +14,17 @@ export class ModelConfig extends TaskConfig { modelName!: string; } -export class CLIPConfig extends ModelConfig {} +export class LoadTextualModelOnConnection { + @ValidateBoolean() + enabled!: boolean; +} + +export class CLIPConfig extends ModelConfig { + @Type(() => LoadTextualModelOnConnection) + @ValidateNested() + @IsObject() + loadTextualModelOnConnection!: LoadTextualModelOnConnection; +} export class DuplicateDetectionConfig extends TaskConfig { @IsNumber() diff --git a/server/src/interfaces/machine-learning.interface.ts b/server/src/interfaces/machine-learning.interface.ts index 5342030c8fde7..205d69f4f5a34 100644 --- a/server/src/interfaces/machine-learning.interface.ts +++ b/server/src/interfaces/machine-learning.interface.ts @@ -46,6 +46,11 @@ export interface Face { score: number; } +export enum LoadTextModelActions { + LOAD, + UNLOAD, +} + export type FacialRecognitionResponse = { [ModelTask.FACIAL_RECOGNITION]: Face[] } & VisualResponse; export type DetectedFaces = { faces: Face[] } & VisualResponse; export type MachineLearningRequest = ClipVisualRequest | ClipTextualRequest | FacialRecognitionRequest; @@ -54,4 +59,5 @@ export interface IMachineLearningRepository { encodeImage(url: string, imagePath: string, config: ModelOptions): Promise; 
encodeText(url: string, text: string, config: ModelOptions): Promise; detectFaces(url: string, imagePath: string, config: FaceDetectionOptions): Promise; + prepareTextModel(url: string, config: ModelOptions, action: LoadTextModelActions): Promise; } diff --git a/server/src/repositories/event.repository.ts b/server/src/repositories/event.repository.ts index 90d8e7bf5d7a8..2d98943008b21 100644 --- a/server/src/repositories/event.repository.ts +++ b/server/src/repositories/event.repository.ts @@ -8,6 +8,7 @@ import { WebSocketServer, } from '@nestjs/websockets'; import { Server, Socket } from 'socket.io'; +import { SystemConfigCore } from 'src/cores/system-config.core'; import { ArgsOf, ClientEventMap, @@ -18,6 +19,8 @@ import { ServerEvents, } from 'src/interfaces/event.interface'; import { ILoggerRepository } from 'src/interfaces/logger.interface'; +import { IMachineLearningRepository, LoadTextModelActions } from 'src/interfaces/machine-learning.interface'; +import { ISystemMetadataRepository } from 'src/interfaces/system-metadata.interface'; import { AuthService } from 'src/services/auth.service'; import { Instrumentation } from 'src/utils/instrumentation'; import { handlePromiseError } from 'src/utils/misc'; @@ -33,6 +36,7 @@ type EmitHandlers = Partial<{ [T in EmitEvent]: Array> }>; @Injectable() export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect, OnGatewayInit, IEventRepository { private emitHandlers: EmitHandlers = {}; + private configCore: SystemConfigCore; @WebSocketServer() private server?: Server; @@ -40,8 +44,11 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect constructor( private moduleRef: ModuleRef, @Inject(ILoggerRepository) private logger: ILoggerRepository, + @Inject(IMachineLearningRepository) private machineLearningRepository: IMachineLearningRepository, + @Inject(ISystemMetadataRepository) systemMetadataRepository: ISystemMetadataRepository, ) { 
this.logger.setContext(EventRepository.name); + this.configCore = SystemConfigCore.create(systemMetadataRepository, this.logger); } afterInit(server: Server) { @@ -63,6 +70,21 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect queryParams: {}, metadata: { adminRoute: false, sharedLinkRoute: false, uri: '/api/socket.io' }, }); + if ('background' in client.handshake.query && client.handshake.query.background === 'false') { + const { machineLearning } = await this.configCore.getConfig({ withCache: true }); + if (machineLearning.clip.loadTextualModelOnConnection.enabled) { + try { + console.log(this.server); + this.machineLearningRepository.prepareTextModel( + machineLearning.url, + machineLearning.clip, + LoadTextModelActions.LOAD, + ); + } catch (error) { + this.logger.warn(error); + } + } + } await client.join(auth.user.id); if (auth.session) { await client.join(auth.session.id); @@ -78,6 +100,21 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect async handleDisconnect(client: Socket) { this.logger.log(`Websocket Disconnect: ${client.id}`); await client.leave(client.nsp.name); + if ('background' in client.handshake.query && client.handshake.query.background === 'false') { + const { machineLearning } = await this.configCore.getConfig({ withCache: true }); + if (machineLearning.clip.loadTextualModelOnConnection.enabled && this.server?.engine.clientsCount == 0) { + try { + this.machineLearningRepository.prepareTextModel( + machineLearning.url, + machineLearning.clip, + LoadTextModelActions.UNLOAD, + ); + this.logger.debug('sent request to unload text model'); + } catch (error) { + this.logger.warn(error); + } + } + } } on(item: EventItem): void { diff --git a/server/src/repositories/machine-learning.repository.ts b/server/src/repositories/machine-learning.repository.ts index b9404022efffa..a65b29fa56321 100644 --- a/server/src/repositories/machine-learning.repository.ts +++ 
b/server/src/repositories/machine-learning.repository.ts @@ -7,6 +7,7 @@ import { FaceDetectionOptions, FacialRecognitionResponse, IMachineLearningRepository, + LoadTextModelActions, MachineLearningRequest, ModelPayload, ModelTask, @@ -20,13 +21,9 @@ const errorPrefix = 'Machine learning request'; @Injectable() export class MachineLearningRepository implements IMachineLearningRepository { private async predict(url: string, payload: ModelPayload, config: MachineLearningRequest): Promise { - const formData = await this.getFormData(payload, config); + const formData = await this.getFormData(config, payload); - const res = await fetch(new URL('/predict', url), { method: 'POST', body: formData }).catch( - (error: Error | any) => { - throw new Error(`${errorPrefix} to "${url}" failed with ${error?.cause || error}`); - }, - ); + const res = await this.fetchData(url, '/predict', formData); if (res.status >= 400) { throw new Error(`${errorPrefix} '${JSON.stringify(config)}' failed with status ${res.status}: ${res.statusText}`); @@ -34,6 +31,30 @@ export class MachineLearningRepository implements IMachineLearningRepository { return res.json(); } + private async fetchData(url: string, path: string, formData?: FormData): Promise { + const res = await fetch(new URL(path, url), { method: 'POST', body: formData }).catch((error: Error | any) => { + throw new Error(`${errorPrefix} to "${url}" failed with ${error?.cause || error}`); + }); + + return res; + } + + private prepareTextModelUrl: Record = { + [LoadTextModelActions.LOAD]: '/load', + [LoadTextModelActions.UNLOAD]: '/unload', + }; + + async prepareTextModel(url: string, { modelName }: CLIPConfig, actions: LoadTextModelActions) { + try { + const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName } } }; + const formData = await this.getFormData(request); + const res = await this.fetchData(url, this.prepareTextModelUrl[actions], formData); + if (res.status >= 400) { + throw new Error(`${errorPrefix} Loadings 
textual model failed with status ${res.status}: ${res.statusText}`); + } + } catch (error) {} + } + async detectFaces(url: string, imagePath: string, { modelName, minScore }: FaceDetectionOptions) { const request = { [ModelTask.FACIAL_RECOGNITION]: { @@ -61,16 +82,17 @@ export class MachineLearningRepository implements IMachineLearningRepository { return response[ModelTask.SEARCH]; } - private async getFormData(payload: ModelPayload, config: MachineLearningRequest): Promise { + private async getFormData(config: MachineLearningRequest, payload?: ModelPayload): Promise { const formData = new FormData(); formData.append('entries', JSON.stringify(config)); - - if ('imagePath' in payload) { - formData.append('image', new Blob([await readFile(payload.imagePath)])); - } else if ('text' in payload) { - formData.append('text', payload.text); - } else { - throw new Error('Invalid input'); + if (payload) { + if ('imagePath' in payload) { + formData.append('image', new Blob([await readFile(payload.imagePath)])); + } else if ('text' in payload) { + formData.append('text', payload.text); + } else { + throw new Error('Invalid input'); + } } return formData; diff --git a/web/src/lib/components/admin-page/settings/machine-learning-settings/machine-learning-settings.svelte b/web/src/lib/components/admin-page/settings/machine-learning-settings/machine-learning-settings.svelte index aac8cd52123be..a188cbc67a5b3 100644 --- a/web/src/lib/components/admin-page/settings/machine-learning-settings/machine-learning-settings.svelte +++ b/web/src/lib/components/admin-page/settings/machine-learning-settings/machine-learning-settings.svelte @@ -75,6 +75,21 @@

+ + +
+ +
+
diff --git a/web/src/lib/i18n/en.json index 22eb1c8f789dd..f8913e329cb8a 100644 --- a/web/src/lib/i18n/en.json +++ b/web/src/lib/i18n/en.json @@ -119,6 +119,12 @@ "machine_learning_min_detection_score_description": "Minimum confidence score for a face to be detected from 0-1. Lower values will detect more faces but may result in false positives.", "machine_learning_min_recognized_faces": "Minimum recognized faces", "machine_learning_min_recognized_faces_description": "The minimum number of recognized faces for a person to be created. Increasing this makes Facial Recognition more precise at the cost of increasing the chance that a face is not assigned to a person.", + "machine_learning_preload_model": "Preload model", + "machine_learning_preload_model_enabled": "Enable preload model", + "machine_learning_preload_model_enabled_description": "Preload the textual model when a connection is established instead of during the first search", + "machine_learning_preload_model_setting_description": "Preload the textual model when a connection is established", + "machine_learning_preload_model_ttl": "Inactivity time before a model is unloaded", + "machine_learning_preload_model_ttl_description": "Amount of inactivity time after the last connection before the textual model is unloaded", "machine_learning_settings": "Machine Learning Settings", "machine_learning_settings_description": "Manage machine learning features and settings", "machine_learning_smart_search": "Smart Search", diff --git a/web/src/lib/stores/websocket.ts index d398ca52a9d89..d1eafff6c083d 100644 --- a/web/src/lib/stores/websocket.ts +++ b/web/src/lib/stores/websocket.ts @@ -35,6 +35,9 @@ const websocket: Socket = io({ reconnection: true, forceNew: true, autoConnect: false, + query: { + background: false, + }, }); export const websocketStore = {