mirror of
https://github.com/immich-app/immich.git
synced 2025-07-09 03:04:16 -04:00
set probes
This commit is contained in:
parent
c80b16d24e
commit
b750440f90
@ -11,11 +11,3 @@ WHERE
|
||||
|
||||
-- DatabaseRepository.getPostgresVersion
|
||||
SHOW server_version
|
||||
|
||||
-- DatabaseRepository.shouldReindex
|
||||
SELECT
|
||||
idx_status
|
||||
FROM
|
||||
pg_vector_index_stat
|
||||
WHERE
|
||||
indexname = $1
|
||||
|
@ -64,6 +64,7 @@ limit
|
||||
$15
|
||||
|
||||
-- SearchRepository.searchSmart
|
||||
begin
|
||||
select
|
||||
"assets".*
|
||||
from
|
||||
@ -83,8 +84,10 @@ limit
|
||||
$7
|
||||
offset
|
||||
$8
|
||||
rollback
|
||||
|
||||
-- SearchRepository.searchDuplicates
|
||||
begin
|
||||
with
|
||||
"cte" as (
|
||||
select
|
||||
@ -102,18 +105,20 @@ with
|
||||
and "assets"."id" != $5::uuid
|
||||
and "assets"."stackId" is null
|
||||
order by
|
||||
smart_search.embedding <=> $6
|
||||
"distance"
|
||||
limit
|
||||
$7
|
||||
$6
|
||||
)
|
||||
select
|
||||
*
|
||||
from
|
||||
"cte"
|
||||
where
|
||||
"cte"."distance" <= $8
|
||||
"cte"."distance" <= $7
|
||||
rollback
|
||||
|
||||
-- SearchRepository.searchFaces
|
||||
begin
|
||||
with
|
||||
"cte" as (
|
||||
select
|
||||
@ -129,16 +134,17 @@ with
|
||||
"assets"."ownerId" = any ($2::uuid[])
|
||||
and "assets"."deletedAt" is null
|
||||
order by
|
||||
face_search.embedding <=> $3
|
||||
"distance"
|
||||
limit
|
||||
$4
|
||||
$3
|
||||
)
|
||||
select
|
||||
*
|
||||
from
|
||||
"cte"
|
||||
where
|
||||
"cte"."distance" <= $5
|
||||
"cte"."distance" <= $4
|
||||
commit
|
||||
|
||||
-- SearchRepository.searchPlaces
|
||||
select
|
||||
|
@ -25,7 +25,7 @@ import { vectorIndexQuery } from 'src/utils/database';
|
||||
import { isValidInteger } from 'src/validation';
|
||||
import { DataSource, QueryRunner } from 'typeorm';
|
||||
|
||||
let cachedVectorExtension: VectorExtension | undefined;
|
||||
export let cachedVectorExtension: VectorExtension | undefined;
|
||||
export async function getVectorExtension(runner: Kysely<DB> | QueryRunner): Promise<VectorExtension> {
|
||||
if (cachedVectorExtension) {
|
||||
return cachedVectorExtension;
|
||||
@ -50,6 +50,11 @@ export async function getVectorExtension(runner: Kysely<DB> | QueryRunner): Prom
|
||||
return cachedVectorExtension;
|
||||
}
|
||||
|
||||
export const probes: Record<VectorIndex, number> = {
|
||||
[VectorIndex.CLIP]: 1,
|
||||
[VectorIndex.FACE]: 1,
|
||||
};
|
||||
|
||||
@Injectable()
|
||||
export class DatabaseRepository {
|
||||
private readonly asyncLock = new AsyncLock();
|
||||
@ -183,21 +188,17 @@ export class DatabaseRepository {
|
||||
for (const indexName of names) {
|
||||
const row = rows.find((index) => index.indexname === indexName);
|
||||
const table = VECTOR_INDEX_TABLES[indexName];
|
||||
if (!row) {
|
||||
promises.push(this.reindexVectors(indexName));
|
||||
continue;
|
||||
}
|
||||
|
||||
switch (vectorExtension) {
|
||||
case DatabaseExtension.VECTOR:
|
||||
case DatabaseExtension.VECTORS: {
|
||||
if (!row.indexdef.toLowerCase().includes(keyword)) {
|
||||
if (!row?.indexdef.toLowerCase().includes(keyword)) {
|
||||
promises.push(this.reindexVectors(indexName));
|
||||
}
|
||||
break;
|
||||
}
|
||||
case DatabaseExtension.VECTORCHORD: {
|
||||
const matches = row.indexdef.match(/(?<=lists = \[)\d+/g);
|
||||
const matches = row?.indexdef.match(/(?<=lists = \[)\d+/g);
|
||||
const lists = matches && matches.length > 0 ? Number(matches[0]) : 1;
|
||||
promises.push(
|
||||
this.db
|
||||
@ -208,11 +209,14 @@ export class DatabaseRepository {
|
||||
const targetLists = this.targetListCount(count);
|
||||
this.logger.log(`targetLists=${targetLists}, current=${lists} for ${indexName} of ${count} rows`);
|
||||
if (
|
||||
!row.indexdef.toLowerCase().includes(keyword) ||
|
||||
!row?.indexdef.toLowerCase().includes(keyword) ||
|
||||
// slack factor is to avoid frequent reindexing if the count is borderline
|
||||
(lists !== targetLists && lists !== this.targetListCount(count * VECTORCHORD_LIST_SLACK_FACTOR))
|
||||
) {
|
||||
probes[indexName] = this.targetProbeCount(targetLists);
|
||||
return this.reindexVectors(indexName, { lists: targetLists });
|
||||
} else {
|
||||
probes[indexName] = this.targetProbeCount(lists);
|
||||
}
|
||||
}),
|
||||
);
|
||||
@ -239,7 +243,7 @@ export class DatabaseRepository {
|
||||
);
|
||||
return;
|
||||
}
|
||||
const dimSize = await this.getDimSize(table);
|
||||
const dimSize = await this.getDimensionSize(table);
|
||||
await this.db.transaction().execute(async (tx) => {
|
||||
await sql`DROP INDEX IF EXISTS ${sql.raw(indexName)}`.execute(tx);
|
||||
if (!rows.some((row) => row.columnName === 'embedding')) {
|
||||
@ -261,7 +265,7 @@ export class DatabaseRepository {
|
||||
await sql`SET search_path TO "$user", public, vectors`.execute(tx);
|
||||
}
|
||||
|
||||
private async getDimSize(table: string, column = 'embedding'): Promise<number> {
|
||||
async getDimensionSize(table: string, column = 'embedding'): Promise<number> {
|
||||
const { rows } = await sql<{ dimsize: number }>`
|
||||
SELECT atttypmod as dimsize
|
||||
FROM pg_attribute f
|
||||
@ -280,7 +284,41 @@ export class DatabaseRepository {
|
||||
return dimSize;
|
||||
}
|
||||
|
||||
// TODO: set probes in queries
|
||||
async setDimensionSize(dimSize: number): Promise<void> {
|
||||
if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
|
||||
throw new Error(`Invalid CLIP dimension size: ${dimSize}`);
|
||||
}
|
||||
|
||||
// this is done in two transactions to handle concurrent writes
|
||||
await this.db.transaction().execute(async (trx) => {
|
||||
await sql`delete from ${sql.table('smart_search')}`.execute(trx);
|
||||
await trx.schema.alterTable('smart_search').dropConstraint('dim_size_constraint').ifExists().execute();
|
||||
await sql`alter table ${sql.table('smart_search')} add constraint dim_size_constraint check (array_length(embedding::real[], 1) = ${sql.lit(dimSize)})`.execute(
|
||||
trx,
|
||||
);
|
||||
});
|
||||
|
||||
const vectorExtension = await this.getVectorExtension();
|
||||
await this.db.transaction().execute(async (trx) => {
|
||||
await sql`drop index if exists clip_index`.execute(trx);
|
||||
await trx.schema
|
||||
.alterTable('smart_search')
|
||||
.alterColumn('embedding', (col) => col.setDataType(sql.raw(`vector(${dimSize})`)))
|
||||
.execute();
|
||||
await sql
|
||||
.raw(vectorIndexQuery({ vectorExtension, table: 'smart_search', indexName: VectorIndex.CLIP }))
|
||||
.execute(trx);
|
||||
await trx.schema.alterTable('smart_search').dropConstraint('dim_size_constraint').ifExists().execute();
|
||||
});
|
||||
probes[VectorIndex.CLIP] = 1;
|
||||
|
||||
await sql`vacuum analyze ${sql.table('smart_search')}`.execute(this.db);
|
||||
}
|
||||
|
||||
async deleteAllSearchEmbeddings(): Promise<void> {
|
||||
await sql`truncate ${sql.table('smart_search')}`.execute(this.db);
|
||||
}
|
||||
|
||||
private targetListCount(count: number) {
|
||||
if (count < 128_000) {
|
||||
return 1;
|
||||
@ -291,6 +329,10 @@ export class DatabaseRepository {
|
||||
}
|
||||
}
|
||||
|
||||
private targetProbeCount(lists: number) {
|
||||
return Math.ceil(lists / 8);
|
||||
}
|
||||
|
||||
async runMigrations(options?: { transaction?: 'all' | 'none' | 'each' }): Promise<void> {
|
||||
const { database } = this.configRepository.getEnv();
|
||||
|
||||
|
@ -5,9 +5,9 @@ import { randomUUID } from 'node:crypto';
|
||||
import { DB, Exif } from 'src/db';
|
||||
import { DummyValue, GenerateSql } from 'src/decorators';
|
||||
import { MapAsset } from 'src/dtos/asset-response.dto';
|
||||
import { AssetStatus, AssetType, AssetVisibility, VectorIndex } from 'src/enum';
|
||||
import { getVectorExtension } from 'src/repositories/database.repository';
|
||||
import { anyUuid, asUuid, searchAssetBuilder, vectorIndexQuery } from 'src/utils/database';
|
||||
import { AssetStatus, AssetType, AssetVisibility, DatabaseExtension, VectorIndex } from 'src/enum';
|
||||
import { cachedVectorExtension, probes } from 'src/repositories/database.repository';
|
||||
import { anyUuid, asUuid, searchAssetBuilder } from 'src/utils/database';
|
||||
import { paginationHelper } from 'src/utils/pagination';
|
||||
import { isValidInteger } from 'src/validation';
|
||||
|
||||
@ -233,19 +233,23 @@ export class SearchRepository {
|
||||
},
|
||||
],
|
||||
})
|
||||
async searchSmart(pagination: SearchPaginationOptions, options: SmartSearchOptions) {
|
||||
searchSmart(pagination: SearchPaginationOptions, options: SmartSearchOptions) {
|
||||
if (!isValidInteger(pagination.size, { min: 1, max: 1000 })) {
|
||||
throw new Error(`Invalid value for 'size': ${pagination.size}`);
|
||||
}
|
||||
|
||||
const items = await searchAssetBuilder(this.db, options)
|
||||
.innerJoin('smart_search', 'assets.id', 'smart_search.assetId')
|
||||
.orderBy(sql`smart_search.embedding <=> ${options.embedding}`)
|
||||
.limit(pagination.size + 1)
|
||||
.offset((pagination.page - 1) * pagination.size)
|
||||
.execute();
|
||||
|
||||
return paginationHelper(items, pagination.size);
|
||||
return this.db.transaction().execute(async (trx) => {
|
||||
if (cachedVectorExtension === DatabaseExtension.VECTORCHORD) {
|
||||
await sql`set local vchord.probes = ${sql.lit(probes[VectorIndex.CLIP])}`.execute(trx);
|
||||
}
|
||||
const items = await searchAssetBuilder(trx, options)
|
||||
.innerJoin('smart_search', 'assets.id', 'smart_search.assetId')
|
||||
.orderBy(sql`smart_search.embedding <=> ${options.embedding}`)
|
||||
.limit(pagination.size + 1)
|
||||
.offset((pagination.page - 1) * pagination.size)
|
||||
.execute();
|
||||
return paginationHelper(items, pagination.size);
|
||||
});
|
||||
}
|
||||
|
||||
@GenerateSql({
|
||||
@ -260,29 +264,35 @@ export class SearchRepository {
|
||||
],
|
||||
})
|
||||
searchDuplicates({ assetId, embedding, maxDistance, type, userIds }: AssetDuplicateSearch) {
|
||||
return this.db
|
||||
.with('cte', (qb) =>
|
||||
qb
|
||||
.selectFrom('assets')
|
||||
.select([
|
||||
'assets.id as assetId',
|
||||
'assets.duplicateId',
|
||||
sql<number>`smart_search.embedding <=> ${embedding}`.as('distance'),
|
||||
])
|
||||
.innerJoin('smart_search', 'assets.id', 'smart_search.assetId')
|
||||
.where('assets.ownerId', '=', anyUuid(userIds))
|
||||
.where('assets.deletedAt', 'is', null)
|
||||
.where('assets.visibility', '!=', AssetVisibility.HIDDEN)
|
||||
.where('assets.type', '=', type)
|
||||
.where('assets.id', '!=', asUuid(assetId))
|
||||
.where('assets.stackId', 'is', null)
|
||||
.orderBy(sql`smart_search.embedding <=> ${embedding}`)
|
||||
.limit(64),
|
||||
)
|
||||
.selectFrom('cte')
|
||||
.selectAll()
|
||||
.where('cte.distance', '<=', maxDistance as number)
|
||||
.execute();
|
||||
return this.db.transaction().execute(async (trx) => {
|
||||
if (cachedVectorExtension === DatabaseExtension.VECTORCHORD) {
|
||||
await sql`set local vchord.probes = ${sql.lit(probes[VectorIndex.CLIP])}`.execute(trx);
|
||||
}
|
||||
|
||||
return await trx
|
||||
.with('cte', (qb) =>
|
||||
qb
|
||||
.selectFrom('assets')
|
||||
.select([
|
||||
'assets.id as assetId',
|
||||
'assets.duplicateId',
|
||||
sql<number>`smart_search.embedding <=> ${embedding}`.as('distance'),
|
||||
])
|
||||
.innerJoin('smart_search', 'assets.id', 'smart_search.assetId')
|
||||
.where('assets.ownerId', '=', anyUuid(userIds))
|
||||
.where('assets.deletedAt', 'is', null)
|
||||
.where('assets.visibility', '!=', AssetVisibility.HIDDEN)
|
||||
.where('assets.type', '=', type)
|
||||
.where('assets.id', '!=', asUuid(assetId))
|
||||
.where('assets.stackId', 'is', null)
|
||||
.orderBy('distance')
|
||||
.limit(64),
|
||||
)
|
||||
.selectFrom('cte')
|
||||
.selectAll()
|
||||
.where('cte.distance', '<=', maxDistance as number)
|
||||
.execute();
|
||||
});
|
||||
}
|
||||
|
||||
@GenerateSql({
|
||||
@ -300,31 +310,39 @@ export class SearchRepository {
|
||||
throw new Error(`Invalid value for 'numResults': ${numResults}`);
|
||||
}
|
||||
|
||||
return this.db
|
||||
.with('cte', (qb) =>
|
||||
qb
|
||||
.selectFrom('asset_faces')
|
||||
.select([
|
||||
'asset_faces.id',
|
||||
'asset_faces.personId',
|
||||
sql<number>`face_search.embedding <=> ${embedding}`.as('distance'),
|
||||
])
|
||||
.innerJoin('assets', 'assets.id', 'asset_faces.assetId')
|
||||
.innerJoin('face_search', 'face_search.faceId', 'asset_faces.id')
|
||||
.leftJoin('person', 'person.id', 'asset_faces.personId')
|
||||
.where('assets.ownerId', '=', anyUuid(userIds))
|
||||
.where('assets.deletedAt', 'is', null)
|
||||
.$if(!!hasPerson, (qb) => qb.where('asset_faces.personId', 'is not', null))
|
||||
.$if(!!minBirthDate, (qb) =>
|
||||
qb.where((eb) => eb.or([eb('person.birthDate', 'is', null), eb('person.birthDate', '<=', minBirthDate!)])),
|
||||
)
|
||||
.orderBy(sql`face_search.embedding <=> ${embedding}`)
|
||||
.limit(numResults),
|
||||
)
|
||||
.selectFrom('cte')
|
||||
.selectAll()
|
||||
.where('cte.distance', '<=', maxDistance)
|
||||
.execute();
|
||||
return this.db.transaction().execute(async (trx) => {
|
||||
if (cachedVectorExtension === DatabaseExtension.VECTORCHORD) {
|
||||
await sql`set local vchord.probes = ${sql.lit(probes[VectorIndex.FACE])}`.execute(trx);
|
||||
}
|
||||
|
||||
return await trx
|
||||
.with('cte', (qb) =>
|
||||
qb
|
||||
.selectFrom('asset_faces')
|
||||
.select([
|
||||
'asset_faces.id',
|
||||
'asset_faces.personId',
|
||||
sql<number>`face_search.embedding <=> ${embedding}`.as('distance'),
|
||||
])
|
||||
.innerJoin('assets', 'assets.id', 'asset_faces.assetId')
|
||||
.innerJoin('face_search', 'face_search.faceId', 'asset_faces.id')
|
||||
.leftJoin('person', 'person.id', 'asset_faces.personId')
|
||||
.where('assets.ownerId', '=', anyUuid(userIds))
|
||||
.where('assets.deletedAt', 'is', null)
|
||||
.$if(!!hasPerson, (qb) => qb.where('asset_faces.personId', 'is not', null))
|
||||
.$if(!!minBirthDate, (qb) =>
|
||||
qb.where((eb) =>
|
||||
eb.or([eb('person.birthDate', 'is', null), eb('person.birthDate', '<=', minBirthDate!)]),
|
||||
),
|
||||
)
|
||||
.orderBy('distance')
|
||||
.limit(numResults),
|
||||
)
|
||||
.selectFrom('cte')
|
||||
.selectAll()
|
||||
.where('cte.distance', '<=', maxDistance)
|
||||
.execute();
|
||||
});
|
||||
}
|
||||
|
||||
@GenerateSql({ params: [DummyValue.STRING] })
|
||||
@ -413,58 +431,6 @@ export class SearchRepository {
|
||||
.execute();
|
||||
}
|
||||
|
||||
async getDimensionSize(): Promise<number> {
|
||||
const { rows } = await sql<{ dimsize: number }>`
|
||||
select atttypmod as dimsize
|
||||
from pg_attribute f
|
||||
join pg_class c ON c.oid = f.attrelid
|
||||
where c.relkind = 'r'::char
|
||||
and f.attnum > 0
|
||||
and c.relname = 'smart_search'
|
||||
and f.attname = 'embedding'
|
||||
`.execute(this.db);
|
||||
|
||||
const dimSize = rows[0]['dimsize'];
|
||||
if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
|
||||
throw new Error(`Could not retrieve CLIP dimension size`);
|
||||
}
|
||||
return dimSize;
|
||||
}
|
||||
|
||||
async setDimensionSize(dimSize: number): Promise<void> {
|
||||
if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
|
||||
throw new Error(`Invalid CLIP dimension size: ${dimSize}`);
|
||||
}
|
||||
|
||||
// this is done in two transactions to handle concurrent writes
|
||||
await this.db.transaction().execute(async (trx) => {
|
||||
await sql`delete from ${sql.table('smart_search')}`.execute(trx);
|
||||
await trx.schema.alterTable('smart_search').dropConstraint('dim_size_constraint').ifExists().execute();
|
||||
await sql`alter table ${sql.table('smart_search')} add constraint dim_size_constraint check (array_length(embedding::real[], 1) = ${sql.lit(dimSize)})`.execute(
|
||||
trx,
|
||||
);
|
||||
});
|
||||
|
||||
const vectorExtension = await getVectorExtension(this.db);
|
||||
await this.db.transaction().execute(async (trx) => {
|
||||
await sql`drop index if exists clip_index`.execute(trx);
|
||||
await trx.schema
|
||||
.alterTable('smart_search')
|
||||
.alterColumn('embedding', (col) => col.setDataType(sql.raw(`vector(${dimSize})`)))
|
||||
.execute();
|
||||
await sql
|
||||
.raw(vectorIndexQuery({ vectorExtension, table: 'smart_search', indexName: VectorIndex.CLIP }))
|
||||
.execute(trx);
|
||||
await trx.schema.alterTable('smart_search').dropConstraint('dim_size_constraint').ifExists().execute();
|
||||
});
|
||||
|
||||
await sql`vacuum analyze ${sql.table('smart_search')}`.execute(this.db);
|
||||
}
|
||||
|
||||
async deleteAllSearchEmbeddings(): Promise<void> {
|
||||
await sql`truncate ${sql.table('smart_search')}`.execute(this.db);
|
||||
}
|
||||
|
||||
async getCountries(userIds: string[]): Promise<string[]> {
|
||||
const res = await this.getExifField('country', userIds).execute();
|
||||
return res.map((row) => row.country!);
|
||||
|
@ -109,7 +109,10 @@ export class DatabaseService extends BaseService {
|
||||
if (!database.skipMigrations) {
|
||||
await this.databaseRepository.runMigrations();
|
||||
}
|
||||
await this.databaseRepository.prewarm(VectorIndex.CLIP);
|
||||
await Promise.all([
|
||||
this.databaseRepository.prewarm(VectorIndex.CLIP),
|
||||
this.databaseRepository.prewarm(VectorIndex.FACE),
|
||||
]);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -54,28 +54,28 @@ describe(SmartInfoService.name, () => {
|
||||
it('should return if machine learning is disabled', async () => {
|
||||
await sut.onConfigInit({ newConfig: systemConfigStub.machineLearningDisabled as SystemConfig });
|
||||
|
||||
expect(mocks.search.getDimensionSize).not.toHaveBeenCalled();
|
||||
expect(mocks.search.setDimensionSize).not.toHaveBeenCalled();
|
||||
expect(mocks.search.deleteAllSearchEmbeddings).not.toHaveBeenCalled();
|
||||
expect(mocks.database.getDimensionSize).not.toHaveBeenCalled();
|
||||
expect(mocks.database.setDimensionSize).not.toHaveBeenCalled();
|
||||
expect(mocks.database.deleteAllSearchEmbeddings).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return if model and DB dimension size are equal', async () => {
|
||||
mocks.search.getDimensionSize.mockResolvedValue(512);
|
||||
mocks.database.getDimensionSize.mockResolvedValue(512);
|
||||
|
||||
await sut.onConfigInit({ newConfig: systemConfigStub.machineLearningEnabled as SystemConfig });
|
||||
|
||||
expect(mocks.search.getDimensionSize).toHaveBeenCalledTimes(1);
|
||||
expect(mocks.search.setDimensionSize).not.toHaveBeenCalled();
|
||||
expect(mocks.search.deleteAllSearchEmbeddings).not.toHaveBeenCalled();
|
||||
expect(mocks.database.getDimensionSize).toHaveBeenCalledTimes(1);
|
||||
expect(mocks.database.setDimensionSize).not.toHaveBeenCalled();
|
||||
expect(mocks.database.deleteAllSearchEmbeddings).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should update DB dimension size if model and DB have different values', async () => {
|
||||
mocks.search.getDimensionSize.mockResolvedValue(768);
|
||||
mocks.database.getDimensionSize.mockResolvedValue(768);
|
||||
|
||||
await sut.onConfigInit({ newConfig: systemConfigStub.machineLearningEnabled as SystemConfig });
|
||||
|
||||
expect(mocks.search.getDimensionSize).toHaveBeenCalledTimes(1);
|
||||
expect(mocks.search.setDimensionSize).toHaveBeenCalledWith(512);
|
||||
expect(mocks.database.getDimensionSize).toHaveBeenCalledTimes(1);
|
||||
expect(mocks.database.setDimensionSize).toHaveBeenCalledWith(512);
|
||||
});
|
||||
});
|
||||
|
||||
@ -89,13 +89,13 @@ describe(SmartInfoService.name, () => {
|
||||
});
|
||||
|
||||
expect(mocks.systemMetadata.get).not.toHaveBeenCalled();
|
||||
expect(mocks.search.getDimensionSize).not.toHaveBeenCalled();
|
||||
expect(mocks.search.setDimensionSize).not.toHaveBeenCalled();
|
||||
expect(mocks.search.deleteAllSearchEmbeddings).not.toHaveBeenCalled();
|
||||
expect(mocks.database.getDimensionSize).not.toHaveBeenCalled();
|
||||
expect(mocks.database.setDimensionSize).not.toHaveBeenCalled();
|
||||
expect(mocks.database.deleteAllSearchEmbeddings).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return if model and DB dimension size are equal', async () => {
|
||||
mocks.search.getDimensionSize.mockResolvedValue(512);
|
||||
mocks.database.getDimensionSize.mockResolvedValue(512);
|
||||
|
||||
await sut.onConfigUpdate({
|
||||
newConfig: {
|
||||
@ -106,13 +106,13 @@ describe(SmartInfoService.name, () => {
|
||||
} as SystemConfig,
|
||||
});
|
||||
|
||||
expect(mocks.search.getDimensionSize).toHaveBeenCalledTimes(1);
|
||||
expect(mocks.search.setDimensionSize).not.toHaveBeenCalled();
|
||||
expect(mocks.search.deleteAllSearchEmbeddings).not.toHaveBeenCalled();
|
||||
expect(mocks.database.getDimensionSize).toHaveBeenCalledTimes(1);
|
||||
expect(mocks.database.setDimensionSize).not.toHaveBeenCalled();
|
||||
expect(mocks.database.deleteAllSearchEmbeddings).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should update DB dimension size if model and DB have different values', async () => {
|
||||
mocks.search.getDimensionSize.mockResolvedValue(512);
|
||||
mocks.database.getDimensionSize.mockResolvedValue(512);
|
||||
|
||||
await sut.onConfigUpdate({
|
||||
newConfig: {
|
||||
@ -123,12 +123,12 @@ describe(SmartInfoService.name, () => {
|
||||
} as SystemConfig,
|
||||
});
|
||||
|
||||
expect(mocks.search.getDimensionSize).toHaveBeenCalledTimes(1);
|
||||
expect(mocks.search.setDimensionSize).toHaveBeenCalledWith(768);
|
||||
expect(mocks.database.getDimensionSize).toHaveBeenCalledTimes(1);
|
||||
expect(mocks.database.setDimensionSize).toHaveBeenCalledWith(768);
|
||||
});
|
||||
|
||||
it('should clear embeddings if old and new models are different', async () => {
|
||||
mocks.search.getDimensionSize.mockResolvedValue(512);
|
||||
mocks.database.getDimensionSize.mockResolvedValue(512);
|
||||
|
||||
await sut.onConfigUpdate({
|
||||
newConfig: {
|
||||
@ -139,9 +139,9 @@ describe(SmartInfoService.name, () => {
|
||||
} as SystemConfig,
|
||||
});
|
||||
|
||||
expect(mocks.search.deleteAllSearchEmbeddings).toHaveBeenCalled();
|
||||
expect(mocks.search.getDimensionSize).toHaveBeenCalledTimes(1);
|
||||
expect(mocks.search.setDimensionSize).not.toHaveBeenCalled();
|
||||
expect(mocks.database.deleteAllSearchEmbeddings).toHaveBeenCalled();
|
||||
expect(mocks.database.getDimensionSize).toHaveBeenCalledTimes(1);
|
||||
expect(mocks.database.setDimensionSize).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
@ -151,7 +151,7 @@ describe(SmartInfoService.name, () => {
|
||||
|
||||
await sut.handleQueueEncodeClip({});
|
||||
|
||||
expect(mocks.search.setDimensionSize).not.toHaveBeenCalled();
|
||||
expect(mocks.database.setDimensionSize).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should queue the assets without clip embeddings', async () => {
|
||||
@ -163,7 +163,7 @@ describe(SmartInfoService.name, () => {
|
||||
{ name: JobName.SMART_SEARCH, data: { id: assetStub.image.id } },
|
||||
]);
|
||||
expect(mocks.assetJob.streamForEncodeClip).toHaveBeenCalledWith(false);
|
||||
expect(mocks.search.setDimensionSize).not.toHaveBeenCalled();
|
||||
expect(mocks.database.setDimensionSize).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should queue all the assets', async () => {
|
||||
@ -175,7 +175,7 @@ describe(SmartInfoService.name, () => {
|
||||
{ name: JobName.SMART_SEARCH, data: { id: assetStub.image.id } },
|
||||
]);
|
||||
expect(mocks.assetJob.streamForEncodeClip).toHaveBeenCalledWith(true);
|
||||
expect(mocks.search.setDimensionSize).toHaveBeenCalledExactlyOnceWith(512);
|
||||
expect(mocks.database.setDimensionSize).toHaveBeenCalledExactlyOnceWith(512);
|
||||
});
|
||||
});
|
||||
|
||||
|
@ -38,7 +38,7 @@ export class SmartInfoService extends BaseService {
|
||||
|
||||
await this.databaseRepository.withLock(DatabaseLock.CLIPDimSize, async () => {
|
||||
const { dimSize } = getCLIPModelInfo(newConfig.machineLearning.clip.modelName);
|
||||
const dbDimSize = await this.searchRepository.getDimensionSize();
|
||||
const dbDimSize = await this.databaseRepository.getDimensionSize('smart_search');
|
||||
this.logger.verbose(`Current database CLIP dimension size is ${dbDimSize}`);
|
||||
|
||||
const modelChange =
|
||||
@ -53,10 +53,10 @@ export class SmartInfoService extends BaseService {
|
||||
`Dimension size of model ${newConfig.machineLearning.clip.modelName} is ${dimSize}, but database expects ${dbDimSize}.`,
|
||||
);
|
||||
this.logger.log(`Updating database CLIP dimension size to ${dimSize}.`);
|
||||
await this.searchRepository.setDimensionSize(dimSize);
|
||||
await this.databaseRepository.setDimensionSize(dimSize);
|
||||
this.logger.log(`Successfully updated database CLIP dimension size from ${dbDimSize} to ${dimSize}.`);
|
||||
} else {
|
||||
await this.searchRepository.deleteAllSearchEmbeddings();
|
||||
await this.databaseRepository.deleteAllSearchEmbeddings();
|
||||
}
|
||||
|
||||
// TODO: A job to reindex all assets should be scheduled, though user
|
||||
@ -74,7 +74,7 @@ export class SmartInfoService extends BaseService {
|
||||
if (force) {
|
||||
const { dimSize } = getCLIPModelInfo(machineLearning.clip.modelName);
|
||||
// in addition to deleting embeddings, update the dimension size in case it failed earlier
|
||||
await this.searchRepository.setDimensionSize(dimSize);
|
||||
await this.databaseRepository.setDimensionSize(dimSize);
|
||||
}
|
||||
|
||||
let queue: JobItem[] = [];
|
||||
|
@ -397,12 +397,14 @@ export function vectorIndexQuery({ vectorExtension, table, indexName, lists }: V
|
||||
lists = [${lists ?? 1}]
|
||||
spherical_centroids = true
|
||||
build_threads = 4
|
||||
sampling_factor = 1024
|
||||
$$)`;
|
||||
}
|
||||
case DatabaseExtension.VECTORS: {
|
||||
return `
|
||||
CREATE INDEX IF NOT EXISTS ${indexName} ON ${table}
|
||||
USING vectors (embedding vector_cos_ops) WITH (options = $$
|
||||
optimizing.optimizing_threads = 4
|
||||
[indexing.hnsw]
|
||||
m = 16
|
||||
ef_construction = 300
|
||||
|
Loading…
x
Reference in New Issue
Block a user