From dbf30b77bfbcb6dd14e87199cdcc498f48982577 Mon Sep 17 00:00:00 2001 From: santanoce <98591467+santanoce@users.noreply.github.com> Date: Fri, 17 Apr 2026 20:45:33 +0200 Subject: [PATCH] feat(server): added backchannel logout api endpoint (#26235) * feat(server): added backchannel logout api endpoint * test(server): fixed e2e tests * fix(server): fixed suggested changes by reviewer * feat(server): created function invalidateOAuth * fix(server): fixed session.repository.sql * test(server): added unit tests for backchannelLogout function * test(server): added e2e tests for oidc backchnnel logout * docs(server): added documentation on backchannel logout url * docs(server): fixed typo * feat(server): minor improvements of the oidc backchannel logout * test(server): fixed tests after merge with main * fix(server): fixed e2e test file * refactor(server): tiny refactor of validateLogoutToken * chore: cleanup * fix: tests * fix: make jwks extractable --------- Co-authored-by: Daniel Dietzler --- docs/docs/administration/oauth.md | 4 + e2e-auth-server/auth-server.ts | 36 +++- e2e-auth-server/test-keys.ts | 38 ++++ e2e/src/specs/server/api/oauth.e2e-spec.ts | 47 ++++- mobile/openapi/README.md | 1 + .../openapi/lib/api/authentication_api.dart | 53 ++++++ open-api/immich-openapi-specs.json | 44 +++++ open-api/typescript-sdk/src/fetch-client.ts | 16 ++ server/src/app.common.ts | 3 +- server/src/controllers/oauth.controller.ts | 16 +- server/src/dtos/auth.dto.ts | 5 + server/src/queries/session.repository.sql | 2 +- server/src/repositories/oauth.repository.ts | 68 +++++++- server/src/repositories/session.repository.ts | 24 ++- .../1776442031775-AddOauthSidToSession.ts | 11 ++ server/src/schema/tables/session.table.ts | 3 + server/src/services/auth.service.spec.ts | 162 +++++++++++++++--- server/src/services/auth.service.ts | 60 ++++++- server/src/services/session.service.spec.ts | 7 +- server/src/services/session.service.ts | 4 +- server/test/factories/session.factory.ts | 1 + 21 files changed, 558 insertions(+), 47 deletions(-) create mode 100644 e2e-auth-server/test-keys.ts create mode 100644 server/src/schema/migrations/1776442031775-AddOauthSidToSession.ts diff --git a/docs/docs/administration/oauth.md b/docs/docs/administration/oauth.md index 3b1e8c729d..8d259d8074 100644 --- a/docs/docs/administration/oauth.md +++ b/docs/docs/administration/oauth.md @@ -50,6 +50,10 @@ Before enabling OAuth in Immich, a new client application needs to be configured - `https://immich.example.com/auth/login` - `https://immich.example.com/user-settings` +3. Configure Backchannel logout URL + + If the authentication server supports it, the **Backchannel logout URL** can be specified, and it is of the form: `http://DOMAIN:PORT/api/oauth/backchannel-logout`. + ## Enable OAuth Once you have a new OAuth client application configured, Immich can be configured using the Administration Settings page, available on the web (Administration -> Settings). diff --git a/e2e-auth-server/auth-server.ts b/e2e-auth-server/auth-server.ts index bcfeca1e1c..15aaa71c1c 100644 --- a/e2e-auth-server/auth-server.ts +++ b/e2e-auth-server/auth-server.ts @@ -1,5 +1,12 @@ -import { exportJWK, generateKeyPair } from 'jose'; +import { + calculateJwkThumbprint, + exportJWK, + importPKCS8, + importSPKI, + SignJWT, +} from 'jose'; import Provider from 'oidc-provider'; +import { PRIVATE_KEY_PEM, PUBLIC_KEY_PEM } from './test-keys'; export enum OAuthClient { DEFAULT = 'client-default', @@ -44,6 +51,29 @@ const claims = [ }, ]; +const privateKey = await importPKCS8(PRIVATE_KEY_PEM, 'RS256', { + extractable: true, +}); +const publicKey = await importSPKI(PUBLIC_KEY_PEM, 'RS256', { + extractable: true, +}); +const kid = await calculateJwkThumbprint(await exportJWK(publicKey)); + +export async function generateLogoutToken(iss: string, sub: string) { + return await new SignJWT({ + iss: iss, + aud: OAuthClient.DEFAULT, + iat: Math.floor(Date.now() / 1000), + jti: crypto.randomUUID(), + sub: sub, + events: { + 'http://schemas.openid.net/event/backchannel-logout': {}, + }, + }) + .setProtectedHeader({ alg: 'RS256', typ: 'logout+jwt', kid: kid }) + .sign(privateKey); +} + const withDefaultClaims = (sub: string) => ({ sub, email: `${sub}@immich.app`, @@ -66,10 +96,6 @@ const getClaims = (sub: string, use?: string) => { }; const setup = async () => { - const { privateKey, publicKey } = await generateKeyPair('RS256', { - extractable: true, - }); - const redirectUris = [ 'http://127.0.0.1:2285/auth/login', 'https://photos.immich.app/oauth/mobile-redirect', diff --git a/e2e-auth-server/test-keys.ts b/e2e-auth-server/test-keys.ts new file mode 100644 index 0000000000..a37e822029 --- /dev/null +++ b/e2e-auth-server/test-keys.ts @@ -0,0 +1,38 @@ +export const PRIVATE_KEY_PEM = `-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCVj5C7hzN3E2HO +TcJ+DN/e2NSTQFj4rPylz4J8xjm8Es7l0k2kK5EEGvUNVGZbw7s055c+6kwP9eqg +B5XFE7+26Fcq1sou6Tbm310kU4dnMW5l2CgwrhaGyb1pNysao0AMLT60dFYqtUwn +ha9ceCsa+ZU1JrknVf3rONtppBvhWoI7CO9XX1keVQ0unHPzCWUjpXTzC8OGEbmB +2w7ZIUf8OfJkd5RZ4OtIpML71W9n13aDxT50x2/EW/pFLFtQ/oaleOKHpvlRXDRX +W86G4moUJym3gHMXMUj2aOcFG2UJnpLruKz3i5qZwYiTRlBP6O9EIQNCVtYxchuN +V1CCcBU1AgMBAAECggEAJLfXMu8Nx89ynPVyyUMMaFfoEpHC9iR0L5obQVpiPMYK +VRqVVLecdftPS9s7eQ58BNBRzdC0ZVu841aRYs3HLNbsZZhPkYZQpAxU//Dg5okY +fzj7Hv5yidt4HN9+Pd8z/3lRMnj4WapifLaBt8xJ2ujJBMBRxzJBsXDnT0+Kx7+y +bYDeuVfyUTEikaK3QZTbuRF3D3eiuN16GG+hv8UqTF2eYbPxdiLjYpTSHa4mH88C +qfJz2Xt4SEzmyeo3G+MO17wDFOwtEe8ojlJfULHnHJSFdUwTfYIFM1bg5/fJ9MOS +/fO3TSG+wkQqjQa6eoGssAzP87fL2XNLzlDtGY/7uQKBgQDHuJHOtf1EjOvNYiP7 +EN+8QGs41ghzt9CQRQxWbHpusR3IW3P83KMXwYmrlG70oOUXBRGSB/ESXUofXc5W +pu5+Y55S44aUnu/a9yOBttYW0dtHZSL0zFT+PlVASwUzFZ2zcH1KXlUkSpfL5OAD +PyDDTnBZ2AWh45fRO9wLo6PPuQKBgQC/tI03RqU3mOjqukKbquYeIpXHfRU5Z0DM +u9ru1THYEl6fmkMXycxo/mvW3awyFuyKy/VodqIgKnFgumEqCHZh6OAMm/LC7TfA +l9tjFSs/MyOqQVD4kbX+z6Oq4c4GccDoXfsQ3gzECoBapegi/F+6/25y+/C8ghXb +J/Jg1GQXXQKBgQDFgWbfzuVZZyrBfu4qGLPJDMN7/114YizknwPma3xf/tN/EcGQ +K/k1QvWMMkvPq1UiAKcxjJ0AFjV482FcG9T6NDWbrtmmG88C8Sex3Ue2ZW2+GuwI +vhDHJIlV/Vp0/Elp7DJa2xLDwuh+gCZvz3vs6KL+ljxrrhCyn8mp0PfsMQKBgFFZ +KnuETOO0zVGdzFoGQTQUdP58A5+iQwsdxB+I9Ge+E80iRso3ZbhADj7VPhbbR3D2 +b6LuhImluQrUzBpsEOAnU7vGCVPSGdBuIDiBaSKebsn2gYeZPWNtdQQ0YZq2dqek +Cb/0mfIuipzsvf7qnSza62F7q4IyqVegMegI+Jg5AoGATM3NMy7JZeKzSkm+3ohU +3xZOwgqKV9SH+0OeYWpuBxT7D7FlrKKI4NJ3XN3hg2f/DJAF6dH11CPe7pk94yol +HMbh+PQUQ6GYvAzxIOvagWboQ3lzeyubNMpyFjfOrIE/WOQCUBZ9tIwCHIarIuyi +QRuNOj3+U8T/n1Ww352HBdw= +-----END PRIVATE KEY-----`; + +export const PUBLIC_KEY_PEM = `-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAlY+Qu4czdxNhzk3Cfgzf +3tjUk0BY+Kz8pc+CfMY5vBLO5dJNpCuRBBr1DVRmW8O7NOeXPupMD/XqoAeVxRO/ +tuhXKtbKLuk25t9dJFOHZzFuZdgoMK4Whsm9aTcrGqNADC0+tHRWKrVMJ4WvXHgr +GvmVNSa5J1X96zjbaaQb4VqCOwjvV19ZHlUNLpxz8wllI6V08wvDhhG5gdsO2SFH +/DnyZHeUWeDrSKTC+9VvZ9d2g8U+dMdvxFv6RSxbUP6GpXjih6b5UVw0V1vOhuJq +FCcpt4BzFzFI9mjnBRtlCZ6S67is94uamcGIk0ZQT+jvRCEDQlbWMXIbjVdQgnAV +NQIDAQAB +-----END PUBLIC KEY-----`; diff --git a/e2e/src/specs/server/api/oauth.e2e-spec.ts b/e2e/src/specs/server/api/oauth.e2e-spec.ts index 98cb28c821..a3bc0d8770 100644 --- a/e2e/src/specs/server/api/oauth.e2e-spec.ts +++ b/e2e/src/specs/server/api/oauth.e2e-spec.ts @@ -1,9 +1,10 @@ -import { OAuthClient, OAuthUser } from '@immich/e2e-auth-server'; +import { OAuthClient, OAuthUser, generateLogoutToken } from '@immich/e2e-auth-server'; import { LoginResponseDto, SystemConfigOAuthDto, getConfigDefaults, getMyUser, + getSessions, startOAuth, updateConfig, } from '@immich/sdk'; @@ -334,6 +335,50 @@ describe(`/oauth`, () => { }); }); + describe(`POST /oauth/backchannel-logout`, () => { + it(`should throw an error if the logout_token is not provided`, async () => { + const { status, body } = await request(app).post('/oauth/backchannel-logout').send({}); + expect(status).toBe(400); + expect(body).toEqual(errorDto.badRequest(['[logout_token] Invalid input: expected string, received undefined'])); + }); + + it(`should throw an error if an invalid logout token is provided`, async () => { + const { status, body } = await request(app) + .post('/oauth/backchannel-logout') + .send({ logout_token: 'invalid token' }); + expect(status).toBe(400); + expect(body).toEqual(errorDto.badRequest('Error backchannel logout: token validation failed')); + }); + + it(`should logout user if a valid logout token is provided`, async () => { + await setupOAuth(admin.accessToken, { + enabled: true, + clientId: OAuthClient.DEFAULT, + clientSecret: OAuthClient.DEFAULT, + autoRegister: true, + signingAlgorithm: 'RS256', + buttonText: 'Login with Immich', + }); + + const callbackParams = await loginWithOAuth('backchannel-logout-user'); + const { status: callbackStatus, body: callbackBody } = await request(app) + .post('/oauth/callback') + .send(callbackParams); + expect(callbackStatus).toBe(201); + + await expect(getSessions({ headers: asBearerAuth(callbackBody.accessToken) })).resolves.toHaveLength(1); + + const logoutToken = await generateLogoutToken('http://0.0.0.0:2286', 'backchannel-logout-user'); + const { status, body } = await request(app).post('/oauth/backchannel-logout').send({ logout_token: logoutToken }); + expect(status).toBe(200); + expect(body).toMatchObject({}); + + await expect(getSessions({ headers: asBearerAuth(callbackBody.accessToken) })).rejects.toMatchObject({ + status: 401, + }); + }); + }); + describe('mobile redirect override', () => { beforeAll(async () => { await setupOAuth(admin.accessToken, { diff --git a/mobile/openapi/README.md b/mobile/openapi/README.md index 138cf1735a..50bbff2bae 100644 --- a/mobile/openapi/README.md +++ b/mobile/openapi/README.md @@ -126,6 +126,7 @@ Class | Method | HTTP request | Description *AuthenticationApi* | [**lockAuthSession**](doc//AuthenticationApi.md#lockauthsession) | **POST** /auth/session/lock | Lock auth session *AuthenticationApi* | [**login**](doc//AuthenticationApi.md#login) | **POST** /auth/login | Login *AuthenticationApi* | [**logout**](doc//AuthenticationApi.md#logout) | **POST** /auth/logout | Logout +*AuthenticationApi* | [**logoutOAuth**](doc//AuthenticationApi.md#logoutoauth) | **POST** /oauth/backchannel-logout | Backchannel OAuth logout *AuthenticationApi* | [**redirectOAuthToMobile**](doc//AuthenticationApi.md#redirectoauthtomobile) | **GET** /oauth/mobile-redirect | Redirect OAuth to mobile *AuthenticationApi* | [**resetPinCode**](doc//AuthenticationApi.md#resetpincode) | **DELETE** /auth/pin-code | Reset pin code *AuthenticationApi* | [**setupPinCode**](doc//AuthenticationApi.md#setuppincode) | **POST** /auth/pin-code | Setup pin code diff --git a/mobile/openapi/lib/api/authentication_api.dart b/mobile/openapi/lib/api/authentication_api.dart index 52d46a525b..e1219f2c03 100644 --- a/mobile/openapi/lib/api/authentication_api.dart +++ b/mobile/openapi/lib/api/authentication_api.dart @@ -424,6 +424,59 @@ class AuthenticationApi { return null; } + /// Backchannel OAuth logout + /// + /// Logout the OAuth account and invalidate the session specified by the sid claim or all sessions if the sid claim is not present. + /// + /// Note: This method returns the HTTP [Response]. + /// + /// Parameters: + /// + /// * [String] logoutToken (required): + /// OAuth logout token + Future logoutOAuthWithHttpInfo(String logoutToken,) async { + // ignore: prefer_const_declarations + final apiPath = r'/oauth/backchannel-logout'; + + // ignore: prefer_final_locals + Object? postBody; + + final queryParams = []; + final headerParams = {}; + final formParams = {}; + + const contentTypes = ['application/x-www-form-urlencoded']; + + if (logoutToken != null) { + formParams[r'logout_token'] = parameterToString(logoutToken); + } + + return apiClient.invokeAPI( + apiPath, + 'POST', + queryParams, + postBody, + headerParams, + formParams, + contentTypes.isEmpty ? null : contentTypes.first, + ); + } + + /// Backchannel OAuth logout + /// + /// Logout the OAuth account and invalidate the session specified by the sid claim or all sessions if the sid claim is not present. + /// + /// Parameters: + /// + /// * [String] logoutToken (required): + /// OAuth logout token + Future logoutOAuth(String logoutToken,) async { + final response = await logoutOAuthWithHttpInfo(logoutToken,); + if (response.statusCode >= HttpStatus.badRequest) { + throw ApiException(response.statusCode, await _decodeBodyBytes(response)); + } + } + /// Redirect OAuth to mobile /// /// Requests to this URL are automatically forwarded to the mobile app, and is used in some cases for OAuth redirecting. diff --git a/open-api/immich-openapi-specs.json b/open-api/immich-openapi-specs.json index 85816e8eda..5853fa6b0d 100644 --- a/open-api/immich-openapi-specs.json +++ b/open-api/immich-openapi-specs.json @@ -7359,6 +7359,38 @@ "x-immich-state": "Stable" } }, + "/oauth/backchannel-logout": { + "post": { + "description": "Logout the OAuth account and invalidate the session specified by the sid claim or all sessions if the sid claim is not present.", + "operationId": "logoutOAuth", + "parameters": [], + "requestBody": { + "content": { + "application/x-www-form-urlencoded": { + "schema": { + "$ref": "#/components/schemas/OAuthBackchannelLogoutDto" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "" + } + }, + "summary": "Backchannel OAuth logout", + "tags": [ + "Authentication" + ], + "x-immich-history": [ + { + "version": "v2", + "state": "Added" + } + ] + } + }, "/oauth/callback": { "post": { "description": "Complete the OAuth authorization process by exchanging the authorization code for a session token.", @@ -19031,6 +19063,18 @@ ], "type": "object" }, + "OAuthBackchannelLogoutDto": { + "properties": { + "logout_token": { + "description": "OAuth logout token", + "type": "string" + } + }, + "required": [ + "logout_token" + ], + "type": "object" + }, "OAuthCallbackDto": { "properties": { "codeVerifier": { diff --git a/open-api/typescript-sdk/src/fetch-client.ts b/open-api/typescript-sdk/src/fetch-client.ts index 6277df805c..f86ede7592 100644 --- a/open-api/typescript-sdk/src/fetch-client.ts +++ b/open-api/typescript-sdk/src/fetch-client.ts @@ -1409,6 +1409,10 @@ export type OAuthAuthorizeResponseDto = { /** OAuth authorization URL */ url: string; }; +export type OAuthBackchannelLogoutDto = { + /** OAuth logout token */ + logout_token: string; +}; export type OAuthCallbackDto = { /** OAuth code verifier (PKCE) */ codeVerifier?: string; @@ -4909,6 +4913,18 @@ export function startOAuth({ oAuthConfigDto }: { body: oAuthConfigDto }))); } +/** + * Backchannel OAuth logout + */ +export function logoutOAuth({ oAuthBackchannelLogoutDto }: { + oAuthBackchannelLogoutDto: OAuthBackchannelLogoutDto; +}, opts?: Oazapfts.RequestOpts) { + return oazapfts.ok(oazapfts.fetchText("/oauth/backchannel-logout", oazapfts.form({ + ...opts, + method: "POST", + body: oAuthBackchannelLogoutDto + }))); +} /** * Finish OAuth */ diff --git a/server/src/app.common.ts b/server/src/app.common.ts index 2159721932..5ea52a0599 100644 --- a/server/src/app.common.ts +++ b/server/src/app.common.ts @@ -1,5 +1,5 @@ import { NestExpressApplication } from '@nestjs/platform-express'; -import { json } from 'body-parser'; +import { json, urlencoded } from 'body-parser'; import compression from 'compression'; import cookieParser from 'cookie-parser'; import helmetMiddleware from 'helmet'; @@ -56,6 +56,7 @@ export async function configureExpress( app.use(cookieParser()); app.use(json({ limit: '10mb' })); + app.use(urlencoded({ limit: '10mb' })); if (configRepository.isDev()) { app.enableCors(); diff --git a/server/src/controllers/oauth.controller.ts b/server/src/controllers/oauth.controller.ts index 797bf497ef..7f2313a058 100644 --- a/server/src/controllers/oauth.controller.ts +++ b/server/src/controllers/oauth.controller.ts @@ -1,11 +1,12 @@ import { Body, Controller, Get, HttpCode, HttpStatus, Post, Redirect, Req, Res } from '@nestjs/common'; -import { ApiTags } from '@nestjs/swagger'; +import { ApiConsumes, ApiTags } from '@nestjs/swagger'; import { Request, Response } from 'express'; import { Endpoint, HistoryBuilder } from 'src/decorators'; import { AuthDto, LoginResponseDto, OAuthAuthorizeResponseDto, + OAuthBackchannelLogoutDto, OAuthCallbackDto, OAuthConfigDto, } from 'src/dtos/auth.dto'; @@ -112,4 +113,17 @@ export class OAuthController { unlinkOAuthAccount(@Auth() auth: AuthDto): Promise { return this.service.unlink(auth); } + + @Post('backchannel-logout') + @HttpCode(HttpStatus.OK) + @ApiConsumes('application/x-www-form-urlencoded') + @Endpoint({ + summary: 'Backchannel OAuth logout', + description: + 'Logout the OAuth account and invalidate the session specified by the sid claim or all sessions if the sid claim is not present.', + history: new HistoryBuilder().added('v2'), + }) + async logoutOAuth(@Body() dto: OAuthBackchannelLogoutDto): Promise { + return this.service.backchannelLogout(dto); + } } diff --git a/server/src/dtos/auth.dto.ts b/server/src/dtos/auth.dto.ts index 95d2bb126a..1f75401e33 100644 --- a/server/src/dtos/auth.dto.ts +++ b/server/src/dtos/auth.dto.ts @@ -124,6 +124,10 @@ const OAuthAuthorizeResponseSchema = z }) .meta({ id: 'OAuthAuthorizeResponseDto' }); +const OAuthBackchannelLogoutSchema = z + .object({ logout_token: z.string().describe('OAuth logout token') }) + .meta({ id: 'OAuthBackchannelLogoutDto' }); + const AuthStatusResponseSchema = z .object({ pinCode: z.boolean().describe('Has PIN code set'), @@ -147,4 +151,5 @@ export class ValidateAccessTokenResponseDto extends createZodDto(ValidateAccessT export class OAuthCallbackDto extends createZodDto(OAuthCallbackSchema) {} export class OAuthConfigDto extends createZodDto(OAuthConfigSchema) {} export class OAuthAuthorizeResponseDto extends createZodDto(OAuthAuthorizeResponseSchema) {} +export class OAuthBackchannelLogoutDto extends createZodDto(OAuthBackchannelLogoutSchema) {} export class AuthStatusResponseDto extends createZodDto(AuthStatusResponseSchema) {} diff --git a/server/src/queries/session.repository.sql b/server/src/queries/session.repository.sql index b399646409..a29b6f7cc3 100644 --- a/server/src/queries/session.repository.sql +++ b/server/src/queries/session.repository.sql @@ -74,7 +74,7 @@ delete from "session" where "id" = $1::uuid --- SessionRepository.invalidate +-- SessionRepository.invalidateAll delete from "session" where "userId" = $1 diff --git a/server/src/repositories/oauth.repository.ts b/server/src/repositories/oauth.repository.ts index 8fb969a233..012648b58d 100644 --- a/server/src/repositories/oauth.repository.ts +++ b/server/src/repositories/oauth.repository.ts @@ -1,4 +1,5 @@ import { Injectable, InternalServerErrorException } from '@nestjs/common'; +import { createRemoteJWKSet, jwtVerify, JWTVerifyGetKey } from 'jose'; import { allowInsecureRequests as allowInsecureRequestsExecute, authorizationCodeGrant, @@ -71,12 +72,12 @@ export class OAuthRepository { return client.serverMetadata().end_session_endpoint; } - async getProfile( + async getProfileAndOAuthSid( config: OAuthConfig, url: string, expectedState: string, codeVerifier: string, - ): Promise { + ): Promise<{ profile: OAuthProfile; sid?: string }> { const client = await this.getClient(config); const pkceCodeVerifier = client.serverMetadata().supportsPKCE() ? codeVerifier : undefined; @@ -96,7 +97,15 @@ export class OAuthRepository { throw new Error('Unexpected profile response, no `sub`'); } - return profile; + let sid: string | undefined; + if (tokens.id_token) { + const claims = tokens.claims(); + if (typeof claims?.sid === 'string') { + sid = claims.sid; + } + } + + return { profile, sid }; } catch (error: Error | any) { if (error.message.includes('unexpected JWT alg received')) { this.logger.warn( @@ -126,6 +135,59 @@ export class OAuthRepository { }; } + private jwksClients: Map = new Map(); // useful for caching and performnce + async validateLogoutToken(config: OAuthConfig, logoutToken: string): Promise<{ sub?: string; sid?: string } | null> { + const client = await this.getClient(config); + const algorithm = client.clientMetadata().id_token_signed_response_alg ?? 'RS256'; + let keyOrGetter: Uint8Array | JWTVerifyGetKey; + + try { + if (algorithm.startsWith('HS')) { + keyOrGetter = new TextEncoder().encode(config.clientSecret); + } else { + const jwksUri = client.serverMetadata().jwks_uri; + if (!jwksUri) { + throw new Error('Unable to get JWKS URI'); + } + + if (!this.jwksClients.has(jwksUri)) { + this.jwksClients.set(jwksUri, createRemoteJWKSet(new URL(jwksUri))); + } + keyOrGetter = this.jwksClients.get(jwksUri) as JWTVerifyGetKey; + } + + const { payload } = await jwtVerify(logoutToken, keyOrGetter as any, { + issuer: client.serverMetadata().issuer, + audience: config.clientId, + algorithms: [algorithm], + maxTokenAge: '2m', + clockTolerance: '5s', + }); + + // Validate specific Logout Token claims (RFC 8963): + // "events" claim must exist and contain the backchannel-logout event + const events = payload.events as Record | undefined; + if (!events || !events['http://schemas.openid.net/event/backchannel-logout']) { + throw new Error('Missing backchannel-logout event claim'); + } + + // "nonce" must not be present + if (payload.nonce) { + throw new Error('Logout token must not contain a nonce'); + } + + return { + sub: payload.sub, + sid: payload.sid as string | undefined, + }; + } catch (error: Error | any) { + this.logger.error(`Error validating JWT logout token: ${error.message}`); + this.logger.error(error); + + throw new Error('Error validating JWT logout token', { cause: error }); + } + } + private async getClient({ issuerUrl, clientId, diff --git a/server/src/repositories/session.repository.ts b/server/src/repositories/session.repository.ts index e008943f21..451b2263e5 100644 --- a/server/src/repositories/session.repository.ts +++ b/server/src/repositories/session.repository.ts @@ -102,7 +102,7 @@ export class SessionRepository { } @GenerateSql({ params: [{ userId: DummyValue.UUID, excludeId: DummyValue.UUID }] }) - async invalidate({ userId, excludeId }: { userId: string; excludeId?: string }) { + async invalidateAll({ userId, excludeId }: { userId: string; excludeId?: string }) { await this.db .deleteFrom('session') .where('userId', '=', userId) @@ -110,6 +110,28 @@ export class SessionRepository { .execute(); } + @GenerateSql({ params: [DummyValue.STRING, DummyValue.STRING] }) + async invalidateOAuth({ oauthSid, oauthId }: { oauthSid?: string; oauthId?: string }): Promise { + let query = this.db.deleteFrom('session').returning('session.id'); + + if (oauthSid && oauthId) { + query = query + .using('user') + .whereRef('user.id', '=', 'session.userId') + .where('session.oauthSid', '=', oauthSid) + .where('user.oauthId', '=', oauthId); + } else if (!oauthSid && oauthId) { + query = query.using('user').whereRef('user.id', '=', 'session.userId').where('user.oauthId', '=', oauthId); + } else if (oauthSid && !oauthId) { + query = query.where('session.oauthSid', '=', oauthSid); + } else { + throw new Error('Invalid arguments: at least one of oauthSid or oauthId must be present'); + } + + const deletedRows = await query.execute(); + return deletedRows.map((row) => row.id); + } + @GenerateSql({ params: [DummyValue.UUID] }) async lockAll(userId: string) { await this.db.updateTable('session').set({ pinExpiresAt: null }).where('userId', '=', userId).execute(); diff --git a/server/src/schema/migrations/1776442031775-AddOauthSidToSession.ts b/server/src/schema/migrations/1776442031775-AddOauthSidToSession.ts new file mode 100644 index 0000000000..7c96bcf8f4 --- /dev/null +++ b/server/src/schema/migrations/1776442031775-AddOauthSidToSession.ts @@ -0,0 +1,11 @@ +import { Kysely, sql } from 'kysely'; + +export async function up(db: Kysely): Promise { + await sql`ALTER TABLE "session" ADD "oauthSid" character varying;`.execute(db); + await sql`CREATE INDEX "session_oauthSid_idx" ON "session" ("oauthSid");`.execute(db); +} + +export async function down(db: Kysely): Promise { + await sql`DROP INDEX "session_oauthSid_idx";`.execute(db); + await sql`ALTER TABLE "session" DROP COLUMN "oauthSid";`.execute(db); +} diff --git a/server/src/schema/tables/session.table.ts b/server/src/schema/tables/session.table.ts index e57628d6da..950c1eeffd 100644 --- a/server/src/schema/tables/session.table.ts +++ b/server/src/schema/tables/session.table.ts @@ -52,4 +52,7 @@ export class SessionTable { @Column({ type: 'timestamp with time zone', nullable: true }) pinExpiresAt!: Timestamp | null; + + @Column({ nullable: true, index: true }) + oauthSid!: string | null; } diff --git a/server/src/services/auth.service.spec.ts b/server/src/services/auth.service.spec.ts index a21790f5fe..8e1c7ff2c4 100644 --- a/server/src/services/auth.service.spec.ts +++ b/server/src/services/auth.service.spec.ts @@ -196,6 +196,64 @@ describe(AuthService.name, () => { }); }); + describe('backchannelLogout', () => { + const dto = { logout_token: 'fake-jwt-token' }; + + it('should throw a Bad Request Exception if OAuth is not enabled', async () => { + await expect(sut.backchannelLogout(dto)).rejects.toBeInstanceOf(BadRequestException); + await expect(sut.backchannelLogout(dto)).rejects.toThrow( + 'Received backchannel logout request but OAuth is not enabled', + ); + }); + + it('should throw a Bad Request Exception if the logout token validation fails', async () => { + mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthEnabled); + mocks.oauth.validateLogoutToken.mockRejectedValue(new Error('Token validation failed')); + + await expect(sut.backchannelLogout(dto)).rejects.toBeInstanceOf(BadRequestException); + await expect(sut.backchannelLogout(dto)).rejects.toThrow('Error backchannel logout: token validation failed'); + }); + + it('should throw a Bad Request Exception if there are no claims in the logout token', async () => { + mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthEnabled); + mocks.oauth.validateLogoutToken.mockResolvedValue(null); + + await expect(sut.backchannelLogout(dto)).rejects.toBeInstanceOf(BadRequestException); + await expect(sut.backchannelLogout(dto)).rejects.toThrow('Invalid logout token: no claims found'); + }); + + it('should throw a Bad Request Exception if there is neither the sub nor the sid claim', async () => { + mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthEnabled); + mocks.oauth.validateLogoutToken.mockResolvedValue({ sub: '', sid: '' }); + + await expect(sut.backchannelLogout(dto)).rejects.toBeInstanceOf(BadRequestException); + await expect(sut.backchannelLogout(dto)).rejects.toThrow( + 'Invalid logout token: it must contain either a sub or a sid claim', + ); + }); + + it('should invalidate the OAuth session(s) if the logout token is valid', async () => { + const claims = { sub: 'fake-sub', sid: 'fake-sid' }; + const deletedSessionIds: string[] = ['fake-session-1', 'fake-session-2']; + + mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthEnabled); + mocks.oauth.validateLogoutToken.mockResolvedValue(claims); + mocks.session.invalidateOAuth.mockResolvedValue(deletedSessionIds); + mocks.event.emit.mockResolvedValue(void 0); + mocks.event.emit.mockResolvedValue(void 0); + + await sut.backchannelLogout(dto); + + expect(mocks.session.invalidateOAuth).toHaveBeenCalledWith({ + oauthSid: claims.sid, + oauthId: claims.sub, + }); + + expect(mocks.event.emit).toHaveBeenCalledWith('SessionDelete', { sessionId: 'fake-session-1' }); + expect(mocks.event.emit).toHaveBeenCalledWith('SessionDelete', { sessionId: 'fake-session-2' }); + }); + }); + describe('adminSignUp', () => { const dto: SignUpDto = { email: 'test@immich.com', password: 'password', name: 'immich admin' }; @@ -250,6 +308,7 @@ describe(AuthService.name, () => { user: UserFactory.create(), pinExpiresAt: null, appVersion: null, + oauthSid: null, }; mocks.session.getByToken.mockResolvedValue(sessionWithToken); @@ -416,6 +475,7 @@ describe(AuthService.name, () => { user: UserFactory.create(), pinExpiresAt: null, appVersion: null, + oauthSid: null, }; mocks.session.getByToken.mockResolvedValue(sessionWithToken); @@ -444,6 +504,7 @@ describe(AuthService.name, () => { isPendingSyncReset: false, pinExpiresAt: null, appVersion: null, + oauthSid: null, }; mocks.session.getByToken.mockResolvedValue(sessionWithToken); @@ -466,6 +527,7 @@ describe(AuthService.name, () => { isPendingSyncReset: false, pinExpiresAt: null, appVersion: null, + oauthSid: null, }; mocks.session.getByToken.mockResolvedValue(sessionWithToken); @@ -601,7 +663,7 @@ describe(AuthService.name, () => { it('should not allow auto registering', async () => { mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthEnabled); mocks.user.getByEmail.mockResolvedValue(void 0); - mocks.oauth.getProfile.mockResolvedValue(OAuthProfileFactory.create()); + mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ profile: OAuthProfileFactory.create() }); await expect( sut.callback( @@ -619,7 +681,7 @@ describe(AuthService.name, () => { const profile = OAuthProfileFactory.create(); mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthEnabled); - mocks.oauth.getProfile.mockResolvedValue(profile); + mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ profile }); mocks.user.getByEmail.mockResolvedValue(user); mocks.user.update.mockResolvedValue(user); mocks.session.create.mockResolvedValue(SessionFactory.create()); @@ -639,7 +701,7 @@ describe(AuthService.name, () => { const profile = OAuthProfileFactory.create({ email: ' TEST@IMMICH.CLOUD ' }); mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthEnabled); - mocks.oauth.getProfile.mockResolvedValue(profile); + mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ profile }); mocks.user.getByEmail.mockResolvedValue(user); mocks.user.update.mockResolvedValue(user); mocks.session.create.mockResolvedValue(SessionFactory.create()); @@ -658,7 +720,7 @@ describe(AuthService.name, () => { const user = UserFactory.create({ oauthId: 'existing-sub' }); mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthWithAutoRegister); - mocks.oauth.getProfile.mockResolvedValue(OAuthProfileFactory.create()); + mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ profile: OAuthProfileFactory.create() }); mocks.user.getByEmail.mockResolvedValueOnce(user); mocks.user.getAdmin.mockResolvedValue(UserFactory.create({ isAdmin: true })); @@ -679,7 +741,7 @@ describe(AuthService.name, () => { mocks.user.getByEmail.mockResolvedValue(void 0); mocks.user.getAdmin.mockResolvedValue(UserFactory.create({ isAdmin: true })); mocks.user.create.mockResolvedValue(UserFactory.create({ oauthId: 'oauth-id' })); - mocks.oauth.getProfile.mockResolvedValue(OAuthProfileFactory.create()); + mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ profile: OAuthProfileFactory.create() }); mocks.session.create.mockResolvedValue(SessionFactory.create()); await sut.callback( @@ -698,7 +760,7 @@ describe(AuthService.name, () => { mocks.user.getAdmin.mockResolvedValue(UserFactory.create({ isAdmin: true })); mocks.user.create.mockResolvedValue(UserFactory.create()); mocks.session.create.mockResolvedValue(SessionFactory.create()); - mocks.oauth.getProfile.mockResolvedValue({ sub: 'sub' }); + mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ profile: { sub: 'sub' } }); await expect( sut.callback( @@ -720,12 +782,12 @@ describe(AuthService.name, () => { it(`should use the mobile redirect override for a url of ${url}`, async () => { mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthWithMobileOverride); mocks.user.getByOAuthId.mockResolvedValue(UserFactory.create()); - mocks.oauth.getProfile.mockResolvedValue(OAuthProfileFactory.create()); + mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ profile: OAuthProfileFactory.create() }); mocks.session.create.mockResolvedValue(SessionFactory.create()); await sut.callback({ url, state: 'xyz789', codeVerifier: 'foo' }, {}, loginDetails); - expect(mocks.oauth.getProfile).toHaveBeenCalledWith( + expect(mocks.oauth.getProfileAndOAuthSid).toHaveBeenCalledWith( expect.objectContaining({}), 'http://mobile-redirect?code=abc123', 'xyz789', @@ -738,7 +800,7 @@ describe(AuthService.name, () => { mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthWithStorageQuota); mocks.user.getByEmail.mockResolvedValue(void 0); mocks.user.getAdmin.mockResolvedValue(UserFactory.create({ isAdmin: true })); - mocks.oauth.getProfile.mockResolvedValue(OAuthProfileFactory.create()); + mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ profile: OAuthProfileFactory.create() }); mocks.user.create.mockResolvedValue(UserFactory.create({ oauthId: 'oauth-id' })); mocks.session.create.mockResolvedValue(SessionFactory.create()); @@ -753,9 +815,9 @@ describe(AuthService.name, () => { it('should infer name from given and family names', async () => { mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.enabled); - mocks.oauth.getProfile.mockResolvedValue( - OAuthProfileFactory.create({ name: undefined, given_name: 'Given', family_name: 'Family' }), - ); + mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ + profile: OAuthProfileFactory.create({ name: undefined, given_name: 'Given', family_name: 'Family' }), + }); mocks.user.getByEmail.mockResolvedValue(void 0); mocks.user.getAdmin.mockResolvedValue(UserFactory.create({ isAdmin: true })); mocks.user.create.mockResolvedValue(UserFactory.create()); @@ -774,7 +836,7 @@ describe(AuthService.name, () => { const profile = OAuthProfileFactory.create({ name: undefined, given_name: undefined, family_name: undefined }); mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.enabled); - mocks.oauth.getProfile.mockResolvedValue(profile); + mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ profile }); mocks.user.getByEmail.mockResolvedValue(void 0); mocks.user.getAdmin.mockResolvedValue(UserFactory.create({ isAdmin: true })); mocks.user.create.mockResolvedValue(UserFactory.create()); @@ -791,7 +853,9 @@ describe(AuthService.name, () => { it('should ignore an invalid storage quota', async () => { mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthWithStorageQuota); - mocks.oauth.getProfile.mockResolvedValue(OAuthProfileFactory.create({ immich_quota: 'abc' })); + mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ + profile: OAuthProfileFactory.create({ immich_quota: 'abc' }), + }); mocks.user.getAdmin.mockResolvedValue(UserFactory.create({ isAdmin: true })); mocks.user.getByEmail.mockResolvedValue(void 0); mocks.user.create.mockResolvedValue(UserFactory.create({ oauthId: 'oauth-id' })); @@ -808,7 +872,9 @@ describe(AuthService.name, () => { it('should ignore a negative quota', async () => { mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthWithStorageQuota); - mocks.oauth.getProfile.mockResolvedValue(OAuthProfileFactory.create({ immich_quota: -5 })); + mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ + profile: OAuthProfileFactory.create({ immich_quota: -5 }), + }); mocks.user.getAdmin.mockResolvedValue(UserFactory.create({ isAdmin: true })); mocks.user.getByEmail.mockResolvedValue(void 0); mocks.user.create.mockResolvedValue(UserFactory.create({ oauthId: 'oauth-id' })); @@ -825,7 +891,7 @@ describe(AuthService.name, () => { it('should set quota for 0 quota', async () => { mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthWithStorageQuota); - mocks.oauth.getProfile.mockResolvedValue(OAuthProfileFactory.create({ immich_quota: 0 })); + mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ profile: OAuthProfileFactory.create({ immich_quota: 0 }) }); mocks.user.getAdmin.mockResolvedValue(UserFactory.create({ isAdmin: true })); mocks.user.getByEmail.mockResolvedValue(void 0); mocks.user.create.mockResolvedValue(UserFactory.create({ oauthId: 'oauth-id' })); @@ -842,7 +908,7 @@ describe(AuthService.name, () => { it('should use a valid storage quota', async () => { mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthWithStorageQuota); - mocks.oauth.getProfile.mockResolvedValue(OAuthProfileFactory.create({ immich_quota: 5 })); + mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ profile: OAuthProfileFactory.create({ immich_quota: 5 }) }); mocks.user.getByEmail.mockResolvedValue(void 0); mocks.user.getAdmin.mockResolvedValue(UserFactory.create({ isAdmin: true })); mocks.user.getByOAuthId.mockResolvedValue(void 0); @@ -864,7 +930,7 @@ describe(AuthService.name, () => { const profile = OAuthProfileFactory.create({ picture: 'https://auth.immich.cloud/profiles/1.jpg' }); mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthEnabled); - mocks.oauth.getProfile.mockResolvedValue(profile); + mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ profile }); mocks.user.getByOAuthId.mockResolvedValue(user); mocks.crypto.randomUUID.mockReturnValue(fileId); mocks.oauth.getProfilePicture.mockResolvedValue({ @@ -891,13 +957,13 @@ describe(AuthService.name, () => { const user = UserFactory.create({ oauthId: 'oauth-id', profileImagePath: 'not-empty' }); mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthEnabled); - mocks.oauth.getProfile.mockResolvedValue( - OAuthProfileFactory.create({ + mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ + profile: OAuthProfileFactory.create({ sub: user.oauthId, email: user.email, picture: 'https://auth.immich.cloud/profiles/1.jpg', }), - ); + }); mocks.user.getByOAuthId.mockResolvedValue(user); mocks.user.update.mockResolvedValue(user); mocks.session.create.mockResolvedValue(SessionFactory.create()); @@ -914,7 +980,9 @@ describe(AuthService.name, () => { it('should only allow "admin" and "user" for the role claim', async () => { mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthWithAutoRegister); - mocks.oauth.getProfile.mockResolvedValue(OAuthProfileFactory.create({ immich_role: 'foo' })); + mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ + profile: OAuthProfileFactory.create({ immich_role: 'foo' }), + }); mocks.user.getByEmail.mockResolvedValue(void 0); mocks.user.getAdmin.mockResolvedValue(UserFactory.create({ isAdmin: true })); mocks.user.getByOAuthId.mockResolvedValue(void 0); @@ -932,7 +1000,9 @@ describe(AuthService.name, () => { it('should create an admin user if the role claim is set to admin', async () => { mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthWithAutoRegister); - mocks.oauth.getProfile.mockResolvedValue(OAuthProfileFactory.create({ immich_role: 'admin' })); + mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ + profile: OAuthProfileFactory.create({ immich_role: 'admin' }), + }); mocks.user.getByEmail.mockResolvedValue(void 0); mocks.user.getByOAuthId.mockResolvedValue(void 0); mocks.user.create.mockResolvedValue(UserFactory.create({ oauthId: 'oauth-id' })); @@ -951,7 +1021,9 @@ describe(AuthService.name, () => { mocks.systemMetadata.get.mockResolvedValue({ oauth: { ...systemConfigStub.oauthWithAutoRegister.oauth, roleClaim: 'my_role' }, }); - mocks.oauth.getProfile.mockResolvedValue(OAuthProfileFactory.create({ my_role: 'admin' })); + mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ + profile: OAuthProfileFactory.create({ my_role: 'admin' }), + }); mocks.user.getByEmail.mockResolvedValue(void 0); mocks.user.getByOAuthId.mockResolvedValue(void 0); mocks.user.create.mockResolvedValue(UserFactory.create({ oauthId: 'oauth-id' })); @@ -974,7 +1046,7 @@ describe(AuthService.name, () => { const profile = OAuthProfileFactory.create(); mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.enabled); - mocks.oauth.getProfile.mockResolvedValue(profile); + mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ profile }); mocks.user.update.mockResolvedValue(user); await sut.link( @@ -986,13 +1058,36 @@ describe(AuthService.name, () => { expect(mocks.user.update).toHaveBeenCalledWith(auth.user.id, { oauthId: profile.sub }); }); + it('should link an account and update the session with the oauthSid', async () => { + const user = UserFactory.create(); + const session = SessionFactory.create(); + const auth = AuthFactory.from(user).session(session).build(); + + mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.enabled); + mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ + profile: { sub: 'sub' }, + sid: session.oauthSid ?? undefined, + }); + mocks.user.update.mockResolvedValue(user); + mocks.session.update.mockResolvedValue(session); + + await sut.link( + auth, + { url: 'http://immich/user-settings?code=abc123', state: 'xyz789', codeVerifier: 'foo' }, + {}, + ); + + expect(mocks.session.update).toHaveBeenCalledWith(session.id, { oauthSid: session.oauthSid }); + expect(mocks.user.update).toHaveBeenCalledWith(auth.user.id, { oauthId: 'sub' }); + }); + it('should not link an already linked oauth.sub', async () => { const authUser = UserFactory.create(); const authApiKey = ApiKeyFactory.create({ permissions: [] }); const auth = { user: authUser, apiKey: authApiKey }; mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.enabled); - mocks.oauth.getProfile.mockResolvedValue(OAuthProfileFactory.create()); + mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ profile: OAuthProfileFactory.create() }); mocks.user.getByOAuthId.mockResolvedValue({ id: 'other-user' } as UserAdmin); await expect( @@ -1015,6 +1110,21 @@ describe(AuthService.name, () => { expect(mocks.user.update).toHaveBeenCalledWith(auth.user.id, { oauthId: '' }); }); + + it('should unlink an account and remove the oauthSid from the session', async () => { + const user = UserFactory.create(); + const session = SessionFactory.create(); + const auth = AuthFactory.from(user).session(session).build(); + + mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.enabled); + mocks.session.update.mockResolvedValue(session); + mocks.user.update.mockResolvedValue(user); + + await sut.unlink(auth); + + expect(mocks.session.update).toHaveBeenCalledWith(session.id, { oauthSid: null }); + expect(mocks.user.update).toHaveBeenCalledWith(auth.user.id, { oauthId: '' }); + }); }); describe('setupPinCode', () => { diff --git a/server/src/services/auth.service.ts b/server/src/services/auth.service.ts index ea3a896298..1824b043ef 100644 --- a/server/src/services/auth.service.ts +++ b/server/src/services/auth.service.ts @@ -12,6 +12,7 @@ import { ChangePasswordDto, LoginCredentialDto, LogoutResponseDto, + OAuthBackchannelLogoutDto, OAuthCallbackDto, OAuthConfigDto, PinCodeChangeDto, @@ -91,6 +92,40 @@ export class AuthService extends BaseService { }; } + async backchannelLogout(dto: OAuthBackchannelLogoutDto): Promise { + const { oauth } = await this.getConfig({ withCache: false }); + if (!oauth.enabled) { + throw new BadRequestException('Received backchannel logout request but OAuth is not enabled'); + } + + let claims; + try { + claims = await this.oauthRepository.validateLogoutToken(oauth, dto.logout_token); + } catch (error: Error | any) { + this.logger.error(`Error backchannel logout: ${error.message}`); + this.logger.error(error); + + throw new BadRequestException('Error backchannel logout: token validation failed'); + } + + if (!claims) { + throw new BadRequestException('Invalid logout token: no claims found'); + } + + if (!claims.sub && !claims.sid) { + throw new BadRequestException('Invalid logout token: it must contain either a sub or a sid claim'); + } + + const deletedSessionIds = await this.sessionRepository.invalidateOAuth({ + oauthSid: claims.sid, + oauthId: claims.sub, + }); + + for (const sessionId of deletedSessionIds) { + await this.eventRepository.emit('SessionDelete', { sessionId }); + } + } + async changePassword(auth: AuthDto, dto: ChangePasswordDto): Promise { const { password, newPassword } = dto; const user = await this.userRepository.getForChangePassword(auth.user.id); @@ -276,7 +311,12 @@ export class AuthService extends BaseService { } const url = this.resolveRedirectUri(oauth, dto.url); - const profile = await this.oauthRepository.getProfile(oauth, url, expectedState, codeVerifier); + const { profile, sid: oauthSid } = await this.oauthRepository.getProfileAndOAuthSid( + oauth, + url, + expectedState, + codeVerifier, + ); const normalizedEmail = profile.email ? profile.email.trim().toLowerCase() : undefined; const { autoRegister, defaultStorageQuota, storageLabelClaim, storageQuotaClaim, roleClaim } = oauth; this.logger.debug(`Logging in with OAuth: ${JSON.stringify(profile)}`); @@ -342,7 +382,7 @@ export class AuthService extends BaseService { await this.syncProfilePicture(user, profile.picture); } - return this.createLoginResponse(user, loginDetails); + return this.createLoginResponse(user, loginDetails, oauthSid); } private async syncProfilePicture(user: UserAdmin, url: string) { @@ -380,18 +420,29 @@ export class AuthService extends BaseService { } const { oauth } = await this.getConfig({ withCache: false }); - const { sub: oauthId } = await this.oauthRepository.getProfile(oauth, dto.url, expectedState, codeVerifier); + const { + profile: { sub: oauthId }, + sid, + } = await this.oauthRepository.getProfileAndOAuthSid(oauth, dto.url, expectedState, codeVerifier); const duplicate = await this.userRepository.getByOAuthId(oauthId); if (duplicate && duplicate.id !== auth.user.id) { this.logger.warn(`OAuth link account failed: sub is already linked to another user (${duplicate.email}).`); throw new BadRequestException('This OAuth account has already been linked to another user.'); } + if (auth.session) { + await this.sessionRepository.update(auth.session.id, { oauthSid: sid }); + } + const user = await this.userRepository.update(auth.user.id, { oauthId }); return mapUserAdmin(user); } async unlink(auth: AuthDto): Promise { + if (auth.session) { + await this.sessionRepository.update(auth.session.id, { oauthSid: null }); + } + const user = await this.userRepository.update(auth.user.id, { oauthId: '' }); return mapUserAdmin(user); } @@ -548,7 +599,7 @@ export class AuthService extends BaseService { await this.sessionRepository.update(auth.session.id, { pinExpiresAt: null }); } - private async createLoginResponse(user: UserAdmin, loginDetails: LoginDetails) { + private async createLoginResponse(user: UserAdmin, loginDetails: LoginDetails, oauthSid?: string) { const token = this.cryptoRepository.randomBytesAsText(32); const hashed = this.cryptoRepository.hashSha256(token); @@ -558,6 +609,7 @@ export class AuthService extends BaseService { deviceType: loginDetails.deviceType, appVersion: loginDetails.appVersion, userId: user.id, + oauthSid: oauthSid ?? null, }); return mapLoginResponse(user, token); diff --git a/server/src/services/session.service.spec.ts b/server/src/services/session.service.spec.ts index 8f4409a508..c6bd5f2f72 100644 --- a/server/src/services/session.service.spec.ts +++ b/server/src/services/session.service.spec.ts @@ -46,11 +46,14 @@ describe('SessionService', () => { const currentSession = SessionFactory.create(); const auth = AuthFactory.from().session(currentSession).build(); - mocks.session.invalidate.mockResolvedValue(); + mocks.session.invalidateAll.mockResolvedValue(); await sut.deleteAll(auth); - expect(mocks.session.invalidate).toHaveBeenCalledWith({ userId: auth.user.id, excludeId: currentSession.id }); + expect(mocks.session.invalidateAll).toHaveBeenCalledWith({ + userId: auth.user.id, + excludeId: currentSession.id, + }); }); }); diff --git a/server/src/services/session.service.ts b/server/src/services/session.service.ts index 8b5bd13928..735a8c2453 100644 --- a/server/src/services/session.service.ts +++ b/server/src/services/session.service.ts @@ -73,7 +73,7 @@ export class SessionService extends BaseService { async deleteAll(auth: AuthDto): Promise { const userId = auth.user.id; const currentSessionId = auth.session?.id; - await this.sessionRepository.invalidate({ userId, excludeId: currentSessionId }); + await this.sessionRepository.invalidateAll({ userId, excludeId: currentSessionId }); } async lock(auth: AuthDto, id: string): Promise { @@ -83,6 +83,6 @@ export class SessionService extends BaseService { @OnEvent({ name: 'AuthChangePassword' }) async onAuthChangePassword({ userId, currentSessionId }: ArgOf<'AuthChangePassword'>): Promise { - await this.sessionRepository.invalidate({ userId, excludeId: currentSessionId }); + await this.sessionRepository.invalidateAll({ userId, excludeId: currentSessionId }); } } diff --git a/server/test/factories/session.factory.ts b/server/test/factories/session.factory.ts index 8d4cb28727..44a25edcfa 100644 --- a/server/test/factories/session.factory.ts +++ b/server/test/factories/session.factory.ts @@ -25,6 +25,7 @@ export class SessionFactory { updateId: newUuidV7(), updatedAt: newDate(), userId: newUuid(), + oauthSid: newUuid(), ...dto, }); }