diff --git a/api/src/auth.ts b/api/src/auth.ts index 101a5dc0..c5c28728 100644 --- a/api/src/auth.ts +++ b/api/src/auth.ts @@ -33,7 +33,7 @@ const Jwt = t.Object({ type Jwt = typeof Jwt.static; const validator = TypeCompiler.Compile(Jwt); -export async function verifyJwt(bearer: string) { +async function verifyJwt(bearer: string) { // @ts-expect-error ts can't understand that there's two overload idk why const { payload } = await jwtVerify(bearer, jwtSecret ?? jwks, { issuer: process.env.JWT_ISSUER, diff --git a/api/src/websockets.ts b/api/src/websockets.ts index 10da619c..d244b7fb 100644 --- a/api/src/websockets.ts +++ b/api/src/websockets.ts @@ -1,6 +1,6 @@ import type { TObject, TString } from "@sinclair/typebox"; import Elysia, { type TSchema, t } from "elysia"; -import { verifyJwt } from "./auth"; +import { auth } from "./auth"; import { updateProgress } from "./controllers/profiles/history"; import { getOrCreateProfile } from "./controllers/profiles/profile"; import { SeedHistory } from "./models/history"; @@ -8,7 +8,7 @@ import { SeedHistory } from "./models/history"; const actionMap = { ping: handler({ message(ws) { - ws.send({ response: "pong" }); + ws.send({ action: "ping", response: "pong" }); }, }), watch: handler({ @@ -20,55 +20,12 @@ const actionMap = { const ret = await updateProgress(profilePk, [ { ...body, playedDate: null }, ]); - ws.send(ret); + ws.send({ action: "watch", ...ret }); }, }), }; -const baseWs = new Elysia() - .guard({ - headers: t.Object( - { - authorization: t.Optional(t.TemplateLiteral("Bearer ${string}")), - "Sec-WebSocket-Protocol": t.Optional( - t.Array( - t.Union([t.Literal("kyoo"), t.TemplateLiteral("Bearer ${string}")]), - ), - ), - }, - { additionalProperties: true }, - ), - }) - .resolve( - async ({ - headers: { authorization, "Sec-WebSocket-Protocol": wsProtocol }, - status, - }) => { - const auth = - authorization ?? - (wsProtocol?.length === 2 && - wsProtocol[0] === "kyoo" && - wsProtocol[1].startsWith("Bearer ") - ? wsProtocol[1] - : null); - const bearer = auth?.slice(7); - if (!bearer) { - return status(403, { - status: 403, - message: "No authorization header was found.", - }); - } - try { - return await verifyJwt(bearer); - } catch (err) { - return status(403, { - status: 403, - message: "Invalid jwt. Verification vailed", - details: err, - }); - } - }, - ); +const baseWs = new Elysia().use(auth); export const appWs = baseWs.ws("/ws", { body: t.Union( diff --git a/auth/jwt.go b/auth/jwt.go index 4337bca9..ab72be8c 100644 --- a/auth/jwt.go +++ b/auth/jwt.go @@ -44,9 +44,17 @@ func (h *Handler) CreateJwt(c echo.Context) error { var token string if auth == "" { - c, _ := c.Request().Cookie("X-Bearer") - if c != nil { - token = c.Value + cookie, _ := c.Request().Cookie("X-Bearer") + if cookie != nil { + token = cookie.Value + } else { + protocol, ok := c.Request().Header["Sec-Websocket-Protocol"] + if ok && + len(protocol) == 2 && + protocol[0] == "kyoo" && + strings.HasPrefix(protocol[1], "Bearer") { + token = protocol[1][len("Bearer "):] + } } } else if strings.HasPrefix(auth, "Bearer ") { token = auth[len("Bearer "):]