package main import ( "cmp" "context" "crypto" "crypto/rand" "crypto/rsa" "crypto/x509" "encoding/base64" "encoding/json" "encoding/pem" "fmt" "maps" "os" "slices" "strconv" "strings" "time" "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" "github.com/lestrrat-go/jwx/v3/jwk" "github.com/zoriya/kyoo/keibi/dbc" ) type Configuration struct { JwtPrivateKey *rsa.PrivateKey JwtPublicKey *rsa.PublicKey JwtKid string PublicUrl string OidcProviders map[string]OidcProviderConfig DefaultClaims jwt.MapClaims FirstUserClaims jwt.MapClaims GuestClaims jwt.MapClaims ProtectedClaims []string ExpirationDelay time.Duration EnvApiKeys []ApiKeyWToken ProfilePicturePath string DisableRegistration bool } type OidcAuthMethod string const ( OidcClientSecretBasic OidcAuthMethod = "ClientSecretBasic" OidcClientSecretPost OidcAuthMethod = "ClientSecretPost" ) type OidcProviderConfig struct { Id string Name string Logo string ClientId string Secret string Authorization string Token string Profile string Scope string AuthMethod OidcAuthMethod } var DefaultConfig = Configuration{ DefaultClaims: make(jwt.MapClaims), FirstUserClaims: make(jwt.MapClaims), OidcProviders: make(map[string]OidcProviderConfig), ProtectedClaims: []string{"permissions"}, ExpirationDelay: 30 * 24 * time.Hour, EnvApiKeys: make([]ApiKeyWToken, 0), } func LoadConfiguration(ctx context.Context, db *dbc.Queries) (*Configuration, error) { ret := DefaultConfig ret.PublicUrl = os.Getenv("PUBLIC_URL") ret.ProfilePicturePath = cmp.Or( os.Getenv("PROFILE_PICTURE_PATH"), "/profile_pictures", ) disableRegistration, err := strconv.ParseBool(cmp.Or(os.Getenv("DISABLE_REGISTRATION"), "false")) if err != nil { return nil, fmt.Errorf("invalid DISABLE_REGISTRATION value: %w", err) } ret.DisableRegistration = disableRegistration claims := os.Getenv("EXTRA_CLAIMS") if claims != "" { err := json.Unmarshal([]byte(claims), &ret.DefaultClaims) if err != nil { return nil, err } } maps.Insert(ret.FirstUserClaims, maps.All(ret.DefaultClaims)) claims = os.Getenv("FIRST_USER_CLAIMS") if claims != "" { err := json.Unmarshal([]byte(claims), &ret.FirstUserClaims) if err != nil { return nil, err } } else { ret.FirstUserClaims = ret.DefaultClaims } claims = os.Getenv("GUEST_CLAIMS") if claims != "" { err := json.Unmarshal([]byte(claims), &ret.GuestClaims) if err != nil { return nil, err } } protected := strings.Split(os.Getenv("PROTECTED_CLAIMS"), ",") ret.ProtectedClaims = append(ret.ProtectedClaims, protected...) rsa_pk_path := os.Getenv("RSA_PRIVATE_KEY_PATH") if rsa_pk_path != "" { privateKeyData, err := os.ReadFile(rsa_pk_path) if err != nil { return nil, err } block, _ := pem.Decode(privateKeyData) if block == nil || block.Type != "RSA PRIVATE KEY" { return nil, err } ret.JwtPrivateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) if err != nil { pkcs8Key, err := x509.ParsePKCS8PrivateKey(block.Bytes) if err != nil { return nil, err } ret.JwtPrivateKey = pkcs8Key.(*rsa.PrivateKey) } } else { var err error ret.JwtPrivateKey, err = rsa.GenerateKey(rand.Reader, 4096) if err != nil { return nil, err } } ret.JwtPublicKey = &ret.JwtPrivateKey.PublicKey key, err := jwk.Import(ret.JwtPublicKey) if err != nil { return nil, err } thumbprint, err := key.Thumbprint(crypto.SHA256) if err != nil { return nil, err } ret.JwtKid = base64.RawStdEncoding.EncodeToString(thumbprint) for _, env := range os.Environ() { if !strings.HasPrefix(env, "KEIBI_APIKEY_") { continue } v := strings.Split(env, "=") if strings.HasSuffix(v[0], "_CLAIMS") { continue } name := strings.TrimPrefix(v[0], "KEIBI_APIKEY_") cstr := os.Getenv(fmt.Sprintf("KEIBI_APIKEY_%s_CLAIMS", name)) var claims jwt.MapClaims if cstr != "" { err := json.Unmarshal([]byte(cstr), &claims) if err != nil { return nil, err } } else { return nil, fmt.Errorf("missing claims env var KEIBI_APIKEY_%s_CLAIMS", name) } name = strings.ToLower(name) ret.EnvApiKeys = append(ret.EnvApiKeys, ApiKeyWToken{ ApiKey: ApiKey{ Id: uuid.New(), Name: name, Claims: claims, }, Token: v[1], }) } apikeys, err := db.ListApiKeys(ctx) if err != nil { return nil, err } for _, key := range apikeys { dup := slices.ContainsFunc(ret.EnvApiKeys, func(k ApiKeyWToken) bool { return k.Name == key.Name }) if dup { return nil, fmt.Errorf( "an api key with the name %s is already defined in database. Can't specify a new one via env var", key.Name, ) } } oidcProviders := make([]string, 0) for _, env := range os.Environ() { parts := strings.SplitN(env, "=", 2) if len(parts) != 2 { continue } k := parts[0] if !strings.HasPrefix(k, "OIDC_") || !strings.HasSuffix(k, "_CLIENTID") { continue } name := strings.TrimSuffix(strings.TrimPrefix(k, "OIDC_"), "_CLIENTID") if name == "" { continue } oidcProviders = append(oidcProviders, name) } for _, name := range oidcProviders { providerId := strings.ToLower(name) provider := OidcProviderConfig{ Id: providerId, Name: os.Getenv(fmt.Sprintf("OIDC_%s_NAME", name)), Logo: os.Getenv(fmt.Sprintf("OIDC_%s_LOGO", name)), ClientId: os.Getenv(fmt.Sprintf("OIDC_%s_CLIENTID", name)), Secret: os.Getenv(fmt.Sprintf("OIDC_%s_SECRET", name)), Authorization: os.Getenv(fmt.Sprintf("OIDC_%s_AUTHORIZATION", name)), Token: os.Getenv(fmt.Sprintf("OIDC_%s_TOKEN", name)), Profile: os.Getenv(fmt.Sprintf("OIDC_%s_PROFILE", name)), Scope: os.Getenv(fmt.Sprintf("OIDC_%s_SCOPE", name)), AuthMethod: OidcClientSecretBasic, } authMethod := os.Getenv(fmt.Sprintf("OIDC_%s_AUTHMETHOD", name)) if authMethod != "" { switch OidcAuthMethod(authMethod) { case OidcClientSecretBasic, OidcClientSecretPost: provider.AuthMethod = OidcAuthMethod(authMethod) default: return nil, fmt.Errorf("invalid OIDC_%s_AUTHMETHOD: %s", name, authMethod) } } if provider.Name == "" { provider.Name = name } if provider.ClientId == "" || provider.Secret == "" || provider.Authorization == "" || provider.Token == "" || provider.Profile == "" { return nil, fmt.Errorf("invalid oidc configuration for provider %s, missing required values", providerId) } if provider.Scope == "" { provider.Scope = "openid profile email" } ret.OidcProviders[providerId] = provider } return &ret, nil }