diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index 57962ddad..6b5394cca 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -322,6 +322,10 @@ func (h *Handler) Provision(ctx caddy.Context) error { } } + if h.StreamRetainOnReload { + registerDetachedTunnelStates(h.tunnel) + } + // warn about unsafe buffering config if h.RequestBuffers == -1 || h.ResponseBuffers == -1 { h.logger.Warn("UNLIMITED BUFFERING: buffering is enabled without any cap on buffer size, which can result in OOM crashes") @@ -522,19 +526,40 @@ func (h Handler) streamLoggerForRequest(req *http.Request) *zap.Logger { return caddy.Log().Named(name) } -// Cleanup cleans up the resources made by h. -func (h *Handler) Cleanup() error { - if !h.StreamRetainOnReload { - // Legacy behavior: close all upgraded connections on reload, either - // immediately or after StreamCloseDelay. - err := h.tunnel.cleanupConnections() - for _, upstream := range h.Upstreams { - _, _ = hosts.Delete(upstream.String()) - } - return err - } +var ( + detachedTunnelStates = make(map[*tunnelState]struct{}) + detachedTunnelStatesMu sync.Mutex +) + +func registerDetachedTunnelStates(ts *tunnelState) { + detachedTunnelStatesMu.Lock() + defer detachedTunnelStatesMu.Unlock() + detachedTunnelStates[ts] = struct{}{} +} + +func notifyDetachedTunnelStatesOfUpstreamRemoval(upstream string, self *tunnelState) error { + detachedTunnelStatesMu.Lock() + defer detachedTunnelStatesMu.Unlock() var err error + for tunnel := range detachedTunnelStates { + if closeErr := tunnel.closeConnectionsForUpstream(upstream); closeErr != nil && tunnel == self && err == nil { + err = closeErr + } + } + return err +} + +func unregisterDetachedTunnelStates(ts *tunnelState) { + detachedTunnelStatesMu.Lock() + defer detachedTunnelStatesMu.Unlock() + delete(detachedTunnelStates, ts) +} + +// Cleanup cleans up the resources made by h. +func (h *Handler) Cleanup() error { + // even if StreamRetainOnReload is true, extended connect websockets may still be running + err := h.tunnel.cleanupAttachedConnections() for _, upstream := range h.Upstreams { // hosts.Delete returns deleted=true when the ref count reaches zero, // meaning no other active config references this upstream. In that @@ -542,7 +567,7 @@ func (h *Handler) Cleanup() error { // to their natural end since the upstream is still in use. deleted, _ := hosts.Delete(upstream.String()) if deleted { - if closeErr := h.tunnel.closeConnectionsForUpstream(upstream.String()); closeErr != nil && err == nil { + if closeErr := notifyDetachedTunnelStatesOfUpstreamRemoval(upstream.String(), h.tunnel); closeErr != nil && err == nil { err = closeErr } } diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go index 3835653df..ccb056d58 100644 --- a/modules/caddyhttp/reverseproxy/streaming.go +++ b/modules/caddyhttp/reverseproxy/streaming.go @@ -197,8 +197,8 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrit } return nil } - deleteFrontConn := tunnel.registerConnection(conn, gracefulClose(conn, false), upstreamAddr) - deleteBackConn := tunnel.registerConnection(backConn, gracefulClose(backConn, true), upstreamAddr) + deleteFrontConn := tunnel.registerConnection(conn, gracefulClose(conn, false), detached, upstreamAddr) + deleteBackConn := tunnel.registerConnection(backConn, gracefulClose(backConn, true), detached, upstreamAddr) if h.streamLogsSkipHandshake() { caddyhttp.SetVar(req.Context(), caddyhttp.LogSkipVar, true) } @@ -442,6 +442,7 @@ func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte, logger *za type openConnection struct { conn io.ReadWriteCloser gracefulClose func() error + detached bool upstream string } @@ -464,29 +465,36 @@ func newTunnelState(logger *zap.Logger, closeDelay time.Duration) *tunnelState { // registerConnection stores conn in the tracking map. The caller must invoke // the returned del func when the connection is done. -func (ts *tunnelState) registerConnection(conn io.ReadWriteCloser, gracefulClose func() error, upstream string) (del func()) { +func (ts *tunnelState) registerConnection(conn io.ReadWriteCloser, gracefulClose func() error, detached bool, upstream string) (del func()) { ts.mu.Lock() - ts.connections[conn] = openConnection{conn, gracefulClose, upstream} + ts.connections[conn] = openConnection{conn, gracefulClose, detached, upstream} ts.mu.Unlock() return func() { ts.mu.Lock() delete(ts.connections, conn) - if len(ts.connections) == 0 && ts.closeTimer != nil { - if ts.closeTimer.Stop() { - ts.logger.Debug("stopped streaming connections close timer - all connections are already closed") + if len(ts.connections) == 0 { + unregisterDetachedTunnelStates(ts) + if ts.closeTimer != nil { + if ts.closeTimer.Stop() { + ts.logger.Debug("stopped streaming connections close timer - all connections are already closed") + } + ts.closeTimer = nil } - ts.closeTimer = nil } ts.mu.Unlock() } } -// closeConnections closes all tracked connections. -func (ts *tunnelState) closeConnections() error { +// closeAttachedConnections closes all tracked attached connections. +func (ts *tunnelState) closeAttachedConnections() error { var err error ts.mu.Lock() defer ts.mu.Unlock() for _, oc := range ts.connections { + // detached connections are only closed when the upstream is gone from the config + if oc.detached { + continue + } if oc.gracefulClose != nil { if gracefulErr := oc.gracefulClose(); gracefulErr != nil && err == nil { err = gracefulErr @@ -499,11 +507,11 @@ func (ts *tunnelState) closeConnections() error { return err } -// cleanupConnections closes upgraded connections. Depending on closeDelay it +// cleanupAttachedConnections closes upgraded attached connections. Depending on closeDelay it // does that either immediately or after a timer. -func (ts *tunnelState) cleanupConnections() error { +func (ts *tunnelState) cleanupAttachedConnections() error { if ts.closeDelay == 0 { - return ts.closeConnections() + return ts.closeAttachedConnections() } ts.mu.Lock() @@ -514,7 +522,7 @@ func (ts *tunnelState) cleanupConnections() error { if c := ts.logger.Check(zapcore.DebugLevel, "closing streaming connections after delay"); c != nil { c.Write(zap.Duration("delay", delay)) } - err := ts.closeConnections() + err := ts.closeAttachedConnections() if err != nil { if c := ts.logger.Check(zapcore.ErrorLevel, "failed to close connections after delay"); c != nil { c.Write(