Merge 59300d2097e5c2f8e81b94ea6113750f3d14c9ff into 63437529e1224f5e7879ce567b1a4502fb97573b

This commit is contained in:
martin 2024-10-02 03:04:21 +07:00 committed by GitHub
commit 9b2a7a7828
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 307 additions and 39 deletions

View File

@@ -11,7 +11,7 @@ from typing import Any, AsyncGenerator, Callable, Iterator
from zipfile import BadZipFile
import orjson
from fastapi import Depends, FastAPI, File, Form, HTTPException
from fastapi import Depends, FastAPI, File, Form, HTTPException, Response
from fastapi.responses import ORJSONResponse
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
from PIL.Image import Image
@@ -124,6 +124,23 @@ def get_entries(entries: str = Form()) -> InferenceEntries:
        raise HTTPException(422, "Invalid request format.")

def get_entry(entries: str = Form()) -> InferenceEntry:
    try:
        request: PipelineRequest = orjson.loads(entries)
        for task, types in request.items():
            for type, entry in types.items():
                parsed: InferenceEntry = {
                    "name": entry["modelName"],
                    "task": task,
                    "type": type,
                    "options": entry.get("options", {}),
                }
                return parsed
    except (orjson.JSONDecodeError, ValidationError, KeyError, AttributeError) as e:
        log.error(f"Invalid request format: {e}")
        raise HTTPException(422, "Invalid request format.")
    # an empty request body should also be rejected instead of falling through and returning None
    raise HTTPException(422, "Invalid request format.")
app = FastAPI(lifespan=lifespan)
@@ -137,6 +154,20 @@ def ping() -> str:
    return "pong"

@app.post("/load", response_model=TextResponse)
async def load_model(entry: InferenceEntry = Depends(get_entry)) -> Response:
    model = await model_cache.get(entry["name"], entry["type"], entry["task"], ttl=settings.model_ttl)
    await load(model)
    return Response(status_code=200)

@app.post("/unload", response_model=TextResponse)
async def unload_model(entry: InferenceEntry = Depends(get_entry)) -> Response:
    await model_cache.unload(entry["name"], entry["type"], entry["task"])
    return Response(status_code=200)

@app.post("/predict", dependencies=[Depends(update_state)])
async def predict(
    entries: InferenceEntries = Depends(get_entries),
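
Note: a minimal way to exercise the two new routes by hand, sketched in TypeScript. The payload shape follows what get_entry parses; it assumes ModelTask.SEARCH and ModelType.TEXTUAL serialize to "clip" and "textual", and the host/port are illustrative only.

// Hypothetical smoke test for the new /load and /unload endpoints.
const entries = JSON.stringify({ clip: { textual: { modelName: 'ViT-B-32__openai' } } });
const formData = new FormData();
formData.append('entries', entries);
// Preload the textual model...
await fetch('http://immich-machine-learning:3003/load', { method: 'POST', body: formData });
// ...and release it again once it is no longer needed.
await fetch('http://immich-machine-learning:3003/unload', { method: 'POST', body: formData });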

View File

@@ -58,3 +58,10 @@ class ModelCache:
    async def revalidate(self, key: str, ttl: int | None) -> None:
        if ttl is not None and key in self.cache._handlers:
            await self.cache.expire(key, ttl)

    async def unload(self, model_name: str, model_type: ModelType, model_task: ModelTask) -> None:
        key = f"{model_name}{model_type}{model_task}"
        async with OptimisticLock(self.cache, key):
            value = await self.cache.get(key)
            if value is not None:
                await self.cache.delete(key)
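
Note: unload rebuilds the cache key inline; this stays correct only while it matches the key format ModelCache.get presumably uses when inserting models, so extracting a shared key helper may be worth considering.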

View File

@@ -116,7 +116,6 @@ Class | Method | HTTP request | Description
*AuthenticationApi* | [**signUpAdmin**](doc//AuthenticationApi.md#signupadmin) | **POST** /auth/admin-sign-up |
*AuthenticationApi* | [**validateAccessToken**](doc//AuthenticationApi.md#validateaccesstoken) | **POST** /auth/validateToken |
*DeprecatedApi* | [**getPersonAssets**](doc//DeprecatedApi.md#getpersonassets) | **GET** /people/{id}/assets |
*DeprecatedApi* | [**getRandom**](doc//DeprecatedApi.md#getrandom) | **GET** /assets/random |
*DownloadApi* | [**downloadArchive**](doc//DownloadApi.md#downloadarchive) | **POST** /download/archive |
*DownloadApi* | [**getDownloadInfo**](doc//DownloadApi.md#getdownloadinfo) | **POST** /download/info |
*DuplicatesApi* | [**getAssetDuplicates**](doc//DuplicatesApi.md#getassetduplicates) | **GET** /duplicates |
@@ -125,7 +124,6 @@ Class | Method | HTTP request | Description
*FileReportsApi* | [**fixAuditFiles**](doc//FileReportsApi.md#fixauditfiles) | **POST** /reports/fix |
*FileReportsApi* | [**getAuditFiles**](doc//FileReportsApi.md#getauditfiles) | **GET** /reports |
*FileReportsApi* | [**getFileChecksums**](doc//FileReportsApi.md#getfilechecksums) | **POST** /reports/checksum |
*JobsApi* | [**createJob**](doc//JobsApi.md#createjob) | **POST** /jobs |
*JobsApi* | [**getAllJobsStatus**](doc//JobsApi.md#getalljobsstatus) | **GET** /jobs |
*JobsApi* | [**sendJobCommand**](doc//JobsApi.md#sendjobcommand) | **PUT** /jobs/{id} |
*LibrariesApi* | [**createLibrary**](doc//LibrariesApi.md#createlibrary) | **POST** /libraries |
@@ -137,6 +135,7 @@ Class | Method | HTTP request | Description
*LibrariesApi* | [**updateLibrary**](doc//LibrariesApi.md#updatelibrary) | **PUT** /libraries/{id} |
*LibrariesApi* | [**validate**](doc//LibrariesApi.md#validate) | **POST** /libraries/{id}/validate |
*MapApi* | [**getMapMarkers**](doc//MapApi.md#getmapmarkers) | **GET** /map/markers |
*MapApi* | [**getMapStyle**](doc//MapApi.md#getmapstyle) | **GET** /map/style.json |
*MapApi* | [**reverseGeocode**](doc//MapApi.md#reversegeocode) | **GET** /map/reverse-geocode |
*MemoriesApi* | [**addMemoryAssets**](doc//MemoriesApi.md#addmemoryassets) | **PUT** /memories/{id}/assets |
*MemoriesApi* | [**createMemory**](doc//MemoriesApi.md#creatememory) | **POST** /memories |
@@ -171,7 +170,6 @@ Class | Method | HTTP request | Description
*SearchApi* | [**searchMetadata**](doc//SearchApi.md#searchmetadata) | **POST** /search/metadata |
*SearchApi* | [**searchPerson**](doc//SearchApi.md#searchperson) | **GET** /search/person |
*SearchApi* | [**searchPlaces**](doc//SearchApi.md#searchplaces) | **GET** /search/places |
*SearchApi* | [**searchRandom**](doc//SearchApi.md#searchrandom) | **POST** /search/random |
*SearchApi* | [**searchSmart**](doc//SearchApi.md#searchsmart) | **POST** /search/smart |
*ServerApi* | [**deleteServerLicense**](doc//ServerApi.md#deleteserverlicense) | **DELETE** /server/license |
*ServerApi* | [**getAboutInfo**](doc//ServerApi.md#getaboutinfo) | **GET** /server/about |
@@ -332,7 +330,6 @@ Class | Method | HTTP request | Description
- [JobCommand](doc//JobCommand.md)
- [JobCommandDto](doc//JobCommandDto.md)
- [JobCountsDto](doc//JobCountsDto.md)
- [JobCreateDto](doc//JobCreateDto.md)
- [JobName](doc//JobName.md)
- [JobSettingsDto](doc//JobSettingsDto.md)
- [JobStatusDto](doc//JobStatusDto.md)
@@ -340,13 +337,14 @@ Class | Method | HTTP request | Description
- [LibraryStatsResponseDto](doc//LibraryStatsResponseDto.md)
- [LicenseKeyDto](doc//LicenseKeyDto.md)
- [LicenseResponseDto](doc//LicenseResponseDto.md)
- [LoadTextualModelOnConnection](doc//LoadTextualModelOnConnection.md)
- [LogLevel](doc//LogLevel.md)
- [LoginCredentialDto](doc//LoginCredentialDto.md)
- [LoginResponseDto](doc//LoginResponseDto.md)
- [LogoutResponseDto](doc//LogoutResponseDto.md)
- [ManualJobName](doc//ManualJobName.md)
- [MapMarkerResponseDto](doc//MapMarkerResponseDto.md)
- [MapReverseGeocodeResponseDto](doc//MapReverseGeocodeResponseDto.md)
- [MapTheme](doc//MapTheme.md)
- [MemoriesResponse](doc//MemoriesResponse.md)
- [MemoriesUpdate](doc//MemoriesUpdate.md)
- [MemoryCreateDto](doc//MemoryCreateDto.md)
@@ -379,7 +377,6 @@ Class | Method | HTTP request | Description
- [PurchaseResponse](doc//PurchaseResponse.md)
- [PurchaseUpdate](doc//PurchaseUpdate.md)
- [QueueStatusDto](doc//QueueStatusDto.md)
- [RandomSearchDto](doc//RandomSearchDto.md)
- [RatingsResponse](doc//RatingsResponse.md)
- [RatingsUpdate](doc//RatingsUpdate.md)
- [ReactionLevel](doc//ReactionLevel.md)
@@ -455,7 +452,6 @@ Class | Method | HTTP request | Description
- [ToneMapping](doc//ToneMapping.md)
- [TranscodeHWAccel](doc//TranscodeHWAccel.md)
- [TranscodePolicy](doc//TranscodePolicy.md)
- [TrashResponseDto](doc//TrashResponseDto.md)
- [UpdateAlbumDto](doc//UpdateAlbumDto.md)
- [UpdateAlbumUserDto](doc//UpdateAlbumUserDto.md)
- [UpdateAssetDto](doc//UpdateAssetDto.md)

View File

@@ -144,7 +144,6 @@ part 'model/image_format.dart';
part 'model/job_command.dart';
part 'model/job_command_dto.dart';
part 'model/job_counts_dto.dart';
part 'model/job_create_dto.dart';
part 'model/job_name.dart';
part 'model/job_settings_dto.dart';
part 'model/job_status_dto.dart';
@@ -152,13 +151,14 @@ part 'model/library_response_dto.dart';
part 'model/library_stats_response_dto.dart';
part 'model/license_key_dto.dart';
part 'model/license_response_dto.dart';
part 'model/load_textual_model_on_connection.dart';
part 'model/log_level.dart';
part 'model/login_credential_dto.dart';
part 'model/login_response_dto.dart';
part 'model/logout_response_dto.dart';
part 'model/manual_job_name.dart';
part 'model/map_marker_response_dto.dart';
part 'model/map_reverse_geocode_response_dto.dart';
part 'model/map_theme.dart';
part 'model/memories_response.dart';
part 'model/memories_update.dart';
part 'model/memory_create_dto.dart';
@@ -191,7 +191,6 @@ part 'model/places_response_dto.dart';
part 'model/purchase_response.dart';
part 'model/purchase_update.dart';
part 'model/queue_status_dto.dart';
part 'model/random_search_dto.dart';
part 'model/ratings_response.dart';
part 'model/ratings_update.dart';
part 'model/reaction_level.dart';
@@ -267,7 +266,6 @@ part 'model/time_bucket_size.dart';
part 'model/tone_mapping.dart';
part 'model/transcode_hw_accel.dart';
part 'model/transcode_policy.dart';
part 'model/trash_response_dto.dart';
part 'model/update_album_dto.dart';
part 'model/update_album_user_dto.dart';
part 'model/update_asset_dto.dart';

View File

@@ -166,6 +166,7 @@ class ApiClient {
/// Returns a native instance of an OpenAPI class matching the [specified type][targetType].
static dynamic fromJson(dynamic value, String targetType, {bool growable = false,}) {
upgradeDto(value, targetType);
try {
switch (targetType) {
case 'String':
@@ -342,8 +343,6 @@ class ApiClient {
return JobCommandDto.fromJson(value);
case 'JobCountsDto':
return JobCountsDto.fromJson(value);
case 'JobCreateDto':
return JobCreateDto.fromJson(value);
case 'JobName':
return JobNameTypeTransformer().decode(value);
case 'JobSettingsDto':
@@ -358,6 +357,8 @@ class ApiClient {
return LicenseKeyDto.fromJson(value);
case 'LicenseResponseDto':
return LicenseResponseDto.fromJson(value);
case 'LoadTextualModelOnConnection':
return LoadTextualModelOnConnection.fromJson(value);
case 'LogLevel':
return LogLevelTypeTransformer().decode(value);
case 'LoginCredentialDto':
@@ -366,12 +367,12 @@ class ApiClient {
return LoginResponseDto.fromJson(value);
case 'LogoutResponseDto':
return LogoutResponseDto.fromJson(value);
case 'ManualJobName':
return ManualJobNameTypeTransformer().decode(value);
case 'MapMarkerResponseDto':
return MapMarkerResponseDto.fromJson(value);
case 'MapReverseGeocodeResponseDto':
return MapReverseGeocodeResponseDto.fromJson(value);
case 'MapTheme':
return MapThemeTypeTransformer().decode(value);
case 'MemoriesResponse':
return MemoriesResponse.fromJson(value);
case 'MemoriesUpdate':
@@ -436,8 +437,6 @@ class ApiClient {
return PurchaseUpdate.fromJson(value);
case 'QueueStatusDto':
return QueueStatusDto.fromJson(value);
case 'RandomSearchDto':
return RandomSearchDto.fromJson(value);
case 'RatingsResponse':
return RatingsResponse.fromJson(value);
case 'RatingsUpdate':
@@ -588,8 +587,6 @@ class ApiClient {
return TranscodeHWAccelTypeTransformer().decode(value);
case 'TranscodePolicy':
return TranscodePolicyTypeTransformer().decode(value);
case 'TrashResponseDto':
return TrashResponseDto.fromJson(value);
case 'UpdateAlbumDto':
return UpdateAlbumDto.fromJson(value);
case 'UpdateAlbumUserDto':

View File

@@ -14,30 +14,36 @@ class CLIPConfig {
/// Returns a new [CLIPConfig] instance.
CLIPConfig({
required this.enabled,
required this.loadTextualModelOnConnection,
required this.modelName,
});
bool enabled;
LoadTextualModelOnConnection loadTextualModelOnConnection;
String modelName;
@override
bool operator ==(Object other) => identical(this, other) || other is CLIPConfig &&
other.enabled == enabled &&
other.loadTextualModelOnConnection == loadTextualModelOnConnection &&
other.modelName == modelName;
@override
int get hashCode =>
// ignore: unnecessary_parenthesis
(enabled.hashCode) +
(loadTextualModelOnConnection.hashCode) +
(modelName.hashCode);
@override
String toString() => 'CLIPConfig[enabled=$enabled, modelName=$modelName]';
String toString() => 'CLIPConfig[enabled=$enabled, loadTextualModelOnConnection=$loadTextualModelOnConnection, modelName=$modelName]';
Map<String, dynamic> toJson() {
final json = <String, dynamic>{};
json[r'enabled'] = this.enabled;
json[r'loadTextualModelOnConnection'] = this.loadTextualModelOnConnection;
json[r'modelName'] = this.modelName;
return json;
}
@@ -46,12 +52,12 @@ class CLIPConfig {
/// [value] if it's a [Map], null otherwise.
// ignore: prefer_constructors_over_static_methods
static CLIPConfig? fromJson(dynamic value) {
upgradeDto(value, "CLIPConfig");
if (value is Map) {
final json = value.cast<String, dynamic>();
return CLIPConfig(
enabled: mapValueOfType<bool>(json, r'enabled')!,
loadTextualModelOnConnection: LoadTextualModelOnConnection.fromJson(json[r'loadTextualModelOnConnection'])!,
modelName: mapValueOfType<String>(json, r'modelName')!,
);
}
@@ -101,6 +107,7 @@ class CLIPConfig {
/// The list of required keys that must be present in a JSON.
static const requiredKeys = <String>{
'enabled',
'loadTextualModelOnConnection',
'modelName',
};
}

View File

@@ -0,0 +1,107 @@
//
// AUTO-GENERATED FILE, DO NOT MODIFY!
//
// @dart=2.18
// ignore_for_file: unused_element, unused_import
// ignore_for_file: always_put_required_named_parameters_first
// ignore_for_file: constant_identifier_names
// ignore_for_file: lines_longer_than_80_chars
part of openapi.api;
class LoadTextualModelOnConnection {
/// Returns a new [LoadTextualModelOnConnection] instance.
LoadTextualModelOnConnection({
required this.enabled,
required this.ttl,
});
bool enabled;
/// Minimum value: 0
num ttl;
@override
bool operator ==(Object other) => identical(this, other) || other is LoadTextualModelOnConnection &&
other.enabled == enabled &&
other.ttl == ttl;
@override
int get hashCode =>
// ignore: unnecessary_parenthesis
(enabled.hashCode) +
(ttl.hashCode);
@override
String toString() => 'LoadTextualModelOnConnection[enabled=$enabled, ttl=$ttl]';
Map<String, dynamic> toJson() {
final json = <String, dynamic>{};
json[r'enabled'] = this.enabled;
json[r'ttl'] = this.ttl;
return json;
}
/// Returns a new [LoadTextualModelOnConnection] instance and imports its values from
/// [value] if it's a [Map], null otherwise.
// ignore: prefer_constructors_over_static_methods
static LoadTextualModelOnConnection? fromJson(dynamic value) {
if (value is Map) {
final json = value.cast<String, dynamic>();
return LoadTextualModelOnConnection(
enabled: mapValueOfType<bool>(json, r'enabled')!,
ttl: num.parse('${json[r'ttl']}'),
);
}
return null;
}
static List<LoadTextualModelOnConnection> listFromJson(dynamic json, {bool growable = false,}) {
final result = <LoadTextualModelOnConnection>[];
if (json is List && json.isNotEmpty) {
for (final row in json) {
final value = LoadTextualModelOnConnection.fromJson(row);
if (value != null) {
result.add(value);
}
}
}
return result.toList(growable: growable);
}
static Map<String, LoadTextualModelOnConnection> mapFromJson(dynamic json) {
final map = <String, LoadTextualModelOnConnection>{};
if (json is Map && json.isNotEmpty) {
json = json.cast<String, dynamic>(); // ignore: parameter_assignments
for (final entry in json.entries) {
final value = LoadTextualModelOnConnection.fromJson(entry.value);
if (value != null) {
map[entry.key] = value;
}
}
}
return map;
}
// maps a json object with a list of LoadTextualModelOnConnection-objects as value to a dart map
static Map<String, List<LoadTextualModelOnConnection>> mapListFromJson(dynamic json, {bool growable = false,}) {
final map = <String, List<LoadTextualModelOnConnection>>{};
if (json is Map && json.isNotEmpty) {
// ignore: parameter_assignments
json = json.cast<String, dynamic>();
for (final entry in json.entries) {
map[entry.key] = LoadTextualModelOnConnection.listFromJson(entry.value, growable: growable,);
}
}
return map;
}
/// The list of required keys that must be present in a JSON.
static const requiredKeys = <String>{
'enabled',
'ttl',
};
}

View File

@@ -5296,8 +5296,8 @@
"name": "password",
"required": false,
"in": "query",
"example": "password",
"schema": {
"example": "password",
"type": "string"
}
},
@@ -8642,12 +8642,16 @@
"enabled": {
"type": "boolean"
},
"loadTextualModelOnConnection": {
"$ref": "#/components/schemas/LoadTextualModelOnConnection"
},
"modelName": {
"type": "string"
}
},
"required": [
"enabled",
"loadTextualModelOnConnection",
"modelName"
],
"type": "object"
@@ -9488,6 +9492,17 @@
],
"type": "object"
},
"LoadTextualModelOnConnection": {
"properties": {
"enabled": {
"type": "boolean"
}
},
"required": [
"enabled"
],
"type": "object"
},
"LogLevel": {
"enum": [
"verbose",

View File

@@ -1150,8 +1150,13 @@ export type SystemConfigLoggingDto = {
enabled: boolean;
level: LogLevel;
};
export type LoadTextualModelOnConnection = {
enabled: boolean;
ttl: number;
};
export type ClipConfig = {
enabled: boolean;
loadTextualModelOnConnection: LoadTextualModelOnConnection;
modelName: string;
};
export type DuplicateDetectionConfig = {

View File

@@ -58,6 +58,9 @@ export interface SystemConfig {
clip: {
enabled: boolean;
modelName: string;
loadTextualModelOnConnection: {
enabled: boolean;
};
};
duplicateDetection: {
enabled: boolean;
@@ -205,6 +208,9 @@ export const defaults = Object.freeze<SystemConfig>({
clip: {
enabled: true,
modelName: 'ViT-B-32__openai',
loadTextualModelOnConnection: {
enabled: false,
},
},
duplicateDetection: {
enabled: true,
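
Note: with these defaults the feature is off. Opting in through Immich's JSON config file would presumably look like the sketch below (only the relevant keys shown; the nesting mirrors the interface above):

{
  "machineLearning": {
    "clip": {
      "enabled": true,
      "modelName": "ViT-B-32__openai",
      "loadTextualModelOnConnection": { "enabled": true }
    }
  }
}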

View File

@@ -1,6 +1,6 @@
import { ApiProperty } from '@nestjs/swagger';
import { Type } from 'class-transformer';
import { IsNotEmpty, IsNumber, IsString, Max, Min } from 'class-validator';
import { IsNotEmpty, IsNumber, IsObject, IsString, Max, Min, ValidateNested } from 'class-validator';
import { ValidateBoolean } from 'src/validation';
export class TaskConfig {
@@ -14,7 +14,17 @@ export class ModelConfig extends TaskConfig {
modelName!: string;
}
export class CLIPConfig extends ModelConfig {}
export class LoadTextualModelOnConnection {
@ValidateBoolean()
enabled!: boolean;
}
export class CLIPConfig extends ModelConfig {
@Type(() => LoadTextualModelOnConnection)
@ValidateNested()
@IsObject()
loadTextualModelOnConnection!: LoadTextualModelOnConnection;
}
export class DuplicateDetectionConfig extends TaskConfig {
@IsNumber()

View File

@@ -46,6 +46,11 @@ export interface Face {
score: number;
}
export enum LoadTextModelActions {
LOAD,
UNLOAD,
}
export type FacialRecognitionResponse = { [ModelTask.FACIAL_RECOGNITION]: Face[] } & VisualResponse;
export type DetectedFaces = { faces: Face[] } & VisualResponse;
export type MachineLearningRequest = ClipVisualRequest | ClipTextualRequest | FacialRecognitionRequest;
@@ -54,4 +59,5 @@ export interface IMachineLearningRepository {
encodeImage(url: string, imagePath: string, config: ModelOptions): Promise<number[]>;
encodeText(url: string, text: string, config: ModelOptions): Promise<number[]>;
detectFaces(url: string, imagePath: string, config: FaceDetectionOptions): Promise<DetectedFaces>;
prepareTextModel(url: string, config: ModelOptions, action: LoadTextModelActions): Promise<void>;
}
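
Note: LoadTextModelActions is a plain numeric enum (LOAD = 0, UNLOAD = 1). The values never cross the wire; the repository only uses them as keys into its route map to pick '/load' or '/unload' on the machine-learning service.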

View File

@@ -8,6 +8,7 @@ import {
WebSocketServer,
} from '@nestjs/websockets';
import { Server, Socket } from 'socket.io';
import { SystemConfigCore } from 'src/cores/system-config.core';
import {
ArgsOf,
ClientEventMap,
@@ -18,6 +19,8 @@ import {
ServerEvents,
} from 'src/interfaces/event.interface';
import { ILoggerRepository } from 'src/interfaces/logger.interface';
import { IMachineLearningRepository, LoadTextModelActions } from 'src/interfaces/machine-learning.interface';
import { ISystemMetadataRepository } from 'src/interfaces/system-metadata.interface';
import { AuthService } from 'src/services/auth.service';
import { Instrumentation } from 'src/utils/instrumentation';
import { handlePromiseError } from 'src/utils/misc';
@@ -33,6 +36,7 @@ type EmitHandlers = Partial<{ [T in EmitEvent]: Array<EventItem<T>> }>;
@Injectable()
export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect, OnGatewayInit, IEventRepository {
private emitHandlers: EmitHandlers = {};
private configCore: SystemConfigCore;
@WebSocketServer()
private server?: Server;
@@ -40,8 +44,11 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect
constructor(
private moduleRef: ModuleRef,
@Inject(ILoggerRepository) private logger: ILoggerRepository,
@Inject(IMachineLearningRepository) private machineLearningRepository: IMachineLearningRepository,
@Inject(ISystemMetadataRepository) systemMetadataRepository: ISystemMetadataRepository,
) {
this.logger.setContext(EventRepository.name);
this.configCore = SystemConfigCore.create(systemMetadataRepository, this.logger);
}
afterInit(server: Server) {
@@ -63,6 +70,21 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect
      queryParams: {},
      metadata: { adminRoute: false, sharedLinkRoute: false, uri: '/api/socket.io' },
    });

    if ('background' in client.handshake.query && client.handshake.query.background === 'false') {
      const { machineLearning } = await this.configCore.getConfig({ withCache: true });
      if (machineLearning.clip.loadTextualModelOnConnection.enabled) {
        try {
          // awaited so the catch below actually sees a failed request
          await this.machineLearningRepository.prepareTextModel(
            machineLearning.url,
            machineLearning.clip,
            LoadTextModelActions.LOAD,
          );
        } catch (error) {
          this.logger.warn(error);
        }
      }
    }

    await client.join(auth.user.id);

    if (auth.session) {
      await client.join(auth.session.id);
@@ -78,6 +100,21 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect
  async handleDisconnect(client: Socket) {
    this.logger.log(`Websocket Disconnect: ${client.id}`);
    await client.leave(client.nsp.name);

    if ('background' in client.handshake.query && client.handshake.query.background === 'false') {
      const { machineLearning } = await this.configCore.getConfig({ withCache: true });
      if (machineLearning.clip.loadTextualModelOnConnection.enabled && this.server?.engine.clientsCount === 0) {
        try {
          await this.machineLearningRepository.prepareTextModel(
            machineLearning.url,
            machineLearning.clip,
            LoadTextModelActions.UNLOAD,
          );
          this.logger.debug('Sent request to unload the textual model');
        } catch (error) {
          this.logger.warn(error);
        }
      }
    }
  }
on<T extends EmitEvent>(item: EventItem<T>): void {
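
Note on the combined behavior: a client connecting with background=false triggers a load request when preloading is enabled, and the unload request is only sent once the last connected client goes away (clientsCount === 0). Socket.io delivers query values as strings, which is why both handlers compare against the string 'false' rather than a boolean.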

View File

@@ -7,6 +7,7 @@ import {
FaceDetectionOptions,
FacialRecognitionResponse,
IMachineLearningRepository,
LoadTextModelActions,
MachineLearningRequest,
ModelPayload,
ModelTask,
@@ -20,13 +21,9 @@ const errorPrefix = 'Machine learning request';
@Injectable()
export class MachineLearningRepository implements IMachineLearningRepository {
private async predict<T>(url: string, payload: ModelPayload, config: MachineLearningRequest): Promise<T> {
const formData = await this.getFormData(payload, config);
const formData = await this.getFormData(config, payload);
const res = await fetch(new URL('/predict', url), { method: 'POST', body: formData }).catch(
(error: Error | any) => {
throw new Error(`${errorPrefix} to "${url}" failed with ${error?.cause || error}`);
},
);
const res = await this.fetchData(url, '/predict', formData);
if (res.status >= 400) {
throw new Error(`${errorPrefix} '${JSON.stringify(config)}' failed with status ${res.status}: ${res.statusText}`);
@@ -34,6 +31,30 @@ export class MachineLearningRepository implements IMachineLearningRepository {
return res.json();
}
private async fetchData(url: string, path: string, formData?: FormData): Promise<Response> {
const res = await fetch(new URL(path, url), { method: 'POST', body: formData }).catch((error: Error | any) => {
throw new Error(`${errorPrefix} to "${url}" failed with ${error?.cause || error}`);
});
return res;
}
private prepareTextModelUrl: Record<LoadTextModelActions, string> = {
[LoadTextModelActions.LOAD]: '/load',
[LoadTextModelActions.UNLOAD]: '/unload',
};
  async prepareTextModel(url: string, { modelName }: CLIPConfig, action: LoadTextModelActions) {
    const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName } } };
    const formData = await this.getFormData(request);
    const res = await this.fetchData(url, this.prepareTextModelUrl[action], formData);
    if (res.status >= 400) {
      throw new Error(`${errorPrefix} Loading the textual model failed with status ${res.status}: ${res.statusText}`);
    }
  }
async detectFaces(url: string, imagePath: string, { modelName, minScore }: FaceDetectionOptions) {
const request = {
[ModelTask.FACIAL_RECOGNITION]: {
@@ -61,16 +82,17 @@ return response[ModelTask.SEARCH];
return response[ModelTask.SEARCH];
}
private async getFormData(payload: ModelPayload, config: MachineLearningRequest): Promise<FormData> {
private async getFormData(config: MachineLearningRequest, payload?: ModelPayload): Promise<FormData> {
const formData = new FormData();
formData.append('entries', JSON.stringify(config));
if ('imagePath' in payload) {
formData.append('image', new Blob([await readFile(payload.imagePath)]));
} else if ('text' in payload) {
formData.append('text', payload.text);
} else {
throw new Error('Invalid input');
if (payload) {
if ('imagePath' in payload) {
formData.append('image', new Blob([await readFile(payload.imagePath)]));
} else if ('text' in payload) {
formData.append('text', payload.text);
} else {
throw new Error('Invalid input');
}
}
return formData;

View File

@@ -75,6 +75,21 @@
</FormatMessage>
</p>
</SettingInputField>
<SettingAccordion
key="Preload clip model"
title={$t('admin.machine_learning_preload_model')}
subtitle={$t('admin.machine_learning_preload_model_setting_description')}
>
<div class="ml-4 mt-4 flex flex-col gap-4">
<SettingSwitch
title={$t('admin.machine_learning_preload_model_enabled')}
subtitle={$t('admin.machine_learning_preload_model_enabled_description')}
bind:checked={config.machineLearning.clip.loadTextualModelOnConnection.enabled}
disabled={disabled || !config.machineLearning.enabled || !config.machineLearning.clip.enabled}
/>
</div>
</SettingAccordion>
</div>
</SettingAccordion>

View File

@@ -119,6 +119,12 @@
"machine_learning_min_detection_score_description": "Minimum confidence score for a face to be detected from 0-1. Lower values will detect more faces but may result in false positives.",
"machine_learning_min_recognized_faces": "Minimum recognized faces",
"machine_learning_min_recognized_faces_description": "The minimum number of recognized faces for a person to be created. Increasing this makes Facial Recognition more precise at the cost of increasing the chance that a face is not assigned to a person.",
"machine_learning_preload_model": "Preload model",
"machine_learning_preload_model_enabled": "Enable preload model",
"machine_learning_preload_model_enabled_description": "Preload the textual model during the connexion instead of during the first search",
"machine_learning_preload_model_setting_description": "Preload the textual model during the connexion",
"machine_learning_preload_model_ttl": "Inactivity time before a model in unloaded",
"machine_learning_preload_model_ttl_description": "Preload the textual model during the connexion",
"machine_learning_settings": "Machine Learning Settings",
"machine_learning_settings_description": "Manage machine learning features and settings",
"machine_learning_smart_search": "Smart Search",

View File

@@ -35,6 +35,9 @@ const websocket: Socket<Events> = io({
reconnection: true,
forceNew: true,
autoConnect: false,
query: {
background: false,
},
});
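
Note: the web client marks itself as a foreground connection here. A hypothetical background consumer (a sync-only session, say) would pass anything other than false to opt out of the preload/unload logic:

const backgroundSocket = io({
  autoConnect: true,
  query: { background: true }, // the server only acts when it receives the string 'false'
});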
export const websocketStore = {