feat(server): added backchannel logout api endpoint (#26235)

* feat(server): added backchannel logout api endpoint

* test(server): fixed e2e tests

* fix(server): fixed suggested changes by reviewer

* feat(server): created function invalidateOAuth

* fix(server): fixed session.repository.sql

* test(server): added unit tests for backchannelLogout function

* test(server): added e2e tests for oidc backchnnel logout

* docs(server): added documentation on backchannel logout url

* docs(server): fixed typo

* feat(server): minor improvements of the oidc backchannel logout

* test(server): fixed tests after merge with main

* fix(server): fixed e2e test file

* refactor(server): tiny refactor of validateLogoutToken

* chore: cleanup

* fix: tests

* fix: make jwks extractable

---------

Co-authored-by: Daniel Dietzler <mail@ddietzler.dev>
This commit is contained in:
santanoce
2026-04-17 20:45:33 +02:00
committed by GitHub
parent 8afca348ff
commit dbf30b77bf
21 changed files with 558 additions and 47 deletions
+2 -1
View File
@@ -1,5 +1,5 @@
import { NestExpressApplication } from '@nestjs/platform-express';
import { json } from 'body-parser';
import { json, urlencoded } from 'body-parser';
import compression from 'compression';
import cookieParser from 'cookie-parser';
import helmetMiddleware from 'helmet';
@@ -56,6 +56,7 @@ export async function configureExpress(
app.use(cookieParser());
app.use(json({ limit: '10mb' }));
app.use(urlencoded({ limit: '10mb' }));
if (configRepository.isDev()) {
app.enableCors();
+15 -1
View File
@@ -1,11 +1,12 @@
import { Body, Controller, Get, HttpCode, HttpStatus, Post, Redirect, Req, Res } from '@nestjs/common';
import { ApiTags } from '@nestjs/swagger';
import { ApiConsumes, ApiTags } from '@nestjs/swagger';
import { Request, Response } from 'express';
import { Endpoint, HistoryBuilder } from 'src/decorators';
import {
AuthDto,
LoginResponseDto,
OAuthAuthorizeResponseDto,
OAuthBackchannelLogoutDto,
OAuthCallbackDto,
OAuthConfigDto,
} from 'src/dtos/auth.dto';
@@ -112,4 +113,17 @@ export class OAuthController {
unlinkOAuthAccount(@Auth() auth: AuthDto): Promise<UserAdminResponseDto> {
return this.service.unlink(auth);
}
@Post('backchannel-logout')
@HttpCode(HttpStatus.OK)
@ApiConsumes('application/x-www-form-urlencoded')
@Endpoint({
summary: 'Backchannel OAuth logout',
description:
'Logout the OAuth account and invalidate the session specified by the sid claim or all sessions if the sid claim is not present.',
history: new HistoryBuilder().added('v2'),
})
async logoutOAuth(@Body() dto: OAuthBackchannelLogoutDto): Promise<void> {
return this.service.backchannelLogout(dto);
}
}
+5
View File
@@ -124,6 +124,10 @@ const OAuthAuthorizeResponseSchema = z
})
.meta({ id: 'OAuthAuthorizeResponseDto' });
const OAuthBackchannelLogoutSchema = z
.object({ logout_token: z.string().describe('OAuth logout token') })
.meta({ id: 'OAuthBackchannelLogoutDto' });
const AuthStatusResponseSchema = z
.object({
pinCode: z.boolean().describe('Has PIN code set'),
@@ -147,4 +151,5 @@ export class ValidateAccessTokenResponseDto extends createZodDto(ValidateAccessT
export class OAuthCallbackDto extends createZodDto(OAuthCallbackSchema) {}
export class OAuthConfigDto extends createZodDto(OAuthConfigSchema) {}
export class OAuthAuthorizeResponseDto extends createZodDto(OAuthAuthorizeResponseSchema) {}
export class OAuthBackchannelLogoutDto extends createZodDto(OAuthBackchannelLogoutSchema) {}
export class AuthStatusResponseDto extends createZodDto(AuthStatusResponseSchema) {}
+1 -1
View File
@@ -74,7 +74,7 @@ delete from "session"
where
"id" = $1::uuid
-- SessionRepository.invalidate
-- SessionRepository.invalidateAll
delete from "session"
where
"userId" = $1
+65 -3
View File
@@ -1,4 +1,5 @@
import { Injectable, InternalServerErrorException } from '@nestjs/common';
import { createRemoteJWKSet, jwtVerify, JWTVerifyGetKey } from 'jose';
import {
allowInsecureRequests as allowInsecureRequestsExecute,
authorizationCodeGrant,
@@ -71,12 +72,12 @@ export class OAuthRepository {
return client.serverMetadata().end_session_endpoint;
}
async getProfile(
async getProfileAndOAuthSid(
config: OAuthConfig,
url: string,
expectedState: string,
codeVerifier: string,
): Promise<OAuthProfile> {
): Promise<{ profile: OAuthProfile; sid?: string }> {
const client = await this.getClient(config);
const pkceCodeVerifier = client.serverMetadata().supportsPKCE() ? codeVerifier : undefined;
@@ -96,7 +97,15 @@ export class OAuthRepository {
throw new Error('Unexpected profile response, no `sub`');
}
return profile;
let sid: string | undefined;
if (tokens.id_token) {
const claims = tokens.claims();
if (typeof claims?.sid === 'string') {
sid = claims.sid;
}
}
return { profile, sid };
} catch (error: Error | any) {
if (error.message.includes('unexpected JWT alg received')) {
this.logger.warn(
@@ -126,6 +135,59 @@ export class OAuthRepository {
};
}
private jwksClients: Map<string, JWTVerifyGetKey> = new Map(); // useful for caching and performnce
async validateLogoutToken(config: OAuthConfig, logoutToken: string): Promise<{ sub?: string; sid?: string } | null> {
const client = await this.getClient(config);
const algorithm = client.clientMetadata().id_token_signed_response_alg ?? 'RS256';
let keyOrGetter: Uint8Array | JWTVerifyGetKey;
try {
if (algorithm.startsWith('HS')) {
keyOrGetter = new TextEncoder().encode(config.clientSecret);
} else {
const jwksUri = client.serverMetadata().jwks_uri;
if (!jwksUri) {
throw new Error('Unable to get JWKS URI');
}
if (!this.jwksClients.has(jwksUri)) {
this.jwksClients.set(jwksUri, createRemoteJWKSet(new URL(jwksUri)));
}
keyOrGetter = this.jwksClients.get(jwksUri) as JWTVerifyGetKey;
}
const { payload } = await jwtVerify(logoutToken, keyOrGetter as any, {
issuer: client.serverMetadata().issuer,
audience: config.clientId,
algorithms: [algorithm],
maxTokenAge: '2m',
clockTolerance: '5s',
});
// Validate specific Logout Token claims (RFC 8963):
// "events" claim must exist and contain the backchannel-logout event
const events = payload.events as Record<string, any> | undefined;
if (!events || !events['http://schemas.openid.net/event/backchannel-logout']) {
throw new Error('Missing backchannel-logout event claim');
}
// "nonce" must not be present
if (payload.nonce) {
throw new Error('Logout token must not contain a nonce');
}
return {
sub: payload.sub,
sid: payload.sid as string | undefined,
};
} catch (error: Error | any) {
this.logger.error(`Error validating JWT logout token: ${error.message}`);
this.logger.error(error);
throw new Error('Error validating JWT logout token', { cause: error });
}
}
private async getClient({
issuerUrl,
clientId,
+23 -1
View File
@@ -102,7 +102,7 @@ export class SessionRepository {
}
@GenerateSql({ params: [{ userId: DummyValue.UUID, excludeId: DummyValue.UUID }] })
async invalidate({ userId, excludeId }: { userId: string; excludeId?: string }) {
async invalidateAll({ userId, excludeId }: { userId: string; excludeId?: string }) {
await this.db
.deleteFrom('session')
.where('userId', '=', userId)
@@ -110,6 +110,28 @@ export class SessionRepository {
.execute();
}
@GenerateSql({ params: [DummyValue.STRING, DummyValue.STRING] })
async invalidateOAuth({ oauthSid, oauthId }: { oauthSid?: string; oauthId?: string }): Promise<string[]> {
let query = this.db.deleteFrom('session').returning('session.id');
if (oauthSid && oauthId) {
query = query
.using('user')
.whereRef('user.id', '=', 'session.userId')
.where('session.oauthSid', '=', oauthSid)
.where('user.oauthId', '=', oauthId);
} else if (!oauthSid && oauthId) {
query = query.using('user').whereRef('user.id', '=', 'session.userId').where('user.oauthId', '=', oauthId);
} else if (oauthSid && !oauthId) {
query = query.where('session.oauthSid', '=', oauthSid);
} else {
throw new Error('Invalid arguments: at least one of oauthSid or oauthId must be present');
}
const deletedRows = await query.execute();
return deletedRows.map((row) => row.id);
}
@GenerateSql({ params: [DummyValue.UUID] })
async lockAll(userId: string) {
await this.db.updateTable('session').set({ pinExpiresAt: null }).where('userId', '=', userId).execute();
@@ -0,0 +1,11 @@
import { Kysely, sql } from 'kysely';
export async function up(db: Kysely<any>): Promise<void> {
await sql`ALTER TABLE "session" ADD "oauthSid" character varying;`.execute(db);
await sql`CREATE INDEX "session_oauthSid_idx" ON "session" ("oauthSid");`.execute(db);
}
export async function down(db: Kysely<any>): Promise<void> {
await sql`DROP INDEX "session_oauthSid_idx";`.execute(db);
await sql`ALTER TABLE "session" DROP COLUMN "oauthSid";`.execute(db);
}
@@ -52,4 +52,7 @@ export class SessionTable {
@Column({ type: 'timestamp with time zone', nullable: true })
pinExpiresAt!: Timestamp | null;
@Column({ nullable: true, index: true })
oauthSid!: string | null;
}
+136 -26
View File
@@ -196,6 +196,64 @@ describe(AuthService.name, () => {
});
});
describe('backchannelLogout', () => {
const dto = { logout_token: 'fake-jwt-token' };
it('should throw a Bad Request Exception if OAuth is not enabled', async () => {
await expect(sut.backchannelLogout(dto)).rejects.toBeInstanceOf(BadRequestException);
await expect(sut.backchannelLogout(dto)).rejects.toThrow(
'Received backchannel logout request but OAuth is not enabled',
);
});
it('should throw a Bad Request Exception if the logout token validation fails', async () => {
mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthEnabled);
mocks.oauth.validateLogoutToken.mockRejectedValue(new Error('Token validation failed'));
await expect(sut.backchannelLogout(dto)).rejects.toBeInstanceOf(BadRequestException);
await expect(sut.backchannelLogout(dto)).rejects.toThrow('Error backchannel logout: token validation failed');
});
it('should throw a Bad Request Exception if there are no claims in the logout token', async () => {
mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthEnabled);
mocks.oauth.validateLogoutToken.mockResolvedValue(null);
await expect(sut.backchannelLogout(dto)).rejects.toBeInstanceOf(BadRequestException);
await expect(sut.backchannelLogout(dto)).rejects.toThrow('Invalid logout token: no claims found');
});
it('should throw a Bad Request Exception if there is neither the sub nor the sid claim', async () => {
mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthEnabled);
mocks.oauth.validateLogoutToken.mockResolvedValue({ sub: '', sid: '' });
await expect(sut.backchannelLogout(dto)).rejects.toBeInstanceOf(BadRequestException);
await expect(sut.backchannelLogout(dto)).rejects.toThrow(
'Invalid logout token: it must contain either a sub or a sid claim',
);
});
it('should invalidate the OAuth session(s) if the logout token is valid', async () => {
const claims = { sub: 'fake-sub', sid: 'fake-sid' };
const deletedSessionIds: string[] = ['fake-session-1', 'fake-session-2'];
mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthEnabled);
mocks.oauth.validateLogoutToken.mockResolvedValue(claims);
mocks.session.invalidateOAuth.mockResolvedValue(deletedSessionIds);
mocks.event.emit.mockResolvedValue(void 0);
mocks.event.emit.mockResolvedValue(void 0);
await sut.backchannelLogout(dto);
expect(mocks.session.invalidateOAuth).toHaveBeenCalledWith({
oauthSid: claims.sid,
oauthId: claims.sub,
});
expect(mocks.event.emit).toHaveBeenCalledWith('SessionDelete', { sessionId: 'fake-session-1' });
expect(mocks.event.emit).toHaveBeenCalledWith('SessionDelete', { sessionId: 'fake-session-2' });
});
});
describe('adminSignUp', () => {
const dto: SignUpDto = { email: 'test@immich.com', password: 'password', name: 'immich admin' };
@@ -250,6 +308,7 @@ describe(AuthService.name, () => {
user: UserFactory.create(),
pinExpiresAt: null,
appVersion: null,
oauthSid: null,
};
mocks.session.getByToken.mockResolvedValue(sessionWithToken);
@@ -416,6 +475,7 @@ describe(AuthService.name, () => {
user: UserFactory.create(),
pinExpiresAt: null,
appVersion: null,
oauthSid: null,
};
mocks.session.getByToken.mockResolvedValue(sessionWithToken);
@@ -444,6 +504,7 @@ describe(AuthService.name, () => {
isPendingSyncReset: false,
pinExpiresAt: null,
appVersion: null,
oauthSid: null,
};
mocks.session.getByToken.mockResolvedValue(sessionWithToken);
@@ -466,6 +527,7 @@ describe(AuthService.name, () => {
isPendingSyncReset: false,
pinExpiresAt: null,
appVersion: null,
oauthSid: null,
};
mocks.session.getByToken.mockResolvedValue(sessionWithToken);
@@ -601,7 +663,7 @@ describe(AuthService.name, () => {
it('should not allow auto registering', async () => {
mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthEnabled);
mocks.user.getByEmail.mockResolvedValue(void 0);
mocks.oauth.getProfile.mockResolvedValue(OAuthProfileFactory.create());
mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ profile: OAuthProfileFactory.create() });
await expect(
sut.callback(
@@ -619,7 +681,7 @@ describe(AuthService.name, () => {
const profile = OAuthProfileFactory.create();
mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthEnabled);
mocks.oauth.getProfile.mockResolvedValue(profile);
mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ profile });
mocks.user.getByEmail.mockResolvedValue(user);
mocks.user.update.mockResolvedValue(user);
mocks.session.create.mockResolvedValue(SessionFactory.create());
@@ -639,7 +701,7 @@ describe(AuthService.name, () => {
const profile = OAuthProfileFactory.create({ email: ' TEST@IMMICH.CLOUD ' });
mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthEnabled);
mocks.oauth.getProfile.mockResolvedValue(profile);
mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ profile });
mocks.user.getByEmail.mockResolvedValue(user);
mocks.user.update.mockResolvedValue(user);
mocks.session.create.mockResolvedValue(SessionFactory.create());
@@ -658,7 +720,7 @@ describe(AuthService.name, () => {
const user = UserFactory.create({ oauthId: 'existing-sub' });
mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthWithAutoRegister);
mocks.oauth.getProfile.mockResolvedValue(OAuthProfileFactory.create());
mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ profile: OAuthProfileFactory.create() });
mocks.user.getByEmail.mockResolvedValueOnce(user);
mocks.user.getAdmin.mockResolvedValue(UserFactory.create({ isAdmin: true }));
@@ -679,7 +741,7 @@ describe(AuthService.name, () => {
mocks.user.getByEmail.mockResolvedValue(void 0);
mocks.user.getAdmin.mockResolvedValue(UserFactory.create({ isAdmin: true }));
mocks.user.create.mockResolvedValue(UserFactory.create({ oauthId: 'oauth-id' }));
mocks.oauth.getProfile.mockResolvedValue(OAuthProfileFactory.create());
mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ profile: OAuthProfileFactory.create() });
mocks.session.create.mockResolvedValue(SessionFactory.create());
await sut.callback(
@@ -698,7 +760,7 @@ describe(AuthService.name, () => {
mocks.user.getAdmin.mockResolvedValue(UserFactory.create({ isAdmin: true }));
mocks.user.create.mockResolvedValue(UserFactory.create());
mocks.session.create.mockResolvedValue(SessionFactory.create());
mocks.oauth.getProfile.mockResolvedValue({ sub: 'sub' });
mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ profile: { sub: 'sub' } });
await expect(
sut.callback(
@@ -720,12 +782,12 @@ describe(AuthService.name, () => {
it(`should use the mobile redirect override for a url of ${url}`, async () => {
mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthWithMobileOverride);
mocks.user.getByOAuthId.mockResolvedValue(UserFactory.create());
mocks.oauth.getProfile.mockResolvedValue(OAuthProfileFactory.create());
mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ profile: OAuthProfileFactory.create() });
mocks.session.create.mockResolvedValue(SessionFactory.create());
await sut.callback({ url, state: 'xyz789', codeVerifier: 'foo' }, {}, loginDetails);
expect(mocks.oauth.getProfile).toHaveBeenCalledWith(
expect(mocks.oauth.getProfileAndOAuthSid).toHaveBeenCalledWith(
expect.objectContaining({}),
'http://mobile-redirect?code=abc123',
'xyz789',
@@ -738,7 +800,7 @@ describe(AuthService.name, () => {
mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthWithStorageQuota);
mocks.user.getByEmail.mockResolvedValue(void 0);
mocks.user.getAdmin.mockResolvedValue(UserFactory.create({ isAdmin: true }));
mocks.oauth.getProfile.mockResolvedValue(OAuthProfileFactory.create());
mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ profile: OAuthProfileFactory.create() });
mocks.user.create.mockResolvedValue(UserFactory.create({ oauthId: 'oauth-id' }));
mocks.session.create.mockResolvedValue(SessionFactory.create());
@@ -753,9 +815,9 @@ describe(AuthService.name, () => {
it('should infer name from given and family names', async () => {
mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.enabled);
mocks.oauth.getProfile.mockResolvedValue(
OAuthProfileFactory.create({ name: undefined, given_name: 'Given', family_name: 'Family' }),
);
mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({
profile: OAuthProfileFactory.create({ name: undefined, given_name: 'Given', family_name: 'Family' }),
});
mocks.user.getByEmail.mockResolvedValue(void 0);
mocks.user.getAdmin.mockResolvedValue(UserFactory.create({ isAdmin: true }));
mocks.user.create.mockResolvedValue(UserFactory.create());
@@ -774,7 +836,7 @@ describe(AuthService.name, () => {
const profile = OAuthProfileFactory.create({ name: undefined, given_name: undefined, family_name: undefined });
mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.enabled);
mocks.oauth.getProfile.mockResolvedValue(profile);
mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ profile });
mocks.user.getByEmail.mockResolvedValue(void 0);
mocks.user.getAdmin.mockResolvedValue(UserFactory.create({ isAdmin: true }));
mocks.user.create.mockResolvedValue(UserFactory.create());
@@ -791,7 +853,9 @@ describe(AuthService.name, () => {
it('should ignore an invalid storage quota', async () => {
mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthWithStorageQuota);
mocks.oauth.getProfile.mockResolvedValue(OAuthProfileFactory.create({ immich_quota: 'abc' }));
mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({
profile: OAuthProfileFactory.create({ immich_quota: 'abc' }),
});
mocks.user.getAdmin.mockResolvedValue(UserFactory.create({ isAdmin: true }));
mocks.user.getByEmail.mockResolvedValue(void 0);
mocks.user.create.mockResolvedValue(UserFactory.create({ oauthId: 'oauth-id' }));
@@ -808,7 +872,9 @@ describe(AuthService.name, () => {
it('should ignore a negative quota', async () => {
mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthWithStorageQuota);
mocks.oauth.getProfile.mockResolvedValue(OAuthProfileFactory.create({ immich_quota: -5 }));
mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({
profile: OAuthProfileFactory.create({ immich_quota: -5 }),
});
mocks.user.getAdmin.mockResolvedValue(UserFactory.create({ isAdmin: true }));
mocks.user.getByEmail.mockResolvedValue(void 0);
mocks.user.create.mockResolvedValue(UserFactory.create({ oauthId: 'oauth-id' }));
@@ -825,7 +891,7 @@ describe(AuthService.name, () => {
it('should set quota for 0 quota', async () => {
mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthWithStorageQuota);
mocks.oauth.getProfile.mockResolvedValue(OAuthProfileFactory.create({ immich_quota: 0 }));
mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ profile: OAuthProfileFactory.create({ immich_quota: 0 }) });
mocks.user.getAdmin.mockResolvedValue(UserFactory.create({ isAdmin: true }));
mocks.user.getByEmail.mockResolvedValue(void 0);
mocks.user.create.mockResolvedValue(UserFactory.create({ oauthId: 'oauth-id' }));
@@ -842,7 +908,7 @@ describe(AuthService.name, () => {
it('should use a valid storage quota', async () => {
mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthWithStorageQuota);
mocks.oauth.getProfile.mockResolvedValue(OAuthProfileFactory.create({ immich_quota: 5 }));
mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ profile: OAuthProfileFactory.create({ immich_quota: 5 }) });
mocks.user.getByEmail.mockResolvedValue(void 0);
mocks.user.getAdmin.mockResolvedValue(UserFactory.create({ isAdmin: true }));
mocks.user.getByOAuthId.mockResolvedValue(void 0);
@@ -864,7 +930,7 @@ describe(AuthService.name, () => {
const profile = OAuthProfileFactory.create({ picture: 'https://auth.immich.cloud/profiles/1.jpg' });
mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthEnabled);
mocks.oauth.getProfile.mockResolvedValue(profile);
mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ profile });
mocks.user.getByOAuthId.mockResolvedValue(user);
mocks.crypto.randomUUID.mockReturnValue(fileId);
mocks.oauth.getProfilePicture.mockResolvedValue({
@@ -891,13 +957,13 @@ describe(AuthService.name, () => {
const user = UserFactory.create({ oauthId: 'oauth-id', profileImagePath: 'not-empty' });
mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthEnabled);
mocks.oauth.getProfile.mockResolvedValue(
OAuthProfileFactory.create({
mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({
profile: OAuthProfileFactory.create({
sub: user.oauthId,
email: user.email,
picture: 'https://auth.immich.cloud/profiles/1.jpg',
}),
);
});
mocks.user.getByOAuthId.mockResolvedValue(user);
mocks.user.update.mockResolvedValue(user);
mocks.session.create.mockResolvedValue(SessionFactory.create());
@@ -914,7 +980,9 @@ describe(AuthService.name, () => {
it('should only allow "admin" and "user" for the role claim', async () => {
mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthWithAutoRegister);
mocks.oauth.getProfile.mockResolvedValue(OAuthProfileFactory.create({ immich_role: 'foo' }));
mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({
profile: OAuthProfileFactory.create({ immich_role: 'foo' }),
});
mocks.user.getByEmail.mockResolvedValue(void 0);
mocks.user.getAdmin.mockResolvedValue(UserFactory.create({ isAdmin: true }));
mocks.user.getByOAuthId.mockResolvedValue(void 0);
@@ -932,7 +1000,9 @@ describe(AuthService.name, () => {
it('should create an admin user if the role claim is set to admin', async () => {
mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthWithAutoRegister);
mocks.oauth.getProfile.mockResolvedValue(OAuthProfileFactory.create({ immich_role: 'admin' }));
mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({
profile: OAuthProfileFactory.create({ immich_role: 'admin' }),
});
mocks.user.getByEmail.mockResolvedValue(void 0);
mocks.user.getByOAuthId.mockResolvedValue(void 0);
mocks.user.create.mockResolvedValue(UserFactory.create({ oauthId: 'oauth-id' }));
@@ -951,7 +1021,9 @@ describe(AuthService.name, () => {
mocks.systemMetadata.get.mockResolvedValue({
oauth: { ...systemConfigStub.oauthWithAutoRegister.oauth, roleClaim: 'my_role' },
});
mocks.oauth.getProfile.mockResolvedValue(OAuthProfileFactory.create({ my_role: 'admin' }));
mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({
profile: OAuthProfileFactory.create({ my_role: 'admin' }),
});
mocks.user.getByEmail.mockResolvedValue(void 0);
mocks.user.getByOAuthId.mockResolvedValue(void 0);
mocks.user.create.mockResolvedValue(UserFactory.create({ oauthId: 'oauth-id' }));
@@ -974,7 +1046,7 @@ describe(AuthService.name, () => {
const profile = OAuthProfileFactory.create();
mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.enabled);
mocks.oauth.getProfile.mockResolvedValue(profile);
mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ profile });
mocks.user.update.mockResolvedValue(user);
await sut.link(
@@ -986,13 +1058,36 @@ describe(AuthService.name, () => {
expect(mocks.user.update).toHaveBeenCalledWith(auth.user.id, { oauthId: profile.sub });
});
it('should link an account and update the session with the oauthSid', async () => {
const user = UserFactory.create();
const session = SessionFactory.create();
const auth = AuthFactory.from(user).session(session).build();
mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.enabled);
mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({
profile: { sub: 'sub' },
sid: session.oauthSid ?? undefined,
});
mocks.user.update.mockResolvedValue(user);
mocks.session.update.mockResolvedValue(session);
await sut.link(
auth,
{ url: 'http://immich/user-settings?code=abc123', state: 'xyz789', codeVerifier: 'foo' },
{},
);
expect(mocks.session.update).toHaveBeenCalledWith(session.id, { oauthSid: session.oauthSid });
expect(mocks.user.update).toHaveBeenCalledWith(auth.user.id, { oauthId: 'sub' });
});
it('should not link an already linked oauth.sub', async () => {
const authUser = UserFactory.create();
const authApiKey = ApiKeyFactory.create({ permissions: [] });
const auth = { user: authUser, apiKey: authApiKey };
mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.enabled);
mocks.oauth.getProfile.mockResolvedValue(OAuthProfileFactory.create());
mocks.oauth.getProfileAndOAuthSid.mockResolvedValue({ profile: OAuthProfileFactory.create() });
mocks.user.getByOAuthId.mockResolvedValue({ id: 'other-user' } as UserAdmin);
await expect(
@@ -1015,6 +1110,21 @@ describe(AuthService.name, () => {
expect(mocks.user.update).toHaveBeenCalledWith(auth.user.id, { oauthId: '' });
});
it('should unlink an account and remove the oauthSid from the session', async () => {
const user = UserFactory.create();
const session = SessionFactory.create();
const auth = AuthFactory.from(user).session(session).build();
mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.enabled);
mocks.session.update.mockResolvedValue(session);
mocks.user.update.mockResolvedValue(user);
await sut.unlink(auth);
expect(mocks.session.update).toHaveBeenCalledWith(session.id, { oauthSid: null });
expect(mocks.user.update).toHaveBeenCalledWith(auth.user.id, { oauthId: '' });
});
});
describe('setupPinCode', () => {
+56 -4
View File
@@ -12,6 +12,7 @@ import {
ChangePasswordDto,
LoginCredentialDto,
LogoutResponseDto,
OAuthBackchannelLogoutDto,
OAuthCallbackDto,
OAuthConfigDto,
PinCodeChangeDto,
@@ -91,6 +92,40 @@ export class AuthService extends BaseService {
};
}
async backchannelLogout(dto: OAuthBackchannelLogoutDto): Promise<void> {
const { oauth } = await this.getConfig({ withCache: false });
if (!oauth.enabled) {
throw new BadRequestException('Received backchannel logout request but OAuth is not enabled');
}
let claims;
try {
claims = await this.oauthRepository.validateLogoutToken(oauth, dto.logout_token);
} catch (error: Error | any) {
this.logger.error(`Error backchannel logout: ${error.message}`);
this.logger.error(error);
throw new BadRequestException('Error backchannel logout: token validation failed');
}
if (!claims) {
throw new BadRequestException('Invalid logout token: no claims found');
}
if (!claims.sub && !claims.sid) {
throw new BadRequestException('Invalid logout token: it must contain either a sub or a sid claim');
}
const deletedSessionIds = await this.sessionRepository.invalidateOAuth({
oauthSid: claims.sid,
oauthId: claims.sub,
});
for (const sessionId of deletedSessionIds) {
await this.eventRepository.emit('SessionDelete', { sessionId });
}
}
async changePassword(auth: AuthDto, dto: ChangePasswordDto): Promise<UserAdminResponseDto> {
const { password, newPassword } = dto;
const user = await this.userRepository.getForChangePassword(auth.user.id);
@@ -276,7 +311,12 @@ export class AuthService extends BaseService {
}
const url = this.resolveRedirectUri(oauth, dto.url);
const profile = await this.oauthRepository.getProfile(oauth, url, expectedState, codeVerifier);
const { profile, sid: oauthSid } = await this.oauthRepository.getProfileAndOAuthSid(
oauth,
url,
expectedState,
codeVerifier,
);
const normalizedEmail = profile.email ? profile.email.trim().toLowerCase() : undefined;
const { autoRegister, defaultStorageQuota, storageLabelClaim, storageQuotaClaim, roleClaim } = oauth;
this.logger.debug(`Logging in with OAuth: ${JSON.stringify(profile)}`);
@@ -342,7 +382,7 @@ export class AuthService extends BaseService {
await this.syncProfilePicture(user, profile.picture);
}
return this.createLoginResponse(user, loginDetails);
return this.createLoginResponse(user, loginDetails, oauthSid);
}
private async syncProfilePicture(user: UserAdmin, url: string) {
@@ -380,18 +420,29 @@ export class AuthService extends BaseService {
}
const { oauth } = await this.getConfig({ withCache: false });
const { sub: oauthId } = await this.oauthRepository.getProfile(oauth, dto.url, expectedState, codeVerifier);
const {
profile: { sub: oauthId },
sid,
} = await this.oauthRepository.getProfileAndOAuthSid(oauth, dto.url, expectedState, codeVerifier);
const duplicate = await this.userRepository.getByOAuthId(oauthId);
if (duplicate && duplicate.id !== auth.user.id) {
this.logger.warn(`OAuth link account failed: sub is already linked to another user (${duplicate.email}).`);
throw new BadRequestException('This OAuth account has already been linked to another user.');
}
if (auth.session) {
await this.sessionRepository.update(auth.session.id, { oauthSid: sid });
}
const user = await this.userRepository.update(auth.user.id, { oauthId });
return mapUserAdmin(user);
}
async unlink(auth: AuthDto): Promise<UserAdminResponseDto> {
if (auth.session) {
await this.sessionRepository.update(auth.session.id, { oauthSid: null });
}
const user = await this.userRepository.update(auth.user.id, { oauthId: '' });
return mapUserAdmin(user);
}
@@ -548,7 +599,7 @@ export class AuthService extends BaseService {
await this.sessionRepository.update(auth.session.id, { pinExpiresAt: null });
}
private async createLoginResponse(user: UserAdmin, loginDetails: LoginDetails) {
private async createLoginResponse(user: UserAdmin, loginDetails: LoginDetails, oauthSid?: string) {
const token = this.cryptoRepository.randomBytesAsText(32);
const hashed = this.cryptoRepository.hashSha256(token);
@@ -558,6 +609,7 @@ export class AuthService extends BaseService {
deviceType: loginDetails.deviceType,
appVersion: loginDetails.appVersion,
userId: user.id,
oauthSid: oauthSid ?? null,
});
return mapLoginResponse(user, token);
+5 -2
View File
@@ -46,11 +46,14 @@ describe('SessionService', () => {
const currentSession = SessionFactory.create();
const auth = AuthFactory.from().session(currentSession).build();
mocks.session.invalidate.mockResolvedValue();
mocks.session.invalidateAll.mockResolvedValue();
await sut.deleteAll(auth);
expect(mocks.session.invalidate).toHaveBeenCalledWith({ userId: auth.user.id, excludeId: currentSession.id });
expect(mocks.session.invalidateAll).toHaveBeenCalledWith({
userId: auth.user.id,
excludeId: currentSession.id,
});
});
});
+2 -2
View File
@@ -73,7 +73,7 @@ export class SessionService extends BaseService {
async deleteAll(auth: AuthDto): Promise<void> {
const userId = auth.user.id;
const currentSessionId = auth.session?.id;
await this.sessionRepository.invalidate({ userId, excludeId: currentSessionId });
await this.sessionRepository.invalidateAll({ userId, excludeId: currentSessionId });
}
async lock(auth: AuthDto, id: string): Promise<void> {
@@ -83,6 +83,6 @@ export class SessionService extends BaseService {
@OnEvent({ name: 'AuthChangePassword' })
async onAuthChangePassword({ userId, currentSessionId }: ArgOf<'AuthChangePassword'>): Promise<void> {
await this.sessionRepository.invalidate({ userId, excludeId: currentSessionId });
await this.sessionRepository.invalidateAll({ userId, excludeId: currentSessionId });
}
}