mirror of
https://github.com/zoriya/Kyoo.git
synced 2026-05-13 02:48:33 -04:00
Implement token refresh for websockets
This commit is contained in:
parent
2cd92c4665
commit
e692fda620
@ -33,7 +33,7 @@ const Jwt = t.Object({
|
||||
type Jwt = typeof Jwt.static;
|
||||
const validator = TypeCompiler.Compile(Jwt);
|
||||
|
||||
async function verifyJwt(bearer: string) {
|
||||
export async function verifyJwt(bearer: string) {
|
||||
// @ts-expect-error ts can't understand that there's two overload idk why
|
||||
const { payload } = await jwtVerify(bearer, jwtSecret ?? jwks, {
|
||||
issuer: process.env.JWT_ISSUER,
|
||||
|
||||
@ -2,7 +2,7 @@ import { getLogger } from "@logtape/logtape";
|
||||
import type { TObject, TString } from "@sinclair/typebox";
|
||||
import { eq } from "drizzle-orm";
|
||||
import Elysia, { type TSchema, t } from "elysia";
|
||||
import { auth } from "./auth";
|
||||
import { auth, verifyJwt } from "./auth";
|
||||
import { updateProgress } from "./controllers/profiles/history";
|
||||
import { getOrCreateProfile } from "./controllers/profiles/profile";
|
||||
import { prepareVideo } from "./controllers/video-metadata";
|
||||
@ -13,6 +13,7 @@ const logger = getLogger();
|
||||
|
||||
const actionMap = {
|
||||
ping: handler({
|
||||
skipRefresh: true,
|
||||
message(ws) {
|
||||
ws.send({ action: "ping", response: "pong" });
|
||||
},
|
||||
@ -90,6 +91,30 @@ export const appWs = baseWs.ws("/ws", {
|
||||
},
|
||||
async message(ws, { action, ...body }) {
|
||||
const handler = actionMap[action as keyof typeof actionMap];
|
||||
if (!handler.skipRefresh) {
|
||||
try {
|
||||
const resp = await fetch(
|
||||
new URL("/auth/jwt", process.env.AUTH_SERVER ?? "http://auth:4568"),
|
||||
{
|
||||
headers: {
|
||||
Authorization: ws.data.headers.authorization!,
|
||||
},
|
||||
},
|
||||
);
|
||||
if (resp.ok) {
|
||||
const data = (await resp.json()) as { token?: string };
|
||||
if (data.token) {
|
||||
const ret = await verifyJwt(data.token);
|
||||
ws.data.jwt = ret.jwt as typeof ws.data.jwt;
|
||||
ws.data.headers.authorization =
|
||||
`Bearer ${data.token}` as typeof ws.data.headers.authorization;
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
logger.error("Failed to refresh jwt: {err}", { err: e });
|
||||
// If refresh fails, continue with the old JWT
|
||||
}
|
||||
}
|
||||
for (const perm of handler.permissions ?? []) {
|
||||
if (!ws.data.jwt.permissions.includes(perm)) {
|
||||
ws.send({
|
||||
@ -108,6 +133,7 @@ type Ws = Parameters<NonNullable<Parameters<typeof baseWs.ws>[1]["open"]>>[0];
|
||||
function handler<Schema extends TSchema = TObject<{}>>(ret: {
|
||||
body?: Schema;
|
||||
permissions?: string[];
|
||||
skipRefresh?: boolean;
|
||||
message: (ws: Ws, body: Schema["static"]) => void | Promise<void>;
|
||||
}) {
|
||||
return ret;
|
||||
|
||||
@ -206,6 +206,7 @@ func (h *Handler) createApiJwt(ctx context.Context, apikey string) (string, erro
|
||||
claims["username"] = key.Name
|
||||
claims["sub"] = key.Id
|
||||
claims["sid"] = key.Id
|
||||
claims["jti"] = uuid.New().String()
|
||||
claims["iss"] = h.config.PublicUrl
|
||||
claims["iat"] = &jwt.NumericDate{
|
||||
Time: time.Now().UTC(),
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// sqlc v1.31.1
|
||||
// source: apikeys.sql
|
||||
|
||||
package dbc
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// sqlc v1.31.1
|
||||
|
||||
package dbc
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// sqlc v1.31.1
|
||||
|
||||
package dbc
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// sqlc v1.31.1
|
||||
// source: oidc.sql
|
||||
|
||||
package dbc
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// sqlc v1.31.1
|
||||
// source: sessions.sql
|
||||
|
||||
package dbc
|
||||
@ -86,6 +86,46 @@ func (q *Queries) DeleteSession(ctx context.Context, arg DeleteSessionParams) (S
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getUserFromSessionId = `-- name: GetUserFromSessionId :one
|
||||
select
|
||||
s.pk,
|
||||
s.id,
|
||||
s.last_used,
|
||||
u.pk, u.id, u.username, u.email, u.password, u.claims, u.created_date, u.last_seen
|
||||
from
|
||||
keibi.users as u
|
||||
inner join keibi.sessions as s on u.pk = s.user_pk
|
||||
where
|
||||
s.id = $1
|
||||
limit 1
|
||||
`
|
||||
|
||||
type GetUserFromSessionIdRow struct {
|
||||
Pk int32 `json:"pk"`
|
||||
Id uuid.UUID `json:"id"`
|
||||
LastUsed time.Time `json:"lastUsed"`
|
||||
User User `json:"user"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetUserFromSessionId(ctx context.Context, id uuid.UUID) (GetUserFromSessionIdRow, error) {
|
||||
row := q.db.QueryRow(ctx, getUserFromSessionId, id)
|
||||
var i GetUserFromSessionIdRow
|
||||
err := row.Scan(
|
||||
&i.Pk,
|
||||
&i.Id,
|
||||
&i.LastUsed,
|
||||
&i.User.Pk,
|
||||
&i.User.Id,
|
||||
&i.User.Username,
|
||||
&i.User.Email,
|
||||
&i.User.Password,
|
||||
&i.User.Claims,
|
||||
&i.User.CreatedDate,
|
||||
&i.User.LastSeen,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getUserFromToken = `-- name: GetUserFromToken :one
|
||||
select
|
||||
s.pk,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// sqlc v1.31.1
|
||||
// source: users.sql
|
||||
|
||||
package dbc
|
||||
|
||||
83
auth/jwt.go
83
auth/jwt.go
@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/http"
|
||||
@ -9,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
)
|
||||
@ -19,7 +21,7 @@ type Jwt struct {
|
||||
}
|
||||
|
||||
// @Summary Get JWT
|
||||
// @Description Convert a session token or an API key to a short lived JWT.
|
||||
// @Description Convert a session token or an API key to a short lived JWT. Passing an existing JWT will refresh it.
|
||||
// @Tags jwt
|
||||
// @Produce json
|
||||
// @Security Token
|
||||
@ -69,6 +71,12 @@ func (h *Handler) CreateJwt(c *echo.Context) error {
|
||||
if jwt == nil {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Guests not allowed.")
|
||||
}
|
||||
} else if _, err := base64.RawURLEncoding.DecodeString(token); err != nil {
|
||||
tkn, err := h.refreshJwt(ctx, token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
jwt = &tkn
|
||||
} else {
|
||||
tkn, err := h.createJwt(ctx, token)
|
||||
if err != nil {
|
||||
@ -94,6 +102,7 @@ func (h *Handler) createGuestJwt() *string {
|
||||
claims["username"] = "guest"
|
||||
claims["sub"] = "00000000-0000-0000-0000-000000000000"
|
||||
claims["sid"] = "00000000-0000-0000-0000-000000000000"
|
||||
claims["jti"] = uuid.New().String()
|
||||
claims["iss"] = h.config.PublicUrl
|
||||
claims["iat"] = &jwt.NumericDate{
|
||||
Time: time.Now().UTC(),
|
||||
@ -128,6 +137,7 @@ func (h *Handler) createJwt(ctx context.Context, token string) (string, error) {
|
||||
claims["username"] = session.User.Username
|
||||
claims["sub"] = session.User.Id.String()
|
||||
claims["sid"] = session.Id.String()
|
||||
claims["jti"] = uuid.New().String()
|
||||
claims["iss"] = h.config.PublicUrl
|
||||
claims["iat"] = &jwt.NumericDate{
|
||||
Time: time.Now().UTC(),
|
||||
@ -144,13 +154,82 @@ func (h *Handler) createJwt(ctx context.Context, token string) (string, error) {
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (h *Handler) refreshJwt(ctx context.Context, jwtToken string) (string, error) {
|
||||
token, err := jwt.ParseWithClaims(jwtToken, jwt.MapClaims{}, func(t *jwt.Token) (any, error) {
|
||||
if t.Method.Alg() != "RS256" {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
return h.config.JwtPublicKey, nil
|
||||
})
|
||||
if err != nil {
|
||||
return "", echo.NewHTTPError(http.StatusForbidden, "Invalid JWT")
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return "", echo.NewHTTPError(http.StatusForbidden, "Invalid JWT claims")
|
||||
}
|
||||
|
||||
sidStr, ok := claims["sid"].(string)
|
||||
if !ok {
|
||||
return "", echo.NewHTTPError(http.StatusForbidden, "Missing session id in JWT")
|
||||
}
|
||||
sid, err := uuid.Parse(sidStr)
|
||||
if err != nil {
|
||||
return "", echo.NewHTTPError(http.StatusForbidden, "Invalid session id in JWT")
|
||||
}
|
||||
|
||||
jtiStr, ok := claims["jti"].(string)
|
||||
if !ok {
|
||||
return "", echo.NewHTTPError(http.StatusForbidden, "Missing token id in JWT")
|
||||
}
|
||||
jti, err := uuid.Parse(jtiStr)
|
||||
if err != nil {
|
||||
return "", echo.NewHTTPError(http.StatusForbidden, "Invalid token id in JWT")
|
||||
}
|
||||
|
||||
session, err := h.db.GetUserFromSessionId(ctx, sid)
|
||||
if err != nil {
|
||||
return "", echo.NewHTTPError(http.StatusForbidden, "Session not found")
|
||||
}
|
||||
|
||||
if session.LastUsed.Add(h.config.ExpirationDelay).Compare(time.Now().UTC()) < 0 {
|
||||
return "", echo.NewHTTPError(http.StatusForbidden, "Session has expired")
|
||||
}
|
||||
|
||||
go func() {
|
||||
h.db.TouchSession(ctx, session.Pk)
|
||||
h.db.TouchUser(ctx, session.User.Pk)
|
||||
}()
|
||||
|
||||
newClaims := maps.Clone(session.User.Claims)
|
||||
newClaims["username"] = session.User.Username
|
||||
newClaims["sub"] = session.User.Id.String()
|
||||
newClaims["sid"] = session.Id.String()
|
||||
newClaims["jti"] = jti.String()
|
||||
newClaims["iss"] = h.config.PublicUrl
|
||||
newClaims["iat"] = &jwt.NumericDate{
|
||||
Time: time.Now().UTC(),
|
||||
}
|
||||
newClaims["exp"] = &jwt.NumericDate{
|
||||
Time: time.Now().UTC().Add(time.Hour),
|
||||
}
|
||||
newJwt := jwt.NewWithClaims(jwt.SigningMethodRS256, newClaims)
|
||||
newJwt.Header["kid"] = h.config.JwtKid
|
||||
t, err := newJwt.SignedString(h.config.JwtPrivateKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// only used for the swagger doc
|
||||
type JwkSet struct {
|
||||
Keys []struct {
|
||||
E string `json:"e" example:"AQAB"`
|
||||
KeyOps []string `json:"key_ops" example:"[verify]"`
|
||||
Kty string `json:"kty" example:"RSA"`
|
||||
N string `json:"n" example:"oBcXcJUR-Sb8_b4qIj28LRAPxdF_6odRr52K5-ymiEkR2DOlEuXBtM-biWxPESW-U-zhfHzdVLf6ioy5xL0bJTh8BMIorkrDliN3vb81jCvyOMgZ7ATMJpMAQMmSDN7sL3U45r22FaoQufCJMQHmUsZPecdQSgj2aFBiRXxsLleYlSezdBVT_gKH-coqeYXSC_hk-ezSq4aDZ10BlDnZ-FA7-ES3T7nBmJEAU7KDAGeSvbYAfYimOW0r-Vc0xQNuwGCfzZtSexKXDbYbNwOVo3SjfCabq-gMfap_owcHbKicGBZu1LDlh7CpkmLQf_kv6GihM2LWFFh6Vwg2cltiwF22EIPlUDtYTkUR0qRkdNJaNkwV5Vv_6r3pzSmu5ovRriKtlrvJMjlTnLb4_ltsge3fw5Z34cJrsp094FbUc2O6Or4FGEXUldieJCnVRhs2_h6SDcmeMXs1zfvE5GlDnq8tZV6WMJ5Sb4jNO7rs_hTkr23_E6mVg-DdtozGfqzRzhIjPym6D_jVfR6dZv5W0sKwOHRmT7nYq-C7b2sAwmNNII296M4Rq-jn0b5pgSeMDYbIpbIA4thU8LYU0lBZp_ZVwWKG1RFZDxz3k9O5UVth2kTpTWlwn0hB1aAvgXHo6in1CScITGA72p73RbDieNnLFaCK4xUVstkWAKLqPxs"`
|
||||
N string `json:"n" example:"oBcXcJUR-Sb8_b4qIj28LRAPxdF_6odRr52K5-ymiEkR2DOlEuXBtM-biWxPESW-U-zhfHzdVLf6ioy5xL0bJTh8BMIorkrDliN3vb81jCvyOMgZ7ATMJpMAQMmSDN7sL3U45r22FaoQufCJMQHmUsZPecdQSgj2aFBiRXxsLleYlSezdBVT_gKH-coqeYXSC_hk-ezSq4aDZ10BlDnZ-FA7-ES3T7nBmJEAU7KDAGeSvbYAfYimOW0r-Vc0xQNuwGCfzZtSexKXDbYbNwOVo3SjfCabq-gMfap_owcHbKicGBZu1LDlh7CpkmLQf_kv6GihM2LWFFh6Vwg2cltiwF22EIPlUDtYTkUR0qRkdNJaNkwV5Vv_6r3pzSmu5ovRriKtlrvJMjlTnLb4_ltsge3fw5Z34cJrsp094FbUc2O6Or4FGEXUldieJCnVRhs2_h6SDcmeMXs1zfvE5GlDnq8tZV6WMJ5Sb4jNO7rs_hTkr23_E6mVg-DdtRS256ozGfqzRzhIjPym6D_jVfR6dZv5W0sKwOHRmT7nYq-C7b2sAwmNNII296M4Rq-jn0b5pgSeMDYbIpbIA4thU8LYU0lBZp_ZVwWKG1RFZDxz3k9O5UVth2kTpTWlwn0hB1aAvgXHo6in1CScITGA72p73RbDieNnLFaCK4xUVstkWAKLqPxs"`
|
||||
Use string `json:"use" example:"sig"`
|
||||
}
|
||||
}
|
||||
|
||||
@ -44,6 +44,19 @@ where s.user_pk = u.pk
|
||||
returning
|
||||
s.*;
|
||||
|
||||
-- name: GetUserFromSessionId :one
|
||||
select
|
||||
s.pk,
|
||||
s.id,
|
||||
s.last_used,
|
||||
sqlc.embed(u)
|
||||
from
|
||||
keibi.users as u
|
||||
inner join keibi.sessions as s on u.pk = s.user_pk
|
||||
where
|
||||
s.id = $1
|
||||
limit 1;
|
||||
|
||||
-- name: ClearOtherSessions :exec
|
||||
delete from keibi.sessions as s using keibi.users as u
|
||||
where s.user_pk = u.pk
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user