Implement token refresh for websockets

This commit is contained in:
Zoe Roux 2026-05-04 11:21:00 +02:00
parent 2cd92c4665
commit e692fda620
No known key found for this signature in database
11 changed files with 169 additions and 10 deletions

View File

@ -33,7 +33,7 @@ const Jwt = t.Object({
type Jwt = typeof Jwt.static;
const validator = TypeCompiler.Compile(Jwt);
async function verifyJwt(bearer: string) {
export async function verifyJwt(bearer: string) {
// @ts-expect-error ts can't understand that there's two overload idk why
const { payload } = await jwtVerify(bearer, jwtSecret ?? jwks, {
issuer: process.env.JWT_ISSUER,

View File

@ -2,7 +2,7 @@ import { getLogger } from "@logtape/logtape";
import type { TObject, TString } from "@sinclair/typebox";
import { eq } from "drizzle-orm";
import Elysia, { type TSchema, t } from "elysia";
import { auth } from "./auth";
import { auth, verifyJwt } from "./auth";
import { updateProgress } from "./controllers/profiles/history";
import { getOrCreateProfile } from "./controllers/profiles/profile";
import { prepareVideo } from "./controllers/video-metadata";
@ -13,6 +13,7 @@ const logger = getLogger();
const actionMap = {
ping: handler({
skipRefresh: true,
message(ws) {
ws.send({ action: "ping", response: "pong" });
},
@ -90,6 +91,30 @@ export const appWs = baseWs.ws("/ws", {
},
async message(ws, { action, ...body }) {
const handler = actionMap[action as keyof typeof actionMap];
if (!handler.skipRefresh) {
try {
const resp = await fetch(
new URL("/auth/jwt", process.env.AUTH_SERVER ?? "http://auth:4568"),
{
headers: {
Authorization: ws.data.headers.authorization!,
},
},
);
if (resp.ok) {
const data = (await resp.json()) as { token?: string };
if (data.token) {
const ret = await verifyJwt(data.token);
ws.data.jwt = ret.jwt as typeof ws.data.jwt;
ws.data.headers.authorization =
`Bearer ${data.token}` as typeof ws.data.headers.authorization;
}
}
} catch (e) {
logger.error("Failed to refresh jwt: {err}", { err: e });
// If refresh fails, continue with the old JWT
}
}
for (const perm of handler.permissions ?? []) {
if (!ws.data.jwt.permissions.includes(perm)) {
ws.send({
@ -108,6 +133,7 @@ type Ws = Parameters<NonNullable<Parameters<typeof baseWs.ws>[1]["open"]>>[0];
function handler<Schema extends TSchema = TObject<{}>>(ret: {
body?: Schema;
permissions?: string[];
skipRefresh?: boolean;
message: (ws: Ws, body: Schema["static"]) => void | Promise<void>;
}) {
return ret;

View File

@ -206,6 +206,7 @@ func (h *Handler) createApiJwt(ctx context.Context, apikey string) (string, erro
claims["username"] = key.Name
claims["sub"] = key.Id
claims["sid"] = key.Id
claims["jti"] = uuid.New().String()
claims["iss"] = h.config.PublicUrl
claims["iat"] = &jwt.NumericDate{
Time: time.Now().UTC(),

View File

@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// sqlc v1.31.1
// source: apikeys.sql
package dbc

View File

@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// sqlc v1.31.1
package dbc

View File

@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// sqlc v1.31.1
package dbc

View File

@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// sqlc v1.31.1
// source: oidc.sql
package dbc

View File

@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// sqlc v1.31.1
// source: sessions.sql
package dbc
@ -86,6 +86,46 @@ func (q *Queries) DeleteSession(ctx context.Context, arg DeleteSessionParams) (S
return i, err
}
const getUserFromSessionId = `-- name: GetUserFromSessionId :one
select
s.pk,
s.id,
s.last_used,
u.pk, u.id, u.username, u.email, u.password, u.claims, u.created_date, u.last_seen
from
keibi.users as u
inner join keibi.sessions as s on u.pk = s.user_pk
where
s.id = $1
limit 1
`
type GetUserFromSessionIdRow struct {
Pk int32 `json:"pk"`
Id uuid.UUID `json:"id"`
LastUsed time.Time `json:"lastUsed"`
User User `json:"user"`
}
func (q *Queries) GetUserFromSessionId(ctx context.Context, id uuid.UUID) (GetUserFromSessionIdRow, error) {
row := q.db.QueryRow(ctx, getUserFromSessionId, id)
var i GetUserFromSessionIdRow
err := row.Scan(
&i.Pk,
&i.Id,
&i.LastUsed,
&i.User.Pk,
&i.User.Id,
&i.User.Username,
&i.User.Email,
&i.User.Password,
&i.User.Claims,
&i.User.CreatedDate,
&i.User.LastSeen,
)
return i, err
}
const getUserFromToken = `-- name: GetUserFromToken :one
select
s.pk,

View File

@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// sqlc v1.31.1
// source: users.sql
package dbc

View File

@ -2,6 +2,7 @@ package main
import (
"context"
"encoding/base64"
"fmt"
"maps"
"net/http"
@ -9,6 +10,7 @@ import (
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/labstack/echo/v5"
"github.com/lestrrat-go/jwx/v3/jwk"
)
@ -19,7 +21,7 @@ type Jwt struct {
}
// @Summary Get JWT
// @Description Convert a session token or an API key to a short lived JWT.
// @Description Convert a session token or an API key to a short lived JWT. Passing an existing JWT will refresh it.
// @Tags jwt
// @Produce json
// @Security Token
@ -69,6 +71,12 @@ func (h *Handler) CreateJwt(c *echo.Context) error {
if jwt == nil {
return echo.NewHTTPError(http.StatusUnauthorized, "Guests not allowed.")
}
} else if _, err := base64.RawURLEncoding.DecodeString(token); err != nil {
tkn, err := h.refreshJwt(ctx, token)
if err != nil {
return err
}
jwt = &tkn
} else {
tkn, err := h.createJwt(ctx, token)
if err != nil {
@ -94,6 +102,7 @@ func (h *Handler) createGuestJwt() *string {
claims["username"] = "guest"
claims["sub"] = "00000000-0000-0000-0000-000000000000"
claims["sid"] = "00000000-0000-0000-0000-000000000000"
claims["jti"] = uuid.New().String()
claims["iss"] = h.config.PublicUrl
claims["iat"] = &jwt.NumericDate{
Time: time.Now().UTC(),
@ -128,6 +137,7 @@ func (h *Handler) createJwt(ctx context.Context, token string) (string, error) {
claims["username"] = session.User.Username
claims["sub"] = session.User.Id.String()
claims["sid"] = session.Id.String()
claims["jti"] = uuid.New().String()
claims["iss"] = h.config.PublicUrl
claims["iat"] = &jwt.NumericDate{
Time: time.Now().UTC(),
@ -144,13 +154,82 @@ func (h *Handler) createJwt(ctx context.Context, token string) (string, error) {
return t, nil
}
func (h *Handler) refreshJwt(ctx context.Context, jwtToken string) (string, error) {
token, err := jwt.ParseWithClaims(jwtToken, jwt.MapClaims{}, func(t *jwt.Token) (any, error) {
if t.Method.Alg() != "RS256" {
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
}
return h.config.JwtPublicKey, nil
})
if err != nil {
return "", echo.NewHTTPError(http.StatusForbidden, "Invalid JWT")
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return "", echo.NewHTTPError(http.StatusForbidden, "Invalid JWT claims")
}
sidStr, ok := claims["sid"].(string)
if !ok {
return "", echo.NewHTTPError(http.StatusForbidden, "Missing session id in JWT")
}
sid, err := uuid.Parse(sidStr)
if err != nil {
return "", echo.NewHTTPError(http.StatusForbidden, "Invalid session id in JWT")
}
jtiStr, ok := claims["jti"].(string)
if !ok {
return "", echo.NewHTTPError(http.StatusForbidden, "Missing token id in JWT")
}
jti, err := uuid.Parse(jtiStr)
if err != nil {
return "", echo.NewHTTPError(http.StatusForbidden, "Invalid token id in JWT")
}
session, err := h.db.GetUserFromSessionId(ctx, sid)
if err != nil {
return "", echo.NewHTTPError(http.StatusForbidden, "Session not found")
}
if session.LastUsed.Add(h.config.ExpirationDelay).Compare(time.Now().UTC()) < 0 {
return "", echo.NewHTTPError(http.StatusForbidden, "Session has expired")
}
go func() {
h.db.TouchSession(ctx, session.Pk)
h.db.TouchUser(ctx, session.User.Pk)
}()
newClaims := maps.Clone(session.User.Claims)
newClaims["username"] = session.User.Username
newClaims["sub"] = session.User.Id.String()
newClaims["sid"] = session.Id.String()
newClaims["jti"] = jti.String()
newClaims["iss"] = h.config.PublicUrl
newClaims["iat"] = &jwt.NumericDate{
Time: time.Now().UTC(),
}
newClaims["exp"] = &jwt.NumericDate{
Time: time.Now().UTC().Add(time.Hour),
}
newJwt := jwt.NewWithClaims(jwt.SigningMethodRS256, newClaims)
newJwt.Header["kid"] = h.config.JwtKid
t, err := newJwt.SignedString(h.config.JwtPrivateKey)
if err != nil {
return "", err
}
return t, nil
}
// only used for the swagger doc
type JwkSet struct {
Keys []struct {
E string `json:"e" example:"AQAB"`
KeyOps []string `json:"key_ops" example:"[verify]"`
Kty string `json:"kty" example:"RSA"`
N string `json:"n" example:"oBcXcJUR-Sb8_b4qIj28LRAPxdF_6odRr52K5-ymiEkR2DOlEuXBtM-biWxPESW-U-zhfHzdVLf6ioy5xL0bJTh8BMIorkrDliN3vb81jCvyOMgZ7ATMJpMAQMmSDN7sL3U45r22FaoQufCJMQHmUsZPecdQSgj2aFBiRXxsLleYlSezdBVT_gKH-coqeYXSC_hk-ezSq4aDZ10BlDnZ-FA7-ES3T7nBmJEAU7KDAGeSvbYAfYimOW0r-Vc0xQNuwGCfzZtSexKXDbYbNwOVo3SjfCabq-gMfap_owcHbKicGBZu1LDlh7CpkmLQf_kv6GihM2LWFFh6Vwg2cltiwF22EIPlUDtYTkUR0qRkdNJaNkwV5Vv_6r3pzSmu5ovRriKtlrvJMjlTnLb4_ltsge3fw5Z34cJrsp094FbUc2O6Or4FGEXUldieJCnVRhs2_h6SDcmeMXs1zfvE5GlDnq8tZV6WMJ5Sb4jNO7rs_hTkr23_E6mVg-DdtozGfqzRzhIjPym6D_jVfR6dZv5W0sKwOHRmT7nYq-C7b2sAwmNNII296M4Rq-jn0b5pgSeMDYbIpbIA4thU8LYU0lBZp_ZVwWKG1RFZDxz3k9O5UVth2kTpTWlwn0hB1aAvgXHo6in1CScITGA72p73RbDieNnLFaCK4xUVstkWAKLqPxs"`
N string `json:"n" example:"oBcXcJUR-Sb8_b4qIj28LRAPxdF_6odRr52K5-ymiEkR2DOlEuXBtM-biWxPESW-U-zhfHzdVLf6ioy5xL0bJTh8BMIorkrDliN3vb81jCvyOMgZ7ATMJpMAQMmSDN7sL3U45r22FaoQufCJMQHmUsZPecdQSgj2aFBiRXxsLleYlSezdBVT_gKH-coqeYXSC_hk-ezSq4aDZ10BlDnZ-FA7-ES3T7nBmJEAU7KDAGeSvbYAfYimOW0r-Vc0xQNuwGCfzZtSexKXDbYbNwOVo3SjfCabq-gMfap_owcHbKicGBZu1LDlh7CpkmLQf_kv6GihM2LWFFh6Vwg2cltiwF22EIPlUDtYTkUR0qRkdNJaNkwV5Vv_6r3pzSmu5ovRriKtlrvJMjlTnLb4_ltsge3fw5Z34cJrsp094FbUc2O6Or4FGEXUldieJCnVRhs2_h6SDcmeMXs1zfvE5GlDnq8tZV6WMJ5Sb4jNO7rs_hTkr23_E6mVg-DdtRS256ozGfqzRzhIjPym6D_jVfR6dZv5W0sKwOHRmT7nYq-C7b2sAwmNNII296M4Rq-jn0b5pgSeMDYbIpbIA4thU8LYU0lBZp_ZVwWKG1RFZDxz3k9O5UVth2kTpTWlwn0hB1aAvgXHo6in1CScITGA72p73RbDieNnLFaCK4xUVstkWAKLqPxs"`
Use string `json:"use" example:"sig"`
}
}

View File

@ -44,6 +44,19 @@ where s.user_pk = u.pk
returning
s.*;
-- name: GetUserFromSessionId :one
select
s.pk,
s.id,
s.last_used,
sqlc.embed(u)
from
keibi.users as u
inner join keibi.sessions as s on u.pk = s.user_pk
where
s.id = $1
limit 1;
-- name: ClearOtherSessions :exec
delete from keibi.sessions as s using keibi.users as u
where s.user_pk = u.pk