From 41aee97386ae8a52231ab8ff7790e49cff802d77 Mon Sep 17 00:00:00 2001 From: Zen Dodd Date: Fri, 24 Apr 2026 05:33:41 +1000 Subject: [PATCH] core: propagate ECH keys to the QUIC listener (#7670) --- listeners.go | 15 +++++++++++- listeners_test.go | 58 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 1 deletion(-) diff --git a/listeners.go b/listeners.go index 84ebaaaba..ace0215b0 100644 --- a/listeners.go +++ b/listeners.go @@ -462,7 +462,10 @@ func (na NetworkAddress) ListenQUIC(ctx context.Context, portOffset uint, config sqs := newSharedQUICState(tlsConf) // http3.ConfigureTLSConfig only uses this field and tls App sets this field as well //nolint:gosec - quicTlsConfig := &tls.Config{GetConfigForClient: sqs.getConfigForClient} + quicTlsConfig := &tls.Config{ + GetConfigForClient: sqs.getConfigForClient, + GetEncryptedClientHelloKeys: sqs.getEncryptedClientHelloKeys, + } // Require clients to verify their source address when we're handling more than 1000 handshakes per second. // TODO: make tunable? limiter := rate.NewLimiter(1000, 1000) @@ -540,6 +543,16 @@ func (sqs *sharedQUICState) getConfigForClient(ch *tls.ClientHelloInfo) (*tls.Co return sqs.activeTlsConf.GetConfigForClient(ch) } +// getEncryptedClientHelloKeys is used as tls.Config's GetEncryptedClientHelloKeys field. +func (sqs *sharedQUICState) getEncryptedClientHelloKeys(ch *tls.ClientHelloInfo) ([]tls.EncryptedClientHelloKey, error) { + sqs.rmu.RLock() + defer sqs.rmu.RUnlock() + if sqs.activeTlsConf.GetEncryptedClientHelloKeys == nil { + return nil, nil + } + return sqs.activeTlsConf.GetEncryptedClientHelloKeys(ch) +} + // addState adds tls.Config and activeRequests to the map if not present and returns the corresponding context and its cancelFunc // so that when cancelled, the active tls.Config will change func (sqs *sharedQUICState) addState(tlsConfig *tls.Config) (context.Context, context.CancelCauseFunc) { diff --git a/listeners_test.go b/listeners_test.go index a4cadd3aa..7bbaca1f9 100644 --- a/listeners_test.go +++ b/listeners_test.go @@ -15,6 +15,7 @@ package caddy import ( + "crypto/tls" "reflect" "testing" @@ -175,6 +176,63 @@ func TestJoinNetworkAddress(t *testing.T) { } } +func TestSharedQUICStateGetEncryptedClientHelloKeys(t *testing.T) { + hello := &tls.ClientHelloInfo{ServerName: "example.com"} + initialKeys := []tls.EncryptedClientHelloKey{{Config: []byte("initial"), PrivateKey: []byte("initial-key")}} + updatedKeys := []tls.EncryptedClientHelloKey{{Config: []byte("updated"), PrivateKey: []byte("updated-key")}} + + initialConfig := &tls.Config{ + GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { + return nil, nil + }, + GetEncryptedClientHelloKeys: func(*tls.ClientHelloInfo) ([]tls.EncryptedClientHelloKey, error) { + return initialKeys, nil + }, + } + + sqs := newSharedQUICState(initialConfig) + + keys, err := sqs.getEncryptedClientHelloKeys(hello) + if err != nil { + t.Fatalf("getting initial ECH keys: %v", err) + } + if !reflect.DeepEqual(keys, initialKeys) { + t.Fatalf("unexpected initial ECH keys: got %#v, want %#v", keys, initialKeys) + } + + updatedConfig := &tls.Config{ + GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { + return nil, nil + }, + GetEncryptedClientHelloKeys: func(*tls.ClientHelloInfo) ([]tls.EncryptedClientHelloKey, error) { + return updatedKeys, nil + }, + } + + _, cancel := sqs.addState(updatedConfig) + sqs.rmu.Lock() + sqs.activeTlsConf = updatedConfig + sqs.rmu.Unlock() + + keys, err = sqs.getEncryptedClientHelloKeys(hello) + if err != nil { + t.Fatalf("getting updated ECH keys: %v", err) + } + if !reflect.DeepEqual(keys, updatedKeys) { + t.Fatalf("unexpected updated ECH keys: got %#v, want %#v", keys, updatedKeys) + } + + cancel(nil) + + keys, err = sqs.getEncryptedClientHelloKeys(hello) + if err != nil { + t.Fatalf("getting restored ECH keys: %v", err) + } + if !reflect.DeepEqual(keys, initialKeys) { + t.Fatalf("unexpected restored ECH keys: got %#v, want %#v", keys, initialKeys) + } +} + func TestParseNetworkAddress(t *testing.T) { for i, tc := range []struct { input string