From aa3d20be3ee451af9465470a28937690104e9422 Mon Sep 17 00:00:00 2001 From: WeidiDeng Date: Mon, 28 Apr 2025 23:14:09 +0800 Subject: [PATCH] reverseproxy: Use DialTLSContext if ServerName has placeholder (#6955) Co-authored-by: Matt Holt --- .../caddyhttp/reverseproxy/httptransport.go | 70 +++++++++---------- 1 file changed, 32 insertions(+), 38 deletions(-) diff --git a/modules/caddyhttp/reverseproxy/httptransport.go b/modules/caddyhttp/reverseproxy/httptransport.go index fec531434..23af0b3f9 100644 --- a/modules/caddyhttp/reverseproxy/httptransport.go +++ b/modules/caddyhttp/reverseproxy/httptransport.go @@ -382,6 +382,36 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e if err != nil { return nil, fmt.Errorf("making TLS client config: %v", err) } + + // servername has a placeholder, so we need to replace it + if strings.Contains(h.TLS.ServerName, "{") { + rt.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + // reuses the dialer from above to establish a plaintext connection + conn, err := dialContext(ctx, network, addr) + if err != nil { + return nil, err + } + + // but add our own handshake logic + repl := ctx.Value(caddy.ReplacerCtxKey).(*caddy.Replacer) + tlsConfig := rt.TLSClientConfig.Clone() + tlsConfig.ServerName = repl.ReplaceAll(tlsConfig.ServerName, "") + tlsConn := tls.Client(conn, tlsConfig) + + // complete the handshake before returning the connection + if rt.TLSHandshakeTimeout != 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, rt.TLSHandshakeTimeout) + defer cancel() + } + err = tlsConn.HandshakeContext(ctx) + if err != nil { + _ = tlsConn.Close() + return nil, err + } + return tlsConn, nil + } + } } if h.KeepAlive != nil { @@ -453,45 +483,9 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e return rt, nil } -// replaceTLSServername checks TLS servername to see if it needs replacing -// if it does need replacing, it creates a new cloned HTTPTransport object to avoid any races -// and does the replacing of the TLS servername on that and returns the new object -// if no replacement is necessary it returns the original -func (h *HTTPTransport) replaceTLSServername(repl *caddy.Replacer) *HTTPTransport { - // check whether we have TLS and need to replace the servername in the TLSClientConfig - if h.TLSEnabled() && strings.Contains(h.TLS.ServerName, "{") { - // make a new h, "copy" the parts we don't need to touch, add a new *tls.Config and replace servername - newtransport := &HTTPTransport{ - Resolver: h.Resolver, - TLS: h.TLS, - KeepAlive: h.KeepAlive, - Compression: h.Compression, - MaxConnsPerHost: h.MaxConnsPerHost, - DialTimeout: h.DialTimeout, - FallbackDelay: h.FallbackDelay, - ResponseHeaderTimeout: h.ResponseHeaderTimeout, - ExpectContinueTimeout: h.ExpectContinueTimeout, - MaxResponseHeaderSize: h.MaxResponseHeaderSize, - WriteBufferSize: h.WriteBufferSize, - ReadBufferSize: h.ReadBufferSize, - Versions: h.Versions, - Transport: h.Transport.Clone(), - h2cTransport: h.h2cTransport, - } - newtransport.Transport.TLSClientConfig.ServerName = repl.ReplaceAll(newtransport.Transport.TLSClientConfig.ServerName, "") - return newtransport - } - - return h -} - // RoundTrip implements http.RoundTripper. func (h *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { - // Try to replace TLS servername if needed - repl := req.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer) - transport := h.replaceTLSServername(repl) - - transport.SetScheme(req) + h.SetScheme(req) // use HTTP/3 if enabled (TODO: This is EXPERIMENTAL) if h.h3Transport != nil { @@ -507,7 +501,7 @@ func (h *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { return h.h2cTransport.RoundTrip(req) } - return transport.Transport.RoundTrip(req) + return h.Transport.RoundTrip(req) } // SetScheme ensures that the outbound request req