1
0
forked from Cutlery/immich

Use the embedding already stored in the database for duplicate search instead of fetching it and passing it to the query

This commit is contained in:
mertalev 2024-03-23 18:41:07 -04:00
parent 33d09fb5ef
commit 179834faeb
No known key found for this signature in database
GPG Key ID: 9181CD92C0A1C5E3
5 changed files with 46 additions and 25 deletions

View File

@ -15,6 +15,7 @@ export class SmartSearchEntity {
type: 'float4',
array: true,
select: false,
transformer: { from: (v) => JSON.parse(v), to: (v) => v },
})
embedding!: number[];
}

View File

@ -133,7 +133,6 @@ export interface SearchExifOptions {
/**
 * Base options for a search that is driven by an embedding.
 * Extended by `FaceEmbeddingSearch`.
 */
export interface SearchEmbeddingOptions {
embedding: Embedding;
userIds: string[];
maxDistance?: number;
}
export interface SearchPeopleOptions {
@ -175,6 +174,13 @@ export type SmartSearchOptions = SearchDateOptions &
/**
 * Options accepted by `ISearchRepository.searchFaces`.
 * Adds face-specific fields on top of `SearchEmbeddingOptions`.
 */
export interface FaceEmbeddingSearch extends SearchEmbeddingOptions {
// NOTE(review): presumably filters on whether a face is assigned to a person — confirm against searchFaces.
hasPerson?: boolean;
// NOTE(review): presumably the maximum number of matches to return — confirm against searchFaces.
numResults: number;
// NOTE(review): presumably an upper bound on embedding distance for a match — confirm against searchFaces.
maxDistance?: number;
}
/**
 * Options accepted by `ISearchRepository.searchDuplicates`.
 * The reference asset is identified by id rather than by passing its embedding.
 */
export interface AssetDuplicateSearch {
/** Id of the reference asset; its stored embedding is looked up by the query and the asset itself is excluded from the results. */
assetId: string;
/** Candidate assets are restricted to those whose owner id is in this list. */
userIds: string[];
/** Only candidates whose computed distance is less than or equal to this value are returned. */
maxDistance?: number;
}
export interface FaceSearchResult {
@ -192,7 +198,7 @@ export interface ISearchRepository {
init(modelName: string): Promise<void>;
searchMetadata(pagination: SearchPaginationOptions, options: AssetSearchOptions): Paginated<AssetEntity>;
searchSmart(pagination: SearchPaginationOptions, options: SmartSearchOptions): Paginated<AssetEntity>;
searchDuplicates(options: SearchEmbeddingOptions): Promise<AssetDuplicateResult[]>;
searchDuplicates(options: AssetDuplicateSearch): Promise<AssetDuplicateResult[]>;
searchFaces(search: FaceEmbeddingSearch): Promise<FaceSearchResult[]>;
upsert(assetId: string, embedding: number[]): Promise<void>;
searchPlaces(placeName: string): Promise<GeodataPlacesEntity[]>;

View File

@ -11,14 +11,14 @@ import { Repository } from 'typeorm';
export class AssetDuplicateRepository implements IAssetDuplicateRepository {
constructor(@InjectRepository(AssetDuplicateEntity) private repository: Repository<AssetDuplicateEntity>) {}
async create(duplicateId: string, assetIds: string[]) {
async create(id: string, assetIds: string[]) {
await this.repository.manager.transaction(async (manager) => {
await manager.upsert(
AssetDuplicateEntity,
assetIds.map((assetId) => ({ duplicateId, assetId })),
assetIds.map((assetId) => ({ id, assetId })),
['assetId'],
);
await manager.update(AssetEntity, assetIds, { duplicateId });
await manager.update(AssetEntity, assetIds, { duplicateId: id });
});
}

View File

@ -10,11 +10,11 @@ import { SmartSearchEntity } from 'src/entities/smart-search.entity';
import { DatabaseExtension } from 'src/interfaces/database.interface';
import {
AssetDuplicateResult,
AssetDuplicateSearch,
AssetSearchOptions,
FaceEmbeddingSearch,
FaceSearchResult,
ISearchRepository,
SearchEmbeddingOptions,
SearchPaginationOptions,
SmartSearchOptions,
} from 'src/interfaces/search.interface';
@ -155,25 +155,27 @@ export class SearchRepository implements ISearchRepository {
},
],
})
searchDuplicates({ embedding, maxDistance, userIds }: SearchEmbeddingOptions): Promise<AssetDuplicateResult[]> {
searchDuplicates({ assetId, maxDistance, userIds }: AssetDuplicateSearch): Promise<AssetDuplicateResult[]> {
const cte = this.assetRepository.createQueryBuilder('asset');
cte
.select('asset.id', 'assetId')
.addSelect('asset.duplicateId')
.addSelect('search.embedding <=> :embedding', 'distance')
.select('search.assetId', 'assetId')
.addSelect('asset.duplicateId', 'duplicateId')
.addSelect(`(SELECT embedding FROM smart_search WHERE "assetId" = :assetId) <=> search.embedding`, 'distance')
.innerJoin('asset.smartSearch', 'search')
.where('asset.ownerId IN (:...userIds )')
.orderBy('asset.embedding <=> :embedding')
.andWhere('asset.id != :assetId', { assetId })
.orderBy('search.embedding <=> (SELECT embedding FROM smart_search WHERE "assetId" = :assetId)')
.limit(64)
.setParameters({ embedding: asVector(embedding), userIds });
.setParameters({ assetId, userIds });
const builder = this.assetRepository
.createQueryBuilder('asset')
const builder = this.assetRepository.manager
.createQueryBuilder()
.addCommonTableExpression(cte, 'cte')
.select('cte.*')
.where('cte.distance <= :maxDistance', { maxDistance });
.from('cte', 'res')
.select('res.*')
.where('res.distance <= :maxDistance', { maxDistance });
return builder.getMany() as any as Promise<AssetDuplicateResult[]>;
return builder.getRawMany() as any as Promise<AssetDuplicateResult[]>;
}
@GenerateSql({

View File

@ -178,24 +178,36 @@ export class SearchService {
async handleSearchDuplicates({ id }: IEntityJob): Promise<JobStatus> {
const { machineLearning } = await this.configCore.getConfig();
if (!machineLearning.enabled || !machineLearning.clip.enabled) {
const asset = await this.assetRepository.getById(id);
if (!asset) {
this.logger.error(`Asset ${id} not found`);
return JobStatus.FAILED;
}
if (!asset.isVisible) {
this.logger.debug(`Asset ${id} is not visible, skipping`);
await this.assetRepository.upsertJobStatus({
assetId: asset.id,
duplicatesDetectedAt: new Date(),
});
return JobStatus.SKIPPED;
}
const asset = await this.assetRepository.getById(id, { smartSearch: true });
if (!asset?.isVisible || asset.duplicateId) {
if (asset.duplicateId) {
this.logger.debug(`Asset ${id} already has a duplicateId, skipping`);
return JobStatus.SKIPPED;
}
if (!asset?.resizePath || !asset.smartSearch?.embedding) {
if (!asset.resizePath) {
this.logger.debug(`Asset ${id} is missing preview image`);
return JobStatus.FAILED;
}
const duplicateAssets = await this.searchRepository.searchDuplicates({
userIds: [asset.ownerId],
embedding: asset.smartSearch.embedding,
assetId: asset.id,
maxDistance: machineLearning.clip.duplicateThreshold,
userIds: [asset.ownerId],
});
if (duplicateAssets.length > 0) {
@ -212,7 +224,7 @@ export class SearchService {
await this.assetRepository.upsertJobStatus({
assetId: asset.id,
facesRecognizedAt: new Date(),
duplicatesDetectedAt: new Date(),
});
return JobStatus.SUCCESS;