mirror of
				https://github.com/caddyserver/caddy.git
				synced 2025-11-03 19:17:29 -05:00 
			
		
		
		
	Proxy: When connecting to websocket backend, reuse the connection isntead of starting a new one.
This commit is contained in:
		
							parent
							
								
									c4e65df262
								
							
						
					
					
						commit
						d534a2139f
					
				@ -183,9 +183,80 @@ var hopHeaders = []string{
 | 
			
		||||
 | 
			
		||||
type respUpdateFn func(resp *http.Response)
 | 
			
		||||
 | 
			
		||||
type hijackedConn struct {
 | 
			
		||||
	net.Conn
 | 
			
		||||
	hj *connHijackerTransport
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *hijackedConn) Read(b []byte) (n int, err error) {
 | 
			
		||||
	n, err = c.Conn.Read(b)
 | 
			
		||||
	c.hj.Replay = append(c.hj.Replay, b[:n]...)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *hijackedConn) Close() error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type connHijackerTransport struct {
 | 
			
		||||
	*http.Transport
 | 
			
		||||
	Conn   net.Conn
 | 
			
		||||
	Replay []byte
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
 | 
			
		||||
	transport := &http.Transport{
 | 
			
		||||
		Proxy: http.ProxyFromEnvironment,
 | 
			
		||||
		Dial: (&net.Dialer{
 | 
			
		||||
			Timeout:   30 * time.Second,
 | 
			
		||||
			KeepAlive: 30 * time.Second,
 | 
			
		||||
		}).Dial,
 | 
			
		||||
		TLSHandshakeTimeout: 10 * time.Second,
 | 
			
		||||
		TLSClientConfig:     &tls.Config{InsecureSkipVerify: true},
 | 
			
		||||
	}
 | 
			
		||||
	if base != nil {
 | 
			
		||||
		if baseTransport, ok := base.(*http.Transport); ok {
 | 
			
		||||
			transport.Proxy = baseTransport.Proxy
 | 
			
		||||
			transport.TLSClientConfig = baseTransport.TLSClientConfig
 | 
			
		||||
			transport.TLSHandshakeTimeout = baseTransport.TLSHandshakeTimeout
 | 
			
		||||
			transport.Dial = baseTransport.Dial
 | 
			
		||||
			transport.DialTLS = baseTransport.DialTLS
 | 
			
		||||
			transport.DisableKeepAlives = true
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	hjTransport := &connHijackerTransport{transport, nil, bufferPool.Get().([]byte)[:0]}
 | 
			
		||||
	oldDial := transport.Dial
 | 
			
		||||
	oldDialTLS := transport.DialTLS
 | 
			
		||||
	if oldDial == nil {
 | 
			
		||||
		oldDial = (&net.Dialer{
 | 
			
		||||
			Timeout:   30 * time.Second,
 | 
			
		||||
			KeepAlive: 30 * time.Second,
 | 
			
		||||
		}).Dial
 | 
			
		||||
	}
 | 
			
		||||
	hjTransport.Dial = func(network, addr string) (net.Conn, error) {
 | 
			
		||||
		c, err := oldDial(network, addr)
 | 
			
		||||
		hjTransport.Conn = c
 | 
			
		||||
		return &hijackedConn{c, hjTransport}, err
 | 
			
		||||
	}
 | 
			
		||||
	if oldDialTLS != nil {
 | 
			
		||||
		hjTransport.DialTLS = func(network, addr string) (net.Conn, error) {
 | 
			
		||||
			c, err := oldDialTLS(network, addr)
 | 
			
		||||
			hjTransport.Conn = c
 | 
			
		||||
			return &hijackedConn{c, hjTransport}, err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return hjTransport
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func requestIsWebsocket(req *http.Request) bool {
 | 
			
		||||
	return !(strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade"))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, respUpdateFn respUpdateFn) error {
 | 
			
		||||
	transport := p.Transport
 | 
			
		||||
	if transport == nil {
 | 
			
		||||
	if requestIsWebsocket(outreq) {
 | 
			
		||||
		transport = newConnHijackerTransport(transport)
 | 
			
		||||
	} else if transport == nil {
 | 
			
		||||
		transport = http.DefaultTransport
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -216,13 +287,22 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, r
 | 
			
		||||
		}
 | 
			
		||||
		defer conn.Close()
 | 
			
		||||
 | 
			
		||||
		backendConn, err := net.Dial("tcp", outreq.URL.Host)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		defer backendConn.Close()
 | 
			
		||||
		var backendConn net.Conn
 | 
			
		||||
		if hj, ok := transport.(*connHijackerTransport); ok {
 | 
			
		||||
			backendConn = hj.Conn
 | 
			
		||||
			if _, err := conn.Write(hj.Replay); err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			bufferPool.Put(hj.Replay)
 | 
			
		||||
		} else {
 | 
			
		||||
			backendConn, err = net.Dial("tcp", outreq.URL.Host)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			defer backendConn.Close()
 | 
			
		||||
 | 
			
		||||
		outreq.Write(backendConn)
 | 
			
		||||
			outreq.Write(backendConn)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		go func() {
 | 
			
		||||
			io.Copy(backendConn, conn) // write tcp stream to backend.
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user