diff --git a/server/src/queries/search.repository.sql b/server/src/queries/search.repository.sql index 06590dc817..73f276a7fb 100644 --- a/server/src/queries/search.repository.sql +++ b/server/src/queries/search.repository.sql @@ -136,6 +136,7 @@ with "asset_faces" inner join "assets" on "assets"."id" = "asset_faces"."assetId" inner join "face_search" on "face_search"."faceId" = "asset_faces"."id" + left join "person" on "person"."id" = "asset_faces"."personId" where "assets"."ownerId" = any ($2::uuid[]) and "assets"."deletedAt" is null diff --git a/server/src/repositories/search.repository.ts b/server/src/repositories/search.repository.ts index e2e389f47c..954ab0fe5a 100644 --- a/server/src/repositories/search.repository.ts +++ b/server/src/repositories/search.repository.ts @@ -163,6 +163,7 @@ export interface FaceEmbeddingSearch extends SearchEmbeddingOptions { hasPerson?: boolean; numResults: number; maxDistance: number; + minBirthDate?: Date; } export interface AssetDuplicateSearch { @@ -338,7 +339,7 @@ export class SearchRepository { }, ], }) - searchFaces({ userIds, embedding, numResults, maxDistance, hasPerson }: FaceEmbeddingSearch) { + searchFaces({ userIds, embedding, numResults, maxDistance, hasPerson, minBirthDate }: FaceEmbeddingSearch) { if (!isValidInteger(numResults, { min: 1, max: 1000 })) { throw new Error(`Invalid value for 'numResults': ${numResults}`); } @@ -354,9 +355,13 @@ export class SearchRepository { ]) .innerJoin('assets', 'assets.id', 'asset_faces.assetId') .innerJoin('face_search', 'face_search.faceId', 'asset_faces.id') + .leftJoin('person', 'person.id', 'asset_faces.personId') .where('assets.ownerId', '=', anyUuid(userIds)) .where('assets.deletedAt', 'is', null) .$if(!!hasPerson, (qb) => qb.where('asset_faces.personId', 'is not', null)) + .$if(!!minBirthDate, (qb) => + qb.where((eb) => eb.or([eb('person.birthDate', 'is', null), eb('person.birthDate', '<=', minBirthDate!)])), + ) .orderBy(sql`face_search.embedding <=> ${embedding}`) .limit(numResults), ) diff --git a/server/src/services/person.service.spec.ts b/server/src/services/person.service.spec.ts index 073cf71247..1d8cdfd3b9 100644 --- a/server/src/services/person.service.spec.ts +++ b/server/src/services/person.service.spec.ts @@ -896,6 +896,66 @@ describe(PersonService.name, () => { }); }); + it('should match existing person if their birth date is unknown', async () => { + if (!faceStub.primaryFace1.person) { + throw new Error('faceStub.primaryFace1.person is null'); + } + + const faces = [ + { ...faceStub.noPerson1, distance: 0 }, + { ...faceStub.primaryFace1, distance: 0.2 }, + { ...faceStub.withBirthDate, distance: 0.3 }, + ] as FaceSearchResult[]; + + mocks.systemMetadata.get.mockResolvedValue({ machineLearning: { facialRecognition: { minFaces: 1 } } }); + mocks.search.searchFaces.mockResolvedValue(faces); + mocks.person.getFaceByIdWithAssets.mockResolvedValue(faceStub.noPerson1); + mocks.person.create.mockResolvedValue(faceStub.primaryFace1.person); + + await sut.handleRecognizeFaces({ id: faceStub.noPerson1.id }); + + expect(mocks.person.create).not.toHaveBeenCalled(); + expect(mocks.person.reassignFaces).toHaveBeenCalledTimes(1); + expect(mocks.person.reassignFaces).toHaveBeenCalledWith({ + faceIds: expect.arrayContaining([faceStub.noPerson1.id]), + newPersonId: faceStub.primaryFace1.person.id, + }); + expect(mocks.person.reassignFaces).toHaveBeenCalledWith({ + faceIds: expect.not.arrayContaining([faceStub.face1.id]), + newPersonId: faceStub.primaryFace1.person.id, + }); + }); + + it('should match existing person if their birth date is before file creation', async () => { + if (!faceStub.primaryFace1.person) { + throw new Error('faceStub.primaryFace1.person is null'); + } + + const faces = [ + { ...faceStub.noPerson1, distance: 0 }, + { ...faceStub.withBirthDate, distance: 0.2 }, + { ...faceStub.primaryFace1, distance: 0.3 }, + ] as FaceSearchResult[]; + + mocks.systemMetadata.get.mockResolvedValue({ machineLearning: { facialRecognition: { minFaces: 1 } } }); + mocks.search.searchFaces.mockResolvedValue(faces); + mocks.person.getFaceByIdWithAssets.mockResolvedValue(faceStub.noPerson1); + mocks.person.create.mockResolvedValue(faceStub.primaryFace1.person); + + await sut.handleRecognizeFaces({ id: faceStub.noPerson1.id }); + + expect(mocks.person.create).not.toHaveBeenCalled(); + expect(mocks.person.reassignFaces).toHaveBeenCalledTimes(1); + expect(mocks.person.reassignFaces).toHaveBeenCalledWith({ + faceIds: expect.arrayContaining([faceStub.noPerson1.id]), + newPersonId: faceStub.withBirthDate.person?.id, + }); + expect(mocks.person.reassignFaces).toHaveBeenCalledWith({ + faceIds: expect.not.arrayContaining([faceStub.face1.id]), + newPersonId: faceStub.withBirthDate.person?.id, + }); + }); + it('should create a new person if the face is a core point with no person', async () => { const faces = [ { ...faceStub.noPerson1, distance: 0 }, diff --git a/server/src/services/person.service.ts b/server/src/services/person.service.ts index b34b0ddcff..ec412ad307 100644 --- a/server/src/services/person.service.ts +++ b/server/src/services/person.service.ts @@ -483,6 +483,7 @@ export class PersonService extends BaseService { embedding: face.faceSearch.embedding, maxDistance: machineLearning.facialRecognition.maxDistance, numResults: machineLearning.facialRecognition.minFaces, + minBirthDate: face.asset.fileCreatedAt, }); // `matches` also includes the face itself @@ -508,6 +509,7 @@ export class PersonService extends BaseService { maxDistance: machineLearning.facialRecognition.maxDistance, numResults: 1, hasPerson: true, + minBirthDate: face.asset.fileCreatedAt, }); if (matchWithPerson.length > 0) { diff --git a/server/test/factory.ts b/server/test/factory.ts index 028b530255..faca12d068 100644 --- a/server/test/factory.ts +++ b/server/test/factory.ts @@ -1,9 +1,9 @@ import { Insertable, Kysely } from 'kysely'; import { randomBytes } from 'node:crypto'; import { Writable } from 'node:stream'; -import { Assets, DB, Partners, Sessions } from 'src/db'; +import { AssetFaces, Assets, DB, Person as DbPerson, FaceSearch, Partners, Sessions } from 'src/db'; import { AuthDto } from 'src/dtos/auth.dto'; -import { AssetType } from 'src/enum'; +import { AssetType, SourceType } from 'src/enum'; import { AccessRepository } from 'src/repositories/access.repository'; import { ActivityRepository } from 'src/repositories/activity.repository'; import { AlbumRepository } from 'src/repositories/album.repository'; @@ -37,7 +37,7 @@ import { VersionHistoryRepository } from 'src/repositories/version-history.repos import { ViewRepository } from 'src/repositories/view-repository'; import { UserTable } from 'src/schema/tables/user.table'; import { newTelemetryRepositoryMock } from 'test/repositories/telemetry.repository.mock'; -import { newUuid } from 'test/small.factory'; +import { newDate, newEmbedding, newUuid } from 'test/small.factory'; import { automock } from 'test/utils'; class CustomWritable extends Writable { @@ -61,12 +61,18 @@ type Asset = Partial>; type User = Partial>; type Session = Omit, 'token'> & { token?: string }; type Partner = Insertable; +type AssetFace = Partial>; +type Person = Partial>; +type Face = Partial>; export class TestFactory { private assets: Asset[] = []; private sessions: Session[] = []; private users: User[] = []; private partners: Partner[] = []; + private assetFaces: AssetFace[] = []; + private persons: Person[] = []; + private faces: Face[] = []; private constructor(private context: TestContext) {} @@ -141,6 +147,53 @@ export class TestFactory { }; } + static assetFace(assetFace: AssetFace) { + const defaults = { + assetId: assetFace.assetId || newUuid(), + boundingBoxX1: assetFace.boundingBoxX1 || 0, + boundingBoxX2: assetFace.boundingBoxX2 || 1, + boundingBoxY1: assetFace.boundingBoxY1 || 0, + boundingBoxY2: assetFace.boundingBoxY2 || 1, + deletedAt: assetFace.deletedAt || null, + id: assetFace.id || newUuid(), + imageHeight: assetFace.imageHeight || 10, + imageWidth: assetFace.imageWidth || 10, + personId: assetFace.personId || null, + sourceType: assetFace.sourceType || SourceType.MACHINE_LEARNING, + }; + + return { ...defaults, ...assetFace }; + } + + static person(person: Person) { + const defaults = { + birthDate: person.birthDate || null, + color: person.color || null, + createdAt: person.createdAt || newDate(), + faceAssetId: person.faceAssetId || null, + id: person.id || newUuid(), + isFavorite: person.isFavorite || false, + isHidden: person.isHidden || false, + name: person.name || 'Test Name', + ownerId: person.ownerId || newUuid(), + thumbnailPath: person.thumbnailPath || '/path/to/thumbnail.jpg', + updatedAt: person.updatedAt || newDate(), + updateId: person.updateId || newUuid(), + }; + return { ...defaults, ...person }; + } + + static face(face: Face) { + const defaults = { + faceId: face.faceId || newUuid(), + embedding: face.embedding || newEmbedding(), + }; + return { + ...defaults, + ...face, + }; + } + withAsset(asset: Asset) { this.assets.push(asset); return this; @@ -161,6 +214,21 @@ export class TestFactory { return this; } + withAssetFace(assetFace: AssetFace) { + this.assetFaces.push(assetFace); + return this; + } + + withPerson(person: Person) { + this.persons.push(person); + return this; + } + + withFaces(face: Face) { + this.faces.push(face); + return this; + } + async create() { for (const user of this.users) { await this.context.createUser(user); @@ -178,6 +246,16 @@ export class TestFactory { await this.context.createAsset(asset); } + for (const person of this.persons) { + await this.context.createPerson(person); + } + + await this.context.refreshFaces( + this.assetFaces, + [], + this.faces.map((f) => TestFactory.face(f)), + ); + return this.context; } } @@ -276,4 +354,16 @@ export class TestContext { createSession(session: Session) { return this.session.create(TestFactory.session(session)); } + + createPerson(person: Person) { + return this.person.create(TestFactory.person(person)); + } + + refreshFaces(facesToAdd: AssetFace[], faceIdsToRemove: string[], embeddingsToAdd?: Insertable[]) { + return this.person.refreshFaces( + facesToAdd.map((f) => TestFactory.assetFace(f)), + faceIdsToRemove, + embeddingsToAdd, + ); + } } diff --git a/server/test/fixtures/face.stub.ts b/server/test/fixtures/face.stub.ts index 74a59a85a8..37fab86962 100644 --- a/server/test/fixtures/face.stub.ts +++ b/server/test/fixtures/face.stub.ts @@ -164,4 +164,19 @@ export const faceStub = { sourceType: SourceType.EXIF, deletedAt: null, }), + withBirthDate: Object.freeze({ + id: 'assetFaceId10', + assetId: assetStub.image.id, + asset: assetStub.image, + personId: personStub.withBirthDate.id, + person: personStub.withBirthDate, + boundingBoxX1: 0, + boundingBoxY1: 0, + boundingBoxX2: 1, + boundingBoxY2: 1, + imageHeight: 1024, + imageWidth: 1024, + sourceType: SourceType.MACHINE_LEARNING, + deletedAt: null, + }), }; diff --git a/server/test/medium/specs/person.service.spec.ts b/server/test/medium/specs/person.service.spec.ts new file mode 100644 index 0000000000..e79564b437 --- /dev/null +++ b/server/test/medium/specs/person.service.spec.ts @@ -0,0 +1,201 @@ +import { Kysely } from 'kysely'; +import { JobStatus, SourceType } from 'src/enum'; +import { PersonService } from 'src/services/person.service'; +import { TestContext, TestFactory } from 'test/factory'; +import { newEmbedding } from 'test/small.factory'; +import { getKyselyDB, newTestService } from 'test/utils'; + +const setup = async (db: Kysely) => { + const context = await TestContext.from(db).create(); + const { sut, mocks } = newTestService(PersonService, context); + + return { sut, mocks, context }; +}; + +describe.concurrent(PersonService.name, () => { + let sut: PersonService; + let context: TestContext; + + beforeAll(async () => { + ({ sut, context } = await setup(await getKyselyDB())); + }); + + describe('handleRecognizeFaces', () => { + it('should skip if face source type is not MACHINE_LEARNING', async () => { + const user = TestFactory.user(); + const asset = TestFactory.asset({ ownerId: user.id }); + const assetFace = TestFactory.assetFace({ assetId: asset.id, sourceType: SourceType.MANUAL }); + const face = TestFactory.face({ faceId: assetFace.id }); + await context.getFactory().withUser(user).withAsset(asset).withAssetFace(assetFace).withFaces(face).create(); + + const result = await sut.handleRecognizeFaces({ id: assetFace.id, deferred: false }); + + expect(result).toBe(JobStatus.SKIPPED); + const newPersonId = await context.db + .selectFrom('asset_faces') + .select('asset_faces.personId') + .where('asset_faces.id', '=', assetFace.id) + .executeTakeFirst(); + expect(newPersonId?.personId).toBeNull(); + }); + + it('should fail if face does not have an embedding', async () => { + const user = TestFactory.user(); + const asset = TestFactory.asset({ ownerId: user.id }); + const assetFace = TestFactory.assetFace({ assetId: asset.id, sourceType: SourceType.MACHINE_LEARNING }); + await context.getFactory().withUser(user).withAsset(asset).withAssetFace(assetFace).create(); + + const result = await sut.handleRecognizeFaces({ id: assetFace.id, deferred: false }); + + expect(result).toBe(JobStatus.FAILED); + const newPersonId = await context.db + .selectFrom('asset_faces') + .select('asset_faces.personId') + .where('asset_faces.id', '=', assetFace.id) + .executeTakeFirst(); + expect(newPersonId?.personId).toBeNull(); + }); + + it('should skip if face already has a person assigned', async () => { + const user = TestFactory.user(); + const asset = TestFactory.asset({ ownerId: user.id }); + const person = TestFactory.person({ ownerId: user.id }); + const assetFace = TestFactory.assetFace({ + assetId: asset.id, + sourceType: SourceType.MACHINE_LEARNING, + personId: person.id, + }); + const face = TestFactory.face({ faceId: assetFace.id }); + await context + .getFactory() + .withUser(user) + .withAsset(asset) + .withPerson(person) + .withAssetFace(assetFace) + .withFaces(face) + .create(); + + const result = await sut.handleRecognizeFaces({ id: assetFace.id, deferred: false }); + + expect(result).toBe(JobStatus.SKIPPED); + const newPersonId = await context.db + .selectFrom('asset_faces') + .select('asset_faces.personId') + .where('asset_faces.id', '=', assetFace.id) + .executeTakeFirst(); + expect(newPersonId?.personId).toEqual(person.id); + }); + + it('should create a new person if no matches are found', async () => { + const user = TestFactory.user(); + const embedding = newEmbedding(); + + let factory = context.getFactory().withUser(user); + + for (let i = 0; i < 3; i++) { + const existingAsset = TestFactory.asset({ ownerId: user.id }); + const existingAssetFace = TestFactory.assetFace({ + assetId: existingAsset.id, + sourceType: SourceType.MACHINE_LEARNING, + }); + const existingFace = TestFactory.face({ faceId: existingAssetFace.id, embedding }); + factory = factory.withAsset(existingAsset).withAssetFace(existingAssetFace).withFaces(existingFace); + } + + const newAsset = TestFactory.asset({ ownerId: user.id }); + const newAssetFace = TestFactory.assetFace({ assetId: newAsset.id, sourceType: SourceType.MACHINE_LEARNING }); + const newFace = TestFactory.face({ faceId: newAssetFace.id, embedding }); + + await factory.withAsset(newAsset).withAssetFace(newAssetFace).withFaces(newFace).create(); + + const result = await sut.handleRecognizeFaces({ id: newAssetFace.id, deferred: false }); + + expect(result).toBe(JobStatus.SUCCESS); + + const newPersonId = await context.db + .selectFrom('asset_faces') + .select('asset_faces.personId') + .where('asset_faces.id', '=', newAssetFace.id) + .executeTakeFirstOrThrow(); + expect(newPersonId.personId).toBeDefined(); + }); + + it('should assign face to an existing person if matches are found', async () => { + const user = TestFactory.user(); + const existingPerson = TestFactory.person({ ownerId: user.id }); + const embedding = newEmbedding(); + + let factory = context.getFactory().withUser(user).withPerson(existingPerson); + + const assetFaces: string[] = []; + + for (let i = 0; i < 3; i++) { + const existingAsset = TestFactory.asset({ ownerId: user.id }); + const existingAssetFace = TestFactory.assetFace({ + assetId: existingAsset.id, + sourceType: SourceType.MACHINE_LEARNING, + }); + assetFaces.push(existingAssetFace.id); + const existingFace = TestFactory.face({ faceId: existingAssetFace.id, embedding }); + factory = factory.withAsset(existingAsset).withAssetFace(existingAssetFace).withFaces(existingFace); + } + + const newAsset = TestFactory.asset({ ownerId: user.id }); + const newAssetFace = TestFactory.assetFace({ assetId: newAsset.id, sourceType: SourceType.MACHINE_LEARNING }); + const newFace = TestFactory.face({ faceId: newAssetFace.id, embedding }); + await factory.withAsset(newAsset).withAssetFace(newAssetFace).withFaces(newFace).create(); + await context.person.reassignFaces({ newPersonId: existingPerson.id, faceIds: assetFaces }); + + const result = await sut.handleRecognizeFaces({ id: newAssetFace.id, deferred: false }); + + expect(result).toBe(JobStatus.SUCCESS); + + const after = await context.db + .selectFrom('asset_faces') + .select('asset_faces.personId') + .where('asset_faces.id', '=', newAssetFace.id) + .executeTakeFirstOrThrow(); + expect(after.personId).toEqual(existingPerson.id); + }); + + it('should not assign face to an existing person if asset is older than person', async () => { + const user = TestFactory.user(); + const assetCreatedAt = new Date('2020-02-23T05:06:29.716Z'); + const birthDate = new Date(assetCreatedAt.getTime() + 3600 * 1000 * 365); + const existingPerson = TestFactory.person({ ownerId: user.id, birthDate }); + const embedding = newEmbedding(); + + let factory = context.getFactory().withUser(user).withPerson(existingPerson); + + const assetFaces: string[] = []; + + for (let i = 0; i < 3; i++) { + const existingAsset = TestFactory.asset({ ownerId: user.id }); + const existingAssetFace = TestFactory.assetFace({ + assetId: existingAsset.id, + sourceType: SourceType.MACHINE_LEARNING, + }); + assetFaces.push(existingAssetFace.id); + const existingFace = TestFactory.face({ faceId: existingAssetFace.id, embedding }); + factory = factory.withAsset(existingAsset).withAssetFace(existingAssetFace).withFaces(existingFace); + } + + const newAsset = TestFactory.asset({ ownerId: user.id, fileCreatedAt: assetCreatedAt }); + const newAssetFace = TestFactory.assetFace({ assetId: newAsset.id, sourceType: SourceType.MACHINE_LEARNING }); + const newFace = TestFactory.face({ faceId: newAssetFace.id, embedding }); + await factory.withAsset(newAsset).withAssetFace(newAssetFace).withFaces(newFace).create(); + await context.person.reassignFaces({ newPersonId: existingPerson.id, faceIds: assetFaces }); + + const result = await sut.handleRecognizeFaces({ id: newAssetFace.id, deferred: false }); + + expect(result).toBe(JobStatus.SKIPPED); + + const after = await context.db + .selectFrom('asset_faces') + .select('asset_faces.personId') + .where('asset_faces.id', '=', newAssetFace.id) + .executeTakeFirstOrThrow(); + expect(after.personId).toBeNull(); + }); + }); +}); diff --git a/server/test/repositories/person.repository.mock.ts b/server/test/repositories/person.repository.mock.ts new file mode 100644 index 0000000000..80a6a25c74 --- /dev/null +++ b/server/test/repositories/person.repository.mock.ts @@ -0,0 +1,36 @@ +import { PersonRepository } from 'src/repositories/person.repository'; +import { RepositoryInterface } from 'src/types'; +import { Mocked, vitest } from 'vitest'; + +export const newPersonRepositoryMock = (): Mocked> => { + return { + reassignFaces: vitest.fn(), + unassignFaces: vitest.fn(), + delete: vitest.fn(), + deleteFaces: vitest.fn(), + getAllFaces: vitest.fn(), + getAll: vitest.fn(), + getAllForUser: vitest.fn(), + getAllWithoutFaces: vitest.fn(), + getFaces: vitest.fn(), + getFaceById: vitest.fn(), + getFaceByIdWithAssets: vitest.fn(), + reassignFace: vitest.fn(), + getById: vitest.fn(), + getByName: vitest.fn(), + getDistinctNames: vitest.fn(), + getStatistics: vitest.fn(), + getNumberOfPeople: vitest.fn(), + create: vitest.fn(), + createAll: vitest.fn(), + refreshFaces: vitest.fn(), + update: vitest.fn(), + updateAll: vitest.fn(), + getFacesByIds: vitest.fn(), + getRandomFace: vitest.fn(), + getLatestFaceDate: vitest.fn(), + createAssetFace: vitest.fn(), + deleteAssetFace: vitest.fn(), + softDeleteAssetFaces: vitest.fn(), + }; +}; diff --git a/server/test/small.factory.ts b/server/test/small.factory.ts index 0f6d059b6a..70ec6e5495 100644 --- a/server/test/small.factory.ts +++ b/server/test/small.factory.ts @@ -12,6 +12,12 @@ export const newUuids = () => export const newDate = () => new Date(); export const newUpdateId = () => 'uuid-v7'; export const newSha1 = () => Buffer.from('this is a fake hash'); +export const newEmbedding = () => { + const embedding = Array.from({ length: 512 }) + .fill(0) + .map(() => Math.random()); + return '[' + embedding + ']'; +}; const authFactory = ({ apiKey, ...user }: Partial & { apiKey?: Partial } = {}) => { const auth: AuthDto = { diff --git a/server/test/utils.ts b/server/test/utils.ts index 4df7904d75..06142dc149 100644 --- a/server/test/utils.ts +++ b/server/test/utils.ts @@ -58,6 +58,7 @@ import { newDatabaseRepositoryMock } from 'test/repositories/database.repository import { newJobRepositoryMock } from 'test/repositories/job.repository.mock'; import { newMediaRepositoryMock } from 'test/repositories/media.repository.mock'; import { newMetadataRepositoryMock } from 'test/repositories/metadata.repository.mock'; +import { newPersonRepositoryMock } from 'test/repositories/person.repository.mock'; import { newStorageRepositoryMock } from 'test/repositories/storage.repository.mock'; import { newSystemMetadataRepositoryMock } from 'test/repositories/system-metadata.repository.mock'; import { ITelemetryRepositoryMock, newTelemetryRepositoryMock } from 'test/repositories/telemetry.repository.mock'; @@ -197,7 +198,7 @@ export const newTestService = ( notification: automock(NotificationRepository, { args: [loggerMock] }), oauth: automock(OAuthRepository, { args: [loggerMock] }), partner: automock(PartnerRepository, { strict: false }), - person: automock(PersonRepository, { strict: false }), + person: newPersonRepositoryMock(), process: automock(ProcessRepository, { args: [loggerMock] }), search: automock(SearchRepository, { args: [loggerMock], strict: false }), // eslint-disable-next-line no-sparse-arrays