mirror of
https://github.com/immich-app/immich.git
synced 2025-07-09 03:04:16 -04:00
set probes
This commit is contained in:
parent
c80b16d24e
commit
b750440f90
@ -11,11 +11,3 @@ WHERE
|
|||||||
|
|
||||||
-- DatabaseRepository.getPostgresVersion
|
-- DatabaseRepository.getPostgresVersion
|
||||||
SHOW server_version
|
SHOW server_version
|
||||||
|
|
||||||
-- DatabaseRepository.shouldReindex
|
|
||||||
SELECT
|
|
||||||
idx_status
|
|
||||||
FROM
|
|
||||||
pg_vector_index_stat
|
|
||||||
WHERE
|
|
||||||
indexname = $1
|
|
||||||
|
@ -64,6 +64,7 @@ limit
|
|||||||
$15
|
$15
|
||||||
|
|
||||||
-- SearchRepository.searchSmart
|
-- SearchRepository.searchSmart
|
||||||
|
begin
|
||||||
select
|
select
|
||||||
"assets".*
|
"assets".*
|
||||||
from
|
from
|
||||||
@ -83,8 +84,10 @@ limit
|
|||||||
$7
|
$7
|
||||||
offset
|
offset
|
||||||
$8
|
$8
|
||||||
|
rollback
|
||||||
|
|
||||||
-- SearchRepository.searchDuplicates
|
-- SearchRepository.searchDuplicates
|
||||||
|
begin
|
||||||
with
|
with
|
||||||
"cte" as (
|
"cte" as (
|
||||||
select
|
select
|
||||||
@ -102,18 +105,20 @@ with
|
|||||||
and "assets"."id" != $5::uuid
|
and "assets"."id" != $5::uuid
|
||||||
and "assets"."stackId" is null
|
and "assets"."stackId" is null
|
||||||
order by
|
order by
|
||||||
smart_search.embedding <=> $6
|
"distance"
|
||||||
limit
|
limit
|
||||||
$7
|
$6
|
||||||
)
|
)
|
||||||
select
|
select
|
||||||
*
|
*
|
||||||
from
|
from
|
||||||
"cte"
|
"cte"
|
||||||
where
|
where
|
||||||
"cte"."distance" <= $8
|
"cte"."distance" <= $7
|
||||||
|
rollback
|
||||||
|
|
||||||
-- SearchRepository.searchFaces
|
-- SearchRepository.searchFaces
|
||||||
|
begin
|
||||||
with
|
with
|
||||||
"cte" as (
|
"cte" as (
|
||||||
select
|
select
|
||||||
@ -129,16 +134,17 @@ with
|
|||||||
"assets"."ownerId" = any ($2::uuid[])
|
"assets"."ownerId" = any ($2::uuid[])
|
||||||
and "assets"."deletedAt" is null
|
and "assets"."deletedAt" is null
|
||||||
order by
|
order by
|
||||||
face_search.embedding <=> $3
|
"distance"
|
||||||
limit
|
limit
|
||||||
$4
|
$3
|
||||||
)
|
)
|
||||||
select
|
select
|
||||||
*
|
*
|
||||||
from
|
from
|
||||||
"cte"
|
"cte"
|
||||||
where
|
where
|
||||||
"cte"."distance" <= $5
|
"cte"."distance" <= $4
|
||||||
|
commit
|
||||||
|
|
||||||
-- SearchRepository.searchPlaces
|
-- SearchRepository.searchPlaces
|
||||||
select
|
select
|
||||||
|
@ -25,7 +25,7 @@ import { vectorIndexQuery } from 'src/utils/database';
|
|||||||
import { isValidInteger } from 'src/validation';
|
import { isValidInteger } from 'src/validation';
|
||||||
import { DataSource, QueryRunner } from 'typeorm';
|
import { DataSource, QueryRunner } from 'typeorm';
|
||||||
|
|
||||||
let cachedVectorExtension: VectorExtension | undefined;
|
export let cachedVectorExtension: VectorExtension | undefined;
|
||||||
export async function getVectorExtension(runner: Kysely<DB> | QueryRunner): Promise<VectorExtension> {
|
export async function getVectorExtension(runner: Kysely<DB> | QueryRunner): Promise<VectorExtension> {
|
||||||
if (cachedVectorExtension) {
|
if (cachedVectorExtension) {
|
||||||
return cachedVectorExtension;
|
return cachedVectorExtension;
|
||||||
@ -50,6 +50,11 @@ export async function getVectorExtension(runner: Kysely<DB> | QueryRunner): Prom
|
|||||||
return cachedVectorExtension;
|
return cachedVectorExtension;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const probes: Record<VectorIndex, number> = {
|
||||||
|
[VectorIndex.CLIP]: 1,
|
||||||
|
[VectorIndex.FACE]: 1,
|
||||||
|
};
|
||||||
|
|
||||||
@Injectable()
|
@Injectable()
|
||||||
export class DatabaseRepository {
|
export class DatabaseRepository {
|
||||||
private readonly asyncLock = new AsyncLock();
|
private readonly asyncLock = new AsyncLock();
|
||||||
@ -183,21 +188,17 @@ export class DatabaseRepository {
|
|||||||
for (const indexName of names) {
|
for (const indexName of names) {
|
||||||
const row = rows.find((index) => index.indexname === indexName);
|
const row = rows.find((index) => index.indexname === indexName);
|
||||||
const table = VECTOR_INDEX_TABLES[indexName];
|
const table = VECTOR_INDEX_TABLES[indexName];
|
||||||
if (!row) {
|
|
||||||
promises.push(this.reindexVectors(indexName));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
switch (vectorExtension) {
|
switch (vectorExtension) {
|
||||||
case DatabaseExtension.VECTOR:
|
case DatabaseExtension.VECTOR:
|
||||||
case DatabaseExtension.VECTORS: {
|
case DatabaseExtension.VECTORS: {
|
||||||
if (!row.indexdef.toLowerCase().includes(keyword)) {
|
if (!row?.indexdef.toLowerCase().includes(keyword)) {
|
||||||
promises.push(this.reindexVectors(indexName));
|
promises.push(this.reindexVectors(indexName));
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case DatabaseExtension.VECTORCHORD: {
|
case DatabaseExtension.VECTORCHORD: {
|
||||||
const matches = row.indexdef.match(/(?<=lists = \[)\d+/g);
|
const matches = row?.indexdef.match(/(?<=lists = \[)\d+/g);
|
||||||
const lists = matches && matches.length > 0 ? Number(matches[0]) : 1;
|
const lists = matches && matches.length > 0 ? Number(matches[0]) : 1;
|
||||||
promises.push(
|
promises.push(
|
||||||
this.db
|
this.db
|
||||||
@ -208,11 +209,14 @@ export class DatabaseRepository {
|
|||||||
const targetLists = this.targetListCount(count);
|
const targetLists = this.targetListCount(count);
|
||||||
this.logger.log(`targetLists=${targetLists}, current=${lists} for ${indexName} of ${count} rows`);
|
this.logger.log(`targetLists=${targetLists}, current=${lists} for ${indexName} of ${count} rows`);
|
||||||
if (
|
if (
|
||||||
!row.indexdef.toLowerCase().includes(keyword) ||
|
!row?.indexdef.toLowerCase().includes(keyword) ||
|
||||||
// slack factor is to avoid frequent reindexing if the count is borderline
|
// slack factor is to avoid frequent reindexing if the count is borderline
|
||||||
(lists !== targetLists && lists !== this.targetListCount(count * VECTORCHORD_LIST_SLACK_FACTOR))
|
(lists !== targetLists && lists !== this.targetListCount(count * VECTORCHORD_LIST_SLACK_FACTOR))
|
||||||
) {
|
) {
|
||||||
|
probes[indexName] = this.targetProbeCount(targetLists);
|
||||||
return this.reindexVectors(indexName, { lists: targetLists });
|
return this.reindexVectors(indexName, { lists: targetLists });
|
||||||
|
} else {
|
||||||
|
probes[indexName] = this.targetProbeCount(lists);
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
@ -239,7 +243,7 @@ export class DatabaseRepository {
|
|||||||
);
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const dimSize = await this.getDimSize(table);
|
const dimSize = await this.getDimensionSize(table);
|
||||||
await this.db.transaction().execute(async (tx) => {
|
await this.db.transaction().execute(async (tx) => {
|
||||||
await sql`DROP INDEX IF EXISTS ${sql.raw(indexName)}`.execute(tx);
|
await sql`DROP INDEX IF EXISTS ${sql.raw(indexName)}`.execute(tx);
|
||||||
if (!rows.some((row) => row.columnName === 'embedding')) {
|
if (!rows.some((row) => row.columnName === 'embedding')) {
|
||||||
@ -261,7 +265,7 @@ export class DatabaseRepository {
|
|||||||
await sql`SET search_path TO "$user", public, vectors`.execute(tx);
|
await sql`SET search_path TO "$user", public, vectors`.execute(tx);
|
||||||
}
|
}
|
||||||
|
|
||||||
private async getDimSize(table: string, column = 'embedding'): Promise<number> {
|
async getDimensionSize(table: string, column = 'embedding'): Promise<number> {
|
||||||
const { rows } = await sql<{ dimsize: number }>`
|
const { rows } = await sql<{ dimsize: number }>`
|
||||||
SELECT atttypmod as dimsize
|
SELECT atttypmod as dimsize
|
||||||
FROM pg_attribute f
|
FROM pg_attribute f
|
||||||
@ -280,7 +284,41 @@ export class DatabaseRepository {
|
|||||||
return dimSize;
|
return dimSize;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: set probes in queries
|
async setDimensionSize(dimSize: number): Promise<void> {
|
||||||
|
if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
|
||||||
|
throw new Error(`Invalid CLIP dimension size: ${dimSize}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
// this is done in two transactions to handle concurrent writes
|
||||||
|
await this.db.transaction().execute(async (trx) => {
|
||||||
|
await sql`delete from ${sql.table('smart_search')}`.execute(trx);
|
||||||
|
await trx.schema.alterTable('smart_search').dropConstraint('dim_size_constraint').ifExists().execute();
|
||||||
|
await sql`alter table ${sql.table('smart_search')} add constraint dim_size_constraint check (array_length(embedding::real[], 1) = ${sql.lit(dimSize)})`.execute(
|
||||||
|
trx,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
const vectorExtension = await this.getVectorExtension();
|
||||||
|
await this.db.transaction().execute(async (trx) => {
|
||||||
|
await sql`drop index if exists clip_index`.execute(trx);
|
||||||
|
await trx.schema
|
||||||
|
.alterTable('smart_search')
|
||||||
|
.alterColumn('embedding', (col) => col.setDataType(sql.raw(`vector(${dimSize})`)))
|
||||||
|
.execute();
|
||||||
|
await sql
|
||||||
|
.raw(vectorIndexQuery({ vectorExtension, table: 'smart_search', indexName: VectorIndex.CLIP }))
|
||||||
|
.execute(trx);
|
||||||
|
await trx.schema.alterTable('smart_search').dropConstraint('dim_size_constraint').ifExists().execute();
|
||||||
|
});
|
||||||
|
probes[VectorIndex.CLIP] = 1;
|
||||||
|
|
||||||
|
await sql`vacuum analyze ${sql.table('smart_search')}`.execute(this.db);
|
||||||
|
}
|
||||||
|
|
||||||
|
async deleteAllSearchEmbeddings(): Promise<void> {
|
||||||
|
await sql`truncate ${sql.table('smart_search')}`.execute(this.db);
|
||||||
|
}
|
||||||
|
|
||||||
private targetListCount(count: number) {
|
private targetListCount(count: number) {
|
||||||
if (count < 128_000) {
|
if (count < 128_000) {
|
||||||
return 1;
|
return 1;
|
||||||
@ -291,6 +329,10 @@ export class DatabaseRepository {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private targetProbeCount(lists: number) {
|
||||||
|
return Math.ceil(lists / 8);
|
||||||
|
}
|
||||||
|
|
||||||
async runMigrations(options?: { transaction?: 'all' | 'none' | 'each' }): Promise<void> {
|
async runMigrations(options?: { transaction?: 'all' | 'none' | 'each' }): Promise<void> {
|
||||||
const { database } = this.configRepository.getEnv();
|
const { database } = this.configRepository.getEnv();
|
||||||
|
|
||||||
|
@ -5,9 +5,9 @@ import { randomUUID } from 'node:crypto';
|
|||||||
import { DB, Exif } from 'src/db';
|
import { DB, Exif } from 'src/db';
|
||||||
import { DummyValue, GenerateSql } from 'src/decorators';
|
import { DummyValue, GenerateSql } from 'src/decorators';
|
||||||
import { MapAsset } from 'src/dtos/asset-response.dto';
|
import { MapAsset } from 'src/dtos/asset-response.dto';
|
||||||
import { AssetStatus, AssetType, AssetVisibility, VectorIndex } from 'src/enum';
|
import { AssetStatus, AssetType, AssetVisibility, DatabaseExtension, VectorIndex } from 'src/enum';
|
||||||
import { getVectorExtension } from 'src/repositories/database.repository';
|
import { cachedVectorExtension, probes } from 'src/repositories/database.repository';
|
||||||
import { anyUuid, asUuid, searchAssetBuilder, vectorIndexQuery } from 'src/utils/database';
|
import { anyUuid, asUuid, searchAssetBuilder } from 'src/utils/database';
|
||||||
import { paginationHelper } from 'src/utils/pagination';
|
import { paginationHelper } from 'src/utils/pagination';
|
||||||
import { isValidInteger } from 'src/validation';
|
import { isValidInteger } from 'src/validation';
|
||||||
|
|
||||||
@ -233,19 +233,23 @@ export class SearchRepository {
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
})
|
})
|
||||||
async searchSmart(pagination: SearchPaginationOptions, options: SmartSearchOptions) {
|
searchSmart(pagination: SearchPaginationOptions, options: SmartSearchOptions) {
|
||||||
if (!isValidInteger(pagination.size, { min: 1, max: 1000 })) {
|
if (!isValidInteger(pagination.size, { min: 1, max: 1000 })) {
|
||||||
throw new Error(`Invalid value for 'size': ${pagination.size}`);
|
throw new Error(`Invalid value for 'size': ${pagination.size}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
const items = await searchAssetBuilder(this.db, options)
|
return this.db.transaction().execute(async (trx) => {
|
||||||
.innerJoin('smart_search', 'assets.id', 'smart_search.assetId')
|
if (cachedVectorExtension === DatabaseExtension.VECTORCHORD) {
|
||||||
.orderBy(sql`smart_search.embedding <=> ${options.embedding}`)
|
await sql`set local vchord.probes = ${sql.lit(probes[VectorIndex.CLIP])}`.execute(trx);
|
||||||
.limit(pagination.size + 1)
|
}
|
||||||
.offset((pagination.page - 1) * pagination.size)
|
const items = await searchAssetBuilder(trx, options)
|
||||||
.execute();
|
.innerJoin('smart_search', 'assets.id', 'smart_search.assetId')
|
||||||
|
.orderBy(sql`smart_search.embedding <=> ${options.embedding}`)
|
||||||
return paginationHelper(items, pagination.size);
|
.limit(pagination.size + 1)
|
||||||
|
.offset((pagination.page - 1) * pagination.size)
|
||||||
|
.execute();
|
||||||
|
return paginationHelper(items, pagination.size);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@GenerateSql({
|
@GenerateSql({
|
||||||
@ -260,29 +264,35 @@ export class SearchRepository {
|
|||||||
],
|
],
|
||||||
})
|
})
|
||||||
searchDuplicates({ assetId, embedding, maxDistance, type, userIds }: AssetDuplicateSearch) {
|
searchDuplicates({ assetId, embedding, maxDistance, type, userIds }: AssetDuplicateSearch) {
|
||||||
return this.db
|
return this.db.transaction().execute(async (trx) => {
|
||||||
.with('cte', (qb) =>
|
if (cachedVectorExtension === DatabaseExtension.VECTORCHORD) {
|
||||||
qb
|
await sql`set local vchord.probes = ${sql.lit(probes[VectorIndex.CLIP])}`.execute(trx);
|
||||||
.selectFrom('assets')
|
}
|
||||||
.select([
|
|
||||||
'assets.id as assetId',
|
return await trx
|
||||||
'assets.duplicateId',
|
.with('cte', (qb) =>
|
||||||
sql<number>`smart_search.embedding <=> ${embedding}`.as('distance'),
|
qb
|
||||||
])
|
.selectFrom('assets')
|
||||||
.innerJoin('smart_search', 'assets.id', 'smart_search.assetId')
|
.select([
|
||||||
.where('assets.ownerId', '=', anyUuid(userIds))
|
'assets.id as assetId',
|
||||||
.where('assets.deletedAt', 'is', null)
|
'assets.duplicateId',
|
||||||
.where('assets.visibility', '!=', AssetVisibility.HIDDEN)
|
sql<number>`smart_search.embedding <=> ${embedding}`.as('distance'),
|
||||||
.where('assets.type', '=', type)
|
])
|
||||||
.where('assets.id', '!=', asUuid(assetId))
|
.innerJoin('smart_search', 'assets.id', 'smart_search.assetId')
|
||||||
.where('assets.stackId', 'is', null)
|
.where('assets.ownerId', '=', anyUuid(userIds))
|
||||||
.orderBy(sql`smart_search.embedding <=> ${embedding}`)
|
.where('assets.deletedAt', 'is', null)
|
||||||
.limit(64),
|
.where('assets.visibility', '!=', AssetVisibility.HIDDEN)
|
||||||
)
|
.where('assets.type', '=', type)
|
||||||
.selectFrom('cte')
|
.where('assets.id', '!=', asUuid(assetId))
|
||||||
.selectAll()
|
.where('assets.stackId', 'is', null)
|
||||||
.where('cte.distance', '<=', maxDistance as number)
|
.orderBy('distance')
|
||||||
.execute();
|
.limit(64),
|
||||||
|
)
|
||||||
|
.selectFrom('cte')
|
||||||
|
.selectAll()
|
||||||
|
.where('cte.distance', '<=', maxDistance as number)
|
||||||
|
.execute();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@GenerateSql({
|
@GenerateSql({
|
||||||
@ -300,31 +310,39 @@ export class SearchRepository {
|
|||||||
throw new Error(`Invalid value for 'numResults': ${numResults}`);
|
throw new Error(`Invalid value for 'numResults': ${numResults}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
return this.db
|
return this.db.transaction().execute(async (trx) => {
|
||||||
.with('cte', (qb) =>
|
if (cachedVectorExtension === DatabaseExtension.VECTORCHORD) {
|
||||||
qb
|
await sql`set local vchord.probes = ${sql.lit(probes[VectorIndex.FACE])}`.execute(trx);
|
||||||
.selectFrom('asset_faces')
|
}
|
||||||
.select([
|
|
||||||
'asset_faces.id',
|
return await trx
|
||||||
'asset_faces.personId',
|
.with('cte', (qb) =>
|
||||||
sql<number>`face_search.embedding <=> ${embedding}`.as('distance'),
|
qb
|
||||||
])
|
.selectFrom('asset_faces')
|
||||||
.innerJoin('assets', 'assets.id', 'asset_faces.assetId')
|
.select([
|
||||||
.innerJoin('face_search', 'face_search.faceId', 'asset_faces.id')
|
'asset_faces.id',
|
||||||
.leftJoin('person', 'person.id', 'asset_faces.personId')
|
'asset_faces.personId',
|
||||||
.where('assets.ownerId', '=', anyUuid(userIds))
|
sql<number>`face_search.embedding <=> ${embedding}`.as('distance'),
|
||||||
.where('assets.deletedAt', 'is', null)
|
])
|
||||||
.$if(!!hasPerson, (qb) => qb.where('asset_faces.personId', 'is not', null))
|
.innerJoin('assets', 'assets.id', 'asset_faces.assetId')
|
||||||
.$if(!!minBirthDate, (qb) =>
|
.innerJoin('face_search', 'face_search.faceId', 'asset_faces.id')
|
||||||
qb.where((eb) => eb.or([eb('person.birthDate', 'is', null), eb('person.birthDate', '<=', minBirthDate!)])),
|
.leftJoin('person', 'person.id', 'asset_faces.personId')
|
||||||
)
|
.where('assets.ownerId', '=', anyUuid(userIds))
|
||||||
.orderBy(sql`face_search.embedding <=> ${embedding}`)
|
.where('assets.deletedAt', 'is', null)
|
||||||
.limit(numResults),
|
.$if(!!hasPerson, (qb) => qb.where('asset_faces.personId', 'is not', null))
|
||||||
)
|
.$if(!!minBirthDate, (qb) =>
|
||||||
.selectFrom('cte')
|
qb.where((eb) =>
|
||||||
.selectAll()
|
eb.or([eb('person.birthDate', 'is', null), eb('person.birthDate', '<=', minBirthDate!)]),
|
||||||
.where('cte.distance', '<=', maxDistance)
|
),
|
||||||
.execute();
|
)
|
||||||
|
.orderBy('distance')
|
||||||
|
.limit(numResults),
|
||||||
|
)
|
||||||
|
.selectFrom('cte')
|
||||||
|
.selectAll()
|
||||||
|
.where('cte.distance', '<=', maxDistance)
|
||||||
|
.execute();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@GenerateSql({ params: [DummyValue.STRING] })
|
@GenerateSql({ params: [DummyValue.STRING] })
|
||||||
@ -413,58 +431,6 @@ export class SearchRepository {
|
|||||||
.execute();
|
.execute();
|
||||||
}
|
}
|
||||||
|
|
||||||
async getDimensionSize(): Promise<number> {
|
|
||||||
const { rows } = await sql<{ dimsize: number }>`
|
|
||||||
select atttypmod as dimsize
|
|
||||||
from pg_attribute f
|
|
||||||
join pg_class c ON c.oid = f.attrelid
|
|
||||||
where c.relkind = 'r'::char
|
|
||||||
and f.attnum > 0
|
|
||||||
and c.relname = 'smart_search'
|
|
||||||
and f.attname = 'embedding'
|
|
||||||
`.execute(this.db);
|
|
||||||
|
|
||||||
const dimSize = rows[0]['dimsize'];
|
|
||||||
if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
|
|
||||||
throw new Error(`Could not retrieve CLIP dimension size`);
|
|
||||||
}
|
|
||||||
return dimSize;
|
|
||||||
}
|
|
||||||
|
|
||||||
async setDimensionSize(dimSize: number): Promise<void> {
|
|
||||||
if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
|
|
||||||
throw new Error(`Invalid CLIP dimension size: ${dimSize}`);
|
|
||||||
}
|
|
||||||
|
|
||||||
// this is done in two transactions to handle concurrent writes
|
|
||||||
await this.db.transaction().execute(async (trx) => {
|
|
||||||
await sql`delete from ${sql.table('smart_search')}`.execute(trx);
|
|
||||||
await trx.schema.alterTable('smart_search').dropConstraint('dim_size_constraint').ifExists().execute();
|
|
||||||
await sql`alter table ${sql.table('smart_search')} add constraint dim_size_constraint check (array_length(embedding::real[], 1) = ${sql.lit(dimSize)})`.execute(
|
|
||||||
trx,
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
const vectorExtension = await getVectorExtension(this.db);
|
|
||||||
await this.db.transaction().execute(async (trx) => {
|
|
||||||
await sql`drop index if exists clip_index`.execute(trx);
|
|
||||||
await trx.schema
|
|
||||||
.alterTable('smart_search')
|
|
||||||
.alterColumn('embedding', (col) => col.setDataType(sql.raw(`vector(${dimSize})`)))
|
|
||||||
.execute();
|
|
||||||
await sql
|
|
||||||
.raw(vectorIndexQuery({ vectorExtension, table: 'smart_search', indexName: VectorIndex.CLIP }))
|
|
||||||
.execute(trx);
|
|
||||||
await trx.schema.alterTable('smart_search').dropConstraint('dim_size_constraint').ifExists().execute();
|
|
||||||
});
|
|
||||||
|
|
||||||
await sql`vacuum analyze ${sql.table('smart_search')}`.execute(this.db);
|
|
||||||
}
|
|
||||||
|
|
||||||
async deleteAllSearchEmbeddings(): Promise<void> {
|
|
||||||
await sql`truncate ${sql.table('smart_search')}`.execute(this.db);
|
|
||||||
}
|
|
||||||
|
|
||||||
async getCountries(userIds: string[]): Promise<string[]> {
|
async getCountries(userIds: string[]): Promise<string[]> {
|
||||||
const res = await this.getExifField('country', userIds).execute();
|
const res = await this.getExifField('country', userIds).execute();
|
||||||
return res.map((row) => row.country!);
|
return res.map((row) => row.country!);
|
||||||
|
@ -109,7 +109,10 @@ export class DatabaseService extends BaseService {
|
|||||||
if (!database.skipMigrations) {
|
if (!database.skipMigrations) {
|
||||||
await this.databaseRepository.runMigrations();
|
await this.databaseRepository.runMigrations();
|
||||||
}
|
}
|
||||||
await this.databaseRepository.prewarm(VectorIndex.CLIP);
|
await Promise.all([
|
||||||
|
this.databaseRepository.prewarm(VectorIndex.CLIP),
|
||||||
|
this.databaseRepository.prewarm(VectorIndex.FACE),
|
||||||
|
]);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -54,28 +54,28 @@ describe(SmartInfoService.name, () => {
|
|||||||
it('should return if machine learning is disabled', async () => {
|
it('should return if machine learning is disabled', async () => {
|
||||||
await sut.onConfigInit({ newConfig: systemConfigStub.machineLearningDisabled as SystemConfig });
|
await sut.onConfigInit({ newConfig: systemConfigStub.machineLearningDisabled as SystemConfig });
|
||||||
|
|
||||||
expect(mocks.search.getDimensionSize).not.toHaveBeenCalled();
|
expect(mocks.database.getDimensionSize).not.toHaveBeenCalled();
|
||||||
expect(mocks.search.setDimensionSize).not.toHaveBeenCalled();
|
expect(mocks.database.setDimensionSize).not.toHaveBeenCalled();
|
||||||
expect(mocks.search.deleteAllSearchEmbeddings).not.toHaveBeenCalled();
|
expect(mocks.database.deleteAllSearchEmbeddings).not.toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should return if model and DB dimension size are equal', async () => {
|
it('should return if model and DB dimension size are equal', async () => {
|
||||||
mocks.search.getDimensionSize.mockResolvedValue(512);
|
mocks.database.getDimensionSize.mockResolvedValue(512);
|
||||||
|
|
||||||
await sut.onConfigInit({ newConfig: systemConfigStub.machineLearningEnabled as SystemConfig });
|
await sut.onConfigInit({ newConfig: systemConfigStub.machineLearningEnabled as SystemConfig });
|
||||||
|
|
||||||
expect(mocks.search.getDimensionSize).toHaveBeenCalledTimes(1);
|
expect(mocks.database.getDimensionSize).toHaveBeenCalledTimes(1);
|
||||||
expect(mocks.search.setDimensionSize).not.toHaveBeenCalled();
|
expect(mocks.database.setDimensionSize).not.toHaveBeenCalled();
|
||||||
expect(mocks.search.deleteAllSearchEmbeddings).not.toHaveBeenCalled();
|
expect(mocks.database.deleteAllSearchEmbeddings).not.toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should update DB dimension size if model and DB have different values', async () => {
|
it('should update DB dimension size if model and DB have different values', async () => {
|
||||||
mocks.search.getDimensionSize.mockResolvedValue(768);
|
mocks.database.getDimensionSize.mockResolvedValue(768);
|
||||||
|
|
||||||
await sut.onConfigInit({ newConfig: systemConfigStub.machineLearningEnabled as SystemConfig });
|
await sut.onConfigInit({ newConfig: systemConfigStub.machineLearningEnabled as SystemConfig });
|
||||||
|
|
||||||
expect(mocks.search.getDimensionSize).toHaveBeenCalledTimes(1);
|
expect(mocks.database.getDimensionSize).toHaveBeenCalledTimes(1);
|
||||||
expect(mocks.search.setDimensionSize).toHaveBeenCalledWith(512);
|
expect(mocks.database.setDimensionSize).toHaveBeenCalledWith(512);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -89,13 +89,13 @@ describe(SmartInfoService.name, () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
expect(mocks.systemMetadata.get).not.toHaveBeenCalled();
|
expect(mocks.systemMetadata.get).not.toHaveBeenCalled();
|
||||||
expect(mocks.search.getDimensionSize).not.toHaveBeenCalled();
|
expect(mocks.database.getDimensionSize).not.toHaveBeenCalled();
|
||||||
expect(mocks.search.setDimensionSize).not.toHaveBeenCalled();
|
expect(mocks.database.setDimensionSize).not.toHaveBeenCalled();
|
||||||
expect(mocks.search.deleteAllSearchEmbeddings).not.toHaveBeenCalled();
|
expect(mocks.database.deleteAllSearchEmbeddings).not.toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should return if model and DB dimension size are equal', async () => {
|
it('should return if model and DB dimension size are equal', async () => {
|
||||||
mocks.search.getDimensionSize.mockResolvedValue(512);
|
mocks.database.getDimensionSize.mockResolvedValue(512);
|
||||||
|
|
||||||
await sut.onConfigUpdate({
|
await sut.onConfigUpdate({
|
||||||
newConfig: {
|
newConfig: {
|
||||||
@ -106,13 +106,13 @@ describe(SmartInfoService.name, () => {
|
|||||||
} as SystemConfig,
|
} as SystemConfig,
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(mocks.search.getDimensionSize).toHaveBeenCalledTimes(1);
|
expect(mocks.database.getDimensionSize).toHaveBeenCalledTimes(1);
|
||||||
expect(mocks.search.setDimensionSize).not.toHaveBeenCalled();
|
expect(mocks.database.setDimensionSize).not.toHaveBeenCalled();
|
||||||
expect(mocks.search.deleteAllSearchEmbeddings).not.toHaveBeenCalled();
|
expect(mocks.database.deleteAllSearchEmbeddings).not.toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should update DB dimension size if model and DB have different values', async () => {
|
it('should update DB dimension size if model and DB have different values', async () => {
|
||||||
mocks.search.getDimensionSize.mockResolvedValue(512);
|
mocks.database.getDimensionSize.mockResolvedValue(512);
|
||||||
|
|
||||||
await sut.onConfigUpdate({
|
await sut.onConfigUpdate({
|
||||||
newConfig: {
|
newConfig: {
|
||||||
@ -123,12 +123,12 @@ describe(SmartInfoService.name, () => {
|
|||||||
} as SystemConfig,
|
} as SystemConfig,
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(mocks.search.getDimensionSize).toHaveBeenCalledTimes(1);
|
expect(mocks.database.getDimensionSize).toHaveBeenCalledTimes(1);
|
||||||
expect(mocks.search.setDimensionSize).toHaveBeenCalledWith(768);
|
expect(mocks.database.setDimensionSize).toHaveBeenCalledWith(768);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should clear embeddings if old and new models are different', async () => {
|
it('should clear embeddings if old and new models are different', async () => {
|
||||||
mocks.search.getDimensionSize.mockResolvedValue(512);
|
mocks.database.getDimensionSize.mockResolvedValue(512);
|
||||||
|
|
||||||
await sut.onConfigUpdate({
|
await sut.onConfigUpdate({
|
||||||
newConfig: {
|
newConfig: {
|
||||||
@ -139,9 +139,9 @@ describe(SmartInfoService.name, () => {
|
|||||||
} as SystemConfig,
|
} as SystemConfig,
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(mocks.search.deleteAllSearchEmbeddings).toHaveBeenCalled();
|
expect(mocks.database.deleteAllSearchEmbeddings).toHaveBeenCalled();
|
||||||
expect(mocks.search.getDimensionSize).toHaveBeenCalledTimes(1);
|
expect(mocks.database.getDimensionSize).toHaveBeenCalledTimes(1);
|
||||||
expect(mocks.search.setDimensionSize).not.toHaveBeenCalled();
|
expect(mocks.database.setDimensionSize).not.toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -151,7 +151,7 @@ describe(SmartInfoService.name, () => {
|
|||||||
|
|
||||||
await sut.handleQueueEncodeClip({});
|
await sut.handleQueueEncodeClip({});
|
||||||
|
|
||||||
expect(mocks.search.setDimensionSize).not.toHaveBeenCalled();
|
expect(mocks.database.setDimensionSize).not.toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should queue the assets without clip embeddings', async () => {
|
it('should queue the assets without clip embeddings', async () => {
|
||||||
@ -163,7 +163,7 @@ describe(SmartInfoService.name, () => {
|
|||||||
{ name: JobName.SMART_SEARCH, data: { id: assetStub.image.id } },
|
{ name: JobName.SMART_SEARCH, data: { id: assetStub.image.id } },
|
||||||
]);
|
]);
|
||||||
expect(mocks.assetJob.streamForEncodeClip).toHaveBeenCalledWith(false);
|
expect(mocks.assetJob.streamForEncodeClip).toHaveBeenCalledWith(false);
|
||||||
expect(mocks.search.setDimensionSize).not.toHaveBeenCalled();
|
expect(mocks.database.setDimensionSize).not.toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should queue all the assets', async () => {
|
it('should queue all the assets', async () => {
|
||||||
@ -175,7 +175,7 @@ describe(SmartInfoService.name, () => {
|
|||||||
{ name: JobName.SMART_SEARCH, data: { id: assetStub.image.id } },
|
{ name: JobName.SMART_SEARCH, data: { id: assetStub.image.id } },
|
||||||
]);
|
]);
|
||||||
expect(mocks.assetJob.streamForEncodeClip).toHaveBeenCalledWith(true);
|
expect(mocks.assetJob.streamForEncodeClip).toHaveBeenCalledWith(true);
|
||||||
expect(mocks.search.setDimensionSize).toHaveBeenCalledExactlyOnceWith(512);
|
expect(mocks.database.setDimensionSize).toHaveBeenCalledExactlyOnceWith(512);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -38,7 +38,7 @@ export class SmartInfoService extends BaseService {
|
|||||||
|
|
||||||
await this.databaseRepository.withLock(DatabaseLock.CLIPDimSize, async () => {
|
await this.databaseRepository.withLock(DatabaseLock.CLIPDimSize, async () => {
|
||||||
const { dimSize } = getCLIPModelInfo(newConfig.machineLearning.clip.modelName);
|
const { dimSize } = getCLIPModelInfo(newConfig.machineLearning.clip.modelName);
|
||||||
const dbDimSize = await this.searchRepository.getDimensionSize();
|
const dbDimSize = await this.databaseRepository.getDimensionSize('smart_search');
|
||||||
this.logger.verbose(`Current database CLIP dimension size is ${dbDimSize}`);
|
this.logger.verbose(`Current database CLIP dimension size is ${dbDimSize}`);
|
||||||
|
|
||||||
const modelChange =
|
const modelChange =
|
||||||
@ -53,10 +53,10 @@ export class SmartInfoService extends BaseService {
|
|||||||
`Dimension size of model ${newConfig.machineLearning.clip.modelName} is ${dimSize}, but database expects ${dbDimSize}.`,
|
`Dimension size of model ${newConfig.machineLearning.clip.modelName} is ${dimSize}, but database expects ${dbDimSize}.`,
|
||||||
);
|
);
|
||||||
this.logger.log(`Updating database CLIP dimension size to ${dimSize}.`);
|
this.logger.log(`Updating database CLIP dimension size to ${dimSize}.`);
|
||||||
await this.searchRepository.setDimensionSize(dimSize);
|
await this.databaseRepository.setDimensionSize(dimSize);
|
||||||
this.logger.log(`Successfully updated database CLIP dimension size from ${dbDimSize} to ${dimSize}.`);
|
this.logger.log(`Successfully updated database CLIP dimension size from ${dbDimSize} to ${dimSize}.`);
|
||||||
} else {
|
} else {
|
||||||
await this.searchRepository.deleteAllSearchEmbeddings();
|
await this.databaseRepository.deleteAllSearchEmbeddings();
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: A job to reindex all assets should be scheduled, though user
|
// TODO: A job to reindex all assets should be scheduled, though user
|
||||||
@ -74,7 +74,7 @@ export class SmartInfoService extends BaseService {
|
|||||||
if (force) {
|
if (force) {
|
||||||
const { dimSize } = getCLIPModelInfo(machineLearning.clip.modelName);
|
const { dimSize } = getCLIPModelInfo(machineLearning.clip.modelName);
|
||||||
// in addition to deleting embeddings, update the dimension size in case it failed earlier
|
// in addition to deleting embeddings, update the dimension size in case it failed earlier
|
||||||
await this.searchRepository.setDimensionSize(dimSize);
|
await this.databaseRepository.setDimensionSize(dimSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
let queue: JobItem[] = [];
|
let queue: JobItem[] = [];
|
||||||
|
@ -397,12 +397,14 @@ export function vectorIndexQuery({ vectorExtension, table, indexName, lists }: V
|
|||||||
lists = [${lists ?? 1}]
|
lists = [${lists ?? 1}]
|
||||||
spherical_centroids = true
|
spherical_centroids = true
|
||||||
build_threads = 4
|
build_threads = 4
|
||||||
|
sampling_factor = 1024
|
||||||
$$)`;
|
$$)`;
|
||||||
}
|
}
|
||||||
case DatabaseExtension.VECTORS: {
|
case DatabaseExtension.VECTORS: {
|
||||||
return `
|
return `
|
||||||
CREATE INDEX IF NOT EXISTS ${indexName} ON ${table}
|
CREATE INDEX IF NOT EXISTS ${indexName} ON ${table}
|
||||||
USING vectors (embedding vector_cos_ops) WITH (options = $$
|
USING vectors (embedding vector_cos_ops) WITH (options = $$
|
||||||
|
optimizing.optimizing_threads = 4
|
||||||
[indexing.hnsw]
|
[indexing.hnsw]
|
||||||
m = 16
|
m = 16
|
||||||
ef_construction = 300
|
ef_construction = 300
|
||||||
|
Loading…
x
Reference in New Issue
Block a user