mirror of
				https://github.com/caddyserver/caddy.git
				synced 2025-11-03 19:17:29 -05:00 
			
		
		
		
	proxy: Fixed support for TLS verification of WebSocket connections
This commit is contained in:
		
							parent
							
								
									153d4a5ac6
								
							
						
					
					
						commit
						b857265f9c
					
				@ -349,9 +349,14 @@ func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
 | 
			
		||||
		MaxIdleConnsPerHost: -1,
 | 
			
		||||
	}
 | 
			
		||||
	if b, _ := base.(*http.Transport); b != nil {
 | 
			
		||||
		tlsClientConfig := b.TLSClientConfig
 | 
			
		||||
		if tlsClientConfig.NextProtos != nil {
 | 
			
		||||
			tlsClientConfig = cloneTLSClientConfig(tlsClientConfig)
 | 
			
		||||
			tlsClientConfig.NextProtos = nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		t.Proxy = b.Proxy
 | 
			
		||||
		t.TLSClientConfig = cloneTLSClientConfig(b.TLSClientConfig)
 | 
			
		||||
		t.TLSClientConfig.NextProtos = nil
 | 
			
		||||
		t.TLSClientConfig = tlsClientConfig
 | 
			
		||||
		t.TLSHandshakeTimeout = b.TLSHandshakeTimeout
 | 
			
		||||
		t.Dial = b.Dial
 | 
			
		||||
		t.DialTLS = b.DialTLS
 | 
			
		||||
@ -363,20 +368,16 @@ func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
 | 
			
		||||
 | 
			
		||||
	dial := getTransportDial(t)
 | 
			
		||||
	dialTLS := getTransportDialTLS(t)
 | 
			
		||||
 | 
			
		||||
	t.Dial = func(network, addr string) (net.Conn, error) {
 | 
			
		||||
		c, err := dial(network, addr)
 | 
			
		||||
		hj.Conn = c
 | 
			
		||||
		return &hijackedConn{c, hj}, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if dialTLS != nil {
 | 
			
		||||
	t.DialTLS = func(network, addr string) (net.Conn, error) {
 | 
			
		||||
		c, err := dialTLS(network, addr)
 | 
			
		||||
		hj.Conn = c
 | 
			
		||||
		return &hijackedConn{c, hj}, err
 | 
			
		||||
	}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return hj
 | 
			
		||||
}
 | 
			
		||||
@ -390,27 +391,35 @@ func getTransportDial(t *http.Transport) func(network, addr string) (net.Conn, e
 | 
			
		||||
	return defaultDialer.Dial
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// getTransportDial returns a TLS Dialer if TLSClientConfig is non-nil
 | 
			
		||||
// getTransportDial always returns a TLS Dialer
 | 
			
		||||
// and defaults to the existing t.DialTLS.
 | 
			
		||||
func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn, error) {
 | 
			
		||||
	if t.DialTLS != nil {
 | 
			
		||||
		return t.DialTLS
 | 
			
		||||
	}
 | 
			
		||||
	if t.TLSClientConfig == nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// newConnHijackerTransport will modify t.Dial after calling this method
 | 
			
		||||
	// => Create a backup reference.
 | 
			
		||||
	plainDial := getTransportDial(t)
 | 
			
		||||
 | 
			
		||||
	// The following DialTLS implementation stems from the Go stdlib and
 | 
			
		||||
	// is identical to what happens if DialTLS is not provided.
 | 
			
		||||
	// Source: https://github.com/golang/go/blob/230a376b5a67f0e9341e1fa47e670ff762213c83/src/net/http/transport.go#L1018-L1051
 | 
			
		||||
	return func(network, addr string) (net.Conn, error) {
 | 
			
		||||
		plainConn, err := plainDial(network, addr)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		tlsConn := tls.Client(plainConn, t.TLSClientConfig)
 | 
			
		||||
		tlsClientConfig := t.TLSClientConfig
 | 
			
		||||
		if tlsClientConfig == nil {
 | 
			
		||||
			tlsClientConfig = &tls.Config{}
 | 
			
		||||
		}
 | 
			
		||||
		if !tlsClientConfig.InsecureSkipVerify && tlsClientConfig.ServerName == "" {
 | 
			
		||||
			tlsClientConfig.ServerName = stripPort(addr)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		tlsConn := tls.Client(plainConn, tlsClientConfig)
 | 
			
		||||
		errc := make(chan error, 2)
 | 
			
		||||
		var timer *time.Timer
 | 
			
		||||
		if d := t.TLSHandshakeTimeout; d != 0 {
 | 
			
		||||
@ -429,16 +438,12 @@ func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn
 | 
			
		||||
			plainConn.Close()
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		if !t.TLSClientConfig.InsecureSkipVerify {
 | 
			
		||||
			serverName := t.TLSClientConfig.ServerName
 | 
			
		||||
			if serverName == "" {
 | 
			
		||||
				serverName = addr
 | 
			
		||||
				idx := strings.LastIndex(serverName, ":")
 | 
			
		||||
				if idx != -1 {
 | 
			
		||||
					serverName = serverName[:idx]
 | 
			
		||||
		if !tlsClientConfig.InsecureSkipVerify {
 | 
			
		||||
			hostname := tlsClientConfig.ServerName
 | 
			
		||||
			if hostname == "" {
 | 
			
		||||
				hostname = stripPort(addr)
 | 
			
		||||
			}
 | 
			
		||||
			}
 | 
			
		||||
			if err := tlsConn.VerifyHostname(serverName); err != nil {
 | 
			
		||||
			if err := tlsConn.VerifyHostname(hostname); err != nil {
 | 
			
		||||
				plainConn.Close()
 | 
			
		||||
				return nil, err
 | 
			
		||||
			}
 | 
			
		||||
@ -448,6 +453,22 @@ func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// stripPort returns address without its port if it has one and
 | 
			
		||||
// works with IP addresses as well as hostnames formatted as host:port.
 | 
			
		||||
//
 | 
			
		||||
// IPv6 addresses (excluding the port) must be enclosed in
 | 
			
		||||
// square brackets similar to the requirements of Go's stdlib.
 | 
			
		||||
func stripPort(address string) string {
 | 
			
		||||
	// Keep in mind that the address might be a IPv6 address
 | 
			
		||||
	// and thus contain a colon, but not have a port.
 | 
			
		||||
	portIdx := strings.LastIndex(address, ":")
 | 
			
		||||
	ipv6Idx := strings.LastIndex(address, "]")
 | 
			
		||||
	if portIdx > ipv6Idx {
 | 
			
		||||
		address = address[:portIdx]
 | 
			
		||||
	}
 | 
			
		||||
	return address
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type tlsHandshakeTimeoutError struct{}
 | 
			
		||||
 | 
			
		||||
func (tlsHandshakeTimeoutError) Timeout() bool   { return true }
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user