diff --git a/machine-learning/app/main.py b/machine-learning/app/main.py
index 000119937e74a..6d359ec2de250 100644
--- a/machine-learning/app/main.py
+++ b/machine-learning/app/main.py
@@ -11,7 +11,7 @@ from typing import Any, AsyncGenerator, Callable, Iterator
from zipfile import BadZipFile
import orjson
-from fastapi import Depends, FastAPI, File, Form, HTTPException
+from fastapi import Depends, FastAPI, File, Form, HTTPException, Response
from fastapi.responses import ORJSONResponse
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
from PIL.Image import Image
@@ -124,6 +124,23 @@ def get_entries(entries: str = Form()) -> InferenceEntries:
raise HTTPException(422, "Invalid request format.")
+def get_entry(entries: str = Form()) -> InferenceEntry:
+ try:
+ request: PipelineRequest = orjson.loads(entries)
+ for task, types in request.items():
+ for type, entry in types.items():
+ parsed: InferenceEntry = {
+ "name": entry["modelName"],
+ "task": task,
+ "type": type,
+ "options": entry.get("options", {}),
+ }
+ return parsed
+ except (orjson.JSONDecodeError, ValidationError, KeyError, AttributeError) as e:
+ log.error(f"Invalid request format: {e}")
+ raise HTTPException(422, "Invalid request format.")
+
+
app = FastAPI(lifespan=lifespan)
@@ -137,6 +154,20 @@ def ping() -> str:
return "pong"
+@app.post("/load", response_model=TextResponse)
+async def load_model(entry: InferenceEntry = Depends(get_entry)) -> None:
+ model = await model_cache.get(entry["name"], entry["type"], entry["task"], ttl=settings.model_ttl)
+ model = await load(model)
+ return Response(status_code=200)
+
+
+@app.post("/unload", response_model=TextResponse)
+async def unload_model(entry: InferenceEntry = Depends(get_entry)) -> None:
+ await model_cache.unload(entry["name"], entry["type"], entry["task"])
+ print("unload")
+ return Response(status_code=200)
+
+
@app.post("/predict", dependencies=[Depends(update_state)])
async def predict(
entries: InferenceEntries = Depends(get_entries),
diff --git a/machine-learning/app/models/cache.py b/machine-learning/app/models/cache.py
index bf8e8a6352145..34c9fd2a41fde 100644
--- a/machine-learning/app/models/cache.py
+++ b/machine-learning/app/models/cache.py
@@ -58,3 +58,10 @@ class ModelCache:
async def revalidate(self, key: str, ttl: int | None) -> None:
if ttl is not None and key in self.cache._handlers:
await self.cache.expire(key, ttl)
+
+ async def unload(self, model_name: str, model_type: ModelType, model_task: ModelTask) -> None:
+ key = f"{model_name}{model_type}{model_task}"
+ async with OptimisticLock(self.cache, key):
+ value = await self.cache.get(key)
+ if value is not None:
+ await self.cache.delete(key)
diff --git a/mobile/openapi/README.md b/mobile/openapi/README.md
index 36f442fd88b51..2504da1282710 100644
--- a/mobile/openapi/README.md
+++ b/mobile/openapi/README.md
@@ -116,7 +116,6 @@ Class | Method | HTTP request | Description
*AuthenticationApi* | [**signUpAdmin**](doc//AuthenticationApi.md#signupadmin) | **POST** /auth/admin-sign-up |
*AuthenticationApi* | [**validateAccessToken**](doc//AuthenticationApi.md#validateaccesstoken) | **POST** /auth/validateToken |
*DeprecatedApi* | [**getPersonAssets**](doc//DeprecatedApi.md#getpersonassets) | **GET** /people/{id}/assets |
-*DeprecatedApi* | [**getRandom**](doc//DeprecatedApi.md#getrandom) | **GET** /assets/random |
*DownloadApi* | [**downloadArchive**](doc//DownloadApi.md#downloadarchive) | **POST** /download/archive |
*DownloadApi* | [**getDownloadInfo**](doc//DownloadApi.md#getdownloadinfo) | **POST** /download/info |
*DuplicatesApi* | [**getAssetDuplicates**](doc//DuplicatesApi.md#getassetduplicates) | **GET** /duplicates |
@@ -125,7 +124,6 @@ Class | Method | HTTP request | Description
*FileReportsApi* | [**fixAuditFiles**](doc//FileReportsApi.md#fixauditfiles) | **POST** /reports/fix |
*FileReportsApi* | [**getAuditFiles**](doc//FileReportsApi.md#getauditfiles) | **GET** /reports |
*FileReportsApi* | [**getFileChecksums**](doc//FileReportsApi.md#getfilechecksums) | **POST** /reports/checksum |
-*JobsApi* | [**createJob**](doc//JobsApi.md#createjob) | **POST** /jobs |
*JobsApi* | [**getAllJobsStatus**](doc//JobsApi.md#getalljobsstatus) | **GET** /jobs |
*JobsApi* | [**sendJobCommand**](doc//JobsApi.md#sendjobcommand) | **PUT** /jobs/{id} |
*LibrariesApi* | [**createLibrary**](doc//LibrariesApi.md#createlibrary) | **POST** /libraries |
@@ -137,6 +135,7 @@ Class | Method | HTTP request | Description
*LibrariesApi* | [**updateLibrary**](doc//LibrariesApi.md#updatelibrary) | **PUT** /libraries/{id} |
*LibrariesApi* | [**validate**](doc//LibrariesApi.md#validate) | **POST** /libraries/{id}/validate |
*MapApi* | [**getMapMarkers**](doc//MapApi.md#getmapmarkers) | **GET** /map/markers |
+*MapApi* | [**getMapStyle**](doc//MapApi.md#getmapstyle) | **GET** /map/style.json |
*MapApi* | [**reverseGeocode**](doc//MapApi.md#reversegeocode) | **GET** /map/reverse-geocode |
*MemoriesApi* | [**addMemoryAssets**](doc//MemoriesApi.md#addmemoryassets) | **PUT** /memories/{id}/assets |
*MemoriesApi* | [**createMemory**](doc//MemoriesApi.md#creatememory) | **POST** /memories |
@@ -171,7 +170,6 @@ Class | Method | HTTP request | Description
*SearchApi* | [**searchMetadata**](doc//SearchApi.md#searchmetadata) | **POST** /search/metadata |
*SearchApi* | [**searchPerson**](doc//SearchApi.md#searchperson) | **GET** /search/person |
*SearchApi* | [**searchPlaces**](doc//SearchApi.md#searchplaces) | **GET** /search/places |
-*SearchApi* | [**searchRandom**](doc//SearchApi.md#searchrandom) | **POST** /search/random |
*SearchApi* | [**searchSmart**](doc//SearchApi.md#searchsmart) | **POST** /search/smart |
*ServerApi* | [**deleteServerLicense**](doc//ServerApi.md#deleteserverlicense) | **DELETE** /server/license |
*ServerApi* | [**getAboutInfo**](doc//ServerApi.md#getaboutinfo) | **GET** /server/about |
@@ -332,7 +330,6 @@ Class | Method | HTTP request | Description
- [JobCommand](doc//JobCommand.md)
- [JobCommandDto](doc//JobCommandDto.md)
- [JobCountsDto](doc//JobCountsDto.md)
- - [JobCreateDto](doc//JobCreateDto.md)
- [JobName](doc//JobName.md)
- [JobSettingsDto](doc//JobSettingsDto.md)
- [JobStatusDto](doc//JobStatusDto.md)
@@ -340,13 +337,14 @@ Class | Method | HTTP request | Description
- [LibraryStatsResponseDto](doc//LibraryStatsResponseDto.md)
- [LicenseKeyDto](doc//LicenseKeyDto.md)
- [LicenseResponseDto](doc//LicenseResponseDto.md)
+ - [LoadTextualModelOnConnection](doc//LoadTextualModelOnConnection.md)
- [LogLevel](doc//LogLevel.md)
- [LoginCredentialDto](doc//LoginCredentialDto.md)
- [LoginResponseDto](doc//LoginResponseDto.md)
- [LogoutResponseDto](doc//LogoutResponseDto.md)
- - [ManualJobName](doc//ManualJobName.md)
- [MapMarkerResponseDto](doc//MapMarkerResponseDto.md)
- [MapReverseGeocodeResponseDto](doc//MapReverseGeocodeResponseDto.md)
+ - [MapTheme](doc//MapTheme.md)
- [MemoriesResponse](doc//MemoriesResponse.md)
- [MemoriesUpdate](doc//MemoriesUpdate.md)
- [MemoryCreateDto](doc//MemoryCreateDto.md)
@@ -379,7 +377,6 @@ Class | Method | HTTP request | Description
- [PurchaseResponse](doc//PurchaseResponse.md)
- [PurchaseUpdate](doc//PurchaseUpdate.md)
- [QueueStatusDto](doc//QueueStatusDto.md)
- - [RandomSearchDto](doc//RandomSearchDto.md)
- [RatingsResponse](doc//RatingsResponse.md)
- [RatingsUpdate](doc//RatingsUpdate.md)
- [ReactionLevel](doc//ReactionLevel.md)
@@ -455,7 +452,6 @@ Class | Method | HTTP request | Description
- [ToneMapping](doc//ToneMapping.md)
- [TranscodeHWAccel](doc//TranscodeHWAccel.md)
- [TranscodePolicy](doc//TranscodePolicy.md)
- - [TrashResponseDto](doc//TrashResponseDto.md)
- [UpdateAlbumDto](doc//UpdateAlbumDto.md)
- [UpdateAlbumUserDto](doc//UpdateAlbumUserDto.md)
- [UpdateAssetDto](doc//UpdateAssetDto.md)
diff --git a/mobile/openapi/lib/api.dart b/mobile/openapi/lib/api.dart
index 6fb7478d04bf2..1c1508fd3485e 100644
--- a/mobile/openapi/lib/api.dart
+++ b/mobile/openapi/lib/api.dart
@@ -144,7 +144,6 @@ part 'model/image_format.dart';
part 'model/job_command.dart';
part 'model/job_command_dto.dart';
part 'model/job_counts_dto.dart';
-part 'model/job_create_dto.dart';
part 'model/job_name.dart';
part 'model/job_settings_dto.dart';
part 'model/job_status_dto.dart';
@@ -152,13 +151,14 @@ part 'model/library_response_dto.dart';
part 'model/library_stats_response_dto.dart';
part 'model/license_key_dto.dart';
part 'model/license_response_dto.dart';
+part 'model/load_textual_model_on_connection.dart';
part 'model/log_level.dart';
part 'model/login_credential_dto.dart';
part 'model/login_response_dto.dart';
part 'model/logout_response_dto.dart';
-part 'model/manual_job_name.dart';
part 'model/map_marker_response_dto.dart';
part 'model/map_reverse_geocode_response_dto.dart';
+part 'model/map_theme.dart';
part 'model/memories_response.dart';
part 'model/memories_update.dart';
part 'model/memory_create_dto.dart';
@@ -191,7 +191,6 @@ part 'model/places_response_dto.dart';
part 'model/purchase_response.dart';
part 'model/purchase_update.dart';
part 'model/queue_status_dto.dart';
-part 'model/random_search_dto.dart';
part 'model/ratings_response.dart';
part 'model/ratings_update.dart';
part 'model/reaction_level.dart';
@@ -267,7 +266,6 @@ part 'model/time_bucket_size.dart';
part 'model/tone_mapping.dart';
part 'model/transcode_hw_accel.dart';
part 'model/transcode_policy.dart';
-part 'model/trash_response_dto.dart';
part 'model/update_album_dto.dart';
part 'model/update_album_user_dto.dart';
part 'model/update_asset_dto.dart';
diff --git a/mobile/openapi/lib/api_client.dart b/mobile/openapi/lib/api_client.dart
index c1025b0bd4820..165f7887d427e 100644
--- a/mobile/openapi/lib/api_client.dart
+++ b/mobile/openapi/lib/api_client.dart
@@ -166,6 +166,7 @@ class ApiClient {
/// Returns a native instance of an OpenAPI class matching the [specified type][targetType].
static dynamic fromJson(dynamic value, String targetType, {bool growable = false,}) {
+ upgradeDto(value, targetType);
try {
switch (targetType) {
case 'String':
@@ -342,8 +343,6 @@ class ApiClient {
return JobCommandDto.fromJson(value);
case 'JobCountsDto':
return JobCountsDto.fromJson(value);
- case 'JobCreateDto':
- return JobCreateDto.fromJson(value);
case 'JobName':
return JobNameTypeTransformer().decode(value);
case 'JobSettingsDto':
@@ -358,6 +357,8 @@ class ApiClient {
return LicenseKeyDto.fromJson(value);
case 'LicenseResponseDto':
return LicenseResponseDto.fromJson(value);
+ case 'LoadTextualModelOnConnection':
+ return LoadTextualModelOnConnection.fromJson(value);
case 'LogLevel':
return LogLevelTypeTransformer().decode(value);
case 'LoginCredentialDto':
@@ -366,12 +367,12 @@ class ApiClient {
return LoginResponseDto.fromJson(value);
case 'LogoutResponseDto':
return LogoutResponseDto.fromJson(value);
- case 'ManualJobName':
- return ManualJobNameTypeTransformer().decode(value);
case 'MapMarkerResponseDto':
return MapMarkerResponseDto.fromJson(value);
case 'MapReverseGeocodeResponseDto':
return MapReverseGeocodeResponseDto.fromJson(value);
+ case 'MapTheme':
+ return MapThemeTypeTransformer().decode(value);
case 'MemoriesResponse':
return MemoriesResponse.fromJson(value);
case 'MemoriesUpdate':
@@ -436,8 +437,6 @@ class ApiClient {
return PurchaseUpdate.fromJson(value);
case 'QueueStatusDto':
return QueueStatusDto.fromJson(value);
- case 'RandomSearchDto':
- return RandomSearchDto.fromJson(value);
case 'RatingsResponse':
return RatingsResponse.fromJson(value);
case 'RatingsUpdate':
@@ -588,8 +587,6 @@ class ApiClient {
return TranscodeHWAccelTypeTransformer().decode(value);
case 'TranscodePolicy':
return TranscodePolicyTypeTransformer().decode(value);
- case 'TrashResponseDto':
- return TrashResponseDto.fromJson(value);
case 'UpdateAlbumDto':
return UpdateAlbumDto.fromJson(value);
case 'UpdateAlbumUserDto':
diff --git a/mobile/openapi/lib/model/clip_config.dart b/mobile/openapi/lib/model/clip_config.dart
index b500d20f2e6eb..f41fb2b6ba80a 100644
--- a/mobile/openapi/lib/model/clip_config.dart
+++ b/mobile/openapi/lib/model/clip_config.dart
@@ -14,30 +14,36 @@ class CLIPConfig {
/// Returns a new [CLIPConfig] instance.
CLIPConfig({
required this.enabled,
+ required this.loadTextualModelOnConnection,
required this.modelName,
});
bool enabled;
+ LoadTextualModelOnConnection loadTextualModelOnConnection;
+
String modelName;
@override
bool operator ==(Object other) => identical(this, other) || other is CLIPConfig &&
other.enabled == enabled &&
+ other.loadTextualModelOnConnection == loadTextualModelOnConnection &&
other.modelName == modelName;
@override
int get hashCode =>
// ignore: unnecessary_parenthesis
(enabled.hashCode) +
+ (loadTextualModelOnConnection.hashCode) +
(modelName.hashCode);
@override
- String toString() => 'CLIPConfig[enabled=$enabled, modelName=$modelName]';
+ String toString() => 'CLIPConfig[enabled=$enabled, loadTextualModelOnConnection=$loadTextualModelOnConnection, modelName=$modelName]';
Map toJson() {
final json = {};
json[r'enabled'] = this.enabled;
+ json[r'loadTextualModelOnConnection'] = this.loadTextualModelOnConnection;
json[r'modelName'] = this.modelName;
return json;
}
@@ -46,12 +52,12 @@ class CLIPConfig {
/// [value] if it's a [Map], null otherwise.
// ignore: prefer_constructors_over_static_methods
static CLIPConfig? fromJson(dynamic value) {
- upgradeDto(value, "CLIPConfig");
if (value is Map) {
final json = value.cast();
return CLIPConfig(
enabled: mapValueOfType(json, r'enabled')!,
+ loadTextualModelOnConnection: LoadTextualModelOnConnection.fromJson(json[r'loadTextualModelOnConnection'])!,
modelName: mapValueOfType(json, r'modelName')!,
);
}
@@ -101,6 +107,7 @@ class CLIPConfig {
/// The list of required keys that must be present in a JSON.
static const requiredKeys = {
'enabled',
+ 'loadTextualModelOnConnection',
'modelName',
};
}
diff --git a/mobile/openapi/lib/model/load_textual_model_on_connection.dart b/mobile/openapi/lib/model/load_textual_model_on_connection.dart
new file mode 100644
index 0000000000000..460799d4fa31e
--- /dev/null
+++ b/mobile/openapi/lib/model/load_textual_model_on_connection.dart
@@ -0,0 +1,107 @@
+//
+// AUTO-GENERATED FILE, DO NOT MODIFY!
+//
+// @dart=2.18
+
+// ignore_for_file: unused_element, unused_import
+// ignore_for_file: always_put_required_named_parameters_first
+// ignore_for_file: constant_identifier_names
+// ignore_for_file: lines_longer_than_80_chars
+
+part of openapi.api;
+
+class LoadTextualModelOnConnection {
+ /// Returns a new [LoadTextualModelOnConnection] instance.
+ LoadTextualModelOnConnection({
+ required this.enabled,
+ required this.ttl,
+ });
+
+ bool enabled;
+
+ /// Minimum value: 0
+ num ttl;
+
+ @override
+ bool operator ==(Object other) => identical(this, other) || other is LoadTextualModelOnConnection &&
+ other.enabled == enabled &&
+ other.ttl == ttl;
+
+ @override
+ int get hashCode =>
+ // ignore: unnecessary_parenthesis
+ (enabled.hashCode) +
+ (ttl.hashCode);
+
+ @override
+ String toString() => 'LoadTextualModelOnConnection[enabled=$enabled, ttl=$ttl]';
+
+ Map toJson() {
+ final json = {};
+ json[r'enabled'] = this.enabled;
+ json[r'ttl'] = this.ttl;
+ return json;
+ }
+
+ /// Returns a new [LoadTextualModelOnConnection] instance and imports its values from
+ /// [value] if it's a [Map], null otherwise.
+ // ignore: prefer_constructors_over_static_methods
+ static LoadTextualModelOnConnection? fromJson(dynamic value) {
+ if (value is Map) {
+ final json = value.cast();
+
+ return LoadTextualModelOnConnection(
+ enabled: mapValueOfType(json, r'enabled')!,
+ ttl: num.parse('${json[r'ttl']}'),
+ );
+ }
+ return null;
+ }
+
+ static List listFromJson(dynamic json, {bool growable = false,}) {
+ final result = [];
+ if (json is List && json.isNotEmpty) {
+ for (final row in json) {
+ final value = LoadTextualModelOnConnection.fromJson(row);
+ if (value != null) {
+ result.add(value);
+ }
+ }
+ }
+ return result.toList(growable: growable);
+ }
+
+ static Map mapFromJson(dynamic json) {
+ final map = {};
+ if (json is Map && json.isNotEmpty) {
+ json = json.cast(); // ignore: parameter_assignments
+ for (final entry in json.entries) {
+ final value = LoadTextualModelOnConnection.fromJson(entry.value);
+ if (value != null) {
+ map[entry.key] = value;
+ }
+ }
+ }
+ return map;
+ }
+
+ // maps a json object with a list of LoadTextualModelOnConnection-objects as value to a dart map
+ static Map> mapListFromJson(dynamic json, {bool growable = false,}) {
+ final map = >{};
+ if (json is Map && json.isNotEmpty) {
+ // ignore: parameter_assignments
+ json = json.cast();
+ for (final entry in json.entries) {
+ map[entry.key] = LoadTextualModelOnConnection.listFromJson(entry.value, growable: growable,);
+ }
+ }
+ return map;
+ }
+
+ /// The list of required keys that must be present in a JSON.
+ static const requiredKeys = {
+ 'enabled',
+ 'ttl',
+ };
+}
+
diff --git a/open-api/immich-openapi-specs.json b/open-api/immich-openapi-specs.json
index d28effd6c5676..844bba542ee88 100644
--- a/open-api/immich-openapi-specs.json
+++ b/open-api/immich-openapi-specs.json
@@ -5296,8 +5296,8 @@
"name": "password",
"required": false,
"in": "query",
+ "example": "password",
"schema": {
- "example": "password",
"type": "string"
}
},
@@ -8642,12 +8642,16 @@
"enabled": {
"type": "boolean"
},
+ "loadTextualModelOnConnection": {
+ "$ref": "#/components/schemas/LoadTextualModelOnConnection"
+ },
"modelName": {
"type": "string"
}
},
"required": [
"enabled",
+ "loadTextualModelOnConnection",
"modelName"
],
"type": "object"
@@ -9488,6 +9492,17 @@
],
"type": "object"
},
+ "LoadTextualModelOnConnection": {
+ "properties": {
+ "enabled": {
+ "type": "boolean"
+ }
+ },
+ "required": [
+ "enabled"
+ ],
+ "type": "object"
+ },
"LogLevel": {
"enum": [
"verbose",
diff --git a/open-api/typescript-sdk/src/fetch-client.ts b/open-api/typescript-sdk/src/fetch-client.ts
index 4f5eed0d13e21..fc4b33a81d356 100644
--- a/open-api/typescript-sdk/src/fetch-client.ts
+++ b/open-api/typescript-sdk/src/fetch-client.ts
@@ -1150,8 +1150,13 @@ export type SystemConfigLoggingDto = {
enabled: boolean;
level: LogLevel;
};
+export type LoadTextualModelOnConnection = {
+ enabled: boolean;
+ ttl: number;
+};
export type ClipConfig = {
enabled: boolean;
+ loadTextualModelOnConnection: LoadTextualModelOnConnection;
modelName: string;
};
export type DuplicateDetectionConfig = {
diff --git a/server/src/config.ts b/server/src/config.ts
index 2e11f740d372c..8d01da87a3998 100644
--- a/server/src/config.ts
+++ b/server/src/config.ts
@@ -58,6 +58,9 @@ export interface SystemConfig {
clip: {
enabled: boolean;
modelName: string;
+ loadTextualModelOnConnection: {
+ enabled: boolean;
+ };
};
duplicateDetection: {
enabled: boolean;
@@ -205,6 +208,9 @@ export const defaults = Object.freeze({
clip: {
enabled: true,
modelName: 'ViT-B-32__openai',
+ loadTextualModelOnConnection: {
+ enabled: false,
+ },
},
duplicateDetection: {
enabled: true,
diff --git a/server/src/dtos/model-config.dto.ts b/server/src/dtos/model-config.dto.ts
index f8b9e2043f3ea..0c1630e5311eb 100644
--- a/server/src/dtos/model-config.dto.ts
+++ b/server/src/dtos/model-config.dto.ts
@@ -1,6 +1,6 @@
import { ApiProperty } from '@nestjs/swagger';
import { Type } from 'class-transformer';
-import { IsNotEmpty, IsNumber, IsString, Max, Min } from 'class-validator';
+import { IsNotEmpty, IsNumber, IsObject, IsString, Max, Min, ValidateNested } from 'class-validator';
import { ValidateBoolean } from 'src/validation';
export class TaskConfig {
@@ -14,7 +14,17 @@ export class ModelConfig extends TaskConfig {
modelName!: string;
}
-export class CLIPConfig extends ModelConfig {}
+export class LoadTextualModelOnConnection {
+ @ValidateBoolean()
+ enabled!: boolean;
+}
+
+export class CLIPConfig extends ModelConfig {
+ @Type(() => LoadTextualModelOnConnection)
+ @ValidateNested()
+ @IsObject()
+ loadTextualModelOnConnection!: LoadTextualModelOnConnection;
+}
export class DuplicateDetectionConfig extends TaskConfig {
@IsNumber()
diff --git a/server/src/interfaces/machine-learning.interface.ts b/server/src/interfaces/machine-learning.interface.ts
index 5342030c8fde7..205d69f4f5a34 100644
--- a/server/src/interfaces/machine-learning.interface.ts
+++ b/server/src/interfaces/machine-learning.interface.ts
@@ -46,6 +46,11 @@ export interface Face {
score: number;
}
+export enum LoadTextModelActions {
+ LOAD,
+ UNLOAD,
+}
+
export type FacialRecognitionResponse = { [ModelTask.FACIAL_RECOGNITION]: Face[] } & VisualResponse;
export type DetectedFaces = { faces: Face[] } & VisualResponse;
export type MachineLearningRequest = ClipVisualRequest | ClipTextualRequest | FacialRecognitionRequest;
@@ -54,4 +59,5 @@ export interface IMachineLearningRepository {
encodeImage(url: string, imagePath: string, config: ModelOptions): Promise;
encodeText(url: string, text: string, config: ModelOptions): Promise;
detectFaces(url: string, imagePath: string, config: FaceDetectionOptions): Promise;
+ prepareTextModel(url: string, config: ModelOptions, action: LoadTextModelActions): Promise;
}
diff --git a/server/src/repositories/event.repository.ts b/server/src/repositories/event.repository.ts
index 90d8e7bf5d7a8..2d98943008b21 100644
--- a/server/src/repositories/event.repository.ts
+++ b/server/src/repositories/event.repository.ts
@@ -8,6 +8,7 @@ import {
WebSocketServer,
} from '@nestjs/websockets';
import { Server, Socket } from 'socket.io';
+import { SystemConfigCore } from 'src/cores/system-config.core';
import {
ArgsOf,
ClientEventMap,
@@ -18,6 +19,8 @@ import {
ServerEvents,
} from 'src/interfaces/event.interface';
import { ILoggerRepository } from 'src/interfaces/logger.interface';
+import { IMachineLearningRepository, LoadTextModelActions } from 'src/interfaces/machine-learning.interface';
+import { ISystemMetadataRepository } from 'src/interfaces/system-metadata.interface';
import { AuthService } from 'src/services/auth.service';
import { Instrumentation } from 'src/utils/instrumentation';
import { handlePromiseError } from 'src/utils/misc';
@@ -33,6 +36,7 @@ type EmitHandlers = Partial<{ [T in EmitEvent]: Array> }>;
@Injectable()
export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect, OnGatewayInit, IEventRepository {
private emitHandlers: EmitHandlers = {};
+ private configCore: SystemConfigCore;
@WebSocketServer()
private server?: Server;
@@ -40,8 +44,11 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect
constructor(
private moduleRef: ModuleRef,
@Inject(ILoggerRepository) private logger: ILoggerRepository,
+ @Inject(IMachineLearningRepository) private machineLearningRepository: IMachineLearningRepository,
+ @Inject(ISystemMetadataRepository) systemMetadataRepository: ISystemMetadataRepository,
) {
this.logger.setContext(EventRepository.name);
+ this.configCore = SystemConfigCore.create(systemMetadataRepository, this.logger);
}
afterInit(server: Server) {
@@ -63,6 +70,21 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect
queryParams: {},
metadata: { adminRoute: false, sharedLinkRoute: false, uri: '/api/socket.io' },
});
+ if ('background' in client.handshake.query && client.handshake.query.background === 'false') {
+ const { machineLearning } = await this.configCore.getConfig({ withCache: true });
+ if (machineLearning.clip.loadTextualModelOnConnection.enabled) {
+ try {
+ console.log(this.server);
+ this.machineLearningRepository.prepareTextModel(
+ machineLearning.url,
+ machineLearning.clip,
+ LoadTextModelActions.LOAD,
+ );
+ } catch (error) {
+ this.logger.warn(error);
+ }
+ }
+ }
await client.join(auth.user.id);
if (auth.session) {
await client.join(auth.session.id);
@@ -78,6 +100,21 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect
async handleDisconnect(client: Socket) {
this.logger.log(`Websocket Disconnect: ${client.id}`);
await client.leave(client.nsp.name);
+ if ('background' in client.handshake.query && client.handshake.query.background === 'false') {
+ const { machineLearning } = await this.configCore.getConfig({ withCache: true });
+ if (machineLearning.clip.loadTextualModelOnConnection.enabled && this.server?.engine.clientsCount == 0) {
+ try {
+ this.machineLearningRepository.prepareTextModel(
+ machineLearning.url,
+ machineLearning.clip,
+ LoadTextModelActions.UNLOAD,
+ );
+ this.logger.debug('sent request to unload text model');
+ } catch (error) {
+ this.logger.warn(error);
+ }
+ }
+ }
}
on(item: EventItem): void {
diff --git a/server/src/repositories/machine-learning.repository.ts b/server/src/repositories/machine-learning.repository.ts
index b9404022efffa..a65b29fa56321 100644
--- a/server/src/repositories/machine-learning.repository.ts
+++ b/server/src/repositories/machine-learning.repository.ts
@@ -7,6 +7,7 @@ import {
FaceDetectionOptions,
FacialRecognitionResponse,
IMachineLearningRepository,
+ LoadTextModelActions,
MachineLearningRequest,
ModelPayload,
ModelTask,
@@ -20,13 +21,9 @@ const errorPrefix = 'Machine learning request';
@Injectable()
export class MachineLearningRepository implements IMachineLearningRepository {
private async predict(url: string, payload: ModelPayload, config: MachineLearningRequest): Promise {
- const formData = await this.getFormData(payload, config);
+ const formData = await this.getFormData(config, payload);
- const res = await fetch(new URL('/predict', url), { method: 'POST', body: formData }).catch(
- (error: Error | any) => {
- throw new Error(`${errorPrefix} to "${url}" failed with ${error?.cause || error}`);
- },
- );
+ const res = await this.fetchData(url, '/predict', formData);
if (res.status >= 400) {
throw new Error(`${errorPrefix} '${JSON.stringify(config)}' failed with status ${res.status}: ${res.statusText}`);
@@ -34,6 +31,30 @@ export class MachineLearningRepository implements IMachineLearningRepository {
return res.json();
}
+ private async fetchData(url: string, path: string, formData?: FormData): Promise {
+ const res = await fetch(new URL(path, url), { method: 'POST', body: formData }).catch((error: Error | any) => {
+ throw new Error(`${errorPrefix} to "${url}" failed with ${error?.cause || error}`);
+ });
+
+ return res;
+ }
+
+ private prepareTextModelUrl: Record = {
+ [LoadTextModelActions.LOAD]: '/load',
+ [LoadTextModelActions.UNLOAD]: '/unload',
+ };
+
+ async prepareTextModel(url: string, { modelName }: CLIPConfig, actions: LoadTextModelActions) {
+ try {
+ const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName } } };
+ const formData = await this.getFormData(request);
+ const res = await this.fetchData(url, this.prepareTextModelUrl[actions], formData);
+ if (res.status >= 400) {
+ throw new Error(`${errorPrefix} Loadings textual model failed with status ${res.status}: ${res.statusText}`);
+ }
+ } catch (error) {}
+ }
+
async detectFaces(url: string, imagePath: string, { modelName, minScore }: FaceDetectionOptions) {
const request = {
[ModelTask.FACIAL_RECOGNITION]: {
@@ -61,16 +82,17 @@ export class MachineLearningRepository implements IMachineLearningRepository {
return response[ModelTask.SEARCH];
}
- private async getFormData(payload: ModelPayload, config: MachineLearningRequest): Promise {
+ private async getFormData(config: MachineLearningRequest, payload?: ModelPayload): Promise {
const formData = new FormData();
formData.append('entries', JSON.stringify(config));
-
- if ('imagePath' in payload) {
- formData.append('image', new Blob([await readFile(payload.imagePath)]));
- } else if ('text' in payload) {
- formData.append('text', payload.text);
- } else {
- throw new Error('Invalid input');
+ if (payload) {
+ if ('imagePath' in payload) {
+ formData.append('image', new Blob([await readFile(payload.imagePath)]));
+ } else if ('text' in payload) {
+ formData.append('text', payload.text);
+ } else {
+ throw new Error('Invalid input');
+ }
}
return formData;
diff --git a/web/src/lib/components/admin-page/settings/machine-learning-settings/machine-learning-settings.svelte b/web/src/lib/components/admin-page/settings/machine-learning-settings/machine-learning-settings.svelte
index aac8cd52123be..a188cbc67a5b3 100644
--- a/web/src/lib/components/admin-page/settings/machine-learning-settings/machine-learning-settings.svelte
+++ b/web/src/lib/components/admin-page/settings/machine-learning-settings/machine-learning-settings.svelte
@@ -75,6 +75,21 @@
+
+
+
+
+
+
diff --git a/web/src/lib/i18n/en.json b/web/src/lib/i18n/en.json
index 22eb1c8f789dd..f8913e329cb8a 100644
--- a/web/src/lib/i18n/en.json
+++ b/web/src/lib/i18n/en.json
@@ -119,6 +119,12 @@
"machine_learning_min_detection_score_description": "Minimum confidence score for a face to be detected from 0-1. Lower values will detect more faces but may result in false positives.",
"machine_learning_min_recognized_faces": "Minimum recognized faces",
"machine_learning_min_recognized_faces_description": "The minimum number of recognized faces for a person to be created. Increasing this makes Facial Recognition more precise at the cost of increasing the chance that a face is not assigned to a person.",
+ "machine_learning_preload_model": "Preload model",
+ "machine_learning_preload_model_enabled": "Enable preload model",
+ "machine_learning_preload_model_enabled_description": "Preload the textual model when a connection is opened instead of at the first search",
+ "machine_learning_preload_model_setting_description": "Preload the textual model when a connection is opened",
+ "machine_learning_preload_model_ttl": "Inactivity time before a model is unloaded",
+ "machine_learning_preload_model_ttl_description": "Inactivity time, in seconds, after which the model is unloaded",
"machine_learning_settings": "Machine Learning Settings",
"machine_learning_settings_description": "Manage machine learning features and settings",
"machine_learning_smart_search": "Smart Search",
diff --git a/web/src/lib/stores/websocket.ts b/web/src/lib/stores/websocket.ts
index d398ca52a9d89..d1eafff6c083d 100644
--- a/web/src/lib/stores/websocket.ts
+++ b/web/src/lib/stores/websocket.ts
@@ -35,6 +35,9 @@ const websocket: Socket = io({
reconnection: true,
forceNew: true,
autoConnect: false,
+ query: {
+ background: false,
+ },
});
export const websocketStore = {