mirror of
https://github.com/immich-app/immich.git
synced 2025-05-31 20:26:27 -04:00
Merge 59300d2097e5c2f8e81b94ea6113750f3d14c9ff into b0bcc6c03ecedff0d756a0e54352586672cea761
This commit is contained in:
commit
5701be51aa
@ -11,7 +11,7 @@ from typing import Any, AsyncGenerator, Callable, Iterator
|
|||||||
from zipfile import BadZipFile
|
from zipfile import BadZipFile
|
||||||
|
|
||||||
import orjson
|
import orjson
|
||||||
from fastapi import Depends, FastAPI, File, Form, HTTPException
|
from fastapi import Depends, FastAPI, File, Form, HTTPException, Response
|
||||||
from fastapi.responses import ORJSONResponse
|
from fastapi.responses import ORJSONResponse
|
||||||
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
|
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
|
||||||
from PIL.Image import Image
|
from PIL.Image import Image
|
||||||
@ -124,6 +124,23 @@ def get_entries(entries: str = Form()) -> InferenceEntries:
|
|||||||
raise HTTPException(422, "Invalid request format.")
|
raise HTTPException(422, "Invalid request format.")
|
||||||
|
|
||||||
|
|
||||||
|
def get_entry(entries: str = Form()) -> InferenceEntry:
|
||||||
|
try:
|
||||||
|
request: PipelineRequest = orjson.loads(entries)
|
||||||
|
for task, types in request.items():
|
||||||
|
for type, entry in types.items():
|
||||||
|
parsed: InferenceEntry = {
|
||||||
|
"name": entry["modelName"],
|
||||||
|
"task": task,
|
||||||
|
"type": type,
|
||||||
|
"options": entry.get("options", {}),
|
||||||
|
}
|
||||||
|
return parsed
|
||||||
|
except (orjson.JSONDecodeError, ValidationError, KeyError, AttributeError) as e:
|
||||||
|
log.error(f"Invalid request format: {e}")
|
||||||
|
raise HTTPException(422, "Invalid request format.")
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(lifespan=lifespan)
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
|
||||||
|
|
||||||
@ -137,6 +154,20 @@ def ping() -> str:
|
|||||||
return "pong"
|
return "pong"
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/load", response_model=TextResponse)
|
||||||
|
async def load_model(entry: InferenceEntry = Depends(get_entry)) -> None:
|
||||||
|
model = await model_cache.get(entry["name"], entry["type"], entry["task"], ttl=settings.model_ttl)
|
||||||
|
model = await load(model)
|
||||||
|
return Response(status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/unload", response_model=TextResponse)
|
||||||
|
async def unload_model(entry: InferenceEntry = Depends(get_entry)) -> None:
|
||||||
|
await model_cache.unload(entry["name"], entry["type"], entry["task"])
|
||||||
|
print("unload")
|
||||||
|
return Response(status_code=200)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/predict", dependencies=[Depends(update_state)])
|
@app.post("/predict", dependencies=[Depends(update_state)])
|
||||||
async def predict(
|
async def predict(
|
||||||
entries: InferenceEntries = Depends(get_entries),
|
entries: InferenceEntries = Depends(get_entries),
|
||||||
|
@ -58,3 +58,10 @@ class ModelCache:
|
|||||||
async def revalidate(self, key: str, ttl: int | None) -> None:
|
async def revalidate(self, key: str, ttl: int | None) -> None:
|
||||||
if ttl is not None and key in self.cache._handlers:
|
if ttl is not None and key in self.cache._handlers:
|
||||||
await self.cache.expire(key, ttl)
|
await self.cache.expire(key, ttl)
|
||||||
|
|
||||||
|
async def unload(self, model_name: str, model_type: ModelType, model_task: ModelTask) -> None:
|
||||||
|
key = f"{model_name}{model_type}{model_task}"
|
||||||
|
async with OptimisticLock(self.cache, key):
|
||||||
|
value = await self.cache.get(key)
|
||||||
|
if value is not None:
|
||||||
|
await self.cache.delete(key)
|
||||||
|
10
mobile/openapi/README.md
generated
10
mobile/openapi/README.md
generated
@ -116,7 +116,6 @@ Class | Method | HTTP request | Description
|
|||||||
*AuthenticationApi* | [**signUpAdmin**](doc//AuthenticationApi.md#signupadmin) | **POST** /auth/admin-sign-up |
|
*AuthenticationApi* | [**signUpAdmin**](doc//AuthenticationApi.md#signupadmin) | **POST** /auth/admin-sign-up |
|
||||||
*AuthenticationApi* | [**validateAccessToken**](doc//AuthenticationApi.md#validateaccesstoken) | **POST** /auth/validateToken |
|
*AuthenticationApi* | [**validateAccessToken**](doc//AuthenticationApi.md#validateaccesstoken) | **POST** /auth/validateToken |
|
||||||
*DeprecatedApi* | [**getPersonAssets**](doc//DeprecatedApi.md#getpersonassets) | **GET** /people/{id}/assets |
|
*DeprecatedApi* | [**getPersonAssets**](doc//DeprecatedApi.md#getpersonassets) | **GET** /people/{id}/assets |
|
||||||
*DeprecatedApi* | [**getRandom**](doc//DeprecatedApi.md#getrandom) | **GET** /assets/random |
|
|
||||||
*DownloadApi* | [**downloadArchive**](doc//DownloadApi.md#downloadarchive) | **POST** /download/archive |
|
*DownloadApi* | [**downloadArchive**](doc//DownloadApi.md#downloadarchive) | **POST** /download/archive |
|
||||||
*DownloadApi* | [**getDownloadInfo**](doc//DownloadApi.md#getdownloadinfo) | **POST** /download/info |
|
*DownloadApi* | [**getDownloadInfo**](doc//DownloadApi.md#getdownloadinfo) | **POST** /download/info |
|
||||||
*DuplicatesApi* | [**getAssetDuplicates**](doc//DuplicatesApi.md#getassetduplicates) | **GET** /duplicates |
|
*DuplicatesApi* | [**getAssetDuplicates**](doc//DuplicatesApi.md#getassetduplicates) | **GET** /duplicates |
|
||||||
@ -125,7 +124,6 @@ Class | Method | HTTP request | Description
|
|||||||
*FileReportsApi* | [**fixAuditFiles**](doc//FileReportsApi.md#fixauditfiles) | **POST** /reports/fix |
|
*FileReportsApi* | [**fixAuditFiles**](doc//FileReportsApi.md#fixauditfiles) | **POST** /reports/fix |
|
||||||
*FileReportsApi* | [**getAuditFiles**](doc//FileReportsApi.md#getauditfiles) | **GET** /reports |
|
*FileReportsApi* | [**getAuditFiles**](doc//FileReportsApi.md#getauditfiles) | **GET** /reports |
|
||||||
*FileReportsApi* | [**getFileChecksums**](doc//FileReportsApi.md#getfilechecksums) | **POST** /reports/checksum |
|
*FileReportsApi* | [**getFileChecksums**](doc//FileReportsApi.md#getfilechecksums) | **POST** /reports/checksum |
|
||||||
*JobsApi* | [**createJob**](doc//JobsApi.md#createjob) | **POST** /jobs |
|
|
||||||
*JobsApi* | [**getAllJobsStatus**](doc//JobsApi.md#getalljobsstatus) | **GET** /jobs |
|
*JobsApi* | [**getAllJobsStatus**](doc//JobsApi.md#getalljobsstatus) | **GET** /jobs |
|
||||||
*JobsApi* | [**sendJobCommand**](doc//JobsApi.md#sendjobcommand) | **PUT** /jobs/{id} |
|
*JobsApi* | [**sendJobCommand**](doc//JobsApi.md#sendjobcommand) | **PUT** /jobs/{id} |
|
||||||
*LibrariesApi* | [**createLibrary**](doc//LibrariesApi.md#createlibrary) | **POST** /libraries |
|
*LibrariesApi* | [**createLibrary**](doc//LibrariesApi.md#createlibrary) | **POST** /libraries |
|
||||||
@ -137,6 +135,7 @@ Class | Method | HTTP request | Description
|
|||||||
*LibrariesApi* | [**updateLibrary**](doc//LibrariesApi.md#updatelibrary) | **PUT** /libraries/{id} |
|
*LibrariesApi* | [**updateLibrary**](doc//LibrariesApi.md#updatelibrary) | **PUT** /libraries/{id} |
|
||||||
*LibrariesApi* | [**validate**](doc//LibrariesApi.md#validate) | **POST** /libraries/{id}/validate |
|
*LibrariesApi* | [**validate**](doc//LibrariesApi.md#validate) | **POST** /libraries/{id}/validate |
|
||||||
*MapApi* | [**getMapMarkers**](doc//MapApi.md#getmapmarkers) | **GET** /map/markers |
|
*MapApi* | [**getMapMarkers**](doc//MapApi.md#getmapmarkers) | **GET** /map/markers |
|
||||||
|
*MapApi* | [**getMapStyle**](doc//MapApi.md#getmapstyle) | **GET** /map/style.json |
|
||||||
*MapApi* | [**reverseGeocode**](doc//MapApi.md#reversegeocode) | **GET** /map/reverse-geocode |
|
*MapApi* | [**reverseGeocode**](doc//MapApi.md#reversegeocode) | **GET** /map/reverse-geocode |
|
||||||
*MemoriesApi* | [**addMemoryAssets**](doc//MemoriesApi.md#addmemoryassets) | **PUT** /memories/{id}/assets |
|
*MemoriesApi* | [**addMemoryAssets**](doc//MemoriesApi.md#addmemoryassets) | **PUT** /memories/{id}/assets |
|
||||||
*MemoriesApi* | [**createMemory**](doc//MemoriesApi.md#creatememory) | **POST** /memories |
|
*MemoriesApi* | [**createMemory**](doc//MemoriesApi.md#creatememory) | **POST** /memories |
|
||||||
@ -171,7 +170,6 @@ Class | Method | HTTP request | Description
|
|||||||
*SearchApi* | [**searchMetadata**](doc//SearchApi.md#searchmetadata) | **POST** /search/metadata |
|
*SearchApi* | [**searchMetadata**](doc//SearchApi.md#searchmetadata) | **POST** /search/metadata |
|
||||||
*SearchApi* | [**searchPerson**](doc//SearchApi.md#searchperson) | **GET** /search/person |
|
*SearchApi* | [**searchPerson**](doc//SearchApi.md#searchperson) | **GET** /search/person |
|
||||||
*SearchApi* | [**searchPlaces**](doc//SearchApi.md#searchplaces) | **GET** /search/places |
|
*SearchApi* | [**searchPlaces**](doc//SearchApi.md#searchplaces) | **GET** /search/places |
|
||||||
*SearchApi* | [**searchRandom**](doc//SearchApi.md#searchrandom) | **POST** /search/random |
|
|
||||||
*SearchApi* | [**searchSmart**](doc//SearchApi.md#searchsmart) | **POST** /search/smart |
|
*SearchApi* | [**searchSmart**](doc//SearchApi.md#searchsmart) | **POST** /search/smart |
|
||||||
*ServerApi* | [**deleteServerLicense**](doc//ServerApi.md#deleteserverlicense) | **DELETE** /server/license |
|
*ServerApi* | [**deleteServerLicense**](doc//ServerApi.md#deleteserverlicense) | **DELETE** /server/license |
|
||||||
*ServerApi* | [**getAboutInfo**](doc//ServerApi.md#getaboutinfo) | **GET** /server/about |
|
*ServerApi* | [**getAboutInfo**](doc//ServerApi.md#getaboutinfo) | **GET** /server/about |
|
||||||
@ -332,7 +330,6 @@ Class | Method | HTTP request | Description
|
|||||||
- [JobCommand](doc//JobCommand.md)
|
- [JobCommand](doc//JobCommand.md)
|
||||||
- [JobCommandDto](doc//JobCommandDto.md)
|
- [JobCommandDto](doc//JobCommandDto.md)
|
||||||
- [JobCountsDto](doc//JobCountsDto.md)
|
- [JobCountsDto](doc//JobCountsDto.md)
|
||||||
- [JobCreateDto](doc//JobCreateDto.md)
|
|
||||||
- [JobName](doc//JobName.md)
|
- [JobName](doc//JobName.md)
|
||||||
- [JobSettingsDto](doc//JobSettingsDto.md)
|
- [JobSettingsDto](doc//JobSettingsDto.md)
|
||||||
- [JobStatusDto](doc//JobStatusDto.md)
|
- [JobStatusDto](doc//JobStatusDto.md)
|
||||||
@ -340,13 +337,14 @@ Class | Method | HTTP request | Description
|
|||||||
- [LibraryStatsResponseDto](doc//LibraryStatsResponseDto.md)
|
- [LibraryStatsResponseDto](doc//LibraryStatsResponseDto.md)
|
||||||
- [LicenseKeyDto](doc//LicenseKeyDto.md)
|
- [LicenseKeyDto](doc//LicenseKeyDto.md)
|
||||||
- [LicenseResponseDto](doc//LicenseResponseDto.md)
|
- [LicenseResponseDto](doc//LicenseResponseDto.md)
|
||||||
|
- [LoadTextualModelOnConnection](doc//LoadTextualModelOnConnection.md)
|
||||||
- [LogLevel](doc//LogLevel.md)
|
- [LogLevel](doc//LogLevel.md)
|
||||||
- [LoginCredentialDto](doc//LoginCredentialDto.md)
|
- [LoginCredentialDto](doc//LoginCredentialDto.md)
|
||||||
- [LoginResponseDto](doc//LoginResponseDto.md)
|
- [LoginResponseDto](doc//LoginResponseDto.md)
|
||||||
- [LogoutResponseDto](doc//LogoutResponseDto.md)
|
- [LogoutResponseDto](doc//LogoutResponseDto.md)
|
||||||
- [ManualJobName](doc//ManualJobName.md)
|
|
||||||
- [MapMarkerResponseDto](doc//MapMarkerResponseDto.md)
|
- [MapMarkerResponseDto](doc//MapMarkerResponseDto.md)
|
||||||
- [MapReverseGeocodeResponseDto](doc//MapReverseGeocodeResponseDto.md)
|
- [MapReverseGeocodeResponseDto](doc//MapReverseGeocodeResponseDto.md)
|
||||||
|
- [MapTheme](doc//MapTheme.md)
|
||||||
- [MemoriesResponse](doc//MemoriesResponse.md)
|
- [MemoriesResponse](doc//MemoriesResponse.md)
|
||||||
- [MemoriesUpdate](doc//MemoriesUpdate.md)
|
- [MemoriesUpdate](doc//MemoriesUpdate.md)
|
||||||
- [MemoryCreateDto](doc//MemoryCreateDto.md)
|
- [MemoryCreateDto](doc//MemoryCreateDto.md)
|
||||||
@ -379,7 +377,6 @@ Class | Method | HTTP request | Description
|
|||||||
- [PurchaseResponse](doc//PurchaseResponse.md)
|
- [PurchaseResponse](doc//PurchaseResponse.md)
|
||||||
- [PurchaseUpdate](doc//PurchaseUpdate.md)
|
- [PurchaseUpdate](doc//PurchaseUpdate.md)
|
||||||
- [QueueStatusDto](doc//QueueStatusDto.md)
|
- [QueueStatusDto](doc//QueueStatusDto.md)
|
||||||
- [RandomSearchDto](doc//RandomSearchDto.md)
|
|
||||||
- [RatingsResponse](doc//RatingsResponse.md)
|
- [RatingsResponse](doc//RatingsResponse.md)
|
||||||
- [RatingsUpdate](doc//RatingsUpdate.md)
|
- [RatingsUpdate](doc//RatingsUpdate.md)
|
||||||
- [ReactionLevel](doc//ReactionLevel.md)
|
- [ReactionLevel](doc//ReactionLevel.md)
|
||||||
@ -455,7 +452,6 @@ Class | Method | HTTP request | Description
|
|||||||
- [ToneMapping](doc//ToneMapping.md)
|
- [ToneMapping](doc//ToneMapping.md)
|
||||||
- [TranscodeHWAccel](doc//TranscodeHWAccel.md)
|
- [TranscodeHWAccel](doc//TranscodeHWAccel.md)
|
||||||
- [TranscodePolicy](doc//TranscodePolicy.md)
|
- [TranscodePolicy](doc//TranscodePolicy.md)
|
||||||
- [TrashResponseDto](doc//TrashResponseDto.md)
|
|
||||||
- [UpdateAlbumDto](doc//UpdateAlbumDto.md)
|
- [UpdateAlbumDto](doc//UpdateAlbumDto.md)
|
||||||
- [UpdateAlbumUserDto](doc//UpdateAlbumUserDto.md)
|
- [UpdateAlbumUserDto](doc//UpdateAlbumUserDto.md)
|
||||||
- [UpdateAssetDto](doc//UpdateAssetDto.md)
|
- [UpdateAssetDto](doc//UpdateAssetDto.md)
|
||||||
|
6
mobile/openapi/lib/api.dart
generated
6
mobile/openapi/lib/api.dart
generated
@ -144,7 +144,6 @@ part 'model/image_format.dart';
|
|||||||
part 'model/job_command.dart';
|
part 'model/job_command.dart';
|
||||||
part 'model/job_command_dto.dart';
|
part 'model/job_command_dto.dart';
|
||||||
part 'model/job_counts_dto.dart';
|
part 'model/job_counts_dto.dart';
|
||||||
part 'model/job_create_dto.dart';
|
|
||||||
part 'model/job_name.dart';
|
part 'model/job_name.dart';
|
||||||
part 'model/job_settings_dto.dart';
|
part 'model/job_settings_dto.dart';
|
||||||
part 'model/job_status_dto.dart';
|
part 'model/job_status_dto.dart';
|
||||||
@ -152,13 +151,14 @@ part 'model/library_response_dto.dart';
|
|||||||
part 'model/library_stats_response_dto.dart';
|
part 'model/library_stats_response_dto.dart';
|
||||||
part 'model/license_key_dto.dart';
|
part 'model/license_key_dto.dart';
|
||||||
part 'model/license_response_dto.dart';
|
part 'model/license_response_dto.dart';
|
||||||
|
part 'model/load_textual_model_on_connection.dart';
|
||||||
part 'model/log_level.dart';
|
part 'model/log_level.dart';
|
||||||
part 'model/login_credential_dto.dart';
|
part 'model/login_credential_dto.dart';
|
||||||
part 'model/login_response_dto.dart';
|
part 'model/login_response_dto.dart';
|
||||||
part 'model/logout_response_dto.dart';
|
part 'model/logout_response_dto.dart';
|
||||||
part 'model/manual_job_name.dart';
|
|
||||||
part 'model/map_marker_response_dto.dart';
|
part 'model/map_marker_response_dto.dart';
|
||||||
part 'model/map_reverse_geocode_response_dto.dart';
|
part 'model/map_reverse_geocode_response_dto.dart';
|
||||||
|
part 'model/map_theme.dart';
|
||||||
part 'model/memories_response.dart';
|
part 'model/memories_response.dart';
|
||||||
part 'model/memories_update.dart';
|
part 'model/memories_update.dart';
|
||||||
part 'model/memory_create_dto.dart';
|
part 'model/memory_create_dto.dart';
|
||||||
@ -191,7 +191,6 @@ part 'model/places_response_dto.dart';
|
|||||||
part 'model/purchase_response.dart';
|
part 'model/purchase_response.dart';
|
||||||
part 'model/purchase_update.dart';
|
part 'model/purchase_update.dart';
|
||||||
part 'model/queue_status_dto.dart';
|
part 'model/queue_status_dto.dart';
|
||||||
part 'model/random_search_dto.dart';
|
|
||||||
part 'model/ratings_response.dart';
|
part 'model/ratings_response.dart';
|
||||||
part 'model/ratings_update.dart';
|
part 'model/ratings_update.dart';
|
||||||
part 'model/reaction_level.dart';
|
part 'model/reaction_level.dart';
|
||||||
@ -267,7 +266,6 @@ part 'model/time_bucket_size.dart';
|
|||||||
part 'model/tone_mapping.dart';
|
part 'model/tone_mapping.dart';
|
||||||
part 'model/transcode_hw_accel.dart';
|
part 'model/transcode_hw_accel.dart';
|
||||||
part 'model/transcode_policy.dart';
|
part 'model/transcode_policy.dart';
|
||||||
part 'model/trash_response_dto.dart';
|
|
||||||
part 'model/update_album_dto.dart';
|
part 'model/update_album_dto.dart';
|
||||||
part 'model/update_album_user_dto.dart';
|
part 'model/update_album_user_dto.dart';
|
||||||
part 'model/update_asset_dto.dart';
|
part 'model/update_asset_dto.dart';
|
||||||
|
13
mobile/openapi/lib/api_client.dart
generated
13
mobile/openapi/lib/api_client.dart
generated
@ -166,6 +166,7 @@ class ApiClient {
|
|||||||
|
|
||||||
/// Returns a native instance of an OpenAPI class matching the [specified type][targetType].
|
/// Returns a native instance of an OpenAPI class matching the [specified type][targetType].
|
||||||
static dynamic fromJson(dynamic value, String targetType, {bool growable = false,}) {
|
static dynamic fromJson(dynamic value, String targetType, {bool growable = false,}) {
|
||||||
|
upgradeDto(value, targetType);
|
||||||
try {
|
try {
|
||||||
switch (targetType) {
|
switch (targetType) {
|
||||||
case 'String':
|
case 'String':
|
||||||
@ -342,8 +343,6 @@ class ApiClient {
|
|||||||
return JobCommandDto.fromJson(value);
|
return JobCommandDto.fromJson(value);
|
||||||
case 'JobCountsDto':
|
case 'JobCountsDto':
|
||||||
return JobCountsDto.fromJson(value);
|
return JobCountsDto.fromJson(value);
|
||||||
case 'JobCreateDto':
|
|
||||||
return JobCreateDto.fromJson(value);
|
|
||||||
case 'JobName':
|
case 'JobName':
|
||||||
return JobNameTypeTransformer().decode(value);
|
return JobNameTypeTransformer().decode(value);
|
||||||
case 'JobSettingsDto':
|
case 'JobSettingsDto':
|
||||||
@ -358,6 +357,8 @@ class ApiClient {
|
|||||||
return LicenseKeyDto.fromJson(value);
|
return LicenseKeyDto.fromJson(value);
|
||||||
case 'LicenseResponseDto':
|
case 'LicenseResponseDto':
|
||||||
return LicenseResponseDto.fromJson(value);
|
return LicenseResponseDto.fromJson(value);
|
||||||
|
case 'LoadTextualModelOnConnection':
|
||||||
|
return LoadTextualModelOnConnection.fromJson(value);
|
||||||
case 'LogLevel':
|
case 'LogLevel':
|
||||||
return LogLevelTypeTransformer().decode(value);
|
return LogLevelTypeTransformer().decode(value);
|
||||||
case 'LoginCredentialDto':
|
case 'LoginCredentialDto':
|
||||||
@ -366,12 +367,12 @@ class ApiClient {
|
|||||||
return LoginResponseDto.fromJson(value);
|
return LoginResponseDto.fromJson(value);
|
||||||
case 'LogoutResponseDto':
|
case 'LogoutResponseDto':
|
||||||
return LogoutResponseDto.fromJson(value);
|
return LogoutResponseDto.fromJson(value);
|
||||||
case 'ManualJobName':
|
|
||||||
return ManualJobNameTypeTransformer().decode(value);
|
|
||||||
case 'MapMarkerResponseDto':
|
case 'MapMarkerResponseDto':
|
||||||
return MapMarkerResponseDto.fromJson(value);
|
return MapMarkerResponseDto.fromJson(value);
|
||||||
case 'MapReverseGeocodeResponseDto':
|
case 'MapReverseGeocodeResponseDto':
|
||||||
return MapReverseGeocodeResponseDto.fromJson(value);
|
return MapReverseGeocodeResponseDto.fromJson(value);
|
||||||
|
case 'MapTheme':
|
||||||
|
return MapThemeTypeTransformer().decode(value);
|
||||||
case 'MemoriesResponse':
|
case 'MemoriesResponse':
|
||||||
return MemoriesResponse.fromJson(value);
|
return MemoriesResponse.fromJson(value);
|
||||||
case 'MemoriesUpdate':
|
case 'MemoriesUpdate':
|
||||||
@ -436,8 +437,6 @@ class ApiClient {
|
|||||||
return PurchaseUpdate.fromJson(value);
|
return PurchaseUpdate.fromJson(value);
|
||||||
case 'QueueStatusDto':
|
case 'QueueStatusDto':
|
||||||
return QueueStatusDto.fromJson(value);
|
return QueueStatusDto.fromJson(value);
|
||||||
case 'RandomSearchDto':
|
|
||||||
return RandomSearchDto.fromJson(value);
|
|
||||||
case 'RatingsResponse':
|
case 'RatingsResponse':
|
||||||
return RatingsResponse.fromJson(value);
|
return RatingsResponse.fromJson(value);
|
||||||
case 'RatingsUpdate':
|
case 'RatingsUpdate':
|
||||||
@ -588,8 +587,6 @@ class ApiClient {
|
|||||||
return TranscodeHWAccelTypeTransformer().decode(value);
|
return TranscodeHWAccelTypeTransformer().decode(value);
|
||||||
case 'TranscodePolicy':
|
case 'TranscodePolicy':
|
||||||
return TranscodePolicyTypeTransformer().decode(value);
|
return TranscodePolicyTypeTransformer().decode(value);
|
||||||
case 'TrashResponseDto':
|
|
||||||
return TrashResponseDto.fromJson(value);
|
|
||||||
case 'UpdateAlbumDto':
|
case 'UpdateAlbumDto':
|
||||||
return UpdateAlbumDto.fromJson(value);
|
return UpdateAlbumDto.fromJson(value);
|
||||||
case 'UpdateAlbumUserDto':
|
case 'UpdateAlbumUserDto':
|
||||||
|
11
mobile/openapi/lib/model/clip_config.dart
generated
11
mobile/openapi/lib/model/clip_config.dart
generated
@ -14,30 +14,36 @@ class CLIPConfig {
|
|||||||
/// Returns a new [CLIPConfig] instance.
|
/// Returns a new [CLIPConfig] instance.
|
||||||
CLIPConfig({
|
CLIPConfig({
|
||||||
required this.enabled,
|
required this.enabled,
|
||||||
|
required this.loadTextualModelOnConnection,
|
||||||
required this.modelName,
|
required this.modelName,
|
||||||
});
|
});
|
||||||
|
|
||||||
bool enabled;
|
bool enabled;
|
||||||
|
|
||||||
|
LoadTextualModelOnConnection loadTextualModelOnConnection;
|
||||||
|
|
||||||
String modelName;
|
String modelName;
|
||||||
|
|
||||||
@override
|
@override
|
||||||
bool operator ==(Object other) => identical(this, other) || other is CLIPConfig &&
|
bool operator ==(Object other) => identical(this, other) || other is CLIPConfig &&
|
||||||
other.enabled == enabled &&
|
other.enabled == enabled &&
|
||||||
|
other.loadTextualModelOnConnection == loadTextualModelOnConnection &&
|
||||||
other.modelName == modelName;
|
other.modelName == modelName;
|
||||||
|
|
||||||
@override
|
@override
|
||||||
int get hashCode =>
|
int get hashCode =>
|
||||||
// ignore: unnecessary_parenthesis
|
// ignore: unnecessary_parenthesis
|
||||||
(enabled.hashCode) +
|
(enabled.hashCode) +
|
||||||
|
(loadTextualModelOnConnection.hashCode) +
|
||||||
(modelName.hashCode);
|
(modelName.hashCode);
|
||||||
|
|
||||||
@override
|
@override
|
||||||
String toString() => 'CLIPConfig[enabled=$enabled, modelName=$modelName]';
|
String toString() => 'CLIPConfig[enabled=$enabled, loadTextualModelOnConnection=$loadTextualModelOnConnection, modelName=$modelName]';
|
||||||
|
|
||||||
Map<String, dynamic> toJson() {
|
Map<String, dynamic> toJson() {
|
||||||
final json = <String, dynamic>{};
|
final json = <String, dynamic>{};
|
||||||
json[r'enabled'] = this.enabled;
|
json[r'enabled'] = this.enabled;
|
||||||
|
json[r'loadTextualModelOnConnection'] = this.loadTextualModelOnConnection;
|
||||||
json[r'modelName'] = this.modelName;
|
json[r'modelName'] = this.modelName;
|
||||||
return json;
|
return json;
|
||||||
}
|
}
|
||||||
@ -46,12 +52,12 @@ class CLIPConfig {
|
|||||||
/// [value] if it's a [Map], null otherwise.
|
/// [value] if it's a [Map], null otherwise.
|
||||||
// ignore: prefer_constructors_over_static_methods
|
// ignore: prefer_constructors_over_static_methods
|
||||||
static CLIPConfig? fromJson(dynamic value) {
|
static CLIPConfig? fromJson(dynamic value) {
|
||||||
upgradeDto(value, "CLIPConfig");
|
|
||||||
if (value is Map) {
|
if (value is Map) {
|
||||||
final json = value.cast<String, dynamic>();
|
final json = value.cast<String, dynamic>();
|
||||||
|
|
||||||
return CLIPConfig(
|
return CLIPConfig(
|
||||||
enabled: mapValueOfType<bool>(json, r'enabled')!,
|
enabled: mapValueOfType<bool>(json, r'enabled')!,
|
||||||
|
loadTextualModelOnConnection: LoadTextualModelOnConnection.fromJson(json[r'loadTextualModelOnConnection'])!,
|
||||||
modelName: mapValueOfType<String>(json, r'modelName')!,
|
modelName: mapValueOfType<String>(json, r'modelName')!,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@ -101,6 +107,7 @@ class CLIPConfig {
|
|||||||
/// The list of required keys that must be present in a JSON.
|
/// The list of required keys that must be present in a JSON.
|
||||||
static const requiredKeys = <String>{
|
static const requiredKeys = <String>{
|
||||||
'enabled',
|
'enabled',
|
||||||
|
'loadTextualModelOnConnection',
|
||||||
'modelName',
|
'modelName',
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
107
mobile/openapi/lib/model/load_textual_model_on_connection.dart
generated
Normal file
107
mobile/openapi/lib/model/load_textual_model_on_connection.dart
generated
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
//
|
||||||
|
// AUTO-GENERATED FILE, DO NOT MODIFY!
|
||||||
|
//
|
||||||
|
// @dart=2.18
|
||||||
|
|
||||||
|
// ignore_for_file: unused_element, unused_import
|
||||||
|
// ignore_for_file: always_put_required_named_parameters_first
|
||||||
|
// ignore_for_file: constant_identifier_names
|
||||||
|
// ignore_for_file: lines_longer_than_80_chars
|
||||||
|
|
||||||
|
part of openapi.api;
|
||||||
|
|
||||||
|
class LoadTextualModelOnConnection {
|
||||||
|
/// Returns a new [LoadTextualModelOnConnection] instance.
|
||||||
|
LoadTextualModelOnConnection({
|
||||||
|
required this.enabled,
|
||||||
|
required this.ttl,
|
||||||
|
});
|
||||||
|
|
||||||
|
bool enabled;
|
||||||
|
|
||||||
|
/// Minimum value: 0
|
||||||
|
num ttl;
|
||||||
|
|
||||||
|
@override
|
||||||
|
bool operator ==(Object other) => identical(this, other) || other is LoadTextualModelOnConnection &&
|
||||||
|
other.enabled == enabled &&
|
||||||
|
other.ttl == ttl;
|
||||||
|
|
||||||
|
@override
|
||||||
|
int get hashCode =>
|
||||||
|
// ignore: unnecessary_parenthesis
|
||||||
|
(enabled.hashCode) +
|
||||||
|
(ttl.hashCode);
|
||||||
|
|
||||||
|
@override
|
||||||
|
String toString() => 'LoadTextualModelOnConnection[enabled=$enabled, ttl=$ttl]';
|
||||||
|
|
||||||
|
Map<String, dynamic> toJson() {
|
||||||
|
final json = <String, dynamic>{};
|
||||||
|
json[r'enabled'] = this.enabled;
|
||||||
|
json[r'ttl'] = this.ttl;
|
||||||
|
return json;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a new [LoadTextualModelOnConnection] instance and imports its values from
|
||||||
|
/// [value] if it's a [Map], null otherwise.
|
||||||
|
// ignore: prefer_constructors_over_static_methods
|
||||||
|
static LoadTextualModelOnConnection? fromJson(dynamic value) {
|
||||||
|
if (value is Map) {
|
||||||
|
final json = value.cast<String, dynamic>();
|
||||||
|
|
||||||
|
return LoadTextualModelOnConnection(
|
||||||
|
enabled: mapValueOfType<bool>(json, r'enabled')!,
|
||||||
|
ttl: num.parse('${json[r'ttl']}'),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
static List<LoadTextualModelOnConnection> listFromJson(dynamic json, {bool growable = false,}) {
|
||||||
|
final result = <LoadTextualModelOnConnection>[];
|
||||||
|
if (json is List && json.isNotEmpty) {
|
||||||
|
for (final row in json) {
|
||||||
|
final value = LoadTextualModelOnConnection.fromJson(row);
|
||||||
|
if (value != null) {
|
||||||
|
result.add(value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result.toList(growable: growable);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Map<String, LoadTextualModelOnConnection> mapFromJson(dynamic json) {
|
||||||
|
final map = <String, LoadTextualModelOnConnection>{};
|
||||||
|
if (json is Map && json.isNotEmpty) {
|
||||||
|
json = json.cast<String, dynamic>(); // ignore: parameter_assignments
|
||||||
|
for (final entry in json.entries) {
|
||||||
|
final value = LoadTextualModelOnConnection.fromJson(entry.value);
|
||||||
|
if (value != null) {
|
||||||
|
map[entry.key] = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return map;
|
||||||
|
}
|
||||||
|
|
||||||
|
// maps a json object with a list of LoadTextualModelOnConnection-objects as value to a dart map
|
||||||
|
static Map<String, List<LoadTextualModelOnConnection>> mapListFromJson(dynamic json, {bool growable = false,}) {
|
||||||
|
final map = <String, List<LoadTextualModelOnConnection>>{};
|
||||||
|
if (json is Map && json.isNotEmpty) {
|
||||||
|
// ignore: parameter_assignments
|
||||||
|
json = json.cast<String, dynamic>();
|
||||||
|
for (final entry in json.entries) {
|
||||||
|
map[entry.key] = LoadTextualModelOnConnection.listFromJson(entry.value, growable: growable,);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return map;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The list of required keys that must be present in a JSON.
|
||||||
|
static const requiredKeys = <String>{
|
||||||
|
'enabled',
|
||||||
|
'ttl',
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
@ -5296,8 +5296,8 @@
|
|||||||
"name": "password",
|
"name": "password",
|
||||||
"required": false,
|
"required": false,
|
||||||
"in": "query",
|
"in": "query",
|
||||||
|
"example": "password",
|
||||||
"schema": {
|
"schema": {
|
||||||
"example": "password",
|
|
||||||
"type": "string"
|
"type": "string"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@ -8642,12 +8642,16 @@
|
|||||||
"enabled": {
|
"enabled": {
|
||||||
"type": "boolean"
|
"type": "boolean"
|
||||||
},
|
},
|
||||||
|
"loadTextualModelOnConnection": {
|
||||||
|
"$ref": "#/components/schemas/LoadTextualModelOnConnection"
|
||||||
|
},
|
||||||
"modelName": {
|
"modelName": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": [
|
"required": [
|
||||||
"enabled",
|
"enabled",
|
||||||
|
"loadTextualModelOnConnection",
|
||||||
"modelName"
|
"modelName"
|
||||||
],
|
],
|
||||||
"type": "object"
|
"type": "object"
|
||||||
@ -9488,6 +9492,17 @@
|
|||||||
],
|
],
|
||||||
"type": "object"
|
"type": "object"
|
||||||
},
|
},
|
||||||
|
"LoadTextualModelOnConnection": {
|
||||||
|
"properties": {
|
||||||
|
"enabled": {
|
||||||
|
"type": "boolean"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": [
|
||||||
|
"enabled"
|
||||||
|
],
|
||||||
|
"type": "object"
|
||||||
|
},
|
||||||
"LogLevel": {
|
"LogLevel": {
|
||||||
"enum": [
|
"enum": [
|
||||||
"verbose",
|
"verbose",
|
||||||
|
@ -1150,8 +1150,13 @@ export type SystemConfigLoggingDto = {
|
|||||||
enabled: boolean;
|
enabled: boolean;
|
||||||
level: LogLevel;
|
level: LogLevel;
|
||||||
};
|
};
|
||||||
|
export type LoadTextualModelOnConnection = {
|
||||||
|
enabled: boolean;
|
||||||
|
ttl: number;
|
||||||
|
};
|
||||||
export type ClipConfig = {
|
export type ClipConfig = {
|
||||||
enabled: boolean;
|
enabled: boolean;
|
||||||
|
loadTextualModelOnConnection: LoadTextualModelOnConnection;
|
||||||
modelName: string;
|
modelName: string;
|
||||||
};
|
};
|
||||||
export type DuplicateDetectionConfig = {
|
export type DuplicateDetectionConfig = {
|
||||||
|
@ -58,6 +58,9 @@ export interface SystemConfig {
|
|||||||
clip: {
|
clip: {
|
||||||
enabled: boolean;
|
enabled: boolean;
|
||||||
modelName: string;
|
modelName: string;
|
||||||
|
loadTextualModelOnConnection: {
|
||||||
|
enabled: boolean;
|
||||||
|
};
|
||||||
};
|
};
|
||||||
duplicateDetection: {
|
duplicateDetection: {
|
||||||
enabled: boolean;
|
enabled: boolean;
|
||||||
@ -205,6 +208,9 @@ export const defaults = Object.freeze<SystemConfig>({
|
|||||||
clip: {
|
clip: {
|
||||||
enabled: true,
|
enabled: true,
|
||||||
modelName: 'ViT-B-32__openai',
|
modelName: 'ViT-B-32__openai',
|
||||||
|
loadTextualModelOnConnection: {
|
||||||
|
enabled: false,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
duplicateDetection: {
|
duplicateDetection: {
|
||||||
enabled: true,
|
enabled: true,
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import { ApiProperty } from '@nestjs/swagger';
|
import { ApiProperty } from '@nestjs/swagger';
|
||||||
import { Type } from 'class-transformer';
|
import { Type } from 'class-transformer';
|
||||||
import { IsNotEmpty, IsNumber, IsString, Max, Min } from 'class-validator';
|
import { IsNotEmpty, IsNumber, IsObject, IsString, Max, Min, ValidateNested } from 'class-validator';
|
||||||
import { ValidateBoolean } from 'src/validation';
|
import { ValidateBoolean } from 'src/validation';
|
||||||
|
|
||||||
export class TaskConfig {
|
export class TaskConfig {
|
||||||
@ -14,7 +14,17 @@ export class ModelConfig extends TaskConfig {
|
|||||||
modelName!: string;
|
modelName!: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export class CLIPConfig extends ModelConfig {}
|
export class LoadTextualModelOnConnection {
|
||||||
|
@ValidateBoolean()
|
||||||
|
enabled!: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
export class CLIPConfig extends ModelConfig {
|
||||||
|
@Type(() => LoadTextualModelOnConnection)
|
||||||
|
@ValidateNested()
|
||||||
|
@IsObject()
|
||||||
|
loadTextualModelOnConnection!: LoadTextualModelOnConnection;
|
||||||
|
}
|
||||||
|
|
||||||
export class DuplicateDetectionConfig extends TaskConfig {
|
export class DuplicateDetectionConfig extends TaskConfig {
|
||||||
@IsNumber()
|
@IsNumber()
|
||||||
|
@ -46,6 +46,11 @@ export interface Face {
|
|||||||
score: number;
|
score: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export enum LoadTextModelActions {
|
||||||
|
LOAD,
|
||||||
|
UNLOAD,
|
||||||
|
}
|
||||||
|
|
||||||
export type FacialRecognitionResponse = { [ModelTask.FACIAL_RECOGNITION]: Face[] } & VisualResponse;
|
export type FacialRecognitionResponse = { [ModelTask.FACIAL_RECOGNITION]: Face[] } & VisualResponse;
|
||||||
export type DetectedFaces = { faces: Face[] } & VisualResponse;
|
export type DetectedFaces = { faces: Face[] } & VisualResponse;
|
||||||
export type MachineLearningRequest = ClipVisualRequest | ClipTextualRequest | FacialRecognitionRequest;
|
export type MachineLearningRequest = ClipVisualRequest | ClipTextualRequest | FacialRecognitionRequest;
|
||||||
@ -54,4 +59,5 @@ export interface IMachineLearningRepository {
|
|||||||
encodeImage(url: string, imagePath: string, config: ModelOptions): Promise<number[]>;
|
encodeImage(url: string, imagePath: string, config: ModelOptions): Promise<number[]>;
|
||||||
encodeText(url: string, text: string, config: ModelOptions): Promise<number[]>;
|
encodeText(url: string, text: string, config: ModelOptions): Promise<number[]>;
|
||||||
detectFaces(url: string, imagePath: string, config: FaceDetectionOptions): Promise<DetectedFaces>;
|
detectFaces(url: string, imagePath: string, config: FaceDetectionOptions): Promise<DetectedFaces>;
|
||||||
|
prepareTextModel(url: string, config: ModelOptions, action: LoadTextModelActions): Promise<void>;
|
||||||
}
|
}
|
||||||
|
@ -8,6 +8,7 @@ import {
|
|||||||
WebSocketServer,
|
WebSocketServer,
|
||||||
} from '@nestjs/websockets';
|
} from '@nestjs/websockets';
|
||||||
import { Server, Socket } from 'socket.io';
|
import { Server, Socket } from 'socket.io';
|
||||||
|
import { SystemConfigCore } from 'src/cores/system-config.core';
|
||||||
import {
|
import {
|
||||||
ArgsOf,
|
ArgsOf,
|
||||||
ClientEventMap,
|
ClientEventMap,
|
||||||
@ -18,6 +19,8 @@ import {
|
|||||||
ServerEvents,
|
ServerEvents,
|
||||||
} from 'src/interfaces/event.interface';
|
} from 'src/interfaces/event.interface';
|
||||||
import { ILoggerRepository } from 'src/interfaces/logger.interface';
|
import { ILoggerRepository } from 'src/interfaces/logger.interface';
|
||||||
|
import { IMachineLearningRepository, LoadTextModelActions } from 'src/interfaces/machine-learning.interface';
|
||||||
|
import { ISystemMetadataRepository } from 'src/interfaces/system-metadata.interface';
|
||||||
import { AuthService } from 'src/services/auth.service';
|
import { AuthService } from 'src/services/auth.service';
|
||||||
import { Instrumentation } from 'src/utils/instrumentation';
|
import { Instrumentation } from 'src/utils/instrumentation';
|
||||||
import { handlePromiseError } from 'src/utils/misc';
|
import { handlePromiseError } from 'src/utils/misc';
|
||||||
@ -33,6 +36,7 @@ type EmitHandlers = Partial<{ [T in EmitEvent]: Array<EventItem<T>> }>;
|
|||||||
@Injectable()
|
@Injectable()
|
||||||
export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect, OnGatewayInit, IEventRepository {
|
export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect, OnGatewayInit, IEventRepository {
|
||||||
private emitHandlers: EmitHandlers = {};
|
private emitHandlers: EmitHandlers = {};
|
||||||
|
private configCore: SystemConfigCore;
|
||||||
|
|
||||||
@WebSocketServer()
|
@WebSocketServer()
|
||||||
private server?: Server;
|
private server?: Server;
|
||||||
@ -40,8 +44,11 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect
|
|||||||
constructor(
|
constructor(
|
||||||
private moduleRef: ModuleRef,
|
private moduleRef: ModuleRef,
|
||||||
@Inject(ILoggerRepository) private logger: ILoggerRepository,
|
@Inject(ILoggerRepository) private logger: ILoggerRepository,
|
||||||
|
@Inject(IMachineLearningRepository) private machineLearningRepository: IMachineLearningRepository,
|
||||||
|
@Inject(ISystemMetadataRepository) systemMetadataRepository: ISystemMetadataRepository,
|
||||||
) {
|
) {
|
||||||
this.logger.setContext(EventRepository.name);
|
this.logger.setContext(EventRepository.name);
|
||||||
|
this.configCore = SystemConfigCore.create(systemMetadataRepository, this.logger);
|
||||||
}
|
}
|
||||||
|
|
||||||
afterInit(server: Server) {
|
afterInit(server: Server) {
|
||||||
@ -63,6 +70,21 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect
|
|||||||
queryParams: {},
|
queryParams: {},
|
||||||
metadata: { adminRoute: false, sharedLinkRoute: false, uri: '/api/socket.io' },
|
metadata: { adminRoute: false, sharedLinkRoute: false, uri: '/api/socket.io' },
|
||||||
});
|
});
|
||||||
|
if ('background' in client.handshake.query && client.handshake.query.background === 'false') {
|
||||||
|
const { machineLearning } = await this.configCore.getConfig({ withCache: true });
|
||||||
|
if (machineLearning.clip.loadTextualModelOnConnection.enabled) {
|
||||||
|
try {
|
||||||
|
console.log(this.server);
|
||||||
|
this.machineLearningRepository.prepareTextModel(
|
||||||
|
machineLearning.url,
|
||||||
|
machineLearning.clip,
|
||||||
|
LoadTextModelActions.LOAD,
|
||||||
|
);
|
||||||
|
} catch (error) {
|
||||||
|
this.logger.warn(error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
await client.join(auth.user.id);
|
await client.join(auth.user.id);
|
||||||
if (auth.session) {
|
if (auth.session) {
|
||||||
await client.join(auth.session.id);
|
await client.join(auth.session.id);
|
||||||
@ -78,6 +100,21 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect
|
|||||||
async handleDisconnect(client: Socket) {
|
async handleDisconnect(client: Socket) {
|
||||||
this.logger.log(`Websocket Disconnect: ${client.id}`);
|
this.logger.log(`Websocket Disconnect: ${client.id}`);
|
||||||
await client.leave(client.nsp.name);
|
await client.leave(client.nsp.name);
|
||||||
|
if ('background' in client.handshake.query && client.handshake.query.background === 'false') {
|
||||||
|
const { machineLearning } = await this.configCore.getConfig({ withCache: true });
|
||||||
|
if (machineLearning.clip.loadTextualModelOnConnection.enabled && this.server?.engine.clientsCount == 0) {
|
||||||
|
try {
|
||||||
|
this.machineLearningRepository.prepareTextModel(
|
||||||
|
machineLearning.url,
|
||||||
|
machineLearning.clip,
|
||||||
|
LoadTextModelActions.UNLOAD,
|
||||||
|
);
|
||||||
|
this.logger.debug('sent request to unload text model');
|
||||||
|
} catch (error) {
|
||||||
|
this.logger.warn(error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
on<T extends EmitEvent>(item: EventItem<T>): void {
|
on<T extends EmitEvent>(item: EventItem<T>): void {
|
||||||
|
@ -7,6 +7,7 @@ import {
|
|||||||
FaceDetectionOptions,
|
FaceDetectionOptions,
|
||||||
FacialRecognitionResponse,
|
FacialRecognitionResponse,
|
||||||
IMachineLearningRepository,
|
IMachineLearningRepository,
|
||||||
|
LoadTextModelActions,
|
||||||
MachineLearningRequest,
|
MachineLearningRequest,
|
||||||
ModelPayload,
|
ModelPayload,
|
||||||
ModelTask,
|
ModelTask,
|
||||||
@ -20,13 +21,9 @@ const errorPrefix = 'Machine learning request';
|
|||||||
@Injectable()
|
@Injectable()
|
||||||
export class MachineLearningRepository implements IMachineLearningRepository {
|
export class MachineLearningRepository implements IMachineLearningRepository {
|
||||||
private async predict<T>(url: string, payload: ModelPayload, config: MachineLearningRequest): Promise<T> {
|
private async predict<T>(url: string, payload: ModelPayload, config: MachineLearningRequest): Promise<T> {
|
||||||
const formData = await this.getFormData(payload, config);
|
const formData = await this.getFormData(config, payload);
|
||||||
|
|
||||||
const res = await fetch(new URL('/predict', url), { method: 'POST', body: formData }).catch(
|
const res = await this.fetchData(url, '/predict', formData);
|
||||||
(error: Error | any) => {
|
|
||||||
throw new Error(`${errorPrefix} to "${url}" failed with ${error?.cause || error}`);
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
if (res.status >= 400) {
|
if (res.status >= 400) {
|
||||||
throw new Error(`${errorPrefix} '${JSON.stringify(config)}' failed with status ${res.status}: ${res.statusText}`);
|
throw new Error(`${errorPrefix} '${JSON.stringify(config)}' failed with status ${res.status}: ${res.statusText}`);
|
||||||
@ -34,6 +31,30 @@ export class MachineLearningRepository implements IMachineLearningRepository {
|
|||||||
return res.json();
|
return res.json();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private async fetchData(url: string, path: string, formData?: FormData): Promise<Response> {
|
||||||
|
const res = await fetch(new URL(path, url), { method: 'POST', body: formData }).catch((error: Error | any) => {
|
||||||
|
throw new Error(`${errorPrefix} to "${url}" failed with ${error?.cause || error}`);
|
||||||
|
});
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
private prepareTextModelUrl: Record<LoadTextModelActions, string> = {
|
||||||
|
[LoadTextModelActions.LOAD]: '/load',
|
||||||
|
[LoadTextModelActions.UNLOAD]: '/unload',
|
||||||
|
};
|
||||||
|
|
||||||
|
async prepareTextModel(url: string, { modelName }: CLIPConfig, actions: LoadTextModelActions) {
|
||||||
|
try {
|
||||||
|
const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName } } };
|
||||||
|
const formData = await this.getFormData(request);
|
||||||
|
const res = await this.fetchData(url, this.prepareTextModelUrl[actions], formData);
|
||||||
|
if (res.status >= 400) {
|
||||||
|
throw new Error(`${errorPrefix} Loadings textual model failed with status ${res.status}: ${res.statusText}`);
|
||||||
|
}
|
||||||
|
} catch (error) {}
|
||||||
|
}
|
||||||
|
|
||||||
async detectFaces(url: string, imagePath: string, { modelName, minScore }: FaceDetectionOptions) {
|
async detectFaces(url: string, imagePath: string, { modelName, minScore }: FaceDetectionOptions) {
|
||||||
const request = {
|
const request = {
|
||||||
[ModelTask.FACIAL_RECOGNITION]: {
|
[ModelTask.FACIAL_RECOGNITION]: {
|
||||||
@ -61,16 +82,17 @@ export class MachineLearningRepository implements IMachineLearningRepository {
|
|||||||
return response[ModelTask.SEARCH];
|
return response[ModelTask.SEARCH];
|
||||||
}
|
}
|
||||||
|
|
||||||
private async getFormData(payload: ModelPayload, config: MachineLearningRequest): Promise<FormData> {
|
private async getFormData(config: MachineLearningRequest, payload?: ModelPayload): Promise<FormData> {
|
||||||
const formData = new FormData();
|
const formData = new FormData();
|
||||||
formData.append('entries', JSON.stringify(config));
|
formData.append('entries', JSON.stringify(config));
|
||||||
|
if (payload) {
|
||||||
if ('imagePath' in payload) {
|
if ('imagePath' in payload) {
|
||||||
formData.append('image', new Blob([await readFile(payload.imagePath)]));
|
formData.append('image', new Blob([await readFile(payload.imagePath)]));
|
||||||
} else if ('text' in payload) {
|
} else if ('text' in payload) {
|
||||||
formData.append('text', payload.text);
|
formData.append('text', payload.text);
|
||||||
} else {
|
} else {
|
||||||
throw new Error('Invalid input');
|
throw new Error('Invalid input');
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return formData;
|
return formData;
|
||||||
|
@ -75,6 +75,21 @@
|
|||||||
</FormatMessage>
|
</FormatMessage>
|
||||||
</p>
|
</p>
|
||||||
</SettingInputField>
|
</SettingInputField>
|
||||||
|
|
||||||
|
<SettingAccordion
|
||||||
|
key="Preload clip model"
|
||||||
|
title={$t('admin.machine_learning_preload_model')}
|
||||||
|
subtitle={$t('admin.machine_learning_preload_model_setting_description')}
|
||||||
|
>
|
||||||
|
<div class="ml-4 mt-4 flex flex-col gap-4">
|
||||||
|
<SettingSwitch
|
||||||
|
title={$t('admin.machine_learning_preload_model_enabled')}
|
||||||
|
subtitle={$t('admin.machine_learning_preload_model_enabled_description')}
|
||||||
|
bind:checked={config.machineLearning.clip.loadTextualModelOnConnection.enabled}
|
||||||
|
disabled={disabled || !config.machineLearning.enabled || !config.machineLearning.clip.enabled}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</SettingAccordion>
|
||||||
</div>
|
</div>
|
||||||
</SettingAccordion>
|
</SettingAccordion>
|
||||||
|
|
||||||
|
@ -119,6 +119,12 @@
|
|||||||
"machine_learning_min_detection_score_description": "Minimum confidence score for a face to be detected from 0-1. Lower values will detect more faces but may result in false positives.",
|
"machine_learning_min_detection_score_description": "Minimum confidence score for a face to be detected from 0-1. Lower values will detect more faces but may result in false positives.",
|
||||||
"machine_learning_min_recognized_faces": "Minimum recognized faces",
|
"machine_learning_min_recognized_faces": "Minimum recognized faces",
|
||||||
"machine_learning_min_recognized_faces_description": "The minimum number of recognized faces for a person to be created. Increasing this makes Facial Recognition more precise at the cost of increasing the chance that a face is not assigned to a person.",
|
"machine_learning_min_recognized_faces_description": "The minimum number of recognized faces for a person to be created. Increasing this makes Facial Recognition more precise at the cost of increasing the chance that a face is not assigned to a person.",
|
||||||
|
"machine_learning_preload_model": "Preload model",
|
||||||
|
"machine_learning_preload_model_enabled": "Enable preload model",
|
||||||
|
"machine_learning_preload_model_enabled_description": "Preload the textual model during the connexion instead of during the first search",
|
||||||
|
"machine_learning_preload_model_setting_description": "Preload the textual model during the connexion",
|
||||||
|
"machine_learning_preload_model_ttl": "Inactivity time before a model in unloaded",
|
||||||
|
"machine_learning_preload_model_ttl_description": "Preload the textual model during the connexion",
|
||||||
"machine_learning_settings": "Machine Learning Settings",
|
"machine_learning_settings": "Machine Learning Settings",
|
||||||
"machine_learning_settings_description": "Manage machine learning features and settings",
|
"machine_learning_settings_description": "Manage machine learning features and settings",
|
||||||
"machine_learning_smart_search": "Smart Search",
|
"machine_learning_smart_search": "Smart Search",
|
||||||
|
@ -35,6 +35,9 @@ const websocket: Socket<Events> = io({
|
|||||||
reconnection: true,
|
reconnection: true,
|
||||||
forceNew: true,
|
forceNew: true,
|
||||||
autoConnect: false,
|
autoConnect: false,
|
||||||
|
query: {
|
||||||
|
background: false,
|
||||||
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
export const websocketStore = {
|
export const websocketStore = {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user