From e692fda6208314f3bb14ffd9ffcca68beef9088c Mon Sep 17 00:00:00 2001 From: Zoe Roux Date: Mon, 4 May 2026 11:21:00 +0200 Subject: [PATCH] Implement token refresh for websockets --- api/src/auth.ts | 2 +- api/src/websockets.ts | 28 +++++++++++- auth/apikey.go | 1 + auth/dbc/apikeys.sql.go | 2 +- auth/dbc/db.go | 2 +- auth/dbc/models.go | 2 +- auth/dbc/oidc.sql.go | 2 +- auth/dbc/sessions.sql.go | 42 +++++++++++++++++- auth/dbc/users.sql.go | 2 +- auth/jwt.go | 83 ++++++++++++++++++++++++++++++++++- auth/sql/queries/sessions.sql | 13 ++++++ 11 files changed, 169 insertions(+), 10 deletions(-) diff --git a/api/src/auth.ts b/api/src/auth.ts index d2be7c8e..472d56da 100644 --- a/api/src/auth.ts +++ b/api/src/auth.ts @@ -33,7 +33,7 @@ const Jwt = t.Object({ type Jwt = typeof Jwt.static; const validator = TypeCompiler.Compile(Jwt); -async function verifyJwt(bearer: string) { +export async function verifyJwt(bearer: string) { // @ts-expect-error ts can't understand that there's two overload idk why const { payload } = await jwtVerify(bearer, jwtSecret ?? jwks, { issuer: process.env.JWT_ISSUER, diff --git a/api/src/websockets.ts b/api/src/websockets.ts index bb9327a8..780bf548 100644 --- a/api/src/websockets.ts +++ b/api/src/websockets.ts @@ -2,7 +2,7 @@ import { getLogger } from "@logtape/logtape"; import type { TObject, TString } from "@sinclair/typebox"; import { eq } from "drizzle-orm"; import Elysia, { type TSchema, t } from "elysia"; -import { auth } from "./auth"; +import { auth, verifyJwt } from "./auth"; import { updateProgress } from "./controllers/profiles/history"; import { getOrCreateProfile } from "./controllers/profiles/profile"; import { prepareVideo } from "./controllers/video-metadata"; @@ -13,6 +13,7 @@ const logger = getLogger(); const actionMap = { ping: handler({ + skipRefresh: true, message(ws) { ws.send({ action: "ping", response: "pong" }); }, @@ -90,6 +91,30 @@ export const appWs = baseWs.ws("/ws", { }, async message(ws, { action, ...body }) { const handler = actionMap[action as keyof typeof actionMap]; + if (!handler.skipRefresh) { + try { + const resp = await fetch( + new URL("/auth/jwt", process.env.AUTH_SERVER ?? "http://auth:4568"), + { + headers: { + Authorization: ws.data.headers.authorization!, + }, + }, + ); + if (resp.ok) { + const data = (await resp.json()) as { token?: string }; + if (data.token) { + const ret = await verifyJwt(data.token); + ws.data.jwt = ret.jwt as typeof ws.data.jwt; + ws.data.headers.authorization = + `Bearer ${data.token}` as typeof ws.data.headers.authorization; + } + } + } catch (e) { + logger.error("Failed to refresh jwt: {err}", { err: e }); + // If refresh fails, continue with the old JWT + } + } for (const perm of handler.permissions ?? []) { if (!ws.data.jwt.permissions.includes(perm)) { ws.send({ @@ -108,6 +133,7 @@ type Ws = Parameters[1]["open"]>>[0]; function handler>(ret: { body?: Schema; permissions?: string[]; + skipRefresh?: boolean; message: (ws: Ws, body: Schema["static"]) => void | Promise; }) { return ret; diff --git a/auth/apikey.go b/auth/apikey.go index a2b96fa4..1ccc4449 100644 --- a/auth/apikey.go +++ b/auth/apikey.go @@ -206,6 +206,7 @@ func (h *Handler) createApiJwt(ctx context.Context, apikey string) (string, erro claims["username"] = key.Name claims["sub"] = key.Id claims["sid"] = key.Id + claims["jti"] = uuid.New().String() claims["iss"] = h.config.PublicUrl claims["iat"] = &jwt.NumericDate{ Time: time.Now().UTC(), diff --git a/auth/dbc/apikeys.sql.go b/auth/dbc/apikeys.sql.go index bf570ee3..aeddf530 100644 --- a/auth/dbc/apikeys.sql.go +++ b/auth/dbc/apikeys.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.1 // source: apikeys.sql package dbc diff --git a/auth/dbc/db.go b/auth/dbc/db.go index dcfc5dbe..e55d2dda 100644 --- a/auth/dbc/db.go +++ b/auth/dbc/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.1 package dbc diff --git a/auth/dbc/models.go b/auth/dbc/models.go index 2cc65f4b..c699a71d 100644 --- a/auth/dbc/models.go +++ b/auth/dbc/models.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.1 package dbc diff --git a/auth/dbc/oidc.sql.go b/auth/dbc/oidc.sql.go index 2764282a..91564382 100644 --- a/auth/dbc/oidc.sql.go +++ b/auth/dbc/oidc.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.1 // source: oidc.sql package dbc diff --git a/auth/dbc/sessions.sql.go b/auth/dbc/sessions.sql.go index 3c6ef7ae..250664cf 100644 --- a/auth/dbc/sessions.sql.go +++ b/auth/dbc/sessions.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.1 // source: sessions.sql package dbc @@ -86,6 +86,46 @@ func (q *Queries) DeleteSession(ctx context.Context, arg DeleteSessionParams) (S return i, err } +const getUserFromSessionId = `-- name: GetUserFromSessionId :one +select + s.pk, + s.id, + s.last_used, + u.pk, u.id, u.username, u.email, u.password, u.claims, u.created_date, u.last_seen +from + keibi.users as u + inner join keibi.sessions as s on u.pk = s.user_pk +where + s.id = $1 +limit 1 +` + +type GetUserFromSessionIdRow struct { + Pk int32 `json:"pk"` + Id uuid.UUID `json:"id"` + LastUsed time.Time `json:"lastUsed"` + User User `json:"user"` +} + +func (q *Queries) GetUserFromSessionId(ctx context.Context, id uuid.UUID) (GetUserFromSessionIdRow, error) { + row := q.db.QueryRow(ctx, getUserFromSessionId, id) + var i GetUserFromSessionIdRow + err := row.Scan( + &i.Pk, + &i.Id, + &i.LastUsed, + &i.User.Pk, + &i.User.Id, + &i.User.Username, + &i.User.Email, + &i.User.Password, + &i.User.Claims, + &i.User.CreatedDate, + &i.User.LastSeen, + ) + return i, err +} + const getUserFromToken = `-- name: GetUserFromToken :one select s.pk, diff --git a/auth/dbc/users.sql.go b/auth/dbc/users.sql.go index e7eee8aa..ec33da35 100644 --- a/auth/dbc/users.sql.go +++ b/auth/dbc/users.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.1 // source: users.sql package dbc diff --git a/auth/jwt.go b/auth/jwt.go index 4798e768..ab503c96 100644 --- a/auth/jwt.go +++ b/auth/jwt.go @@ -2,6 +2,7 @@ package main import ( "context" + "encoding/base64" "fmt" "maps" "net/http" @@ -9,6 +10,7 @@ import ( "time" "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" "github.com/labstack/echo/v5" "github.com/lestrrat-go/jwx/v3/jwk" ) @@ -19,7 +21,7 @@ type Jwt struct { } // @Summary Get JWT -// @Description Convert a session token or an API key to a short lived JWT. +// @Description Convert a session token or an API key to a short lived JWT. Passing an existing JWT will refresh it. // @Tags jwt // @Produce json // @Security Token @@ -69,6 +71,12 @@ func (h *Handler) CreateJwt(c *echo.Context) error { if jwt == nil { return echo.NewHTTPError(http.StatusUnauthorized, "Guests not allowed.") } + } else if _, err := base64.RawURLEncoding.DecodeString(token); err != nil { + tkn, err := h.refreshJwt(ctx, token) + if err != nil { + return err + } + jwt = &tkn } else { tkn, err := h.createJwt(ctx, token) if err != nil { @@ -94,6 +102,7 @@ func (h *Handler) createGuestJwt() *string { claims["username"] = "guest" claims["sub"] = "00000000-0000-0000-0000-000000000000" claims["sid"] = "00000000-0000-0000-0000-000000000000" + claims["jti"] = uuid.New().String() claims["iss"] = h.config.PublicUrl claims["iat"] = &jwt.NumericDate{ Time: time.Now().UTC(), @@ -128,6 +137,7 @@ func (h *Handler) createJwt(ctx context.Context, token string) (string, error) { claims["username"] = session.User.Username claims["sub"] = session.User.Id.String() claims["sid"] = session.Id.String() + claims["jti"] = uuid.New().String() claims["iss"] = h.config.PublicUrl claims["iat"] = &jwt.NumericDate{ Time: time.Now().UTC(), @@ -144,13 +154,82 @@ func (h *Handler) createJwt(ctx context.Context, token string) (string, error) { return t, nil } +func (h *Handler) refreshJwt(ctx context.Context, jwtToken string) (string, error) { + token, err := jwt.ParseWithClaims(jwtToken, jwt.MapClaims{}, func(t *jwt.Token) (any, error) { + if t.Method.Alg() != "RS256" { + return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) + } + return h.config.JwtPublicKey, nil + }) + if err != nil { + return "", echo.NewHTTPError(http.StatusForbidden, "Invalid JWT") + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return "", echo.NewHTTPError(http.StatusForbidden, "Invalid JWT claims") + } + + sidStr, ok := claims["sid"].(string) + if !ok { + return "", echo.NewHTTPError(http.StatusForbidden, "Missing session id in JWT") + } + sid, err := uuid.Parse(sidStr) + if err != nil { + return "", echo.NewHTTPError(http.StatusForbidden, "Invalid session id in JWT") + } + + jtiStr, ok := claims["jti"].(string) + if !ok { + return "", echo.NewHTTPError(http.StatusForbidden, "Missing token id in JWT") + } + jti, err := uuid.Parse(jtiStr) + if err != nil { + return "", echo.NewHTTPError(http.StatusForbidden, "Invalid token id in JWT") + } + + session, err := h.db.GetUserFromSessionId(ctx, sid) + if err != nil { + return "", echo.NewHTTPError(http.StatusForbidden, "Session not found") + } + + if session.LastUsed.Add(h.config.ExpirationDelay).Compare(time.Now().UTC()) < 0 { + return "", echo.NewHTTPError(http.StatusForbidden, "Session has expired") + } + + go func() { + h.db.TouchSession(ctx, session.Pk) + h.db.TouchUser(ctx, session.User.Pk) + }() + + newClaims := maps.Clone(session.User.Claims) + newClaims["username"] = session.User.Username + newClaims["sub"] = session.User.Id.String() + newClaims["sid"] = session.Id.String() + newClaims["jti"] = jti.String() + newClaims["iss"] = h.config.PublicUrl + newClaims["iat"] = &jwt.NumericDate{ + Time: time.Now().UTC(), + } + newClaims["exp"] = &jwt.NumericDate{ + Time: time.Now().UTC().Add(time.Hour), + } + newJwt := jwt.NewWithClaims(jwt.SigningMethodRS256, newClaims) + newJwt.Header["kid"] = h.config.JwtKid + t, err := newJwt.SignedString(h.config.JwtPrivateKey) + if err != nil { + return "", err + } + return t, nil +} + // only used for the swagger doc type JwkSet struct { Keys []struct { E string `json:"e" example:"AQAB"` KeyOps []string `json:"key_ops" example:"[verify]"` Kty string `json:"kty" example:"RSA"` - N string `json:"n" example:"oBcXcJUR-Sb8_b4qIj28LRAPxdF_6odRr52K5-ymiEkR2DOlEuXBtM-biWxPESW-U-zhfHzdVLf6ioy5xL0bJTh8BMIorkrDliN3vb81jCvyOMgZ7ATMJpMAQMmSDN7sL3U45r22FaoQufCJMQHmUsZPecdQSgj2aFBiRXxsLleYlSezdBVT_gKH-coqeYXSC_hk-ezSq4aDZ10BlDnZ-FA7-ES3T7nBmJEAU7KDAGeSvbYAfYimOW0r-Vc0xQNuwGCfzZtSexKXDbYbNwOVo3SjfCabq-gMfap_owcHbKicGBZu1LDlh7CpkmLQf_kv6GihM2LWFFh6Vwg2cltiwF22EIPlUDtYTkUR0qRkdNJaNkwV5Vv_6r3pzSmu5ovRriKtlrvJMjlTnLb4_ltsge3fw5Z34cJrsp094FbUc2O6Or4FGEXUldieJCnVRhs2_h6SDcmeMXs1zfvE5GlDnq8tZV6WMJ5Sb4jNO7rs_hTkr23_E6mVg-DdtozGfqzRzhIjPym6D_jVfR6dZv5W0sKwOHRmT7nYq-C7b2sAwmNNII296M4Rq-jn0b5pgSeMDYbIpbIA4thU8LYU0lBZp_ZVwWKG1RFZDxz3k9O5UVth2kTpTWlwn0hB1aAvgXHo6in1CScITGA72p73RbDieNnLFaCK4xUVstkWAKLqPxs"` + N string `json:"n" example:"oBcXcJUR-Sb8_b4qIj28LRAPxdF_6odRr52K5-ymiEkR2DOlEuXBtM-biWxPESW-U-zhfHzdVLf6ioy5xL0bJTh8BMIorkrDliN3vb81jCvyOMgZ7ATMJpMAQMmSDN7sL3U45r22FaoQufCJMQHmUsZPecdQSgj2aFBiRXxsLleYlSezdBVT_gKH-coqeYXSC_hk-ezSq4aDZ10BlDnZ-FA7-ES3T7nBmJEAU7KDAGeSvbYAfYimOW0r-Vc0xQNuwGCfzZtSexKXDbYbNwOVo3SjfCabq-gMfap_owcHbKicGBZu1LDlh7CpkmLQf_kv6GihM2LWFFh6Vwg2cltiwF22EIPlUDtYTkUR0qRkdNJaNkwV5Vv_6r3pzSmu5ovRriKtlrvJMjlTnLb4_ltsge3fw5Z34cJrsp094FbUc2O6Or4FGEXUldieJCnVRhs2_h6SDcmeMXs1zfvE5GlDnq8tZV6WMJ5Sb4jNO7rs_hTkr23_E6mVg-DdtRS256ozGfqzRzhIjPym6D_jVfR6dZv5W0sKwOHRmT7nYq-C7b2sAwmNNII296M4Rq-jn0b5pgSeMDYbIpbIA4thU8LYU0lBZp_ZVwWKG1RFZDxz3k9O5UVth2kTpTWlwn0hB1aAvgXHo6in1CScITGA72p73RbDieNnLFaCK4xUVstkWAKLqPxs"` Use string `json:"use" example:"sig"` } } diff --git a/auth/sql/queries/sessions.sql b/auth/sql/queries/sessions.sql index 187f1627..f591c2e6 100644 --- a/auth/sql/queries/sessions.sql +++ b/auth/sql/queries/sessions.sql @@ -44,6 +44,19 @@ where s.user_pk = u.pk returning s.*; +-- name: GetUserFromSessionId :one +select + s.pk, + s.id, + s.last_used, + sqlc.embed(u) +from + keibi.users as u + inner join keibi.sessions as s on u.pk = s.user_pk +where + s.id = $1 +limit 1; + -- name: ClearOtherSessions :exec delete from keibi.sessions as s using keibi.users as u where s.user_pk = u.pk