From b5ace8d6ed3d1b222d886ebd5a2021aee48a7865 Mon Sep 17 00:00:00 2001 From: Zoe Roux Date: Wed, 25 Mar 2026 18:29:44 +0100 Subject: [PATCH] Add oidc in all users returns --- auth/apikey.go | 2 +- auth/dbc/users.sql.go | 214 +++++++++++------- auth/models/users.go | 58 +++++ auth/oidc.go | 37 +-- auth/page.go | 6 +- auth/sessions.go | 9 +- .../000005_user_oidc_domain.down.sql | 5 + .../migrations/000005_user_oidc_domain.up.sql | 5 + auth/sql/queries/users.sql | 72 ++++-- auth/sqlc.yaml | 4 + auth/users.go | 158 ++++--------- front/src/query/query.tsx | 2 +- front/src/ui/login/logic.tsx | 20 +- front/src/ui/login/oidc-callback.tsx | 13 +- 14 files changed, 365 insertions(+), 240 deletions(-) create mode 100644 auth/models/users.go create mode 100644 auth/sql/migrations/000005_user_oidc_domain.down.sql create mode 100644 auth/sql/migrations/000005_user_oidc_domain.up.sql diff --git a/auth/apikey.go b/auth/apikey.go index 7a3b5ce5..a2b96fa4 100644 --- a/auth/apikey.go +++ b/auth/apikey.go @@ -96,7 +96,7 @@ func (h *Handler) CreateApiKey(c *echo.Context) error { UseId: true, Id: uid, }) - user = &u[0].User.Pk + user = &u.User.Pk } dbkey, err := h.db.CreateApiKey(ctx, dbc.CreateApiKeyParams{ diff --git a/auth/dbc/users.sql.go b/auth/dbc/users.sql.go index 07bb6dc3..e7eee8aa 100644 --- a/auth/dbc/users.sql.go +++ b/auth/dbc/users.sql.go @@ -11,6 +11,7 @@ import ( jwt "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" + "github.com/zoriya/kyoo/keibi/models" ) const createUser = `-- name: CreateUser :one @@ -100,32 +101,55 @@ func (q *Queries) DeleteUser(ctx context.Context, id uuid.UUID) (User, error) { const getAllUsers = `-- name: GetAllUsers :many select - pk, id, username, email, password, claims, created_date, last_seen + u.pk, u.id, u.username, u.email, u.password, u.claims, u.created_date, u.last_seen, + coalesce( + jsonb_object_agg( + h.provider, + jsonb_build_object( + 'id', h.id, + 'username', h.username, + 'profileUrl', h.profile_url + ) + ) filter ( + where + h.provider is not null + ), + '{}'::jsonb + )::keibi.user_oidc as oidc from - keibi.users + keibi.users as u + left join keibi.oidc_handle as h on u.pk = h.user_pk +group by + u.pk order by - id + u.pk limit $1 ` -func (q *Queries) GetAllUsers(ctx context.Context, limit int32) ([]User, error) { +type GetAllUsersRow struct { + User User `json:"user"` + Oidc models.OidcMap `json:"oidc"` +} + +func (q *Queries) GetAllUsers(ctx context.Context, limit int32) ([]GetAllUsersRow, error) { rows, err := q.db.Query(ctx, getAllUsers, limit) if err != nil { return nil, err } defer rows.Close() - var items []User + var items []GetAllUsersRow for rows.Next() { - var i User + var i GetAllUsersRow if err := rows.Scan( - &i.Pk, - &i.Id, - &i.Username, - &i.Email, - &i.Password, - &i.Claims, - &i.CreatedDate, - &i.LastSeen, + &i.User.Pk, + &i.User.Id, + &i.User.Username, + &i.User.Email, + &i.User.Password, + &i.User.Claims, + &i.User.CreatedDate, + &i.User.LastSeen, + &i.Oidc, ); err != nil { return nil, err } @@ -139,89 +163,52 @@ func (q *Queries) GetAllUsers(ctx context.Context, limit int32) ([]User, error) const getAllUsersAfter = `-- name: GetAllUsersAfter :many select - pk, id, username, email, password, claims, created_date, last_seen + u.pk, u.id, u.username, u.email, u.password, u.claims, u.created_date, u.last_seen, + coalesce( + jsonb_object_agg( + h.provider, + jsonb_build_object( + 'id', h.id, + 'username', h.username, + 'profileUrl', h.profile_url + ) + ) filter ( + where + h.provider is not null + ), + '{}'::jsonb + )::keibi.user_oidc as oidc from - keibi.users + keibi.users as u + left join keibi.oidc_handle as h on u.pk = h.user_pk where - id >= $2 + u.pk >= $2 +group by + u.pk order by - id + u.pk limit $1 ` type GetAllUsersAfterParams struct { - Limit int32 `json:"limit"` - AfterId uuid.UUID `json:"afterId"` + Limit int32 `json:"limit"` + AfterPk int32 `json:"afterPk"` } -func (q *Queries) GetAllUsersAfter(ctx context.Context, arg GetAllUsersAfterParams) ([]User, error) { - rows, err := q.db.Query(ctx, getAllUsersAfter, arg.Limit, arg.AfterId) +type GetAllUsersAfterRow struct { + User User `json:"user"` + Oidc models.OidcMap `json:"oidc"` +} + +func (q *Queries) GetAllUsersAfter(ctx context.Context, arg GetAllUsersAfterParams) ([]GetAllUsersAfterRow, error) { + rows, err := q.db.Query(ctx, getAllUsersAfter, arg.Limit, arg.AfterPk) if err != nil { return nil, err } defer rows.Close() - var items []User + var items []GetAllUsersAfterRow for rows.Next() { - var i User - if err := rows.Scan( - &i.Pk, - &i.Id, - &i.Username, - &i.Email, - &i.Password, - &i.Claims, - &i.CreatedDate, - &i.LastSeen, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const getUser = `-- name: GetUser :many -select - u.pk, u.id, u.username, u.email, u.password, u.claims, u.created_date, u.last_seen, - h.provider, - h.id, - h.username, - h.profile_url -from - keibi.users as u - left join keibi.oidc_handle as h on u.pk = h.user_pk -where ($1::boolean - and u.id = $2) - or (not $1 - and u.username = $3) -` - -type GetUserParams struct { - UseId bool `json:"useId"` - Id uuid.UUID `json:"id"` - Username string `json:"username"` -} - -type GetUserRow struct { - User User `json:"user"` - Provider *string `json:"provider"` - Id *string `json:"id"` - Username *string `json:"username"` - ProfileUrl *string `json:"profileUrl"` -} - -func (q *Queries) GetUser(ctx context.Context, arg GetUserParams) ([]GetUserRow, error) { - rows, err := q.db.Query(ctx, getUser, arg.UseId, arg.Id, arg.Username) - if err != nil { - return nil, err - } - defer rows.Close() - var items []GetUserRow - for rows.Next() { - var i GetUserRow + var i GetAllUsersAfterRow if err := rows.Scan( &i.User.Pk, &i.User.Id, @@ -231,10 +218,7 @@ func (q *Queries) GetUser(ctx context.Context, arg GetUserParams) ([]GetUserRow, &i.User.Claims, &i.User.CreatedDate, &i.User.LastSeen, - &i.Provider, - &i.Id, - &i.Username, - &i.ProfileUrl, + &i.Oidc, ); err != nil { return nil, err } @@ -246,6 +230,62 @@ func (q *Queries) GetUser(ctx context.Context, arg GetUserParams) ([]GetUserRow, return items, nil } +const getUser = `-- name: GetUser :one +select + u.pk, u.id, u.username, u.email, u.password, u.claims, u.created_date, u.last_seen, + coalesce( + jsonb_object_agg( + h.provider, + jsonb_build_object( + 'id', h.id, + 'username', h.username, + 'profileUrl', h.profile_url + ) + ) filter ( + where + h.provider is not null + ), + '{}'::jsonb + )::keibi.user_oidc as oidc +from + keibi.users as u + left join keibi.oidc_handle as h on u.pk = h.user_pk +where ($1::boolean + and u.id = $2) + or (not $1 + and u.username = $3) +group by + u.pk +` + +type GetUserParams struct { + UseId bool `json:"useId"` + Id uuid.UUID `json:"id"` + Username string `json:"username"` +} + +type GetUserRow struct { + User User `json:"user"` + Oidc models.OidcMap `json:"oidc"` +} + +func (q *Queries) GetUser(ctx context.Context, arg GetUserParams) (GetUserRow, error) { + row := q.db.QueryRow(ctx, getUser, arg.UseId, arg.Id, arg.Username) + var i GetUserRow + err := row.Scan( + &i.User.Pk, + &i.User.Id, + &i.User.Username, + &i.User.Email, + &i.User.Password, + &i.User.Claims, + &i.User.CreatedDate, + &i.User.LastSeen, + &i.Oidc, + ) + return i, err +} + const getUserByEmail = `-- name: GetUserByEmail :one select pk, id, username, email, password, claims, created_date, last_seen diff --git a/auth/models/users.go b/auth/models/users.go new file mode 100644 index 00000000..7e6c04ca --- /dev/null +++ b/auth/models/users.go @@ -0,0 +1,58 @@ +package models + +import ( + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" +) + +type User struct { + // Primary key in database + Pk int32 `json:"-"` + // Id of the user. + Id uuid.UUID `json:"id" example:"e05089d6-9179-4b5b-a63e-94dd5fc2a397"` + // Username of the user. Can be used as a login. + Username string `json:"username" example:"zoriya"` + // Email of the user. Can be used as a login. + Email string `json:"email" format:"email" example:"kyoo@zoriya.dev"` + // When was this account created? + CreatedDate time.Time `json:"createdDate" example:"2025-03-29T18:20:05.267Z"` + // When was the last time this account made any authorized request? + LastSeen time.Time `json:"lastSeen" example:"2025-03-29T18:20:05.267Z"` + // List of custom claims JWT created via get /jwt will have + Claims jwt.MapClaims `json:"claims" example:"isAdmin: true"` + // List of other login method available for this user. Access tokens wont be returned here. + Oidc map[string]OidcHandle `json:"oidc,omitempty"` +} + +type OidcHandle struct { + // Id of this oidc handle. + Id string `json:"id" example:"e05089d6-9179-4b5b-a63e-94dd5fc2a397"` + // Username of the user on the external service. + Username string `json:"username" example:"zoriya"` + // Link to the profile of the user on the external service. Null if unknown or irrelevant. + ProfileUrl *string `json:"profileUrl" format:"url" example:"https://myanimelist.net/profile/zoriya"` +} +type OidcMap = map[string]OidcHandle + +type RegisterDto struct { + // Username of the new account, can't contain @ signs. Can be used for login. + Username string `json:"username" validate:"required,excludes=@" example:"zoriya"` + // Valid email that could be used for forgotten password requests. Can be used for login. + Email string `json:"email" validate:"required,email" format:"email" example:"kyoo@zoriya.dev"` + // Password to use. + Password string `json:"password" validate:"required" example:"password1234"` +} + +type EditUserDto struct { + Username *string `json:"username,omitempty" validate:"omitnil,excludes=@" example:"zoriya"` + Email *string `json:"email,omitempty" validate:"omitnil,email" example:"kyoo@zoriya.dev"` + Claims jwt.MapClaims `json:"claims,omitempty" example:"preferOriginal: true"` +} + +type EditPasswordDto struct { + OldPassword string `json:"oldPassword" validate:"required" example:"password1234"` + NewPassword string `json:"newPassword" validate:"required" example:"password1234"` +} + diff --git a/auth/oidc.go b/auth/oidc.go index 58543f84..bd8420ae 100644 --- a/auth/oidc.go +++ b/auth/oidc.go @@ -17,6 +17,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/labstack/echo/v5" "github.com/zoriya/kyoo/keibi/dbc" + "github.com/zoriya/kyoo/keibi/models" ) type OidcProvider struct { @@ -341,15 +342,6 @@ func (h *Handler) LinkOidcTo( uid uuid.UUID, ) error { ctx := c.Request().Context() - rows, err := h.db.GetUser(ctx, dbc.GetUserParams{ - UseId: true, - Id: uid, - }) - if err != nil { - return err - } - user := rows[0].User - existing, err := h.db.GetUserByOidc(ctx, dbc.GetUserByOidcParams{ Provider: provider.Id, Id: profile.Sub, @@ -366,8 +358,16 @@ func (h *Handler) LinkOidcTo( expireAt = new(time.Now().UTC().Add(time.Duration(token.ExpiresIn * float64(time.Second)))) } + dbuser, err := h.db.GetUser(ctx, dbc.GetUserParams{ + UseId: true, + Id: uid, + }) + if err != nil { + return err + } + err = h.db.UpsertOidcHandle(ctx, dbc.UpsertOidcHandleParams{ - UserPk: user.Pk, + UserPk: dbuser.User.Pk, Provider: provider.Id, Id: profile.Sub, Username: profile.Username, @@ -379,7 +379,14 @@ func (h *Handler) LinkOidcTo( if err != nil { return err } - return c.JSON(http.StatusOK, MapDbUser(&user)) + ret := MapDbUser(&dbuser.User) + ret.Oidc = dbuser.Oidc + ret.Oidc[provider.Id] = models.OidcHandle{ + Id: profile.Sub, + Username: profile.Username, + ProfileUrl: nil, + } + return c.JSON(http.StatusOK, ret) } func (h *Handler) CreateUserByOidc( @@ -460,16 +467,18 @@ func (h *Handler) OidcUnlink(c *echo.Context) error { } ctx := c.Request().Context() - rows, err := h.db.GetUser(ctx, dbc.GetUserParams{UseId: true, Id: uid}) + user, err := h.db.GetUser(ctx, dbc.GetUserParams{UseId: true, Id: uid}) if err != nil { return err } - if len(rows) == 0 { + if err == pgx.ErrNoRows { return echo.NewHTTPError(http.StatusNotFound, "No user found") + } else if err != nil { + return nil } err = h.db.DeleteOidcHandle(ctx, dbc.DeleteOidcHandleParams{ - UserPk: rows[0].User.Pk, + UserPk: user.User.Pk, Provider: providerName, }) if err != nil { diff --git a/auth/page.go b/auth/page.go index c82dffff..6d492754 100644 --- a/auth/page.go +++ b/auth/page.go @@ -1,6 +1,10 @@ package main -import "net/url" +import ( + "net/url" + + . "github.com/zoriya/kyoo/keibi/models" +) type Page[T any] struct { Items []T `json:"items"` diff --git a/auth/sessions.go b/auth/sessions.go index 9476be77..97cbc257 100644 --- a/auth/sessions.go +++ b/auth/sessions.go @@ -15,6 +15,7 @@ import ( "github.com/labstack/echo/v5" "github.com/mileusna/useragent" "github.com/zoriya/kyoo/keibi/dbc" + . "github.com/zoriya/kyoo/keibi/models" ) type Session struct { @@ -161,7 +162,7 @@ func (h *Handler) ListMySessions(c *echo.Context) error { return err } - dbSessions, err := h.db.GetUserSessions(ctx, users[0].User.Pk) + dbSessions, err := h.db.GetUserSessions(ctx, users.User.Pk) if err != nil { return err } @@ -201,11 +202,13 @@ func (h *Handler) ListUserSessions(c *echo.Context) error { if err != nil { return err } - if len(users) == 0 { + if err == pgx.ErrNoRows { return echo.NewHTTPError(http.StatusNotFound, "No user found with id or username") + } else if err != nil { + return err } - dbSessions, err := h.db.GetUserSessions(ctx, users[0].User.Pk) + dbSessions, err := h.db.GetUserSessions(ctx, users.User.Pk) if err != nil { return err } diff --git a/auth/sql/migrations/000005_user_oidc_domain.down.sql b/auth/sql/migrations/000005_user_oidc_domain.down.sql new file mode 100644 index 00000000..f845d488 --- /dev/null +++ b/auth/sql/migrations/000005_user_oidc_domain.down.sql @@ -0,0 +1,5 @@ +begin; + +drop domain keibi.user_oidc; + +commit; diff --git a/auth/sql/migrations/000005_user_oidc_domain.up.sql b/auth/sql/migrations/000005_user_oidc_domain.up.sql new file mode 100644 index 00000000..81e016dd --- /dev/null +++ b/auth/sql/migrations/000005_user_oidc_domain.up.sql @@ -0,0 +1,5 @@ +begin; + +create domain keibi.user_oidc as jsonb; + +commit; diff --git a/auth/sql/queries/users.sql b/auth/sql/queries/users.sql index a3ce4ec6..4a5790d0 100644 --- a/auth/sql/queries/users.sql +++ b/auth/sql/queries/users.sql @@ -1,37 +1,83 @@ -- name: GetAllUsers :many select - * + sqlc.embed(u), + coalesce( + jsonb_object_agg( + h.provider, + jsonb_build_object( + 'id', h.id, + 'username', h.username, + 'profileUrl', h.profile_url + ) + ) filter ( + where + h.provider is not null + ), + '{}'::jsonb + )::keibi.user_oidc as oidc from - keibi.users + keibi.users as u + left join keibi.oidc_handle as h on u.pk = h.user_pk +group by + u.pk order by - id + u.pk limit $1; -- name: GetAllUsersAfter :many select - * + sqlc.embed(u), + coalesce( + jsonb_object_agg( + h.provider, + jsonb_build_object( + 'id', h.id, + 'username', h.username, + 'profileUrl', h.profile_url + ) + ) filter ( + where + h.provider is not null + ), + '{}'::jsonb + )::keibi.user_oidc as oidc from - keibi.users + keibi.users as u + left join keibi.oidc_handle as h on u.pk = h.user_pk where - id >= sqlc.arg(after_id) + u.pk >= sqlc.arg(after_pk) +group by + u.pk order by - id + u.pk limit $1; --- name: GetUser :many +-- name: GetUser :one select sqlc.embed(u), - h.provider, - h.id, - h.username, - h.profile_url + coalesce( + jsonb_object_agg( + h.provider, + jsonb_build_object( + 'id', h.id, + 'username', h.username, + 'profileUrl', h.profile_url + ) + ) filter ( + where + h.provider is not null + ), + '{}'::jsonb + )::keibi.user_oidc as oidc from keibi.users as u left join keibi.oidc_handle as h on u.pk = h.user_pk where (@use_id::boolean and u.id = @id) or (not @use_id - and u.username = @username); + and u.username = @username) +group by + u.pk; -- name: GetUserByLogin :one select diff --git a/auth/sqlc.yaml b/auth/sqlc.yaml index 7e157c0a..a0345ed9 100644 --- a/auth/sqlc.yaml +++ b/auth/sqlc.yaml @@ -30,6 +30,10 @@ sql: - db_type: "jsonb" go_type: type: "interface{}" + - db_type: "keibi.user_oidc" + go_type: + import: "github.com/zoriya/kyoo/keibi/models" + type: "OidcMap" - column: "keibi.users.claims" go_type: import: "github.com/golang-jwt/jwt/v5" diff --git a/auth/users.go b/auth/users.go index ca20151e..461600fb 100644 --- a/auth/users.go +++ b/auth/users.go @@ -5,66 +5,18 @@ import ( "encoding/hex" "fmt" "net/http" + "strconv" "strings" - "time" "github.com/alexedwards/argon2id" - "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" "github.com/jackc/pgerrcode" "github.com/jackc/pgx/v5" "github.com/labstack/echo/v5" "github.com/zoriya/kyoo/keibi/dbc" + . "github.com/zoriya/kyoo/keibi/models" ) -type User struct { - // Primary key in database - Pk int32 `json:"-"` - // Id of the user. - Id uuid.UUID `json:"id" example:"e05089d6-9179-4b5b-a63e-94dd5fc2a397"` - // Username of the user. Can be used as a login. - Username string `json:"username" example:"zoriya"` - // Email of the user. Can be used as a login. - Email string `json:"email" format:"email" example:"kyoo@zoriya.dev"` - // When was this account created? - CreatedDate time.Time `json:"createdDate" example:"2025-03-29T18:20:05.267Z"` - // When was the last time this account made any authorized request? - LastSeen time.Time `json:"lastSeen" example:"2025-03-29T18:20:05.267Z"` - // List of custom claims JWT created via get /jwt will have - Claims jwt.MapClaims `json:"claims" example:"isAdmin: true"` - // List of other login method available for this user. Access tokens wont be returned here. - Oidc map[string]OidcHandle `json:"oidc,omitempty"` -} - -type OidcHandle struct { - // Id of this oidc handle. - Id string `json:"id" example:"e05089d6-9179-4b5b-a63e-94dd5fc2a397"` - // Username of the user on the external service. - Username string `json:"username" example:"zoriya"` - // Link to the profile of the user on the external service. Null if unknown or irrelevant. - ProfileUrl *string `json:"profileUrl" format:"url" example:"https://myanimelist.net/profile/zoriya"` -} - -type RegisterDto struct { - // Username of the new account, can't contain @ signs. Can be used for login. - Username string `json:"username" validate:"required,excludes=@" example:"zoriya"` - // Valid email that could be used for forgotten password requests. Can be used for login. - Email string `json:"email" validate:"required,email" format:"email" example:"kyoo@zoriya.dev"` - // Password to use. - Password string `json:"password" validate:"required" example:"password1234"` -} - -type EditUserDto struct { - Username *string `json:"username,omitempty" validate:"omitnil,excludes=@" example:"zoriya"` - Email *string `json:"email,omitempty" validate:"omitnil,email" example:"kyoo@zoriya.dev"` - Claims jwt.MapClaims `json:"claims,omitempty" example:"preferOriginal: true"` -} - -type EditPasswordDto struct { - OldPassword string `json:"oldPassword" validate:"required" example:"password1234"` - NewPassword string `json:"newPassword" validate:"required" example:"password1234"` -} - func MapDbUser(user *dbc.User) User { return User{ Pk: user.Pk, @@ -74,15 +26,7 @@ func MapDbUser(user *dbc.User) User { CreatedDate: user.CreatedDate, LastSeen: user.LastSeen, Claims: user.Claims, - Oidc: make(map[string]OidcHandle), - } -} - -func MapOidc(oidc *dbc.GetUserRow) OidcHandle { - return OidcHandle{ - Id: *oidc.Id, - Username: *oidc.Username, - ProfileUrl: oidc.ProfileUrl, + Oidc: nil, } } @@ -107,29 +51,40 @@ func (h *Handler) ListUsers(c *echo.Context) error { limit := int32(20) id := c.Param("after") - var users []dbc.User if id == "" { - users, err = h.db.GetAllUsers(ctx, limit) - } else { - uid, uerr := uuid.Parse(id) - if uerr != nil { - return echo.NewHTTPError(http.StatusUnprocessableEntity, "Invalid `after` parameter, uuid was expected") + users, err := h.db.GetAllUsers(ctx, limit) + if err != nil { + return err } - users, err = h.db.GetAllUsersAfter(ctx, dbc.GetAllUsersAfterParams{ + + ret := make([]User, 0, len(users)) + for _, user := range users { + u := MapDbUser(&user.User) + u.Oidc = user.Oidc + ret = append(ret, u) + } + return c.JSON(200, NewPage(ret, c.Request().URL, limit)) + } else { + pk, err := strconv.Atoi(id) + if err != nil { + return echo.NewHTTPError(http.StatusUnprocessableEntity, "Invalid `after` parameter") + } + users, err := h.db.GetAllUsersAfter(ctx, dbc.GetAllUsersAfterParams{ Limit: limit, - AfterId: uid, + AfterPk: int32(pk), }) - } + if err != nil { + return err + } - if err != nil { - return err + ret := make([]User, 0, len(users)) + for _, user := range users { + u := MapDbUser(&user.User) + u.Oidc = user.Oidc + ret = append(ret, u) + } + return c.JSON(200, NewPage(ret, c.Request().URL, limit)) } - - var ret []User - for _, user := range users { - ret = append(ret, MapDbUser(&user)) - } - return c.JSON(200, NewPage(ret, c.Request().URL, limit)) } // @Summary Get user @@ -150,27 +105,21 @@ func (h *Handler) GetUser(c *echo.Context) error { } id := c.Param("id") - uid, err := uuid.Parse(c.Param("id")) + uid, err := uuid.Parse(id) dbuser, err := h.db.GetUser(ctx, dbc.GetUserParams{ UseId: err == nil, Id: uid, Username: id, }) - if err != nil { + if err == pgx.ErrNoRows { + return echo.NewHTTPError(404, fmt.Sprintf("No user found with id or username: '%s'.", id)) + } else if err != nil { return err } - if len(dbuser) == 0 { - return echo.NewHTTPError(404, fmt.Sprintf("No user found with id or username: '%s'.", id)) - } - user := MapDbUser(&dbuser[0].User) - for _, oidc := range dbuser { - if oidc.Provider != nil { - user.Oidc[*oidc.Provider] = MapOidc(&oidc) - } - } - - return c.JSON(200, user) + ret := MapDbUser(&dbuser.User) + ret.Oidc = dbuser.Oidc + return c.JSON(200, ret) } // @Summary Get me @@ -192,21 +141,15 @@ func (h *Handler) GetMe(c *echo.Context) error { UseId: true, Id: id, }) - if err != nil { + if err == pgx.ErrNoRows { + return c.JSON(403, "Invalid jwt token, couldn't find user.") + } else if err != nil { return err } - if len(dbuser) == 0 { - return c.JSON(403, "Invalid jwt token, couldn't find user.") - } - user := MapDbUser(&dbuser[0].User) - for _, oidc := range dbuser { - if oidc.Provider != nil { - user.Oidc[*oidc.Provider] = MapOidc(&oidc) - } - } - - return c.JSON(200, user) + ret := MapDbUser(&dbuser.User) + ret.Oidc = dbuser.Oidc + return c.JSON(200, ret) } func (h *Handler) streamGravatar(c *echo.Context, email string) error { @@ -264,7 +207,7 @@ func (h *Handler) GetMyLogo(c *echo.Context) error { return err } - return h.streamGravatar(c, users[0].User.Email) + return h.streamGravatar(c, users.User.Email) } // @Summary Get user logo @@ -291,14 +234,13 @@ func (h *Handler) GetUserLogo(c *echo.Context) error { Id: uid, Username: id, }) - if err != nil { + if err == pgx.ErrNoRows { + return echo.NewHTTPError(404, fmt.Sprintf("No user found with id or username: '%s'.", id)) + } else if err != nil { return err } - if len(users) == 0 { - return echo.NewHTTPError(404, fmt.Sprintf("No user found with id or username: '%s'.", id)) - } - return h.streamGravatar(c, users[0].User.Email) + return h.streamGravatar(c, users.User.Email) } // @Summary Register @@ -534,7 +476,7 @@ func (h *Handler) ChangePassword(c *echo.Context) error { match, err := argon2id.ComparePasswordAndHash( req.OldPassword, - *user[0].User.Password, + *user.User.Password, ) if err != nil { return err diff --git a/front/src/query/query.tsx b/front/src/query/query.tsx index 96781d46..14ec37eb 100644 --- a/front/src/query/query.tsx +++ b/front/src/query/query.tsx @@ -148,7 +148,7 @@ export type QueryIdentifier = { }; }; -const toQueryKey = (query: { +export const toQueryKey = (query: { apiUrl: string; path: (string | undefined)[]; params?: { diff --git a/front/src/ui/login/logic.tsx b/front/src/ui/login/logic.tsx index 50f6759b..2eb316fe 100644 --- a/front/src/ui/login/logic.tsx +++ b/front/src/ui/login/logic.tsx @@ -60,21 +60,19 @@ export const oidcLogin = async ( ) => { apiUrl ??= defaultApiUrl; try { - const ret = await queryFn({ + const { token } = await queryFn({ method: "GET", url: `${apiUrl}/auth/oidc/callback/${provider}?token=${code}`, authToken: linkToToken, - parser: linkToToken ? z.object({ token: z.string() }) : User, + parser: linkToToken ? null : z.object({ token: z.string() }), + }); + if (linkToToken) return { ok: true, value: null }; + const user = await queryFn({ + method: "GET", + url: `${apiUrl}/auth/users/me`, + authToken: token, + parser: User, }); - const token = linkToToken ?? (ret as { token: string }).token; - const user = linkToToken - ? (ret as User) - : await queryFn({ - method: "GET", - url: `${apiUrl}/auth/users/me`, - authToken: token, - parser: User, - }); const account: Account = { ...user, apiUrl, token, selected: true }; addAccount(account); return { ok: true, value: account }; diff --git a/front/src/ui/login/oidc-callback.tsx b/front/src/ui/login/oidc-callback.tsx index f3248981..48a04bd4 100644 --- a/front/src/ui/login/oidc-callback.tsx +++ b/front/src/ui/login/oidc-callback.tsx @@ -1,7 +1,9 @@ +import { useQueryClient } from "@tanstack/react-query"; import { useRouter } from "expo-router"; import { useEffect } from "react"; import { P } from "~/primitives"; import { useToken } from "~/providers/account-context"; +import { toQueryKey } from "~/query"; import { useQueryState } from "~/utils"; import { oidcLogin } from "./logic"; @@ -14,6 +16,7 @@ export const OidcCallbackPage = () => { const [link] = useQueryState("link", undefined!); const router = useRouter(); + const queryClient = useQueryClient(); // biome-ignore lint/correctness/useExhaustiveDependencies: useMountEffect useEffect(() => { @@ -28,7 +31,15 @@ export const OidcCallbackPage = () => { apiUrl, ); if (loginError) onError(loginError); - else router.replace(link ? "/settings" : "/"); + else if (link) { + queryClient.invalidateQueries({ + queryKey: toQueryKey({ + apiUrl, + path: ["auth", "users", "me"], + }), + }); + router.replace("/settings"); + } else router.replace("/"); } if (error) onError(error);