set probes

This commit is contained in:
mertalev 2025-05-09 15:01:11 -04:00
parent c80b16d24e
commit b750440f90
No known key found for this signature in database
GPG Key ID: DF6ABC77AAD98C95
8 changed files with 180 additions and 169 deletions

View File

@ -11,11 +11,3 @@ WHERE
-- DatabaseRepository.getPostgresVersion -- DatabaseRepository.getPostgresVersion
SHOW server_version SHOW server_version
-- DatabaseRepository.shouldReindex
SELECT
idx_status
FROM
pg_vector_index_stat
WHERE
indexname = $1

View File

@ -64,6 +64,7 @@ limit
$15 $15
-- SearchRepository.searchSmart -- SearchRepository.searchSmart
begin
select select
"assets".* "assets".*
from from
@ -83,8 +84,10 @@ limit
$7 $7
offset offset
$8 $8
rollback
-- SearchRepository.searchDuplicates -- SearchRepository.searchDuplicates
begin
with with
"cte" as ( "cte" as (
select select
@ -102,18 +105,20 @@ with
and "assets"."id" != $5::uuid and "assets"."id" != $5::uuid
and "assets"."stackId" is null and "assets"."stackId" is null
order by order by
smart_search.embedding <=> $6 "distance"
limit limit
$7 $6
) )
select select
* *
from from
"cte" "cte"
where where
"cte"."distance" <= $8 "cte"."distance" <= $7
rollback
-- SearchRepository.searchFaces -- SearchRepository.searchFaces
begin
with with
"cte" as ( "cte" as (
select select
@ -129,16 +134,17 @@ with
"assets"."ownerId" = any ($2::uuid[]) "assets"."ownerId" = any ($2::uuid[])
and "assets"."deletedAt" is null and "assets"."deletedAt" is null
order by order by
face_search.embedding <=> $3 "distance"
limit limit
$4 $3
) )
select select
* *
from from
"cte" "cte"
where where
"cte"."distance" <= $5 "cte"."distance" <= $4
commit
-- SearchRepository.searchPlaces -- SearchRepository.searchPlaces
select select

View File

@ -25,7 +25,7 @@ import { vectorIndexQuery } from 'src/utils/database';
import { isValidInteger } from 'src/validation'; import { isValidInteger } from 'src/validation';
import { DataSource, QueryRunner } from 'typeorm'; import { DataSource, QueryRunner } from 'typeorm';
let cachedVectorExtension: VectorExtension | undefined; export let cachedVectorExtension: VectorExtension | undefined;
export async function getVectorExtension(runner: Kysely<DB> | QueryRunner): Promise<VectorExtension> { export async function getVectorExtension(runner: Kysely<DB> | QueryRunner): Promise<VectorExtension> {
if (cachedVectorExtension) { if (cachedVectorExtension) {
return cachedVectorExtension; return cachedVectorExtension;
@ -50,6 +50,11 @@ export async function getVectorExtension(runner: Kysely<DB> | QueryRunner): Prom
return cachedVectorExtension; return cachedVectorExtension;
} }
/**
 * Probe count per vector index, keyed by `VectorIndex`. Every index starts at 1.
 *
 * This object is mutable shared state: `DatabaseRepository` reassigns entries
 * (from the list count of a VectorChord index, and back to 1 for CLIP when the
 * dimension size is changed), and `SearchRepository` reads entries to issue
 * `set local vchord.probes` before running a vector search.
 */
export const probes: Record<VectorIndex, number> = {
[VectorIndex.CLIP]: 1,
[VectorIndex.FACE]: 1,
};
@Injectable() @Injectable()
export class DatabaseRepository { export class DatabaseRepository {
private readonly asyncLock = new AsyncLock(); private readonly asyncLock = new AsyncLock();
@ -183,21 +188,17 @@ export class DatabaseRepository {
for (const indexName of names) { for (const indexName of names) {
const row = rows.find((index) => index.indexname === indexName); const row = rows.find((index) => index.indexname === indexName);
const table = VECTOR_INDEX_TABLES[indexName]; const table = VECTOR_INDEX_TABLES[indexName];
if (!row) {
promises.push(this.reindexVectors(indexName));
continue;
}
switch (vectorExtension) { switch (vectorExtension) {
case DatabaseExtension.VECTOR: case DatabaseExtension.VECTOR:
case DatabaseExtension.VECTORS: { case DatabaseExtension.VECTORS: {
if (!row.indexdef.toLowerCase().includes(keyword)) { if (!row?.indexdef.toLowerCase().includes(keyword)) {
promises.push(this.reindexVectors(indexName)); promises.push(this.reindexVectors(indexName));
} }
break; break;
} }
case DatabaseExtension.VECTORCHORD: { case DatabaseExtension.VECTORCHORD: {
const matches = row.indexdef.match(/(?<=lists = \[)\d+/g); const matches = row?.indexdef.match(/(?<=lists = \[)\d+/g);
const lists = matches && matches.length > 0 ? Number(matches[0]) : 1; const lists = matches && matches.length > 0 ? Number(matches[0]) : 1;
promises.push( promises.push(
this.db this.db
@ -208,11 +209,14 @@ export class DatabaseRepository {
const targetLists = this.targetListCount(count); const targetLists = this.targetListCount(count);
this.logger.log(`targetLists=${targetLists}, current=${lists} for ${indexName} of ${count} rows`); this.logger.log(`targetLists=${targetLists}, current=${lists} for ${indexName} of ${count} rows`);
if ( if (
!row.indexdef.toLowerCase().includes(keyword) || !row?.indexdef.toLowerCase().includes(keyword) ||
// slack factor is to avoid frequent reindexing if the count is borderline // slack factor is to avoid frequent reindexing if the count is borderline
(lists !== targetLists && lists !== this.targetListCount(count * VECTORCHORD_LIST_SLACK_FACTOR)) (lists !== targetLists && lists !== this.targetListCount(count * VECTORCHORD_LIST_SLACK_FACTOR))
) { ) {
probes[indexName] = this.targetProbeCount(targetLists);
return this.reindexVectors(indexName, { lists: targetLists }); return this.reindexVectors(indexName, { lists: targetLists });
} else {
probes[indexName] = this.targetProbeCount(lists);
} }
}), }),
); );
@ -239,7 +243,7 @@ export class DatabaseRepository {
); );
return; return;
} }
const dimSize = await this.getDimSize(table); const dimSize = await this.getDimensionSize(table);
await this.db.transaction().execute(async (tx) => { await this.db.transaction().execute(async (tx) => {
await sql`DROP INDEX IF EXISTS ${sql.raw(indexName)}`.execute(tx); await sql`DROP INDEX IF EXISTS ${sql.raw(indexName)}`.execute(tx);
if (!rows.some((row) => row.columnName === 'embedding')) { if (!rows.some((row) => row.columnName === 'embedding')) {
@ -261,7 +265,7 @@ export class DatabaseRepository {
await sql`SET search_path TO "$user", public, vectors`.execute(tx); await sql`SET search_path TO "$user", public, vectors`.execute(tx);
} }
private async getDimSize(table: string, column = 'embedding'): Promise<number> { async getDimensionSize(table: string, column = 'embedding'): Promise<number> {
const { rows } = await sql<{ dimsize: number }>` const { rows } = await sql<{ dimsize: number }>`
SELECT atttypmod as dimsize SELECT atttypmod as dimsize
FROM pg_attribute f FROM pg_attribute f
@ -280,7 +284,41 @@ export class DatabaseRepository {
return dimSize; return dimSize;
} }
// TODO: set probes in queries async setDimensionSize(dimSize: number): Promise<void> {
if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
throw new Error(`Invalid CLIP dimension size: ${dimSize}`);
}
// this is done in two transactions to handle concurrent writes
await this.db.transaction().execute(async (trx) => {
await sql`delete from ${sql.table('smart_search')}`.execute(trx);
await trx.schema.alterTable('smart_search').dropConstraint('dim_size_constraint').ifExists().execute();
await sql`alter table ${sql.table('smart_search')} add constraint dim_size_constraint check (array_length(embedding::real[], 1) = ${sql.lit(dimSize)})`.execute(
trx,
);
});
const vectorExtension = await this.getVectorExtension();
await this.db.transaction().execute(async (trx) => {
await sql`drop index if exists clip_index`.execute(trx);
await trx.schema
.alterTable('smart_search')
.alterColumn('embedding', (col) => col.setDataType(sql.raw(`vector(${dimSize})`)))
.execute();
await sql
.raw(vectorIndexQuery({ vectorExtension, table: 'smart_search', indexName: VectorIndex.CLIP }))
.execute(trx);
await trx.schema.alterTable('smart_search').dropConstraint('dim_size_constraint').ifExists().execute();
});
probes[VectorIndex.CLIP] = 1;
await sql`vacuum analyze ${sql.table('smart_search')}`.execute(this.db);
}
/** Empties the `smart_search` table by issuing a single `truncate` statement on `this.db`. */
async deleteAllSearchEmbeddings(): Promise<void> {
  const truncateStatement = sql`truncate ${sql.table('smart_search')}`;
  await truncateStatement.execute(this.db);
}
private targetListCount(count: number) { private targetListCount(count: number) {
if (count < 128_000) { if (count < 128_000) {
return 1; return 1;
@ -291,6 +329,10 @@ export class DatabaseRepository {
} }
} }
/** Probe count for an index built with `lists` lists: one eighth of the list count, rounded up. */
private targetProbeCount(lists: number) {
  const listsPerProbe = 8;
  const probeCount = Math.ceil(lists / listsPerProbe);
  return probeCount;
}
async runMigrations(options?: { transaction?: 'all' | 'none' | 'each' }): Promise<void> { async runMigrations(options?: { transaction?: 'all' | 'none' | 'each' }): Promise<void> {
const { database } = this.configRepository.getEnv(); const { database } = this.configRepository.getEnv();

View File

@ -5,9 +5,9 @@ import { randomUUID } from 'node:crypto';
import { DB, Exif } from 'src/db'; import { DB, Exif } from 'src/db';
import { DummyValue, GenerateSql } from 'src/decorators'; import { DummyValue, GenerateSql } from 'src/decorators';
import { MapAsset } from 'src/dtos/asset-response.dto'; import { MapAsset } from 'src/dtos/asset-response.dto';
import { AssetStatus, AssetType, AssetVisibility, VectorIndex } from 'src/enum'; import { AssetStatus, AssetType, AssetVisibility, DatabaseExtension, VectorIndex } from 'src/enum';
import { getVectorExtension } from 'src/repositories/database.repository'; import { cachedVectorExtension, probes } from 'src/repositories/database.repository';
import { anyUuid, asUuid, searchAssetBuilder, vectorIndexQuery } from 'src/utils/database'; import { anyUuid, asUuid, searchAssetBuilder } from 'src/utils/database';
import { paginationHelper } from 'src/utils/pagination'; import { paginationHelper } from 'src/utils/pagination';
import { isValidInteger } from 'src/validation'; import { isValidInteger } from 'src/validation';
@ -233,19 +233,23 @@ export class SearchRepository {
}, },
], ],
}) })
async searchSmart(pagination: SearchPaginationOptions, options: SmartSearchOptions) { searchSmart(pagination: SearchPaginationOptions, options: SmartSearchOptions) {
if (!isValidInteger(pagination.size, { min: 1, max: 1000 })) { if (!isValidInteger(pagination.size, { min: 1, max: 1000 })) {
throw new Error(`Invalid value for 'size': ${pagination.size}`); throw new Error(`Invalid value for 'size': ${pagination.size}`);
} }
const items = await searchAssetBuilder(this.db, options) return this.db.transaction().execute(async (trx) => {
.innerJoin('smart_search', 'assets.id', 'smart_search.assetId') if (cachedVectorExtension === DatabaseExtension.VECTORCHORD) {
.orderBy(sql`smart_search.embedding <=> ${options.embedding}`) await sql`set local vchord.probes = ${sql.lit(probes[VectorIndex.CLIP])}`.execute(trx);
.limit(pagination.size + 1) }
.offset((pagination.page - 1) * pagination.size) const items = await searchAssetBuilder(trx, options)
.execute(); .innerJoin('smart_search', 'assets.id', 'smart_search.assetId')
.orderBy(sql`smart_search.embedding <=> ${options.embedding}`)
return paginationHelper(items, pagination.size); .limit(pagination.size + 1)
.offset((pagination.page - 1) * pagination.size)
.execute();
return paginationHelper(items, pagination.size);
});
} }
@GenerateSql({ @GenerateSql({
@ -260,29 +264,35 @@ export class SearchRepository {
], ],
}) })
searchDuplicates({ assetId, embedding, maxDistance, type, userIds }: AssetDuplicateSearch) { searchDuplicates({ assetId, embedding, maxDistance, type, userIds }: AssetDuplicateSearch) {
return this.db return this.db.transaction().execute(async (trx) => {
.with('cte', (qb) => if (cachedVectorExtension === DatabaseExtension.VECTORCHORD) {
qb await sql`set local vchord.probes = ${sql.lit(probes[VectorIndex.CLIP])}`.execute(trx);
.selectFrom('assets') }
.select([
'assets.id as assetId', return await trx
'assets.duplicateId', .with('cte', (qb) =>
sql<number>`smart_search.embedding <=> ${embedding}`.as('distance'), qb
]) .selectFrom('assets')
.innerJoin('smart_search', 'assets.id', 'smart_search.assetId') .select([
.where('assets.ownerId', '=', anyUuid(userIds)) 'assets.id as assetId',
.where('assets.deletedAt', 'is', null) 'assets.duplicateId',
.where('assets.visibility', '!=', AssetVisibility.HIDDEN) sql<number>`smart_search.embedding <=> ${embedding}`.as('distance'),
.where('assets.type', '=', type) ])
.where('assets.id', '!=', asUuid(assetId)) .innerJoin('smart_search', 'assets.id', 'smart_search.assetId')
.where('assets.stackId', 'is', null) .where('assets.ownerId', '=', anyUuid(userIds))
.orderBy(sql`smart_search.embedding <=> ${embedding}`) .where('assets.deletedAt', 'is', null)
.limit(64), .where('assets.visibility', '!=', AssetVisibility.HIDDEN)
) .where('assets.type', '=', type)
.selectFrom('cte') .where('assets.id', '!=', asUuid(assetId))
.selectAll() .where('assets.stackId', 'is', null)
.where('cte.distance', '<=', maxDistance as number) .orderBy('distance')
.execute(); .limit(64),
)
.selectFrom('cte')
.selectAll()
.where('cte.distance', '<=', maxDistance as number)
.execute();
});
} }
@GenerateSql({ @GenerateSql({
@ -300,31 +310,39 @@ export class SearchRepository {
throw new Error(`Invalid value for 'numResults': ${numResults}`); throw new Error(`Invalid value for 'numResults': ${numResults}`);
} }
return this.db return this.db.transaction().execute(async (trx) => {
.with('cte', (qb) => if (cachedVectorExtension === DatabaseExtension.VECTORCHORD) {
qb await sql`set local vchord.probes = ${sql.lit(probes[VectorIndex.FACE])}`.execute(trx);
.selectFrom('asset_faces') }
.select([
'asset_faces.id', return await trx
'asset_faces.personId', .with('cte', (qb) =>
sql<number>`face_search.embedding <=> ${embedding}`.as('distance'), qb
]) .selectFrom('asset_faces')
.innerJoin('assets', 'assets.id', 'asset_faces.assetId') .select([
.innerJoin('face_search', 'face_search.faceId', 'asset_faces.id') 'asset_faces.id',
.leftJoin('person', 'person.id', 'asset_faces.personId') 'asset_faces.personId',
.where('assets.ownerId', '=', anyUuid(userIds)) sql<number>`face_search.embedding <=> ${embedding}`.as('distance'),
.where('assets.deletedAt', 'is', null) ])
.$if(!!hasPerson, (qb) => qb.where('asset_faces.personId', 'is not', null)) .innerJoin('assets', 'assets.id', 'asset_faces.assetId')
.$if(!!minBirthDate, (qb) => .innerJoin('face_search', 'face_search.faceId', 'asset_faces.id')
qb.where((eb) => eb.or([eb('person.birthDate', 'is', null), eb('person.birthDate', '<=', minBirthDate!)])), .leftJoin('person', 'person.id', 'asset_faces.personId')
) .where('assets.ownerId', '=', anyUuid(userIds))
.orderBy(sql`face_search.embedding <=> ${embedding}`) .where('assets.deletedAt', 'is', null)
.limit(numResults), .$if(!!hasPerson, (qb) => qb.where('asset_faces.personId', 'is not', null))
) .$if(!!minBirthDate, (qb) =>
.selectFrom('cte') qb.where((eb) =>
.selectAll() eb.or([eb('person.birthDate', 'is', null), eb('person.birthDate', '<=', minBirthDate!)]),
.where('cte.distance', '<=', maxDistance) ),
.execute(); )
.orderBy('distance')
.limit(numResults),
)
.selectFrom('cte')
.selectAll()
.where('cte.distance', '<=', maxDistance)
.execute();
});
} }
@GenerateSql({ params: [DummyValue.STRING] }) @GenerateSql({ params: [DummyValue.STRING] })
@ -413,58 +431,6 @@ export class SearchRepository {
.execute(); .execute();
} }
/**
 * Reads the type modifier (`atttypmod`) of the `smart_search.embedding` column
 * from the Postgres catalog and returns it as the CLIP dimension size.
 *
 * @returns the dimension size, an integer between 1 and 2 ** 16
 * @throws Error if the catalog query returns no row, or the value is not an
 *   integer in that range
 */
async getDimensionSize(): Promise<number> {
  const { rows } = await sql<{ dimsize: number }>`
select atttypmod as dimsize
from pg_attribute f
join pg_class c ON c.oid = f.attrelid
where c.relkind = 'r'::char
and f.attnum > 0
and c.relname = 'smart_search'
and f.attname = 'embedding'
`.execute(this.db);

  // The query yields no rows when the table or column cannot be found. Guard the
  // lookup so that case raises the descriptive error below rather than a
  // TypeError from reading a property of `undefined`.
  const dimSize = rows[0]?.dimsize;
  if (dimSize === undefined || !isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
    throw new Error(`Could not retrieve CLIP dimension size`);
  }
  return dimSize;
}
/**
 * Changes the dimension size of the `smart_search.embedding` column to `dimSize`.
 * All existing rows in `smart_search` are deleted and `clip_index` is rebuilt.
 *
 * @param dimSize new dimension size; must be an integer between 1 and 2 ** 16
 * @throws Error if `dimSize` fails that validation
 */
async setDimensionSize(dimSize: number): Promise<void> {
if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
throw new Error(`Invalid CLIP dimension size: ${dimSize}`);
}
// this is done in two transactions to handle concurrent writes
// first transaction: clear the table, then replace `dim_size_constraint` with a
// check that every embedding has exactly `dimSize` elements
await this.db.transaction().execute(async (trx) => {
await sql`delete from ${sql.table('smart_search')}`.execute(trx);
await trx.schema.alterTable('smart_search').dropConstraint('dim_size_constraint').ifExists().execute();
await sql`alter table ${sql.table('smart_search')} add constraint dim_size_constraint check (array_length(embedding::real[], 1) = ${sql.lit(dimSize)})`.execute(
trx,
);
});
// the extension is resolved between the two transactions; it selects the
// index statement built by vectorIndexQuery below
const vectorExtension = await getVectorExtension(this.db);
// second transaction: drop the old index, retype the column to vector(dimSize),
// recreate the index, then remove the check constraint added above
await this.db.transaction().execute(async (trx) => {
await sql`drop index if exists clip_index`.execute(trx);
await trx.schema
.alterTable('smart_search')
.alterColumn('embedding', (col) => col.setDataType(sql.raw(`vector(${dimSize})`)))
.execute();
await sql
.raw(vectorIndexQuery({ vectorExtension, table: 'smart_search', indexName: VectorIndex.CLIP }))
.execute(trx);
await trx.schema.alterTable('smart_search').dropConstraint('dim_size_constraint').ifExists().execute();
});
// executed on this.db directly, after both transactions have completed
await sql`vacuum analyze ${sql.table('smart_search')}`.execute(this.db);
}
/** Truncates the `smart_search` table, discarding every stored search embedding row. */
async deleteAllSearchEmbeddings(): Promise<void> {
  const smartSearchTable = sql.table('smart_search');
  const truncateQuery = sql`truncate ${smartSearchTable}`;
  await truncateQuery.execute(this.db);
}
async getCountries(userIds: string[]): Promise<string[]> { async getCountries(userIds: string[]): Promise<string[]> {
const res = await this.getExifField('country', userIds).execute(); const res = await this.getExifField('country', userIds).execute();
return res.map((row) => row.country!); return res.map((row) => row.country!);

View File

@ -109,7 +109,10 @@ export class DatabaseService extends BaseService {
if (!database.skipMigrations) { if (!database.skipMigrations) {
await this.databaseRepository.runMigrations(); await this.databaseRepository.runMigrations();
} }
await this.databaseRepository.prewarm(VectorIndex.CLIP); await Promise.all([
this.databaseRepository.prewarm(VectorIndex.CLIP),
this.databaseRepository.prewarm(VectorIndex.FACE),
]);
}); });
} }

View File

@ -54,28 +54,28 @@ describe(SmartInfoService.name, () => {
it('should return if machine learning is disabled', async () => { it('should return if machine learning is disabled', async () => {
await sut.onConfigInit({ newConfig: systemConfigStub.machineLearningDisabled as SystemConfig }); await sut.onConfigInit({ newConfig: systemConfigStub.machineLearningDisabled as SystemConfig });
expect(mocks.search.getDimensionSize).not.toHaveBeenCalled(); expect(mocks.database.getDimensionSize).not.toHaveBeenCalled();
expect(mocks.search.setDimensionSize).not.toHaveBeenCalled(); expect(mocks.database.setDimensionSize).not.toHaveBeenCalled();
expect(mocks.search.deleteAllSearchEmbeddings).not.toHaveBeenCalled(); expect(mocks.database.deleteAllSearchEmbeddings).not.toHaveBeenCalled();
}); });
it('should return if model and DB dimension size are equal', async () => { it('should return if model and DB dimension size are equal', async () => {
mocks.search.getDimensionSize.mockResolvedValue(512); mocks.database.getDimensionSize.mockResolvedValue(512);
await sut.onConfigInit({ newConfig: systemConfigStub.machineLearningEnabled as SystemConfig }); await sut.onConfigInit({ newConfig: systemConfigStub.machineLearningEnabled as SystemConfig });
expect(mocks.search.getDimensionSize).toHaveBeenCalledTimes(1); expect(mocks.database.getDimensionSize).toHaveBeenCalledTimes(1);
expect(mocks.search.setDimensionSize).not.toHaveBeenCalled(); expect(mocks.database.setDimensionSize).not.toHaveBeenCalled();
expect(mocks.search.deleteAllSearchEmbeddings).not.toHaveBeenCalled(); expect(mocks.database.deleteAllSearchEmbeddings).not.toHaveBeenCalled();
}); });
it('should update DB dimension size if model and DB have different values', async () => { it('should update DB dimension size if model and DB have different values', async () => {
mocks.search.getDimensionSize.mockResolvedValue(768); mocks.database.getDimensionSize.mockResolvedValue(768);
await sut.onConfigInit({ newConfig: systemConfigStub.machineLearningEnabled as SystemConfig }); await sut.onConfigInit({ newConfig: systemConfigStub.machineLearningEnabled as SystemConfig });
expect(mocks.search.getDimensionSize).toHaveBeenCalledTimes(1); expect(mocks.database.getDimensionSize).toHaveBeenCalledTimes(1);
expect(mocks.search.setDimensionSize).toHaveBeenCalledWith(512); expect(mocks.database.setDimensionSize).toHaveBeenCalledWith(512);
}); });
}); });
@ -89,13 +89,13 @@ describe(SmartInfoService.name, () => {
}); });
expect(mocks.systemMetadata.get).not.toHaveBeenCalled(); expect(mocks.systemMetadata.get).not.toHaveBeenCalled();
expect(mocks.search.getDimensionSize).not.toHaveBeenCalled(); expect(mocks.database.getDimensionSize).not.toHaveBeenCalled();
expect(mocks.search.setDimensionSize).not.toHaveBeenCalled(); expect(mocks.database.setDimensionSize).not.toHaveBeenCalled();
expect(mocks.search.deleteAllSearchEmbeddings).not.toHaveBeenCalled(); expect(mocks.database.deleteAllSearchEmbeddings).not.toHaveBeenCalled();
}); });
it('should return if model and DB dimension size are equal', async () => { it('should return if model and DB dimension size are equal', async () => {
mocks.search.getDimensionSize.mockResolvedValue(512); mocks.database.getDimensionSize.mockResolvedValue(512);
await sut.onConfigUpdate({ await sut.onConfigUpdate({
newConfig: { newConfig: {
@ -106,13 +106,13 @@ describe(SmartInfoService.name, () => {
} as SystemConfig, } as SystemConfig,
}); });
expect(mocks.search.getDimensionSize).toHaveBeenCalledTimes(1); expect(mocks.database.getDimensionSize).toHaveBeenCalledTimes(1);
expect(mocks.search.setDimensionSize).not.toHaveBeenCalled(); expect(mocks.database.setDimensionSize).not.toHaveBeenCalled();
expect(mocks.search.deleteAllSearchEmbeddings).not.toHaveBeenCalled(); expect(mocks.database.deleteAllSearchEmbeddings).not.toHaveBeenCalled();
}); });
it('should update DB dimension size if model and DB have different values', async () => { it('should update DB dimension size if model and DB have different values', async () => {
mocks.search.getDimensionSize.mockResolvedValue(512); mocks.database.getDimensionSize.mockResolvedValue(512);
await sut.onConfigUpdate({ await sut.onConfigUpdate({
newConfig: { newConfig: {
@ -123,12 +123,12 @@ describe(SmartInfoService.name, () => {
} as SystemConfig, } as SystemConfig,
}); });
expect(mocks.search.getDimensionSize).toHaveBeenCalledTimes(1); expect(mocks.database.getDimensionSize).toHaveBeenCalledTimes(1);
expect(mocks.search.setDimensionSize).toHaveBeenCalledWith(768); expect(mocks.database.setDimensionSize).toHaveBeenCalledWith(768);
}); });
it('should clear embeddings if old and new models are different', async () => { it('should clear embeddings if old and new models are different', async () => {
mocks.search.getDimensionSize.mockResolvedValue(512); mocks.database.getDimensionSize.mockResolvedValue(512);
await sut.onConfigUpdate({ await sut.onConfigUpdate({
newConfig: { newConfig: {
@ -139,9 +139,9 @@ describe(SmartInfoService.name, () => {
} as SystemConfig, } as SystemConfig,
}); });
expect(mocks.search.deleteAllSearchEmbeddings).toHaveBeenCalled(); expect(mocks.database.deleteAllSearchEmbeddings).toHaveBeenCalled();
expect(mocks.search.getDimensionSize).toHaveBeenCalledTimes(1); expect(mocks.database.getDimensionSize).toHaveBeenCalledTimes(1);
expect(mocks.search.setDimensionSize).not.toHaveBeenCalled(); expect(mocks.database.setDimensionSize).not.toHaveBeenCalled();
}); });
}); });
@ -151,7 +151,7 @@ describe(SmartInfoService.name, () => {
await sut.handleQueueEncodeClip({}); await sut.handleQueueEncodeClip({});
expect(mocks.search.setDimensionSize).not.toHaveBeenCalled(); expect(mocks.database.setDimensionSize).not.toHaveBeenCalled();
}); });
it('should queue the assets without clip embeddings', async () => { it('should queue the assets without clip embeddings', async () => {
@ -163,7 +163,7 @@ describe(SmartInfoService.name, () => {
{ name: JobName.SMART_SEARCH, data: { id: assetStub.image.id } }, { name: JobName.SMART_SEARCH, data: { id: assetStub.image.id } },
]); ]);
expect(mocks.assetJob.streamForEncodeClip).toHaveBeenCalledWith(false); expect(mocks.assetJob.streamForEncodeClip).toHaveBeenCalledWith(false);
expect(mocks.search.setDimensionSize).not.toHaveBeenCalled(); expect(mocks.database.setDimensionSize).not.toHaveBeenCalled();
}); });
it('should queue all the assets', async () => { it('should queue all the assets', async () => {
@ -175,7 +175,7 @@ describe(SmartInfoService.name, () => {
{ name: JobName.SMART_SEARCH, data: { id: assetStub.image.id } }, { name: JobName.SMART_SEARCH, data: { id: assetStub.image.id } },
]); ]);
expect(mocks.assetJob.streamForEncodeClip).toHaveBeenCalledWith(true); expect(mocks.assetJob.streamForEncodeClip).toHaveBeenCalledWith(true);
expect(mocks.search.setDimensionSize).toHaveBeenCalledExactlyOnceWith(512); expect(mocks.database.setDimensionSize).toHaveBeenCalledExactlyOnceWith(512);
}); });
}); });

View File

@ -38,7 +38,7 @@ export class SmartInfoService extends BaseService {
await this.databaseRepository.withLock(DatabaseLock.CLIPDimSize, async () => { await this.databaseRepository.withLock(DatabaseLock.CLIPDimSize, async () => {
const { dimSize } = getCLIPModelInfo(newConfig.machineLearning.clip.modelName); const { dimSize } = getCLIPModelInfo(newConfig.machineLearning.clip.modelName);
const dbDimSize = await this.searchRepository.getDimensionSize(); const dbDimSize = await this.databaseRepository.getDimensionSize('smart_search');
this.logger.verbose(`Current database CLIP dimension size is ${dbDimSize}`); this.logger.verbose(`Current database CLIP dimension size is ${dbDimSize}`);
const modelChange = const modelChange =
@ -53,10 +53,10 @@ export class SmartInfoService extends BaseService {
`Dimension size of model ${newConfig.machineLearning.clip.modelName} is ${dimSize}, but database expects ${dbDimSize}.`, `Dimension size of model ${newConfig.machineLearning.clip.modelName} is ${dimSize}, but database expects ${dbDimSize}.`,
); );
this.logger.log(`Updating database CLIP dimension size to ${dimSize}.`); this.logger.log(`Updating database CLIP dimension size to ${dimSize}.`);
await this.searchRepository.setDimensionSize(dimSize); await this.databaseRepository.setDimensionSize(dimSize);
this.logger.log(`Successfully updated database CLIP dimension size from ${dbDimSize} to ${dimSize}.`); this.logger.log(`Successfully updated database CLIP dimension size from ${dbDimSize} to ${dimSize}.`);
} else { } else {
await this.searchRepository.deleteAllSearchEmbeddings(); await this.databaseRepository.deleteAllSearchEmbeddings();
} }
// TODO: A job to reindex all assets should be scheduled, though user // TODO: A job to reindex all assets should be scheduled, though user
@ -74,7 +74,7 @@ export class SmartInfoService extends BaseService {
if (force) { if (force) {
const { dimSize } = getCLIPModelInfo(machineLearning.clip.modelName); const { dimSize } = getCLIPModelInfo(machineLearning.clip.modelName);
// in addition to deleting embeddings, update the dimension size in case it failed earlier // in addition to deleting embeddings, update the dimension size in case it failed earlier
await this.searchRepository.setDimensionSize(dimSize); await this.databaseRepository.setDimensionSize(dimSize);
} }
let queue: JobItem[] = []; let queue: JobItem[] = [];

View File

@ -397,12 +397,14 @@ export function vectorIndexQuery({ vectorExtension, table, indexName, lists }: V
lists = [${lists ?? 1}] lists = [${lists ?? 1}]
spherical_centroids = true spherical_centroids = true
build_threads = 4 build_threads = 4
sampling_factor = 1024
$$)`; $$)`;
} }
case DatabaseExtension.VECTORS: { case DatabaseExtension.VECTORS: {
return ` return `
CREATE INDEX IF NOT EXISTS ${indexName} ON ${table} CREATE INDEX IF NOT EXISTS ${indexName} ON ${table}
USING vectors (embedding vector_cos_ops) WITH (options = $$ USING vectors (embedding vector_cos_ops) WITH (options = $$
optimizing.optimizing_threads = 4
[indexing.hnsw] [indexing.hnsw]
m = 16 m = 16
ef_construction = 300 ef_construction = 300