diff --git a/modules/caddytls/connpolicy.go b/modules/caddytls/connpolicy.go index d9fc6bcfe..cb7de18b0 100644 --- a/modules/caddytls/connpolicy.go +++ b/modules/caddytls/connpolicy.go @@ -749,10 +749,14 @@ func (clientauth *ClientAuthentication) provision(ctx caddy.Context) error { // if we have TrustedCACerts explicitly set, create an 'inline' CA and return if len(clientauth.TrustedCACerts) > 0 { - clientauth.ca = InlineCAPool{ + caPool := InlineCAPool{ TrustedCACerts: clientauth.TrustedCACerts, } - return nil + err := caPool.Provision(ctx) + if err != nil { + return nil + } + clientauth.ca = caPool } // if we don't have any CARaw set, there's not much work to do diff --git a/modules/caddytls/connpolicy_test.go b/modules/caddytls/connpolicy_test.go index 0caed2899..82ecbc40d 100644 --- a/modules/caddytls/connpolicy_test.go +++ b/modules/caddytls/connpolicy_test.go @@ -20,6 +20,7 @@ import ( "reflect" "testing" + "github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" ) @@ -278,3 +279,49 @@ func TestClientAuthenticationUnmarshalCaddyfileWithDirectiveName(t *testing.T) { }) } } + +func TestClientAuthenticationProvision(t *testing.T) { + tests := []struct { + name string + ca ClientAuthentication + wantErr bool + }{ + { + name: "specifying both 'CARaw' and 'TrustedCACerts' produces an error", + ca: ClientAuthentication{ + CARaw: json.RawMessage(`{"provider":"inline","trusted_ca_certs":["foo"]}`), + TrustedCACerts: []string{"foo"}, + }, + wantErr: true, + }, + { + name: "specifying both 'CARaw' and 'TrustedCACertPEMFiles' produces an error", + ca: ClientAuthentication{ + CARaw: json.RawMessage(`{"provider":"inline","trusted_ca_certs":["foo"]}`), + TrustedCACertPEMFiles: []string{"foo"}, + }, + wantErr: true, + }, + { + name: "setting 'TrustedCACerts' provisions the cert pool", + ca: ClientAuthentication{ + TrustedCACerts: []string{test_der_1}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.ca.provision(caddy.Context{}) + if (err != nil) != tt.wantErr { + t.Errorf("ClientAuthentication.provision() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + if tt.ca.ca.CertPool() == nil { + t.Error("CertPool is nil, expected non-nil value") + } + } + }) + } +}