diff --git a/server/src/interfaces/oauth.interface.ts b/server/src/interfaces/oauth.interface.ts new file mode 100644 index 0000000000..5e629726a0 --- /dev/null +++ b/server/src/interfaces/oauth.interface.ts @@ -0,0 +1,22 @@ +import { UserinfoResponse } from 'openid-client'; + +export const IOAuthRepository = 'IOAuthRepository'; + +export type OAuthConfig = { + clientId: string; + clientSecret: string; + issuerUrl: string; + mobileOverrideEnabled: boolean; + mobileRedirectUri: string; + profileSigningAlgorithm: string; + scope: string; + signingAlgorithm: string; +}; +export type OAuthProfile = UserinfoResponse; + +export interface IOAuthRepository { + init(): void; + authorize(config: OAuthConfig, redirectUrl: string): Promise; + getLogoutEndpoint(config: OAuthConfig): Promise; + getProfile(config: OAuthConfig, url: string, redirectUrl: string): Promise; +} diff --git a/server/src/repositories/index.ts b/server/src/repositories/index.ts index 5da4f678d3..5bf08d0d78 100644 --- a/server/src/repositories/index.ts +++ b/server/src/repositories/index.ts @@ -20,6 +20,7 @@ import { IMetadataRepository } from 'src/interfaces/metadata.interface'; import { IMetricRepository } from 'src/interfaces/metric.interface'; import { IMoveRepository } from 'src/interfaces/move.interface'; import { INotificationRepository } from 'src/interfaces/notification.interface'; +import { IOAuthRepository } from 'src/interfaces/oauth.interface'; import { IPartnerRepository } from 'src/interfaces/partner.interface'; import { IPersonRepository } from 'src/interfaces/person.interface'; import { ISearchRepository } from 'src/interfaces/search.interface'; @@ -56,6 +57,7 @@ import { MetadataRepository } from 'src/repositories/metadata.repository'; import { MetricRepository } from 'src/repositories/metric.repository'; import { MoveRepository } from 'src/repositories/move.repository'; import { NotificationRepository } from 'src/repositories/notification.repository'; +import { OAuthRepository } from 'src/repositories/oauth.repository'; import { PartnerRepository } from 'src/repositories/partner.repository'; import { PersonRepository } from 'src/repositories/person.repository'; import { SearchRepository } from 'src/repositories/search.repository'; @@ -94,6 +96,7 @@ export const repositories = [ { provide: IMetricRepository, useClass: MetricRepository }, { provide: IMoveRepository, useClass: MoveRepository }, { provide: INotificationRepository, useClass: NotificationRepository }, + { provide: IOAuthRepository, useClass: OAuthRepository }, { provide: IPartnerRepository, useClass: PartnerRepository }, { provide: IPersonRepository, useClass: PersonRepository }, { provide: ISearchRepository, useClass: SearchRepository }, diff --git a/server/src/repositories/oauth.repository.ts b/server/src/repositories/oauth.repository.ts new file mode 100644 index 0000000000..adde7099d0 --- /dev/null +++ b/server/src/repositories/oauth.repository.ts @@ -0,0 +1,73 @@ +import { Inject, Injectable, InternalServerErrorException } from '@nestjs/common'; +import { custom, generators, Issuer } from 'openid-client'; +import { ILoggerRepository } from 'src/interfaces/logger.interface'; +import { IOAuthRepository, OAuthConfig, OAuthProfile } from 'src/interfaces/oauth.interface'; +import { Instrumentation } from 'src/utils/instrumentation'; + +@Instrumentation() +@Injectable() +export class OAuthRepository implements IOAuthRepository { + constructor(@Inject(ILoggerRepository) private logger: ILoggerRepository) { + this.logger.setContext(OAuthRepository.name); + } + + init() { + custom.setHttpOptionsDefaults({ timeout: 30_000 }); + } + + async authorize(config: OAuthConfig, redirectUrl: string) { + const client = await this.getClient(config); + return client.authorizationUrl({ + redirect_uri: redirectUrl, + scope: config.scope, + state: generators.state(), + }); + } + + async getLogoutEndpoint(config: OAuthConfig) { + const client = await this.getClient(config); + return client.issuer.metadata.end_session_endpoint; + } + + async getProfile(config: OAuthConfig, url: string, redirectUrl: string): Promise { + const client = await this.getClient(config); + const params = client.callbackParams(url); + try { + const tokens = await client.callback(redirectUrl, params, { state: params.state }); + return await client.userinfo(tokens.access_token || ''); + } catch (error: Error | any) { + if (error.message.includes('unexpected JWT alg received')) { + this.logger.warn( + [ + 'Algorithm mismatch. Make sure the signing algorithm is set correctly in the OAuth settings.', + 'Or, that you have specified a signing key in your OAuth provider.', + ].join(' '), + ); + } + + throw error; + } + } + + private async getClient({ + issuerUrl, + clientId, + clientSecret, + profileSigningAlgorithm, + signingAlgorithm, + }: OAuthConfig) { + try { + const issuer = await Issuer.discover(issuerUrl); + return new issuer.Client({ + client_id: clientId, + client_secret: clientSecret, + response_types: ['code'], + userinfo_signed_response_alg: profileSigningAlgorithm === 'none' ? undefined : profileSigningAlgorithm, + id_token_signed_response_alg: signingAlgorithm, + }); + } catch (error: any | AggregateError) { + this.logger.error(`Error in OAuth discovery: ${error}`, error?.stack, error?.errors); + throw new InternalServerErrorException(`Error in OAuth discovery: ${error}`, { cause: error }); + } + } +} diff --git a/server/src/services/auth.service.spec.ts b/server/src/services/auth.service.spec.ts index 9f81e03171..3a5df0790d 100644 --- a/server/src/services/auth.service.spec.ts +++ b/server/src/services/auth.service.spec.ts @@ -1,5 +1,4 @@ import { BadRequestException, ForbiddenException, UnauthorizedException } from '@nestjs/common'; -import { Issuer, generators } from 'openid-client'; import { AuthDto, SignUpDto } from 'src/dtos/auth.dto'; import { UserMetadataEntity } from 'src/entities/user-metadata.entity'; import { UserEntity } from 'src/entities/user.entity'; @@ -7,6 +6,7 @@ import { AuthType, Permission } from 'src/enum'; import { IKeyRepository } from 'src/interfaces/api-key.interface'; import { ICryptoRepository } from 'src/interfaces/crypto.interface'; import { IEventRepository } from 'src/interfaces/event.interface'; +import { IOAuthRepository } from 'src/interfaces/oauth.interface'; import { ISessionRepository } from 'src/interfaces/session.interface'; import { ISharedLinkRepository } from 'src/interfaces/shared-link.interface'; import { ISystemMetadataRepository } from 'src/interfaces/system-metadata.interface'; @@ -19,7 +19,7 @@ import { sharedLinkStub } from 'test/fixtures/shared-link.stub'; import { systemConfigStub } from 'test/fixtures/system-config.stub'; import { userStub } from 'test/fixtures/user.stub'; import { newTestService } from 'test/utils'; -import { Mock, Mocked, vitest } from 'vitest'; +import { Mocked } from 'vitest'; // const token = Buffer.from('my-api-key', 'utf8').toString('base64'); @@ -53,36 +53,19 @@ describe('AuthService', () => { let cryptoMock: Mocked; let eventMock: Mocked; let keyMock: Mocked; + let oauthMock: Mocked; let sessionMock: Mocked; let sharedLinkMock: Mocked; let systemMock: Mocked; let userMock: Mocked; - let callbackMock: Mock; - let userinfoMock: Mock; - beforeEach(() => { - callbackMock = vitest.fn().mockReturnValue({ access_token: 'access-token' }); - userinfoMock = vitest.fn().mockResolvedValue({ sub, email }); - - vitest.spyOn(generators, 'state').mockReturnValue('state'); - vitest.spyOn(Issuer, 'discover').mockResolvedValue({ - id_token_signing_alg_values_supported: ['RS256'], - Client: vitest.fn().mockResolvedValue({ - issuer: { - metadata: { - end_session_endpoint: 'http://end-session-endpoint', - }, - }, - authorizationUrl: vitest.fn().mockReturnValue('http://authorization-url'), - callbackParams: vitest.fn().mockReturnValue({ state: 'state' }), - callback: callbackMock, - userinfo: userinfoMock, - }), - } as any); - - ({ sut, cryptoMock, eventMock, keyMock, sessionMock, sharedLinkMock, systemMock, userMock } = + ({ sut, cryptoMock, eventMock, keyMock, oauthMock, sessionMock, sharedLinkMock, systemMock, userMock } = newTestService(AuthService)); + + oauthMock.authorize.mockResolvedValue('access-token'); + oauthMock.getProfile.mockResolvedValue({ sub, email }); + oauthMock.getLogoutEndpoint.mockResolvedValue('http://end-session-endpoint'); }); it('should be defined', () => { @@ -515,21 +498,21 @@ describe('AuthService', () => { expect(userMock.create).toHaveBeenCalledTimes(1); }); - // TODO write once oidc has been moved to a repo and can be mocked. - // it('should throw an error if user should be auto registered but the email claim does not exist', async () => { - // systemMock.get.mockResolvedValue(systemConfigStub.enabled); - // userMock.getByEmail.mockResolvedValue(null); - // userMock.getAdmin.mockResolvedValue(userStub.user1); - // userMock.create.mockResolvedValue(userStub.user1); - // sessionMock.create.mockResolvedValue(sessionStub.valid); + it('should throw an error if user should be auto registered but the email claim does not exist', async () => { + systemMock.get.mockResolvedValue(systemConfigStub.enabled); + userMock.getByEmail.mockResolvedValue(null); + userMock.getAdmin.mockResolvedValue(userStub.user1); + userMock.create.mockResolvedValue(userStub.user1); + sessionMock.create.mockResolvedValue(sessionStub.valid); + oauthMock.getProfile.mockResolvedValue({ sub, email: undefined }); - // await expect(sut.callback({ url: 'http://immich/auth/login?code=abc123' }, loginDetails)).rejects.toBeInstanceOf( - // BadRequestException, - // ); + await expect(sut.callback({ url: 'http://immich/auth/login?code=abc123' }, loginDetails)).rejects.toBeInstanceOf( + BadRequestException, + ); - // expect(userMock.getByEmail).toHaveBeenCalledTimes(1); - // expect(userMock.create).toHaveBeenCalledTimes(1); - // }); + expect(userMock.getByEmail).not.toHaveBeenCalled(); + expect(userMock.create).not.toHaveBeenCalled(); + }); for (const url of [ 'app.immich:/', @@ -545,7 +528,7 @@ describe('AuthService', () => { sessionMock.create.mockResolvedValue(sessionStub.valid); await sut.callback({ url }, loginDetails); - expect(callbackMock).toHaveBeenCalledWith('http://mobile-redirect', { state: 'state' }, { state: 'state' }); + expect(oauthMock.getProfile).toHaveBeenCalledWith(expect.objectContaining({}), url, 'http://mobile-redirect'); }); } @@ -567,7 +550,7 @@ describe('AuthService', () => { userMock.getByEmail.mockResolvedValue(null); userMock.getAdmin.mockResolvedValue(userStub.user1); userMock.create.mockResolvedValue(userStub.user1); - userinfoMock.mockResolvedValue({ sub, email, immich_quota: 'abc' }); + oauthMock.getProfile.mockResolvedValue({ sub, email, immich_quota: 'abc' }); await expect(sut.callback({ url: 'http://immich/auth/login?code=abc123' }, loginDetails)).resolves.toEqual( loginResponseStub.user1oauth, @@ -581,7 +564,7 @@ describe('AuthService', () => { userMock.getByEmail.mockResolvedValue(null); userMock.getAdmin.mockResolvedValue(userStub.user1); userMock.create.mockResolvedValue(userStub.user1); - userinfoMock.mockResolvedValue({ sub, email, immich_quota: -5 }); + oauthMock.getProfile.mockResolvedValue({ sub, email, immich_quota: -5 }); await expect(sut.callback({ url: 'http://immich/auth/login?code=abc123' }, loginDetails)).resolves.toEqual( loginResponseStub.user1oauth, @@ -595,7 +578,7 @@ describe('AuthService', () => { userMock.getByEmail.mockResolvedValue(null); userMock.getAdmin.mockResolvedValue(userStub.user1); userMock.create.mockResolvedValue(userStub.user1); - userinfoMock.mockResolvedValue({ sub, email, immich_quota: 0 }); + oauthMock.getProfile.mockResolvedValue({ sub, email, immich_quota: 0 }); await expect(sut.callback({ url: 'http://immich/auth/login?code=abc123' }, loginDetails)).resolves.toEqual( loginResponseStub.user1oauth, @@ -615,7 +598,7 @@ describe('AuthService', () => { userMock.getByEmail.mockResolvedValue(null); userMock.getAdmin.mockResolvedValue(userStub.user1); userMock.create.mockResolvedValue(userStub.user1); - userinfoMock.mockResolvedValue({ sub, email, immich_quota: 5 }); + oauthMock.getProfile.mockResolvedValue({ sub, email, immich_quota: 5 }); await expect(sut.callback({ url: 'http://immich/auth/login?code=abc123' }, loginDetails)).resolves.toEqual( loginResponseStub.user1oauth, diff --git a/server/src/services/auth.service.ts b/server/src/services/auth.service.ts index d16c7af09a..1fbc8b69ff 100644 --- a/server/src/services/auth.service.ts +++ b/server/src/services/auth.service.ts @@ -1,16 +1,8 @@ -import { - BadRequestException, - ForbiddenException, - Injectable, - InternalServerErrorException, - UnauthorizedException, -} from '@nestjs/common'; +import { BadRequestException, ForbiddenException, Injectable, UnauthorizedException } from '@nestjs/common'; import { isNumber, isString } from 'class-validator'; import cookieParser from 'cookie'; import { DateTime } from 'luxon'; import { IncomingHttpHeaders } from 'node:http'; -import { Issuer, UserinfoResponse, custom, generators } from 'openid-client'; -import { SystemConfig } from 'src/config'; import { LOGIN_URL, MOBILE_REDIRECT, SALT_ROUNDS } from 'src/constants'; import { OnEvent } from 'src/decorators'; import { @@ -30,6 +22,7 @@ import { import { UserAdminResponseDto, mapUserAdmin } from 'src/dtos/user.dto'; import { UserEntity } from 'src/entities/user.entity'; import { AuthType, Permission } from 'src/enum'; +import { OAuthProfile } from 'src/interfaces/oauth.interface'; import { BaseService } from 'src/services/base.service'; import { isGranted } from 'src/utils/access'; import { HumanReadableSize } from 'src/utils/bytes'; @@ -42,8 +35,6 @@ export interface LoginDetails { deviceOS: string; } -type OAuthProfile = UserinfoResponse; - interface ClaimOptions { key: string; default: T; @@ -65,7 +56,7 @@ export type ValidateRequest = { export class AuthService extends BaseService { @OnEvent({ name: 'app.bootstrap' }) onBootstrap() { - custom.setHttpOptionsDefaults({ timeout: 30_000 }); + this.oauthRepository.init(); } async login(dto: LoginCredentialDto, details: LoginDetails) { @@ -191,21 +182,20 @@ export class AuthService extends BaseService { } async authorize(dto: OAuthConfigDto): Promise { - const config = await this.getConfig({ withCache: false }); - const client = await this.getOAuthClient(config); - const url = client.authorizationUrl({ - redirect_uri: this.normalize(config, dto.redirectUri), - scope: config.oauth.scope, - state: generators.state(), - }); + const { oauth } = await this.getConfig({ withCache: false }); + if (!oauth.enabled) { + throw new BadRequestException('OAuth is not enabled'); + } + + const url = await this.oauthRepository.authorize(oauth, dto.redirectUri); return { url }; } async callback(dto: OAuthCallbackDto, loginDetails: LoginDetails) { - const config = await this.getConfig({ withCache: false }); - const profile = await this.getOAuthProfile(config, dto.url); - const { autoRegister, defaultStorageQuota, storageLabelClaim, storageQuotaClaim } = config.oauth; + const { oauth } = await this.getConfig({ withCache: false }); + const profile = await this.oauthRepository.getProfile(oauth, dto.url, this.normalize(oauth, dto.url.split('?')[0])); + const { autoRegister, defaultStorageQuota, storageLabelClaim, storageQuotaClaim } = oauth; this.logger.debug(`Logging in with OAuth: ${JSON.stringify(profile)}`); let user = await this.userRepository.getByOAuthId(profile.sub); @@ -263,8 +253,12 @@ export class AuthService extends BaseService { } async link(auth: AuthDto, dto: OAuthCallbackDto): Promise { - const config = await this.getConfig({ withCache: false }); - const { sub: oauthId } = await this.getOAuthProfile(config, dto.url); + const { oauth } = await this.getConfig({ withCache: false }); + const { sub: oauthId } = await this.oauthRepository.getProfile( + oauth, + dto.url, + this.normalize(oauth, dto.url.split('?')[0]), + ); const duplicate = await this.userRepository.getByOAuthId(oauthId); if (duplicate && duplicate.id !== auth.user.id) { this.logger.warn(`OAuth link account failed: sub is already linked to another user (${duplicate.email}).`); @@ -290,60 +284,7 @@ export class AuthService extends BaseService { return LOGIN_URL; } - const client = await this.getOAuthClient(config); - return client.issuer.metadata.end_session_endpoint || LOGIN_URL; - } - - private async getOAuthProfile(config: SystemConfig, url: string): Promise { - const redirectUri = this.normalize(config, url.split('?')[0]); - const client = await this.getOAuthClient(config); - const params = client.callbackParams(url); - try { - const tokens = await client.callback(redirectUri, params, { state: params.state }); - return client.userinfo(tokens.access_token || ''); - } catch (error: Error | any) { - if (error.message.includes('unexpected JWT alg received')) { - this.logger.warn( - [ - 'Algorithm mismatch. Make sure the signing algorithm is set correctly in the OAuth settings.', - 'Or, that you have specified a signing key in your OAuth provider.', - ].join(' '), - ); - } - - throw error; - } - } - - private async getOAuthClient(config: SystemConfig) { - const { enabled, clientId, clientSecret, issuerUrl, signingAlgorithm, profileSigningAlgorithm } = config.oauth; - - if (!enabled) { - throw new BadRequestException('OAuth2 is not enabled'); - } - - try { - const issuer = await Issuer.discover(issuerUrl); - return new issuer.Client({ - client_id: clientId, - client_secret: clientSecret, - response_types: ['code'], - userinfo_signed_response_alg: profileSigningAlgorithm === 'none' ? undefined : profileSigningAlgorithm, - id_token_signed_response_alg: signingAlgorithm, - }); - } catch (error: any | AggregateError) { - this.logger.error(`Error in OAuth discovery: ${error}`, error?.stack, error?.errors); - throw new InternalServerErrorException(`Error in OAuth discovery: ${error}`, { cause: error }); - } - } - - private normalize(config: SystemConfig, redirectUri: string) { - const isMobile = redirectUri.startsWith('app.immich:/'); - const { mobileRedirectUri, mobileOverrideEnabled } = config.oauth; - if (isMobile && mobileOverrideEnabled && mobileRedirectUri) { - return mobileRedirectUri; - } - return redirectUri; + return (await this.oauthRepository.getLogoutEndpoint(config.oauth)) || LOGIN_URL; } private getBearerToken(headers: IncomingHttpHeaders): string | null { @@ -427,4 +368,15 @@ export class AuthService extends BaseService { const value = profile[options.key as keyof OAuthProfile]; return options.isValid(value) ? (value as T) : options.default; } + + private normalize( + { mobileRedirectUri, mobileOverrideEnabled }: { mobileRedirectUri: string; mobileOverrideEnabled: boolean }, + redirectUri: string, + ) { + const isMobile = redirectUri.startsWith('app.immich:/'); + if (isMobile && mobileOverrideEnabled && mobileRedirectUri) { + return mobileRedirectUri; + } + return redirectUri; + } } diff --git a/server/src/services/base.service.ts b/server/src/services/base.service.ts index 3c28451a6d..e98f88ade1 100644 --- a/server/src/services/base.service.ts +++ b/server/src/services/base.service.ts @@ -23,6 +23,7 @@ import { IMetadataRepository } from 'src/interfaces/metadata.interface'; import { IMetricRepository } from 'src/interfaces/metric.interface'; import { IMoveRepository } from 'src/interfaces/move.interface'; import { INotificationRepository } from 'src/interfaces/notification.interface'; +import { IOAuthRepository } from 'src/interfaces/oauth.interface'; import { IPartnerRepository } from 'src/interfaces/partner.interface'; import { IPersonRepository } from 'src/interfaces/person.interface'; import { ISearchRepository } from 'src/interfaces/search.interface'; @@ -65,6 +66,7 @@ export class BaseService { @Inject(IMetricRepository) protected metricRepository: IMetricRepository, @Inject(IMoveRepository) protected moveRepository: IMoveRepository, @Inject(INotificationRepository) protected notificationRepository: INotificationRepository, + @Inject(IOAuthRepository) protected oauthRepository: IOAuthRepository, @Inject(IPartnerRepository) protected partnerRepository: IPartnerRepository, @Inject(IPersonRepository) protected personRepository: IPersonRepository, @Inject(ISearchRepository) protected searchRepository: ISearchRepository, diff --git a/server/test/repositories/oauth.repository.mock.ts b/server/test/repositories/oauth.repository.mock.ts new file mode 100644 index 0000000000..f87b3781e9 --- /dev/null +++ b/server/test/repositories/oauth.repository.mock.ts @@ -0,0 +1,11 @@ +import { IOAuthRepository } from 'src/interfaces/oauth.interface'; +import { Mocked } from 'vitest'; + +export const newOAuthRepositoryMock = (): Mocked => { + return { + init: vitest.fn(), + authorize: vitest.fn(), + getLogoutEndpoint: vitest.fn(), + getProfile: vitest.fn(), + }; +}; diff --git a/server/test/utils.ts b/server/test/utils.ts index c744443bd6..05257c19ee 100644 --- a/server/test/utils.ts +++ b/server/test/utils.ts @@ -21,6 +21,7 @@ import { newMetadataRepositoryMock } from 'test/repositories/metadata.repository import { newMetricRepositoryMock } from 'test/repositories/metric.repository.mock'; import { newMoveRepositoryMock } from 'test/repositories/move.repository.mock'; import { newNotificationRepositoryMock } from 'test/repositories/notification.repository.mock'; +import { newOAuthRepositoryMock } from 'test/repositories/oauth.repository.mock'; import { newPartnerRepositoryMock } from 'test/repositories/partner.repository.mock'; import { newPersonRepositoryMock } from 'test/repositories/person.repository.mock'; import { newSearchRepositoryMock } from 'test/repositories/search.repository.mock'; @@ -64,6 +65,7 @@ export const newTestService = (Service: Constructor(Service: Constructor(Service: Constructor