reverseproxy: Optionally detach stream (websockets) from config lifecycle

This commit is contained in:
Francis Lavoie
2026-04-13 04:23:23 -04:00
parent 4430756d5c
commit 476fd0c077
12 changed files with 1396 additions and 166 deletions
@@ -99,6 +99,8 @@ func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error)
// stream_buffer_size <size>
// stream_timeout <duration>
// stream_close_delay <duration>
// stream_retain_on_reload
// stream_log_skip_handshake
// verbose_logs
//
// # request manipulation
@@ -703,6 +705,18 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
h.StreamCloseDelay = caddy.Duration(dur)
}
case "stream_retain_on_reload":
if d.NextArg() {
return d.ArgErr()
}
h.StreamRetainOnReload = true
case "stream_log_skip_handshake":
if d.NextArg() {
return d.ArgErr()
}
h.StreamLogSkipHandshake = true
case "trusted_proxies":
for d.NextArg() {
if d.Val() == "private_ranges" {
@@ -80,7 +80,7 @@ func (h CopyResponseHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request
hrc.isFinalized = true
// write the response
return hrc.handler.finalizeResponse(rw, req, hrc.response, repl, hrc.start, hrc.logger)
return hrc.handler.finalizeResponse(rw, req, hrc.response, repl, hrc.start, hrc.logger, hrc.upstreamAddr)
}
// CopyResponseHeadersHandler is a special HTTP handler which may
+79
View File
@@ -16,6 +16,10 @@ import (
var reverseProxyMetrics = struct {
once sync.Once
upstreamsHealthy *prometheus.GaugeVec
streamsActive *prometheus.GaugeVec
streamsTotal *prometheus.CounterVec
streamDuration *prometheus.HistogramVec
streamBytes *prometheus.CounterVec
logger *zap.Logger
}{}
@@ -23,6 +27,8 @@ func initReverseProxyMetrics(handler *Handler, registry *prometheus.Registry) {
const ns, sub = "caddy", "reverse_proxy"
upstreamsLabels := []string{"upstream"}
streamResultLabels := []string{"upstream", "result"}
streamBytesLabels := []string{"upstream", "direction"}
reverseProxyMetrics.once.Do(func() {
reverseProxyMetrics.upstreamsHealthy = prometheus.NewGaugeVec(prometheus.GaugeOpts{
Namespace: ns,
@@ -30,6 +36,31 @@ func initReverseProxyMetrics(handler *Handler, registry *prometheus.Registry) {
Name: "upstreams_healthy",
Help: "Health status of reverse proxy upstreams.",
}, upstreamsLabels)
reverseProxyMetrics.streamsActive = prometheus.NewGaugeVec(prometheus.GaugeOpts{
Namespace: ns,
Subsystem: sub,
Name: "streams_active",
Help: "Number of currently active upgraded reverse proxy streams.",
}, upstreamsLabels)
reverseProxyMetrics.streamsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: ns,
Subsystem: sub,
Name: "streams_total",
Help: "Total number of upgraded reverse proxy streams by close result.",
}, streamResultLabels)
reverseProxyMetrics.streamDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: ns,
Subsystem: sub,
Name: "stream_duration_seconds",
Help: "Duration of upgraded reverse proxy streams by close result.",
Buckets: prometheus.DefBuckets,
}, streamResultLabels)
reverseProxyMetrics.streamBytes = prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: ns,
Subsystem: sub,
Name: "stream_bytes_total",
Help: "Total bytes proxied across upgraded reverse proxy streams.",
}, streamBytesLabels)
})
// duplicate registration could happen if multiple sites with reverse proxy are configured; so ignore the error because
@@ -42,10 +73,58 @@ func initReverseProxyMetrics(handler *Handler, registry *prometheus.Registry) {
}) {
panic(err)
}
if err := registry.Register(reverseProxyMetrics.streamsActive); err != nil &&
!errors.Is(err, prometheus.AlreadyRegisteredError{
ExistingCollector: reverseProxyMetrics.streamsActive,
NewCollector: reverseProxyMetrics.streamsActive,
}) {
panic(err)
}
if err := registry.Register(reverseProxyMetrics.streamsTotal); err != nil &&
!errors.Is(err, prometheus.AlreadyRegisteredError{
ExistingCollector: reverseProxyMetrics.streamsTotal,
NewCollector: reverseProxyMetrics.streamsTotal,
}) {
panic(err)
}
if err := registry.Register(reverseProxyMetrics.streamDuration); err != nil &&
!errors.Is(err, prometheus.AlreadyRegisteredError{
ExistingCollector: reverseProxyMetrics.streamDuration,
NewCollector: reverseProxyMetrics.streamDuration,
}) {
panic(err)
}
if err := registry.Register(reverseProxyMetrics.streamBytes); err != nil &&
!errors.Is(err, prometheus.AlreadyRegisteredError{
ExistingCollector: reverseProxyMetrics.streamBytes,
NewCollector: reverseProxyMetrics.streamBytes,
}) {
panic(err)
}
reverseProxyMetrics.logger = handler.logger.Named("reverse_proxy.metrics")
}
func trackActiveStream(upstream string) func(result string, duration time.Duration, toBackend, fromBackend int64) {
labels := prometheus.Labels{"upstream": upstream}
reverseProxyMetrics.streamsActive.With(labels).Inc()
var once sync.Once
return func(result string, duration time.Duration, toBackend, fromBackend int64) {
once.Do(func() {
reverseProxyMetrics.streamsActive.With(labels).Dec()
reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, result).Inc()
reverseProxyMetrics.streamDuration.WithLabelValues(upstream, result).Observe(duration.Seconds())
if toBackend > 0 {
reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "to_upstream").Add(float64(toBackend))
}
if fromBackend > 0 {
reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "from_upstream").Add(float64(fromBackend))
}
})
}
}
type metricsUpstreamsHealthyUpdater struct {
handler *Handler
}
@@ -0,0 +1,67 @@
package reverseproxy
import (
"testing"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/testutil"
)
func TestTrackActiveStreamRecordsLifecycleAndBytes(t *testing.T) {
const upstream = "127.0.0.1:7443"
// Use fresh metric vectors for deterministic assertions in this unit test.
reverseProxyMetrics.streamsActive = prometheus.NewGaugeVec(prometheus.GaugeOpts{}, []string{"upstream"})
reverseProxyMetrics.streamsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "result"})
reverseProxyMetrics.streamDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{}, []string{"upstream", "result"})
reverseProxyMetrics.streamBytes = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "direction"})
finish := trackActiveStream(upstream)
if got := testutil.ToFloat64(reverseProxyMetrics.streamsActive.WithLabelValues(upstream)); got != 1 {
t.Fatalf("active streams = %v, want 1", got)
}
finish("closed", 150*time.Millisecond, 1234, 4321)
if got := testutil.ToFloat64(reverseProxyMetrics.streamsActive.WithLabelValues(upstream)); got != 0 {
t.Fatalf("active streams = %v, want 0", got)
}
if got := testutil.ToFloat64(reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, "closed")); got != 1 {
t.Fatalf("streams_total closed = %v, want 1", got)
}
if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "to_upstream")); got != 1234 {
t.Fatalf("bytes to_upstream = %v, want 1234", got)
}
if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "from_upstream")); got != 4321 {
t.Fatalf("bytes from_upstream = %v, want 4321", got)
}
// A second finish call should be ignored by the once guard.
finish("error", 1*time.Second, 111, 222)
if got := testutil.ToFloat64(reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, "error")); got != 0 {
t.Fatalf("streams_total error = %v, want 0", got)
}
}
func TestTrackActiveStreamDoesNotCountZeroBytes(t *testing.T) {
const upstream = "127.0.0.1:9000"
reverseProxyMetrics.streamsActive = prometheus.NewGaugeVec(prometheus.GaugeOpts{}, []string{"upstream"})
reverseProxyMetrics.streamsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "result"})
reverseProxyMetrics.streamDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{}, []string{"upstream", "result"})
reverseProxyMetrics.streamBytes = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "direction"})
trackActiveStream(upstream)("timeout", 250*time.Millisecond, 0, 0)
if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "to_upstream")); got != 0 {
t.Fatalf("bytes to_upstream = %v, want 0", got)
}
if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "from_upstream")); got != 0 {
t.Fatalf("bytes from_upstream = %v, want 0", got)
}
if got := testutil.ToFloat64(reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, "timeout")); got != 1 {
t.Fatalf("streams_total timeout = %v, want 1", got)
}
}
+48 -17
View File
@@ -186,6 +186,18 @@ type Handler struct {
// by the previous config closing. Default: no delay.
StreamCloseDelay caddy.Duration `json:"stream_close_delay,omitempty"`
// If true, upgraded connections such as WebSockets are retained across
// config reloads when their upstream still exists in the new config.
// Connections using upstreams that are removed are closed during cleanup.
// By default this is false, preserving legacy behavior where upgraded
// connections are closed on reload (optionally delayed by stream_close_delay).
StreamRetainOnReload bool `json:"stream_retain_on_reload,omitempty"`
// If true, suppresses the access log entry normally emitted when an
// upgraded stream handshake completes and the request unwinds. By default
// the handshake is still logged as a normal request with status 101.
StreamLogSkipHandshake bool `json:"stream_log_skip_handshake,omitempty"`
// If configured, rewrites the copy of the upstream request.
// Allows changing the request method and URI (path and query).
// Since the rewrite is applied to the copy, it does not persist
@@ -240,10 +252,9 @@ type Handler struct {
// Holds the handle_response Caddyfile tokens while adapting
handleResponseSegments []*caddyfile.Dispenser
// Stores upgraded requests (hijacked connections) for proper cleanup
connections map[io.ReadWriteCloser]openConnection
connectionsCloseTimer *time.Timer
connectionsMu *sync.Mutex
// Tracks hijacked/upgraded connections (WebSocket etc.) so they can be
// closed when their upstream is removed from the config.
tunnel *tunnelState
ctx caddy.Context
logger *zap.Logger
@@ -267,8 +278,7 @@ func (h *Handler) Provision(ctx caddy.Context) error {
h.events = eventAppIface.(*caddyevents.App)
h.ctx = ctx
h.logger = ctx.Logger()
h.connections = make(map[io.ReadWriteCloser]openConnection)
h.connectionsMu = new(sync.Mutex)
h.tunnel = newTunnelState(h.logger, time.Duration(h.StreamCloseDelay))
// warn about unsafe buffering config
if h.RequestBuffers == -1 || h.ResponseBuffers == -1 {
@@ -439,13 +449,29 @@ func (h *Handler) Provision(ctx caddy.Context) error {
// Cleanup cleans up the resources made by h.
func (h *Handler) Cleanup() error {
err := h.cleanupConnections()
// remove hosts from our config from the pool
for _, upstream := range h.Upstreams {
_, _ = hosts.Delete(upstream.String())
if !h.StreamRetainOnReload {
// Legacy behavior: close all upgraded connections on reload, either
// immediately or after StreamCloseDelay.
err := h.tunnel.cleanupConnections()
for _, upstream := range h.Upstreams {
_, _ = hosts.Delete(upstream.String())
}
return err
}
var err error
for _, upstream := range h.Upstreams {
// hosts.Delete returns deleted=true when the ref count reaches zero,
// meaning no other active config references this upstream. In that
// case close any tunnels proxying to it; otherwise let them survive
// to their natural end since the upstream is still in use.
deleted, _ := hosts.Delete(upstream.String())
if deleted {
if closeErr := h.tunnel.closeConnectionsForUpstream(upstream.String()); closeErr != nil && err == nil {
err = closeErr
}
}
}
return err
}
@@ -1092,10 +1118,11 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe
// we use the original request here, so that any routes from 'next'
// see the original request rather than the proxy cloned request.
hrc := &handleResponseContext{
handler: h,
response: res,
start: start,
logger: logger,
handler: h,
response: res,
start: start,
logger: logger,
upstreamAddr: di.Upstream.String(),
}
ctx := origReq.Context()
ctx = context.WithValue(ctx, proxyHandleResponseContextCtxKey, hrc)
@@ -1125,7 +1152,7 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe
}
// copy the response body and headers back to the upstream client
return h.finalizeResponse(rw, req, res, repl, start, logger)
return h.finalizeResponse(rw, req, res, repl, start, logger, di.Upstream.String())
}
// finalizeResponse prepares and copies the response.
@@ -1136,11 +1163,12 @@ func (h *Handler) finalizeResponse(
repl *caddy.Replacer,
start time.Time,
logger *zap.Logger,
upstreamAddr string,
) error {
// deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
if res.StatusCode == http.StatusSwitchingProtocols {
var wg sync.WaitGroup
h.handleUpgradeResponse(logger, &wg, rw, req, res)
h.handleUpgradeResponse(logger, &wg, rw, req, res, upstreamAddr)
wg.Wait()
return nil
}
@@ -1719,6 +1747,9 @@ type handleResponseContext struct {
// i.e. copied and closed, to make sure that it doesn't
// happen twice.
isFinalized bool
// upstreamAddr is the selected upstream address for this request.
upstreamAddr string
}
// proxyHandleResponseContextCtxKey is the context key for the active proxy handler
+288 -105
View File
@@ -26,6 +26,7 @@ import (
"io"
weakrand "math/rand/v2"
"mime"
"net"
"net/http"
"sync"
"time"
@@ -35,6 +36,7 @@ import (
"go.uber.org/zap/zapcore"
"golang.org/x/net/http/httpguts"
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
)
@@ -57,7 +59,7 @@ func (rwc h2ReadWriteCloser) Write(p []byte) (n int, err error) {
return n, nil
}
func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, rw http.ResponseWriter, req *http.Request, res *http.Response) {
func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, rw http.ResponseWriter, req *http.Request, res *http.Response, upstreamAddr string) {
reqUpType := upgradeType(req.Header)
resUpType := upgradeType(res.Header)
@@ -90,13 +92,22 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
copyHeader(rw.Header(), res.Header)
normalizeWebsocketHeaders(rw.Header())
// Capture all h fields needed by the tunnel now, so that the Handler (h)
// is not referenced after this function returns (for HTTP/1.1 hijacked
// connections the tunnel runs in a detached goroutine).
tunnel := h.tunnel
bufferSize := h.StreamBufferSize
streamTimeout := time.Duration(h.StreamTimeout)
var (
conn io.ReadWriteCloser
brw *bufio.ReadWriter
isH2 bool
)
// websocket over http2 or http3 if extended connect is enabled, assuming backend doesn't support this, the request will be modified to http1.1 upgrade
// TODO: once we can reliably detect backend support this, it can be removed for those backends
if body, ok := caddyhttp.GetVar(req.Context(), "extended_connect_websocket_body").(io.ReadCloser); ok {
isH2 = true
req.Body = body
rw.Header().Del("Upgrade")
rw.Header().Del("Connection")
@@ -143,26 +154,24 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
}
}
// adopted from https://github.com/golang/go/commit/8bcf2834afdf6a1f7937390903a41518715ef6f5
backConnCloseCh := make(chan struct{})
go func() {
// Ensure that the cancellation of a request closes the backend.
// See issue https://golang.org/issue/35559.
select {
case <-req.Context().Done():
case <-backConnCloseCh:
}
backConn.Close()
}()
defer close(backConnCloseCh)
start := time.Now()
defer func() {
conn.Close()
if c := logger.Check(zapcore.DebugLevel, "connection closed"); c != nil {
c.Write(zap.Duration("duration", time.Since(start)))
}
}()
// For H2 extended connect: close backConn when the request context is
// cancelled (e.g. client disconnects). For HTTP/1.1 hijacked connections
// we skip this because req.Context() may be cancelled when ServeHTTP
// returns early, which would prematurely close the backend connection.
if isH2 {
// adopted from https://github.com/golang/go/commit/8bcf2834afdf6a1f7937390903a41518715ef6f5
backConnCloseCh := make(chan struct{})
go func() {
// Ensure that the cancellation of a request closes the backend.
// See issue https://golang.org/issue/35559.
select {
case <-req.Context().Done():
case <-backConnCloseCh:
}
backConn.Close()
}()
defer close(backConnCloseCh)
}
if err := brw.Flush(); err != nil {
if c := logger.Check(zapcore.DebugLevel, "response flush"); c != nil {
@@ -184,13 +193,11 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
}
}
// Ensure the hijacked client connection, and the new connection established
// with the backend, are both closed in the event of a server shutdown. This
// is done by registering them. We also try to gracefully close connections
// we recognize as websockets.
// We need to make sure the client connection messages (i.e. to upstream)
// are masked, so we need to know whether the connection is considered the
// server or the client side of the proxy.
// Register both connections with the tunnel tracker. We also try to
// gracefully close connections we recognize as websockets. We need to make
// sure the client connection messages (i.e. to upstream) are masked, so we
// need to know whether the connection is considered the server or the
// client side of the proxy.
gracefulClose := func(conn io.ReadWriteCloser, isClient bool) func() error {
if isWebsocket(req) {
return func() error {
@@ -199,43 +206,186 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
}
return nil
}
deleteFrontConn := h.registerConnection(conn, gracefulClose(conn, false))
deleteBackConn := h.registerConnection(backConn, gracefulClose(backConn, true))
defer deleteFrontConn()
deleteFrontConn := tunnel.registerConnection(conn, gracefulClose(conn, false), upstreamAddr)
deleteBackConn := tunnel.registerConnection(backConn, gracefulClose(backConn, true), upstreamAddr)
if h.StreamLogSkipHandshake {
caddyhttp.SetVar(req.Context(), caddyhttp.LogSkipVar, true)
}
repl := req.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
repl.Set("http.reverse_proxy.upgraded", true)
finishMetrics := trackActiveStream(upstreamAddr)
start := time.Now()
if isH2 {
h.handleH2UpgradeTunnel(logger, wg, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics)
} else {
h.handleDetachedUpgradeTunnel(logger, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics)
// Return immediately without touching wg. finalizeResponse's
// wg.Wait() returns at once since wg was never incremented.
}
}
func (h *Handler) handleH2UpgradeTunnel(
logger *zap.Logger,
wg *sync.WaitGroup,
conn io.ReadWriteCloser,
backConn io.ReadWriteCloser,
deleteFrontConn func(),
deleteBackConn func(),
bufferSize int,
streamTimeout time.Duration,
start time.Time,
finishMetrics func(result string, duration time.Duration, toBackend, fromBackend int64),
) {
// H2 extended connect: ServeHTTP must block because rw and req.Body are
// only valid while the handler goroutine is running. Defers clean up
// when the select below fires and this function returns.
defer deleteBackConn()
spc := switchProtocolCopier{
user: conn,
backend: backConn,
wg: wg,
bufferSize: h.StreamBufferSize,
}
// setup the timeout if requested
var timeoutc <-chan time.Time
if h.StreamTimeout > 0 {
timer := time.NewTimer(time.Duration(h.StreamTimeout))
defer timer.Stop()
timeoutc = timer.C
}
defer deleteFrontConn()
var (
toBackend int64
fromBackend int64
result = "closed"
)
// when a stream timeout is encountered, no error will be read from errc
// a buffer size of 2 will allow both the read and write goroutines to send the error and exit
// see: https://github.com/caddyserver/caddy/issues/7418
errc := make(chan error, 2)
spc := switchProtocolCopier{
user: conn,
backend: backConn,
wg: wg,
bufferSize: bufferSize,
sent: &toBackend,
received: &fromBackend,
}
wg.Add(2)
var timeoutc <-chan time.Time
if streamTimeout > 0 {
timer := time.NewTimer(streamTimeout)
defer timer.Stop()
timeoutc = timer.C
}
go spc.copyToBackend(errc)
go spc.copyFromBackend(errc)
select {
case err := <-errc:
result = classifyStreamResult(err)
if c := logger.Check(zapcore.DebugLevel, "streaming error"); c != nil {
c.Write(zap.Error(err))
}
case time := <-timeoutc:
case t := <-timeoutc:
result = "timeout"
if c := logger.Check(zapcore.DebugLevel, "stream timed out"); c != nil {
c.Write(zap.Time("timeout", time))
c.Write(zap.Time("timeout", t))
}
}
// Close both ends to unblock the still-running copy goroutine,
// then wait for it so byte counts are final before metrics/logging.
conn.Close()
backConn.Close()
wg.Wait()
finishMetrics(result, time.Since(start), toBackend, fromBackend)
if c := logger.Check(zapcore.DebugLevel, "connection closed"); c != nil {
c.Write(
zap.Duration("duration", time.Since(start)),
zap.Int64("bytes_to_backend", toBackend),
zap.Int64("bytes_from_backend", fromBackend),
)
}
}
func (h *Handler) handleDetachedUpgradeTunnel(
logger *zap.Logger,
conn io.ReadWriteCloser,
backConn io.ReadWriteCloser,
deleteFrontConn func(),
deleteBackConn func(),
bufferSize int,
streamTimeout time.Duration,
start time.Time,
finishMetrics func(result string, duration time.Duration, toBackend, fromBackend int64),
) {
// HTTP/1.1 hijacked connection: launch a detached goroutine so that
// ServeHTTP can return immediately, allowing the Handler to be GC'd
// after a config reload. The goroutine captures only tunnel (a small
// *tunnelState), logger, conn/backConn, and scalar config values.
go func() {
var (
toBackend int64
fromBackend int64
result = "closed"
)
defer deleteBackConn()
defer deleteFrontConn()
defer func() {
finishMetrics(result, time.Since(start), toBackend, fromBackend)
if c := logger.Check(zapcore.DebugLevel, "connection closed"); c != nil {
c.Write(
zap.Duration("duration", time.Since(start)),
zap.Int64("bytes_to_backend", toBackend),
zap.Int64("bytes_from_backend", fromBackend),
)
}
}()
var innerWg sync.WaitGroup
// when a stream timeout is encountered, no error will be read from errc
// a buffer size of 2 will allow both the read and write goroutines to send the error and exit
// see: https://github.com/caddyserver/caddy/issues/7418
errc := make(chan error, 2)
spc := switchProtocolCopier{
user: conn,
backend: backConn,
wg: &innerWg,
bufferSize: bufferSize,
sent: &toBackend,
received: &fromBackend,
}
innerWg.Add(2)
var timeoutc <-chan time.Time
if streamTimeout > 0 {
timer := time.NewTimer(streamTimeout)
defer timer.Stop()
timeoutc = timer.C
}
go spc.copyToBackend(errc)
go spc.copyFromBackend(errc)
select {
case err := <-errc:
result = classifyStreamResult(err)
if c := logger.Check(zapcore.DebugLevel, "streaming error"); c != nil {
c.Write(zap.Error(err))
}
case t := <-timeoutc:
result = "timeout"
if c := logger.Check(zapcore.DebugLevel, "stream timed out"); c != nil {
c.Write(zap.Time("timeout", t))
}
}
// Close both ends to unblock the still-running copy goroutine,
// then wait for it to finish so byte counts are accurate before
// the deferred log fires.
conn.Close()
backConn.Close()
innerWg.Wait()
}()
}
func classifyStreamResult(err error) string {
if err == nil || errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) {
return "closed"
}
return "error"
}
// flushInterval returns the p.FlushInterval value, conditionally
@@ -375,75 +525,86 @@ func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte, logger *za
}
}
// registerConnection holds onto conn so it can be closed in the event
// of a server shutdown. This is useful because hijacked connections or
// connections dialed to backends don't close when server is shut down.
// The caller should call the returned delete() function when the
// connection is done to remove it from memory.
func (h *Handler) registerConnection(conn io.ReadWriteCloser, gracefulClose func() error) (del func()) {
h.connectionsMu.Lock()
h.connections[conn] = openConnection{conn, gracefulClose}
h.connectionsMu.Unlock()
return func() {
h.connectionsMu.Lock()
delete(h.connections, conn)
// if there is no connection left before the connections close timer fires
if len(h.connections) == 0 && h.connectionsCloseTimer != nil {
// we release the timer that holds the reference to Handler
if (*h.connectionsCloseTimer).Stop() {
h.logger.Debug("stopped streaming connections close timer - all connections are already closed")
}
h.connectionsCloseTimer = nil
}
h.connectionsMu.Unlock()
// openConnection maps an open connection to an optional function for graceful
// close and records which upstream address the connection is proxying to.
type openConnection struct {
conn io.ReadWriteCloser
gracefulClose func() error
upstream string
}
// tunnelState tracks hijacked/upgraded connections for selective cleanup.
type tunnelState struct {
connections map[io.ReadWriteCloser]openConnection
closeTimer *time.Timer
closeDelay time.Duration
mu sync.Mutex
logger *zap.Logger
}
func newTunnelState(logger *zap.Logger, closeDelay time.Duration) *tunnelState {
return &tunnelState{
connections: make(map[io.ReadWriteCloser]openConnection),
closeDelay: closeDelay,
logger: logger,
}
}
// closeConnections immediately closes all hijacked connections (both to client and backend).
func (h *Handler) closeConnections() error {
var err error
h.connectionsMu.Lock()
defer h.connectionsMu.Unlock()
// registerConnection stores conn in the tracking map. The caller must invoke
// the returned del func when the connection is done.
func (ts *tunnelState) registerConnection(conn io.ReadWriteCloser, gracefulClose func() error, upstream string) (del func()) {
ts.mu.Lock()
ts.connections[conn] = openConnection{conn, gracefulClose, upstream}
ts.mu.Unlock()
return func() {
ts.mu.Lock()
delete(ts.connections, conn)
if len(ts.connections) == 0 && ts.closeTimer != nil {
if ts.closeTimer.Stop() {
ts.logger.Debug("stopped streaming connections close timer - all connections are already closed")
}
ts.closeTimer = nil
}
ts.mu.Unlock()
}
}
for _, oc := range h.connections {
// closeConnections closes all tracked connections.
func (ts *tunnelState) closeConnections() error {
var err error
ts.mu.Lock()
defer ts.mu.Unlock()
for _, oc := range ts.connections {
if oc.gracefulClose != nil {
// this is potentially blocking while we have the lock on the connections
// map, but that should be OK since the server has in theory shut down
// and we are no longer using the connections map
gracefulErr := oc.gracefulClose()
if gracefulErr != nil && err == nil {
if gracefulErr := oc.gracefulClose(); gracefulErr != nil && err == nil {
err = gracefulErr
}
}
closeErr := oc.conn.Close()
if closeErr != nil && err == nil {
if closeErr := oc.conn.Close(); closeErr != nil && err == nil {
err = closeErr
}
}
return err
}
// cleanupConnections closes hijacked connections.
// Depending on the value of StreamCloseDelay it does that either immediately
// or sets up a timer that will do that later.
func (h *Handler) cleanupConnections() error {
if h.StreamCloseDelay == 0 {
return h.closeConnections()
// cleanupConnections closes upgraded connections. Depending on closeDelay it
// does that either immediately or after a timer.
func (ts *tunnelState) cleanupConnections() error {
if ts.closeDelay == 0 {
return ts.closeConnections()
}
h.connectionsMu.Lock()
defer h.connectionsMu.Unlock()
// the handler is shut down, no new connection can appear,
// so we can skip setting up the timer when there are no connections
if len(h.connections) > 0 {
delay := time.Duration(h.StreamCloseDelay)
h.connectionsCloseTimer = time.AfterFunc(delay, func() {
if c := h.logger.Check(zapcore.DebugLevel, "closing streaming connections after delay"); c != nil {
ts.mu.Lock()
defer ts.mu.Unlock()
if len(ts.connections) > 0 {
delay := ts.closeDelay
ts.closeTimer = time.AfterFunc(delay, func() {
if c := ts.logger.Check(zapcore.DebugLevel, "closing streaming connections after delay"); c != nil {
c.Write(zap.Duration("delay", delay))
}
err := h.closeConnections()
err := ts.closeConnections()
if err != nil {
if c := h.logger.Check(zapcore.ErrorLevel, "failed to closed connections after delay"); c != nil {
if c := ts.logger.Check(zapcore.ErrorLevel, "failed to close connections after delay"); c != nil {
c.Write(
zap.Error(err),
zap.Duration("delay", delay),
@@ -567,11 +728,26 @@ func isWebsocket(r *http.Request) bool {
httpguts.HeaderValuesContainsToken(r.Header["Upgrade"], "websocket")
}
// openConnection maps an open connection to
// an optional function for graceful close.
type openConnection struct {
conn io.ReadWriteCloser
gracefulClose func() error
// closeConnectionsForUpstream closes all tracked connections that were
// established to the given upstream address.
func (ts *tunnelState) closeConnectionsForUpstream(addr string) error {
var err error
ts.mu.Lock()
defer ts.mu.Unlock()
for _, oc := range ts.connections {
if oc.upstream != addr {
continue
}
if oc.gracefulClose != nil {
if gracefulErr := oc.gracefulClose(); gracefulErr != nil && err == nil {
err = gracefulErr
}
}
if closeErr := oc.conn.Close(); closeErr != nil && err == nil {
err = closeErr
}
}
return err
}
type maxLatencyWriter struct {
@@ -642,16 +818,23 @@ type switchProtocolCopier struct {
user, backend io.ReadWriteCloser
wg *sync.WaitGroup
bufferSize int
// sent and received accumulate byte counts for each direction.
// They are written before wg.Done() and read after wg.Wait(), so no
// additional synchronization is needed beyond the WaitGroup barrier.
sent *int64 // bytes copied to backend; must be non-nil
received *int64 // bytes copied from backend; must be non-nil
}
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
_, err := io.CopyBuffer(c.user, c.backend, c.buffer())
n, err := io.CopyBuffer(c.user, c.backend, c.buffer())
*c.received = n
errc <- err
c.wg.Done()
}
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
_, err := io.CopyBuffer(c.backend, c.user, c.buffer())
n, err := io.CopyBuffer(c.backend, c.user, c.buffer())
*c.sent = n
errc <- err
c.wg.Done()
}
@@ -7,8 +7,10 @@ import (
"strings"
"sync"
"testing"
"time"
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
)
func TestHandlerCopyResponse(t *testing.T) {
@@ -41,12 +43,15 @@ func TestSwitchProtocolCopierBufferSize(t *testing.T) {
var wg sync.WaitGroup
var errc = make(chan error, 1)
var dst bytes.Buffer
var sent, received int64
copier := switchProtocolCopier{
user: nopReadWriteCloser{Reader: strings.NewReader("hello")},
backend: nopReadWriteCloser{Writer: &dst},
wg: &wg,
bufferSize: 7,
sent: &sent,
received: &received,
}
buf := copier.buffer()
@@ -80,3 +85,132 @@ type nopReadWriteCloser struct {
}
func (nopReadWriteCloser) Close() error { return nil }
type trackingReadWriteCloser struct {
closed chan struct{}
one sync.Once
}
func newTrackingReadWriteCloser() *trackingReadWriteCloser {
return &trackingReadWriteCloser{closed: make(chan struct{})}
}
func (c *trackingReadWriteCloser) Read(_ []byte) (int, error) { return 0, io.EOF }
func (c *trackingReadWriteCloser) Write(p []byte) (int, error) { return len(p), nil }
func (c *trackingReadWriteCloser) Close() error {
c.one.Do(func() {
close(c.closed)
})
return nil
}
func (c *trackingReadWriteCloser) isClosed() bool {
select {
case <-c.closed:
return true
default:
return false
}
}
func TestHandlerCleanupLegacyModeClosesAllConnections(t *testing.T) {
ts := newTunnelState(caddy.Log(), 0)
connA := newTrackingReadWriteCloser()
connB := newTrackingReadWriteCloser()
ts.registerConnection(connA, nil, "a")
ts.registerConnection(connB, nil, "b")
h := &Handler{
tunnel: ts,
StreamRetainOnReload: false,
}
if err := h.Cleanup(); err != nil {
t.Fatalf("cleanup failed: %v", err)
}
if !connA.isClosed() || !connB.isClosed() {
t.Fatalf("legacy cleanup should close all upgraded connections")
}
}
func TestHandlerCleanupLegacyModeHonorsDelay(t *testing.T) {
ts := newTunnelState(caddy.Log(), 40*time.Millisecond)
conn := newTrackingReadWriteCloser()
ts.registerConnection(conn, nil, "a")
h := &Handler{
tunnel: ts,
StreamRetainOnReload: false,
}
if err := h.Cleanup(); err != nil {
t.Fatalf("cleanup failed: %v", err)
}
if conn.isClosed() {
t.Fatal("connection should not close immediately when stream_close_delay is set")
}
select {
case <-conn.closed:
case <-time.After(500 * time.Millisecond):
t.Fatal("connection did not close after stream_close_delay elapsed")
}
}
func TestHandlerCleanupRetainModeClosesOnlyRemovedUpstreams(t *testing.T) {
const upstreamA = "upstream-a"
const upstreamB = "upstream-b"
// Simulate old+new configs both referencing upstreamA (refcount 2),
// while upstreamB is only referenced by the old config (refcount 1).
hosts.LoadOrStore(upstreamA, struct{}{})
hosts.LoadOrStore(upstreamA, struct{}{})
hosts.LoadOrStore(upstreamB, struct{}{})
t.Cleanup(func() {
_, _ = hosts.Delete(upstreamA)
_, _ = hosts.Delete(upstreamA)
_, _ = hosts.Delete(upstreamB)
})
ts := newTunnelState(caddy.Log(), 0)
connA := newTrackingReadWriteCloser()
connB := newTrackingReadWriteCloser()
ts.registerConnection(connA, nil, upstreamA)
ts.registerConnection(connB, nil, upstreamB)
h := &Handler{
tunnel: ts,
StreamRetainOnReload: true,
Upstreams: UpstreamPool{
&Upstream{Dial: upstreamA},
&Upstream{Dial: upstreamB},
},
}
if err := h.Cleanup(); err != nil {
t.Fatalf("cleanup failed: %v", err)
}
if connA.isClosed() {
t.Fatal("connection for retained upstream should remain open")
}
if !connB.isClosed() {
t.Fatal("connection for removed upstream should be closed")
}
}
func TestHandlerUnmarshalCaddyfileStreamLogSkipHandshake(t *testing.T) {
d := caddyfile.NewTestDispenser(`
reverse_proxy localhost:9000 {
stream_log_skip_handshake
}
`)
var h Handler
if err := h.UnmarshalCaddyfile(d); err != nil {
t.Fatalf("UnmarshalCaddyfile() error = %v", err)
}
if !h.StreamLogSkipHandshake {
t.Fatal("expected stream_log_skip_handshake to enable StreamLogSkipHandshake")
}
}