From 476fd0c077fe38c359d51f3433b604ac2a144559 Mon Sep 17 00:00:00 2001 From: Francis Lavoie Date: Mon, 13 Apr 2026 04:23:23 -0400 Subject: [PATCH] reverseproxy: Optionally detach stream (websockets) from config lifecycle --- .../reverseproxy_upgrade_handlers_test.go | 130 +++++ .../integration/stream_reload_stress_test.go | 487 ++++++++++++++++++ modules/caddyhttp/encode/encode.go | 18 +- modules/caddyhttp/responsewriter.go | 80 +-- modules/caddyhttp/responsewriter_test.go | 93 ++++ modules/caddyhttp/reverseproxy/caddyfile.go | 14 + .../caddyhttp/reverseproxy/copyresponse.go | 2 +- modules/caddyhttp/reverseproxy/metrics.go | 79 +++ .../caddyhttp/reverseproxy/metrics_test.go | 67 +++ .../caddyhttp/reverseproxy/reverseproxy.go | 65 ++- modules/caddyhttp/reverseproxy/streaming.go | 393 ++++++++++---- .../caddyhttp/reverseproxy/streaming_test.go | 134 +++++ 12 files changed, 1396 insertions(+), 166 deletions(-) create mode 100644 caddytest/integration/reverseproxy_upgrade_handlers_test.go create mode 100644 caddytest/integration/stream_reload_stress_test.go create mode 100644 modules/caddyhttp/reverseproxy/metrics_test.go diff --git a/caddytest/integration/reverseproxy_upgrade_handlers_test.go b/caddytest/integration/reverseproxy_upgrade_handlers_test.go new file mode 100644 index 000000000..dda93db0e --- /dev/null +++ b/caddytest/integration/reverseproxy_upgrade_handlers_test.go @@ -0,0 +1,130 @@ +package integration + +import ( + "bufio" + "fmt" + "io" + "net" + "net/textproto" + "strings" + "testing" + "time" + + "github.com/caddyserver/caddy/v2/caddytest" +) + +func TestReverseProxyUpgradeWithEncode(t *testing.T) { + tester := caddytest.NewTester(t) + backend := newUpgradeEchoBackend(t) + defer backend.Close() + + tester.InitServer(fmt.Sprintf(` +{ + admin localhost:2999 + http_port 9080 + https_port 9443 + grace_period 1ns + skip_install_trust +} + +localhost:9080 { + route { + encode gzip + reverse_proxy %s + } +} +`, backend.addr), "caddyfile") + + client := newUpgradedStreamClientWithHeaders(t, map[string]string{ + "Accept-Encoding": "gzip", + }) + defer client.Close() + + if err := client.echo("encode-upgrade\n"); err != nil { + t.Fatalf("upgraded stream echo through encode failed: %v", err) + } +} + +func TestReverseProxyUpgradeWithInterceptHandleResponse(t *testing.T) { + tester := caddytest.NewTester(t) + backend := newUpgradeEchoBackend(t) + defer backend.Close() + + tester.InitServer(fmt.Sprintf(` +{ + admin localhost:2999 + http_port 9080 + https_port 9443 + grace_period 1ns + skip_install_trust +} + +localhost:9080 { + route { + intercept { + @upgrade status 101 + handle_response @upgrade { + respond "should-not-run" + } + } + reverse_proxy %s + } +} +`, backend.addr), "caddyfile") + + client := newUpgradedStreamClientWithHeaders(t, nil) + defer client.Close() + + if err := client.echo("intercept-upgrade\n"); err != nil { + t.Fatalf("upgraded stream echo through intercept failed: %v", err) + } +} + +func newUpgradedStreamClientWithHeaders(t *testing.T, extraHeaders map[string]string) *upgradedStreamClient { + t.Helper() + + conn, err := net.DialTimeout("tcp", "127.0.0.1:9080", 5*time.Second) + if err != nil { + t.Fatalf("dialing caddy: %v", err) + } + + requestLines := []string{ + "GET /upgrade HTTP/1.1", + "Host: localhost:9080", + "Connection: Upgrade", + "Upgrade: stress-stream", + } + for k, v := range extraHeaders { + requestLines = append(requestLines, k+": "+v) + } + requestLines = append(requestLines, "", "") + + if _, err := io.WriteString(conn, strings.Join(requestLines, "\r\n")); err != nil { + _ = conn.Close() + t.Fatalf("writing upgrade request: %v", err) + } + + reader := bufio.NewReader(conn) + tproto := textproto.NewReader(reader) + statusLine, err := tproto.ReadLine() + if err != nil { + _ = conn.Close() + t.Fatalf("reading upgrade status line: %v", err) + } + if !strings.Contains(statusLine, "101") { + _ = conn.Close() + t.Fatalf("unexpected upgrade status: %s", statusLine) + } + + headers, err := tproto.ReadMIMEHeader() + if err != nil { + _ = conn.Close() + t.Fatalf("reading upgrade headers: %v", err) + } + if !strings.EqualFold(headers.Get("Connection"), "Upgrade") { + _ = conn.Close() + t.Fatalf("unexpected upgrade response headers: %v", headers) + } + + return &upgradedStreamClient{conn: conn, reader: reader} +} diff --git a/caddytest/integration/stream_reload_stress_test.go b/caddytest/integration/stream_reload_stress_test.go new file mode 100644 index 000000000..cd0b354ca --- /dev/null +++ b/caddytest/integration/stream_reload_stress_test.go @@ -0,0 +1,487 @@ +package integration + +import ( + "bufio" + "bytes" + "fmt" + "io" + "net" + "net/http" + "net/textproto" + "os" + "runtime" + "runtime/debug" + "runtime/pprof" + "strconv" + "strings" + "sync" + "testing" + "time" + + "github.com/caddyserver/caddy/v2/caddytest" +) + +// stressCloseDelay is the stream_close_delay used for the close_delay scenario. +// Long enough to outlast all test reloads; short enough to keep total test time reasonable. +const stressCloseDelay = 3 * time.Second + +func TestReverseProxyReloadStressUpgradedStreamsHeapProfiles(t *testing.T) { + tester := caddytest.NewTester(t).WithDefaultOverrides(caddytest.Config{ + LoadRequestTimeout: 30 * time.Second, + TestRequestTimeout: 30 * time.Second, + }) + + backend := newUpgradeEchoBackend(t) + defer backend.Close() + + // Three scenarios, each sequential so they don't share Caddy state: + // + // legacy – no delay, close on reload immediately (old default) + // close_delay – stream_close_delay, the old "keep-alive workaround" + // retain – stream_retain_on_reload, the new explicit retain flag + // + // Reloads are spread across time and interleaved with echo-checks so + // stream health is exercised at each reload boundary, not only at the end. + legacy := runReloadStress(t, tester, backend.addr, "legacy", false, 0) + closeDelay := runReloadStress(t, tester, backend.addr, "close_delay", false, stressCloseDelay) + retain := runReloadStress(t, tester, backend.addr, "retain", true, 0) + + if legacy.aliveAfterReloads != 0 { + t.Fatalf("legacy mode left %d upgraded streams alive after reloads", legacy.aliveAfterReloads) + } + if closeDelay.aliveBeforeDelayExpiry == 0 { + t.Fatalf("close_delay mode: all streams closed before delay expired (expected them alive)") + } + if closeDelay.aliveAfterReloads != 0 { + t.Fatalf("close_delay mode left %d upgraded streams alive after delay expiry", closeDelay.aliveAfterReloads) + } + if retain.aliveAfterReloads != retain.streamCount { + t.Fatalf("retain mode kept %d/%d upgraded streams alive after reloads", retain.aliveAfterReloads, retain.streamCount) + } + + t.Logf("legacy heap: before=%s mid=%s after=%s delta(before→after)=%s objects(before=%d after=%d) handler_frames(before=%d after=%d)", + formatBytes(legacy.beforeReload.HeapInuse), + formatBytes(legacy.midReload.HeapInuse), + formatBytes(legacy.afterReload.HeapInuse), + formatBytesDiff(legacy.beforeReload.HeapInuse, legacy.afterReload.HeapInuse), + legacy.beforeReload.HeapObjects, legacy.afterReload.HeapObjects, + legacy.beforeReload.handlerFrames, legacy.afterReload.handlerFrames, + ) + t.Logf("close_delay heap: before=%s mid=%s after=%s delta(before→after)=%s objects(before=%d after=%d) handler_frames(before=%d after=%d)", + formatBytes(closeDelay.beforeReload.HeapInuse), + formatBytes(closeDelay.midReload.HeapInuse), + formatBytes(closeDelay.afterReload.HeapInuse), + formatBytesDiff(closeDelay.beforeReload.HeapInuse, closeDelay.afterReload.HeapInuse), + closeDelay.beforeReload.HeapObjects, closeDelay.afterReload.HeapObjects, + closeDelay.beforeReload.handlerFrames, closeDelay.afterReload.handlerFrames, + ) + t.Logf("retain heap: before=%s mid=%s after=%s delta(before→after)=%s objects(before=%d after=%d) handler_frames(before=%d after=%d)", + formatBytes(retain.beforeReload.HeapInuse), + formatBytes(retain.midReload.HeapInuse), + formatBytes(retain.afterReload.HeapInuse), + formatBytesDiff(retain.beforeReload.HeapInuse, retain.afterReload.HeapInuse), + retain.beforeReload.HeapObjects, retain.afterReload.HeapObjects, + retain.beforeReload.handlerFrames, retain.afterReload.handlerFrames, + ) +} + +type stressRunResult struct { + streamCount int + aliveAfterReloads int + aliveBeforeDelayExpiry int // only meaningful for close_delay mode + beforeReload heapSnapshot + midReload heapSnapshot // after all reloads, before delay expiry clean-up + afterReload heapSnapshot // after all streams have been fully cleaned up +} + +type heapSnapshot struct { + HeapInuse uint64 + HeapObjects uint64 + handlerFrames int + profileBytes int +} + +// runReloadStress opens streamCount upgraded streams, then performs reloadCount +// config reloads spread over time. An echo check is performed every 6 reloads so +// stream health is exercised at each reload boundary rather than only at the end. +// closeDelay mirrors the stream_close_delay config option; pass 0 to disable. +func runReloadStress(t *testing.T, tester *caddytest.Tester, backendAddr, mode string, retain bool, closeDelay time.Duration) stressRunResult { + t.Helper() + + const echoEvery = 6 // perform an echo check every N reloads + + streamCount := envIntOrDefault(t, "CADDY_STRESS_STREAM_COUNT", 12) + reloadCount := envIntOrDefault(t, "CADDY_STRESS_RELOAD_COUNT", 24) + + tester.InitServer(reloadStressConfig(backendAddr, retain, closeDelay, 0), "caddyfile") + + clients := make([]*upgradedStreamClient, 0, streamCount) + for i := 0; i < streamCount; i++ { + client := newUpgradedStreamClient(t) + clients = append(clients, client) + if err := client.echo(fmt.Sprintf("%s-warmup-%02d\n", mode, i)); err != nil { + closeClients(clients) + t.Fatalf("warmup echo failed in %s mode: %v", mode, err) + } + } + defer closeClients(clients) + + before := captureHeapSnapshot(t) + + // Reloads are spread across time; between batches of echoEvery reloads we + // pause briefly and measure stream health so the snapshot reflects real-world + // reload cadence rather than a tight loop. + for i := 1; i <= reloadCount; i++ { + loadCaddyfileConfig(t, reloadStressConfig(backendAddr, retain, closeDelay, i)) + + // Small pause after each reload to let connection teardown propagate. + time.Sleep(50 * time.Millisecond) + + if i%echoEvery == 0 { + alive := countAliveStreams(clients) + t.Logf("%s mode: %d/%d streams alive after reload %d", mode, alive, streamCount, i) + + // In retain mode every stream must survive every reload (upstream unchanged). + if retain { + for j, client := range clients { + if err := client.echo(fmt.Sprintf("%s-mid-%02d-%02d\n", mode, i, j)); err != nil { + t.Fatalf("retain mode stream %d died at reload %d: %v", j, i, err) + } + } + } + } + } + + // mid snapshot: after all reloads but before any close_delay timer has fired + // (the delay is long enough to still be running at this point). + mid := captureHeapSnapshot(t) + + // For legacy mode: the reloads close streams immediately; wait for that to complete. + // For close_delay mode: streams are still alive here; wait for the delay to fire. + // For retain mode: streams survive indefinitely; no wait needed. + var aliveBeforeDelayExpiry int + aliveAfterReloads := countAliveStreams(clients) + switch { + case retain: + // nothing to wait for + case closeDelay > 0: + // streams should still be alive at this point (delay hasn't expired) + aliveBeforeDelayExpiry = aliveAfterReloads + t.Logf("%s mode: %d/%d streams alive before close_delay expires; waiting %v for cleanup", + mode, aliveBeforeDelayExpiry, streamCount, closeDelay) + time.Sleep(closeDelay + 200*time.Millisecond) + aliveAfterReloads = countAliveStreams(clients) + default: + deadline := time.Now().Add(2 * time.Second) + for aliveAfterReloads > 0 && time.Now().Before(deadline) { + time.Sleep(50 * time.Millisecond) + aliveAfterReloads = countAliveStreams(clients) + } + } + + after := captureHeapSnapshot(t) + t.Logf("%s mode heap profile size: before=%dB mid=%dB after=%dB objects(before=%d mid=%d after=%d)", + mode, + before.profileBytes, mid.profileBytes, after.profileBytes, + before.HeapObjects, mid.HeapObjects, after.HeapObjects, + ) + + return stressRunResult{ + streamCount: streamCount, + aliveAfterReloads: aliveAfterReloads, + aliveBeforeDelayExpiry: aliveBeforeDelayExpiry, + beforeReload: before, + midReload: mid, + afterReload: after, + } +} + +func envIntOrDefault(t *testing.T, key string, def int) int { + t.Helper() + raw := strings.TrimSpace(os.Getenv(key)) + if raw == "" { + return def + } + v, err := strconv.Atoi(raw) + if err != nil || v <= 0 { + t.Fatalf("invalid %s=%q: must be a positive integer", key, raw) + } + return v +} + +func loadCaddyfileConfig(t *testing.T, rawConfig string) { + t.Helper() + + client := &http.Client{Timeout: 30 * time.Second} + req, err := http.NewRequest(http.MethodPost, "http://localhost:2999/load", strings.NewReader(rawConfig)) + if err != nil { + t.Fatalf("creating load request: %v", err) + } + req.Header.Set("Content-Type", "text/caddyfile") + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("loading config: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("reading load response: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("loading config failed: status=%d body=%s", resp.StatusCode, body) + } +} + +func reloadStressConfig(backendAddr string, retain bool, closeDelay time.Duration, revision int) string { + var directives string + if retain { + directives += "\n\t\tstream_retain_on_reload" + } + if closeDelay > 0 { + directives += fmt.Sprintf("\n\t\tstream_close_delay %s", closeDelay) + } + + return fmt.Sprintf(` +{ + admin localhost:2999 + http_port 9080 + https_port 9443 + grace_period 1ns + skip_install_trust +} + +localhost:9080 { + reverse_proxy %s { + header_up X-Reload-Revision %d%s + } +} +`, backendAddr, revision, directives) +} + +func captureHeapSnapshot(t *testing.T) heapSnapshot { + t.Helper() + + runtime.GC() + debug.FreeOSMemory() + + var mem runtime.MemStats + runtime.ReadMemStats(&mem) + + var buf bytes.Buffer + if err := pprof.Lookup("heap").WriteTo(&buf, 1); err != nil { + t.Fatalf("capturing heap profile: %v", err) + } + profile := buf.String() + + return heapSnapshot{ + HeapInuse: mem.HeapInuse, + HeapObjects: mem.HeapObjects, + handlerFrames: strings.Count(profile, "modules/caddyhttp/reverseproxy.(*Handler)"), + profileBytes: buf.Len(), + } +} + +func countAliveStreams(clients []*upgradedStreamClient) int { + alive := 0 + for index, client := range clients { + if err := client.echo(fmt.Sprintf("alive-check-%02d\n", index)); err == nil { + alive++ + } + } + return alive +} + +func closeClients(clients []*upgradedStreamClient) { + for _, client := range clients { + if client != nil { + _ = client.Close() + } + } +} + +func formatBytes(value uint64) string { + const unit = 1024 + if value < unit { + return fmt.Sprintf("%d B", value) + } + div, exp := uint64(unit), 0 + for n := value / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %ciB", float64(value)/float64(div), "KMGTPE"[exp]) +} + +func formatBytesDiff(before, after uint64) string { + if after >= before { + return "+" + formatBytes(after-before) + } + return "-" + formatBytes(before-after) +} + +type upgradedStreamClient struct { + conn net.Conn + reader *bufio.Reader + mu sync.Mutex +} + +func newUpgradedStreamClient(t *testing.T) *upgradedStreamClient { + t.Helper() + + conn, err := net.DialTimeout("tcp", "127.0.0.1:9080", 5*time.Second) + if err != nil { + t.Fatalf("dialing caddy: %v", err) + } + + request := strings.Join([]string{ + "GET /upgrade HTTP/1.1", + "Host: localhost:9080", + "Connection: Upgrade", + "Upgrade: stress-stream", + "", + "", + }, "\r\n") + if _, err := io.WriteString(conn, request); err != nil { + _ = conn.Close() + t.Fatalf("writing upgrade request: %v", err) + } + + reader := bufio.NewReader(conn) + tproto := textproto.NewReader(reader) + statusLine, err := tproto.ReadLine() + if err != nil { + _ = conn.Close() + t.Fatalf("reading upgrade status line: %v", err) + } + if !strings.Contains(statusLine, "101") { + _ = conn.Close() + t.Fatalf("unexpected upgrade status: %s", statusLine) + } + + headers, err := tproto.ReadMIMEHeader() + if err != nil { + _ = conn.Close() + t.Fatalf("reading upgrade headers: %v", err) + } + if !strings.EqualFold(headers.Get("Connection"), "Upgrade") { + _ = conn.Close() + t.Fatalf("unexpected upgrade response headers: %v", headers) + } + + return &upgradedStreamClient{conn: conn, reader: reader} +} + +func (c *upgradedStreamClient) echo(payload string) error { + c.mu.Lock() + defer c.mu.Unlock() + + deadline := time.Now().Add(1 * time.Second) + if err := c.conn.SetWriteDeadline(deadline); err != nil { + return err + } + if _, err := io.WriteString(c.conn, payload); err != nil { + return err + } + if err := c.conn.SetReadDeadline(deadline); err != nil { + return err + } + + buf := make([]byte, len(payload)) + if _, err := io.ReadFull(c.reader, buf); err != nil { + return err + } + if string(buf) != payload { + return fmt.Errorf("unexpected echoed payload: got %q want %q", string(buf), payload) + } + return nil +} + +func (c *upgradedStreamClient) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + return c.conn.Close() +} + +type upgradeEchoBackend struct { + addr string + ln net.Listener + mu sync.Mutex + conns map[net.Conn]struct{} + server *http.Server +} + +func newUpgradeEchoBackend(t *testing.T) *upgradeEchoBackend { + t.Helper() + + backend := &upgradeEchoBackend{conns: make(map[net.Conn]struct{})} + backend.server = &http.Server{ + Handler: http.HandlerFunc(backend.serveHTTP), + } + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listening for backend: %v", err) + } + backend.ln = ln + backend.addr = ln.Addr().String() + + go func() { + _ = backend.server.Serve(ln) + }() + + return backend +} + +func (b *upgradeEchoBackend) serveHTTP(w http.ResponseWriter, r *http.Request) { + if !strings.EqualFold(r.Header.Get("Connection"), "Upgrade") || !strings.EqualFold(r.Header.Get("Upgrade"), "stress-stream") { + http.Error(w, "upgrade required", http.StatusUpgradeRequired) + return + } + + hijacker, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "hijacking not supported", http.StatusInternalServerError) + return + } + + conn, rw, err := hijacker.Hijack() + if err != nil { + return + } + + b.trackConn(conn) + _, _ = rw.WriteString("HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: stress-stream\r\n\r\n") + _ = rw.Flush() + + go func() { + defer b.untrackConn(conn) + defer conn.Close() + _, _ = io.Copy(conn, conn) + }() +} + +func (b *upgradeEchoBackend) trackConn(conn net.Conn) { + b.mu.Lock() + b.conns[conn] = struct{}{} + b.mu.Unlock() +} + +func (b *upgradeEchoBackend) untrackConn(conn net.Conn) { + b.mu.Lock() + delete(b.conns, conn) + b.mu.Unlock() +} + +func (b *upgradeEchoBackend) Close() { + _ = b.server.Close() + _ = b.ln.Close() + + b.mu.Lock() + defer b.mu.Unlock() + for conn := range b.conns { + _ = conn.Close() + } + clear(b.conns) +} diff --git a/modules/caddyhttp/encode/encode.go b/modules/caddyhttp/encode/encode.go index ac995c37b..ecf85495a 100644 --- a/modules/caddyhttp/encode/encode.go +++ b/modules/caddyhttp/encode/encode.go @@ -405,6 +405,11 @@ func (rw *responseWriter) ReadFrom(r io.Reader) (int64, error) { // Close writes any remaining buffered response and // deallocates any active resources. func (rw *responseWriter) Close() error { + if caddyhttp.ResponseWriterHijacked(rw.ResponseWriter) { + rw.releaseEncoder() + return nil + } + // didn't write, probably head request if !rw.wroteHeader { cl, err := strconv.Atoi(rw.Header().Get("Content-Length")) @@ -422,13 +427,20 @@ func (rw *responseWriter) Close() error { var err error if rw.w != nil { err = rw.w.Close() - rw.w.Reset(nil) - rw.config.writerPools[rw.encodingName].Put(rw.w) - rw.w = nil + rw.releaseEncoder() } return err } +func (rw *responseWriter) releaseEncoder() { + if rw.w == nil { + return + } + rw.w.Reset(nil) + rw.config.writerPools[rw.encodingName].Put(rw.w) + rw.w = nil +} + // Unwrap returns the underlying ResponseWriter. func (rw *responseWriter) Unwrap() http.ResponseWriter { return rw.ResponseWriter diff --git a/modules/caddyhttp/responsewriter.go b/modules/caddyhttp/responsewriter.go index 904c30c03..d5b43bf42 100644 --- a/modules/caddyhttp/responsewriter.go +++ b/modules/caddyhttp/responsewriter.go @@ -70,6 +70,7 @@ type responseRecorder struct { size int wroteHeader bool stream bool + hijacked bool readSize *int } @@ -144,7 +145,8 @@ func NewResponseRecorder(w http.ResponseWriter, buf *bytes.Buffer, shouldBuffer // WriteHeader writes the headers with statusCode to the wrapped // ResponseWriter unless the response is to be buffered instead. -// 1xx responses are never buffered. +// 1xx responses are never buffered, except 101 which is treated +// as a final upgrade response. func (rr *responseRecorder) WriteHeader(statusCode int) { if rr.wroteHeader { return @@ -153,6 +155,12 @@ func (rr *responseRecorder) WriteHeader(statusCode int) { // save statusCode always, in case HTTP middleware upgrades websocket // connections by manually setting headers and writing status 101 rr.statusCode = statusCode + if statusCode == http.StatusSwitchingProtocols { + rr.stream = true + rr.wroteHeader = true + rr.ResponseWriterWrapper.WriteHeader(statusCode) + return + } // decide whether we should buffer the response if rr.shouldBuffer == nil { @@ -222,7 +230,14 @@ func (rr *responseRecorder) Buffered() bool { return !rr.stream } +func (rr *responseRecorder) Hijacked() bool { + return rr.hijacked +} + func (rr *responseRecorder) WriteResponse() error { + if rr.hijacked { + return nil + } if rr.statusCode == 0 { // could happen if no handlers actually wrote anything, // and this prevents a panic; status must be > 0 @@ -258,13 +273,16 @@ func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { if err != nil { return nil, nil, err } - // Per http documentation, returned bufio.Writer is empty, but bufio.Read maybe not - conn = &hijackedConn{conn, rr} + rr.hijacked = true + rr.stream = true + rr.wroteHeader = true + // Per http documentation, returned bufio.Writer is empty, but bufio.Read maybe not. + // Return the raw hijacked connection so upgraded stream traffic does not keep + // traversing the response recorder hot path. brw.Writer.Reset(conn) buffered := brw.Reader.Buffered() if buffered != 0 { - conn.(*hijackedConn).updateReadSize(buffered) data, _ := brw.Peek(buffered) brw.Reader.Reset(io.MultiReader(bytes.NewReader(data), conn)) // peek to make buffered data appear, as Reset will make it 0 @@ -275,40 +293,24 @@ func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { return conn, brw, nil } -// used to track the size of hijacked response writers -type hijackedConn struct { - net.Conn - rr *responseRecorder -} - -func (hc *hijackedConn) updateReadSize(n int) { - if hc.rr.readSize != nil { - *hc.rr.readSize += n +// ResponseWriterHijacked reports whether w or one of its wrapped response +// writers has been hijacked. +func ResponseWriterHijacked(w http.ResponseWriter) bool { + for w != nil { + if hijacked, ok := w.(interface{ Hijacked() bool }); ok && hijacked.Hijacked() { + return true + } + unwrapper, ok := w.(interface{ Unwrap() http.ResponseWriter }) + if !ok { + return false + } + next := unwrapper.Unwrap() + if next == w { + return false + } + w = next } -} - -func (hc *hijackedConn) Read(p []byte) (int, error) { - n, err := hc.Conn.Read(p) - hc.updateReadSize(n) - return n, err -} - -func (hc *hijackedConn) WriteTo(w io.Writer) (int64, error) { - n, err := io.Copy(w, hc.Conn) - hc.updateReadSize(int(n)) - return n, err -} - -func (hc *hijackedConn) Write(p []byte) (int, error) { - n, err := hc.Conn.Write(p) - hc.rr.size += n - return n, err -} - -func (hc *hijackedConn) ReadFrom(r io.Reader) (int64, error) { - n, err := io.Copy(hc.Conn, r) - hc.rr.size += int(n) - return n, err + return false } // ResponseRecorder is a http.ResponseWriter that records @@ -319,6 +321,7 @@ type ResponseRecorder interface { Status() int Buffer() *bytes.Buffer Buffered() bool + Hijacked() bool Size() int WriteResponse() error } @@ -338,7 +341,4 @@ var ( // see PR #5022 (25%-50% speedup) _ io.ReaderFrom = (*ResponseWriterWrapper)(nil) _ io.ReaderFrom = (*responseRecorder)(nil) - _ io.ReaderFrom = (*hijackedConn)(nil) - - _ io.WriterTo = (*hijackedConn)(nil) ) diff --git a/modules/caddyhttp/responsewriter_test.go b/modules/caddyhttp/responsewriter_test.go index c08ad26a4..72e416db1 100644 --- a/modules/caddyhttp/responsewriter_test.go +++ b/modules/caddyhttp/responsewriter_test.go @@ -1,11 +1,14 @@ package caddyhttp import ( + "bufio" "bytes" "io" + "net" "net/http" "strings" "testing" + "time" ) type responseWriterSpy interface { @@ -44,6 +47,50 @@ func (rf *readFromRespWriter) ReadFrom(r io.Reader) (int64, error) { func (rf *readFromRespWriter) CalledReadFrom() bool { return rf.called } +type hijackRespWriter struct { + baseRespWriter + header http.Header + status int + conn net.Conn +} + +func newHijackRespWriter() *hijackRespWriter { + return &hijackRespWriter{ + header: make(http.Header), + conn: stubConn{}, + } +} + +func (hrw *hijackRespWriter) Header() http.Header { + return hrw.header +} + +func (hrw *hijackRespWriter) WriteHeader(statusCode int) { + hrw.status = statusCode +} + +func (hrw *hijackRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + br := bufio.NewReader(hrw.conn) + bw := bufio.NewWriter(hrw.conn) + return hrw.conn, bufio.NewReadWriter(br, bw), nil +} + +type stubConn struct{} + +func (stubConn) Read(_ []byte) (int, error) { return 0, io.EOF } +func (stubConn) Write(p []byte) (int, error) { return len(p), nil } +func (stubConn) Close() error { return nil } +func (stubConn) LocalAddr() net.Addr { return stubAddr("local") } +func (stubConn) RemoteAddr() net.Addr { return stubAddr("remote") } +func (stubConn) SetDeadline(time.Time) error { return nil } +func (stubConn) SetReadDeadline(time.Time) error { return nil } +func (stubConn) SetWriteDeadline(time.Time) error { return nil } + +type stubAddr string + +func (a stubAddr) Network() string { return "tcp" } +func (a stubAddr) String() string { return string(a) } + func TestResponseWriterWrapperReadFrom(t *testing.T) { tests := map[string]struct { responseWriter responseWriterSpy @@ -169,3 +216,49 @@ func TestResponseRecorderReadFrom(t *testing.T) { }) } } + +func TestResponseRecorderSwitchingProtocolsIsHijackAware(t *testing.T) { + w := newHijackRespWriter() + var buf bytes.Buffer + + rr := NewResponseRecorder(w, &buf, func(status int, header http.Header) bool { + return true + }) + rr.WriteHeader(http.StatusSwitchingProtocols) + + if rr.Buffered() { + t.Fatal("101 switching protocols response should not remain buffered") + } + if rr.Status() != http.StatusSwitchingProtocols { + t.Fatalf("status = %d, want %d", rr.Status(), http.StatusSwitchingProtocols) + } + if w.status != http.StatusSwitchingProtocols { + t.Fatalf("underlying status = %d, want %d", w.status, http.StatusSwitchingProtocols) + } + + hj, ok := rr.(http.Hijacker) + if !ok { + t.Fatal("response recorder does not implement http.Hijacker") + } + conn, _, err := hj.Hijack() + if err != nil { + t.Fatalf("Hijack() error = %v", err) + } + defer conn.Close() + + if !rr.Hijacked() { + t.Fatal("response recorder should report hijacked state") + } + if !ResponseWriterHijacked(rr) { + t.Fatal("ResponseWriterHijacked() should report true after hijack") + } + if err := rr.WriteResponse(); err != nil { + t.Fatalf("WriteResponse() after hijack returned error: %v", err) + } + if rr.Size() != 0 { + t.Fatalf("size = %d, want 0 after hijack handshake", rr.Size()) + } + if got := w.Written(); got != "" { + t.Fatalf("unexpected buffered body write after hijack: %q", got) + } +} diff --git a/modules/caddyhttp/reverseproxy/caddyfile.go b/modules/caddyhttp/reverseproxy/caddyfile.go index a370a2873..c75ecb55f 100644 --- a/modules/caddyhttp/reverseproxy/caddyfile.go +++ b/modules/caddyhttp/reverseproxy/caddyfile.go @@ -99,6 +99,8 @@ func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) // stream_buffer_size // stream_timeout // stream_close_delay +// stream_retain_on_reload +// stream_log_skip_handshake // verbose_logs // // # request manipulation @@ -703,6 +705,18 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { h.StreamCloseDelay = caddy.Duration(dur) } + case "stream_retain_on_reload": + if d.NextArg() { + return d.ArgErr() + } + h.StreamRetainOnReload = true + + case "stream_log_skip_handshake": + if d.NextArg() { + return d.ArgErr() + } + h.StreamLogSkipHandshake = true + case "trusted_proxies": for d.NextArg() { if d.Val() == "private_ranges" { diff --git a/modules/caddyhttp/reverseproxy/copyresponse.go b/modules/caddyhttp/reverseproxy/copyresponse.go index c1c9de92b..ec1720d31 100644 --- a/modules/caddyhttp/reverseproxy/copyresponse.go +++ b/modules/caddyhttp/reverseproxy/copyresponse.go @@ -80,7 +80,7 @@ func (h CopyResponseHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request hrc.isFinalized = true // write the response - return hrc.handler.finalizeResponse(rw, req, hrc.response, repl, hrc.start, hrc.logger) + return hrc.handler.finalizeResponse(rw, req, hrc.response, repl, hrc.start, hrc.logger, hrc.upstreamAddr) } // CopyResponseHeadersHandler is a special HTTP handler which may diff --git a/modules/caddyhttp/reverseproxy/metrics.go b/modules/caddyhttp/reverseproxy/metrics.go index 248842730..4b26d8641 100644 --- a/modules/caddyhttp/reverseproxy/metrics.go +++ b/modules/caddyhttp/reverseproxy/metrics.go @@ -16,6 +16,10 @@ import ( var reverseProxyMetrics = struct { once sync.Once upstreamsHealthy *prometheus.GaugeVec + streamsActive *prometheus.GaugeVec + streamsTotal *prometheus.CounterVec + streamDuration *prometheus.HistogramVec + streamBytes *prometheus.CounterVec logger *zap.Logger }{} @@ -23,6 +27,8 @@ func initReverseProxyMetrics(handler *Handler, registry *prometheus.Registry) { const ns, sub = "caddy", "reverse_proxy" upstreamsLabels := []string{"upstream"} + streamResultLabels := []string{"upstream", "result"} + streamBytesLabels := []string{"upstream", "direction"} reverseProxyMetrics.once.Do(func() { reverseProxyMetrics.upstreamsHealthy = prometheus.NewGaugeVec(prometheus.GaugeOpts{ Namespace: ns, @@ -30,6 +36,31 @@ func initReverseProxyMetrics(handler *Handler, registry *prometheus.Registry) { Name: "upstreams_healthy", Help: "Health status of reverse proxy upstreams.", }, upstreamsLabels) + reverseProxyMetrics.streamsActive = prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: ns, + Subsystem: sub, + Name: "streams_active", + Help: "Number of currently active upgraded reverse proxy streams.", + }, upstreamsLabels) + reverseProxyMetrics.streamsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: ns, + Subsystem: sub, + Name: "streams_total", + Help: "Total number of upgraded reverse proxy streams by close result.", + }, streamResultLabels) + reverseProxyMetrics.streamDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: ns, + Subsystem: sub, + Name: "stream_duration_seconds", + Help: "Duration of upgraded reverse proxy streams by close result.", + Buckets: prometheus.DefBuckets, + }, streamResultLabels) + reverseProxyMetrics.streamBytes = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: ns, + Subsystem: sub, + Name: "stream_bytes_total", + Help: "Total bytes proxied across upgraded reverse proxy streams.", + }, streamBytesLabels) }) // duplicate registration could happen if multiple sites with reverse proxy are configured; so ignore the error because @@ -42,10 +73,58 @@ func initReverseProxyMetrics(handler *Handler, registry *prometheus.Registry) { }) { panic(err) } + if err := registry.Register(reverseProxyMetrics.streamsActive); err != nil && + !errors.Is(err, prometheus.AlreadyRegisteredError{ + ExistingCollector: reverseProxyMetrics.streamsActive, + NewCollector: reverseProxyMetrics.streamsActive, + }) { + panic(err) + } + if err := registry.Register(reverseProxyMetrics.streamsTotal); err != nil && + !errors.Is(err, prometheus.AlreadyRegisteredError{ + ExistingCollector: reverseProxyMetrics.streamsTotal, + NewCollector: reverseProxyMetrics.streamsTotal, + }) { + panic(err) + } + if err := registry.Register(reverseProxyMetrics.streamDuration); err != nil && + !errors.Is(err, prometheus.AlreadyRegisteredError{ + ExistingCollector: reverseProxyMetrics.streamDuration, + NewCollector: reverseProxyMetrics.streamDuration, + }) { + panic(err) + } + if err := registry.Register(reverseProxyMetrics.streamBytes); err != nil && + !errors.Is(err, prometheus.AlreadyRegisteredError{ + ExistingCollector: reverseProxyMetrics.streamBytes, + NewCollector: reverseProxyMetrics.streamBytes, + }) { + panic(err) + } reverseProxyMetrics.logger = handler.logger.Named("reverse_proxy.metrics") } +func trackActiveStream(upstream string) func(result string, duration time.Duration, toBackend, fromBackend int64) { + labels := prometheus.Labels{"upstream": upstream} + reverseProxyMetrics.streamsActive.With(labels).Inc() + + var once sync.Once + return func(result string, duration time.Duration, toBackend, fromBackend int64) { + once.Do(func() { + reverseProxyMetrics.streamsActive.With(labels).Dec() + reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, result).Inc() + reverseProxyMetrics.streamDuration.WithLabelValues(upstream, result).Observe(duration.Seconds()) + if toBackend > 0 { + reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "to_upstream").Add(float64(toBackend)) + } + if fromBackend > 0 { + reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "from_upstream").Add(float64(fromBackend)) + } + }) + } +} + type metricsUpstreamsHealthyUpdater struct { handler *Handler } diff --git a/modules/caddyhttp/reverseproxy/metrics_test.go b/modules/caddyhttp/reverseproxy/metrics_test.go new file mode 100644 index 000000000..edbe9ca8d --- /dev/null +++ b/modules/caddyhttp/reverseproxy/metrics_test.go @@ -0,0 +1,67 @@ +package reverseproxy + +import ( + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/testutil" +) + +func TestTrackActiveStreamRecordsLifecycleAndBytes(t *testing.T) { + const upstream = "127.0.0.1:7443" + + // Use fresh metric vectors for deterministic assertions in this unit test. + reverseProxyMetrics.streamsActive = prometheus.NewGaugeVec(prometheus.GaugeOpts{}, []string{"upstream"}) + reverseProxyMetrics.streamsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "result"}) + reverseProxyMetrics.streamDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{}, []string{"upstream", "result"}) + reverseProxyMetrics.streamBytes = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "direction"}) + + finish := trackActiveStream(upstream) + + if got := testutil.ToFloat64(reverseProxyMetrics.streamsActive.WithLabelValues(upstream)); got != 1 { + t.Fatalf("active streams = %v, want 1", got) + } + + finish("closed", 150*time.Millisecond, 1234, 4321) + + if got := testutil.ToFloat64(reverseProxyMetrics.streamsActive.WithLabelValues(upstream)); got != 0 { + t.Fatalf("active streams = %v, want 0", got) + } + if got := testutil.ToFloat64(reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, "closed")); got != 1 { + t.Fatalf("streams_total closed = %v, want 1", got) + } + if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "to_upstream")); got != 1234 { + t.Fatalf("bytes to_upstream = %v, want 1234", got) + } + if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "from_upstream")); got != 4321 { + t.Fatalf("bytes from_upstream = %v, want 4321", got) + } + + // A second finish call should be ignored by the once guard. + finish("error", 1*time.Second, 111, 222) + if got := testutil.ToFloat64(reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, "error")); got != 0 { + t.Fatalf("streams_total error = %v, want 0", got) + } +} + +func TestTrackActiveStreamDoesNotCountZeroBytes(t *testing.T) { + const upstream = "127.0.0.1:9000" + + reverseProxyMetrics.streamsActive = prometheus.NewGaugeVec(prometheus.GaugeOpts{}, []string{"upstream"}) + reverseProxyMetrics.streamsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "result"}) + reverseProxyMetrics.streamDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{}, []string{"upstream", "result"}) + reverseProxyMetrics.streamBytes = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "direction"}) + + trackActiveStream(upstream)("timeout", 250*time.Millisecond, 0, 0) + + if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "to_upstream")); got != 0 { + t.Fatalf("bytes to_upstream = %v, want 0", got) + } + if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "from_upstream")); got != 0 { + t.Fatalf("bytes from_upstream = %v, want 0", got) + } + if got := testutil.ToFloat64(reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, "timeout")); got != 1 { + t.Fatalf("streams_total timeout = %v, want 1", got) + } +} diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index 3b9b56a05..df731ce3d 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -186,6 +186,18 @@ type Handler struct { // by the previous config closing. Default: no delay. StreamCloseDelay caddy.Duration `json:"stream_close_delay,omitempty"` + // If true, upgraded connections such as WebSockets are retained across + // config reloads when their upstream still exists in the new config. + // Connections using upstreams that are removed are closed during cleanup. + // By default this is false, preserving legacy behavior where upgraded + // connections are closed on reload (optionally delayed by stream_close_delay). + StreamRetainOnReload bool `json:"stream_retain_on_reload,omitempty"` + + // If true, suppresses the access log entry normally emitted when an + // upgraded stream handshake completes and the request unwinds. By default + // the handshake is still logged as a normal request with status 101. + StreamLogSkipHandshake bool `json:"stream_log_skip_handshake,omitempty"` + // If configured, rewrites the copy of the upstream request. // Allows changing the request method and URI (path and query). // Since the rewrite is applied to the copy, it does not persist @@ -240,10 +252,9 @@ type Handler struct { // Holds the handle_response Caddyfile tokens while adapting handleResponseSegments []*caddyfile.Dispenser - // Stores upgraded requests (hijacked connections) for proper cleanup - connections map[io.ReadWriteCloser]openConnection - connectionsCloseTimer *time.Timer - connectionsMu *sync.Mutex + // Tracks hijacked/upgraded connections (WebSocket etc.) so they can be + // closed when their upstream is removed from the config. + tunnel *tunnelState ctx caddy.Context logger *zap.Logger @@ -267,8 +278,7 @@ func (h *Handler) Provision(ctx caddy.Context) error { h.events = eventAppIface.(*caddyevents.App) h.ctx = ctx h.logger = ctx.Logger() - h.connections = make(map[io.ReadWriteCloser]openConnection) - h.connectionsMu = new(sync.Mutex) + h.tunnel = newTunnelState(h.logger, time.Duration(h.StreamCloseDelay)) // warn about unsafe buffering config if h.RequestBuffers == -1 || h.ResponseBuffers == -1 { @@ -439,13 +449,29 @@ func (h *Handler) Provision(ctx caddy.Context) error { // Cleanup cleans up the resources made by h. func (h *Handler) Cleanup() error { - err := h.cleanupConnections() - - // remove hosts from our config from the pool - for _, upstream := range h.Upstreams { - _, _ = hosts.Delete(upstream.String()) + if !h.StreamRetainOnReload { + // Legacy behavior: close all upgraded connections on reload, either + // immediately or after StreamCloseDelay. + err := h.tunnel.cleanupConnections() + for _, upstream := range h.Upstreams { + _, _ = hosts.Delete(upstream.String()) + } + return err } + var err error + for _, upstream := range h.Upstreams { + // hosts.Delete returns deleted=true when the ref count reaches zero, + // meaning no other active config references this upstream. In that + // case close any tunnels proxying to it; otherwise let them survive + // to their natural end since the upstream is still in use. + deleted, _ := hosts.Delete(upstream.String()) + if deleted { + if closeErr := h.tunnel.closeConnectionsForUpstream(upstream.String()); closeErr != nil && err == nil { + err = closeErr + } + } + } return err } @@ -1092,10 +1118,11 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe // we use the original request here, so that any routes from 'next' // see the original request rather than the proxy cloned request. hrc := &handleResponseContext{ - handler: h, - response: res, - start: start, - logger: logger, + handler: h, + response: res, + start: start, + logger: logger, + upstreamAddr: di.Upstream.String(), } ctx := origReq.Context() ctx = context.WithValue(ctx, proxyHandleResponseContextCtxKey, hrc) @@ -1125,7 +1152,7 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe } // copy the response body and headers back to the upstream client - return h.finalizeResponse(rw, req, res, repl, start, logger) + return h.finalizeResponse(rw, req, res, repl, start, logger, di.Upstream.String()) } // finalizeResponse prepares and copies the response. @@ -1136,11 +1163,12 @@ func (h *Handler) finalizeResponse( repl *caddy.Replacer, start time.Time, logger *zap.Logger, + upstreamAddr string, ) error { // deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) if res.StatusCode == http.StatusSwitchingProtocols { var wg sync.WaitGroup - h.handleUpgradeResponse(logger, &wg, rw, req, res) + h.handleUpgradeResponse(logger, &wg, rw, req, res, upstreamAddr) wg.Wait() return nil } @@ -1719,6 +1747,9 @@ type handleResponseContext struct { // i.e. copied and closed, to make sure that it doesn't // happen twice. isFinalized bool + + // upstreamAddr is the selected upstream address for this request. + upstreamAddr string } // proxyHandleResponseContextCtxKey is the context key for the active proxy handler diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go index e454ee655..37fca6014 100644 --- a/modules/caddyhttp/reverseproxy/streaming.go +++ b/modules/caddyhttp/reverseproxy/streaming.go @@ -26,6 +26,7 @@ import ( "io" weakrand "math/rand/v2" "mime" + "net" "net/http" "sync" "time" @@ -35,6 +36,7 @@ import ( "go.uber.org/zap/zapcore" "golang.org/x/net/http/httpguts" + "github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2/modules/caddyhttp" ) @@ -57,7 +59,7 @@ func (rwc h2ReadWriteCloser) Write(p []byte) (n int, err error) { return n, nil } -func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, rw http.ResponseWriter, req *http.Request, res *http.Response) { +func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, rw http.ResponseWriter, req *http.Request, res *http.Response, upstreamAddr string) { reqUpType := upgradeType(req.Header) resUpType := upgradeType(res.Header) @@ -90,13 +92,22 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, copyHeader(rw.Header(), res.Header) normalizeWebsocketHeaders(rw.Header()) + // Capture all h fields needed by the tunnel now, so that the Handler (h) + // is not referenced after this function returns (for HTTP/1.1 hijacked + // connections the tunnel runs in a detached goroutine). + tunnel := h.tunnel + bufferSize := h.StreamBufferSize + streamTimeout := time.Duration(h.StreamTimeout) + var ( conn io.ReadWriteCloser brw *bufio.ReadWriter + isH2 bool ) // websocket over http2 or http3 if extended connect is enabled, assuming backend doesn't support this, the request will be modified to http1.1 upgrade // TODO: once we can reliably detect backend support this, it can be removed for those backends if body, ok := caddyhttp.GetVar(req.Context(), "extended_connect_websocket_body").(io.ReadCloser); ok { + isH2 = true req.Body = body rw.Header().Del("Upgrade") rw.Header().Del("Connection") @@ -143,26 +154,24 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, } } - // adopted from https://github.com/golang/go/commit/8bcf2834afdf6a1f7937390903a41518715ef6f5 - backConnCloseCh := make(chan struct{}) - go func() { - // Ensure that the cancellation of a request closes the backend. - // See issue https://golang.org/issue/35559. - select { - case <-req.Context().Done(): - case <-backConnCloseCh: - } - backConn.Close() - }() - defer close(backConnCloseCh) - - start := time.Now() - defer func() { - conn.Close() - if c := logger.Check(zapcore.DebugLevel, "connection closed"); c != nil { - c.Write(zap.Duration("duration", time.Since(start))) - } - }() + // For H2 extended connect: close backConn when the request context is + // cancelled (e.g. client disconnects). For HTTP/1.1 hijacked connections + // we skip this because req.Context() may be cancelled when ServeHTTP + // returns early, which would prematurely close the backend connection. + if isH2 { + // adopted from https://github.com/golang/go/commit/8bcf2834afdf6a1f7937390903a41518715ef6f5 + backConnCloseCh := make(chan struct{}) + go func() { + // Ensure that the cancellation of a request closes the backend. + // See issue https://golang.org/issue/35559. + select { + case <-req.Context().Done(): + case <-backConnCloseCh: + } + backConn.Close() + }() + defer close(backConnCloseCh) + } if err := brw.Flush(); err != nil { if c := logger.Check(zapcore.DebugLevel, "response flush"); c != nil { @@ -184,13 +193,11 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, } } - // Ensure the hijacked client connection, and the new connection established - // with the backend, are both closed in the event of a server shutdown. This - // is done by registering them. We also try to gracefully close connections - // we recognize as websockets. - // We need to make sure the client connection messages (i.e. to upstream) - // are masked, so we need to know whether the connection is considered the - // server or the client side of the proxy. + // Register both connections with the tunnel tracker. We also try to + // gracefully close connections we recognize as websockets. We need to make + // sure the client connection messages (i.e. to upstream) are masked, so we + // need to know whether the connection is considered the server or the + // client side of the proxy. gracefulClose := func(conn io.ReadWriteCloser, isClient bool) func() error { if isWebsocket(req) { return func() error { @@ -199,43 +206,186 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, } return nil } - deleteFrontConn := h.registerConnection(conn, gracefulClose(conn, false)) - deleteBackConn := h.registerConnection(backConn, gracefulClose(backConn, true)) - defer deleteFrontConn() + deleteFrontConn := tunnel.registerConnection(conn, gracefulClose(conn, false), upstreamAddr) + deleteBackConn := tunnel.registerConnection(backConn, gracefulClose(backConn, true), upstreamAddr) + if h.StreamLogSkipHandshake { + caddyhttp.SetVar(req.Context(), caddyhttp.LogSkipVar, true) + } + repl := req.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer) + repl.Set("http.reverse_proxy.upgraded", true) + finishMetrics := trackActiveStream(upstreamAddr) + + start := time.Now() + + if isH2 { + h.handleH2UpgradeTunnel(logger, wg, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics) + } else { + h.handleDetachedUpgradeTunnel(logger, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics) + // Return immediately without touching wg. finalizeResponse's + // wg.Wait() returns at once since wg was never incremented. + } +} + +func (h *Handler) handleH2UpgradeTunnel( + logger *zap.Logger, + wg *sync.WaitGroup, + conn io.ReadWriteCloser, + backConn io.ReadWriteCloser, + deleteFrontConn func(), + deleteBackConn func(), + bufferSize int, + streamTimeout time.Duration, + start time.Time, + finishMetrics func(result string, duration time.Duration, toBackend, fromBackend int64), +) { + // H2 extended connect: ServeHTTP must block because rw and req.Body are + // only valid while the handler goroutine is running. Defers clean up + // when the select below fires and this function returns. defer deleteBackConn() - - spc := switchProtocolCopier{ - user: conn, - backend: backConn, - wg: wg, - bufferSize: h.StreamBufferSize, - } - - // setup the timeout if requested - var timeoutc <-chan time.Time - if h.StreamTimeout > 0 { - timer := time.NewTimer(time.Duration(h.StreamTimeout)) - defer timer.Stop() - timeoutc = timer.C - } + defer deleteFrontConn() + var ( + toBackend int64 + fromBackend int64 + result = "closed" + ) // when a stream timeout is encountered, no error will be read from errc // a buffer size of 2 will allow both the read and write goroutines to send the error and exit // see: https://github.com/caddyserver/caddy/issues/7418 errc := make(chan error, 2) + spc := switchProtocolCopier{ + user: conn, + backend: backConn, + wg: wg, + bufferSize: bufferSize, + sent: &toBackend, + received: &fromBackend, + } wg.Add(2) + + var timeoutc <-chan time.Time + if streamTimeout > 0 { + timer := time.NewTimer(streamTimeout) + defer timer.Stop() + timeoutc = timer.C + } + go spc.copyToBackend(errc) go spc.copyFromBackend(errc) select { case err := <-errc: + result = classifyStreamResult(err) if c := logger.Check(zapcore.DebugLevel, "streaming error"); c != nil { c.Write(zap.Error(err)) } - case time := <-timeoutc: + case t := <-timeoutc: + result = "timeout" if c := logger.Check(zapcore.DebugLevel, "stream timed out"); c != nil { - c.Write(zap.Time("timeout", time)) + c.Write(zap.Time("timeout", t)) } } + + // Close both ends to unblock the still-running copy goroutine, + // then wait for it so byte counts are final before metrics/logging. + conn.Close() + backConn.Close() + wg.Wait() + + finishMetrics(result, time.Since(start), toBackend, fromBackend) + if c := logger.Check(zapcore.DebugLevel, "connection closed"); c != nil { + c.Write( + zap.Duration("duration", time.Since(start)), + zap.Int64("bytes_to_backend", toBackend), + zap.Int64("bytes_from_backend", fromBackend), + ) + } +} + +func (h *Handler) handleDetachedUpgradeTunnel( + logger *zap.Logger, + conn io.ReadWriteCloser, + backConn io.ReadWriteCloser, + deleteFrontConn func(), + deleteBackConn func(), + bufferSize int, + streamTimeout time.Duration, + start time.Time, + finishMetrics func(result string, duration time.Duration, toBackend, fromBackend int64), +) { + // HTTP/1.1 hijacked connection: launch a detached goroutine so that + // ServeHTTP can return immediately, allowing the Handler to be GC'd + // after a config reload. The goroutine captures only tunnel (a small + // *tunnelState), logger, conn/backConn, and scalar config values. + go func() { + var ( + toBackend int64 + fromBackend int64 + result = "closed" + ) + defer deleteBackConn() + defer deleteFrontConn() + defer func() { + finishMetrics(result, time.Since(start), toBackend, fromBackend) + if c := logger.Check(zapcore.DebugLevel, "connection closed"); c != nil { + c.Write( + zap.Duration("duration", time.Since(start)), + zap.Int64("bytes_to_backend", toBackend), + zap.Int64("bytes_from_backend", fromBackend), + ) + } + }() + + var innerWg sync.WaitGroup + // when a stream timeout is encountered, no error will be read from errc + // a buffer size of 2 will allow both the read and write goroutines to send the error and exit + // see: https://github.com/caddyserver/caddy/issues/7418 + errc := make(chan error, 2) + spc := switchProtocolCopier{ + user: conn, + backend: backConn, + wg: &innerWg, + bufferSize: bufferSize, + sent: &toBackend, + received: &fromBackend, + } + innerWg.Add(2) + + var timeoutc <-chan time.Time + if streamTimeout > 0 { + timer := time.NewTimer(streamTimeout) + defer timer.Stop() + timeoutc = timer.C + } + + go spc.copyToBackend(errc) + go spc.copyFromBackend(errc) + select { + case err := <-errc: + result = classifyStreamResult(err) + if c := logger.Check(zapcore.DebugLevel, "streaming error"); c != nil { + c.Write(zap.Error(err)) + } + case t := <-timeoutc: + result = "timeout" + if c := logger.Check(zapcore.DebugLevel, "stream timed out"); c != nil { + c.Write(zap.Time("timeout", t)) + } + } + + // Close both ends to unblock the still-running copy goroutine, + // then wait for it to finish so byte counts are accurate before + // the deferred log fires. + conn.Close() + backConn.Close() + innerWg.Wait() + }() +} + +func classifyStreamResult(err error) string { + if err == nil || errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) { + return "closed" + } + return "error" } // flushInterval returns the p.FlushInterval value, conditionally @@ -375,75 +525,86 @@ func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte, logger *za } } -// registerConnection holds onto conn so it can be closed in the event -// of a server shutdown. This is useful because hijacked connections or -// connections dialed to backends don't close when server is shut down. -// The caller should call the returned delete() function when the -// connection is done to remove it from memory. -func (h *Handler) registerConnection(conn io.ReadWriteCloser, gracefulClose func() error) (del func()) { - h.connectionsMu.Lock() - h.connections[conn] = openConnection{conn, gracefulClose} - h.connectionsMu.Unlock() - return func() { - h.connectionsMu.Lock() - delete(h.connections, conn) - // if there is no connection left before the connections close timer fires - if len(h.connections) == 0 && h.connectionsCloseTimer != nil { - // we release the timer that holds the reference to Handler - if (*h.connectionsCloseTimer).Stop() { - h.logger.Debug("stopped streaming connections close timer - all connections are already closed") - } - h.connectionsCloseTimer = nil - } - h.connectionsMu.Unlock() +// openConnection maps an open connection to an optional function for graceful +// close and records which upstream address the connection is proxying to. +type openConnection struct { + conn io.ReadWriteCloser + gracefulClose func() error + upstream string +} + +// tunnelState tracks hijacked/upgraded connections for selective cleanup. +type tunnelState struct { + connections map[io.ReadWriteCloser]openConnection + closeTimer *time.Timer + closeDelay time.Duration + mu sync.Mutex + logger *zap.Logger +} + +func newTunnelState(logger *zap.Logger, closeDelay time.Duration) *tunnelState { + return &tunnelState{ + connections: make(map[io.ReadWriteCloser]openConnection), + closeDelay: closeDelay, + logger: logger, } } -// closeConnections immediately closes all hijacked connections (both to client and backend). -func (h *Handler) closeConnections() error { - var err error - h.connectionsMu.Lock() - defer h.connectionsMu.Unlock() +// registerConnection stores conn in the tracking map. The caller must invoke +// the returned del func when the connection is done. +func (ts *tunnelState) registerConnection(conn io.ReadWriteCloser, gracefulClose func() error, upstream string) (del func()) { + ts.mu.Lock() + ts.connections[conn] = openConnection{conn, gracefulClose, upstream} + ts.mu.Unlock() + return func() { + ts.mu.Lock() + delete(ts.connections, conn) + if len(ts.connections) == 0 && ts.closeTimer != nil { + if ts.closeTimer.Stop() { + ts.logger.Debug("stopped streaming connections close timer - all connections are already closed") + } + ts.closeTimer = nil + } + ts.mu.Unlock() + } +} - for _, oc := range h.connections { +// closeConnections closes all tracked connections. +func (ts *tunnelState) closeConnections() error { + var err error + ts.mu.Lock() + defer ts.mu.Unlock() + for _, oc := range ts.connections { if oc.gracefulClose != nil { - // this is potentially blocking while we have the lock on the connections - // map, but that should be OK since the server has in theory shut down - // and we are no longer using the connections map - gracefulErr := oc.gracefulClose() - if gracefulErr != nil && err == nil { + if gracefulErr := oc.gracefulClose(); gracefulErr != nil && err == nil { err = gracefulErr } } - closeErr := oc.conn.Close() - if closeErr != nil && err == nil { + if closeErr := oc.conn.Close(); closeErr != nil && err == nil { err = closeErr } } return err } -// cleanupConnections closes hijacked connections. -// Depending on the value of StreamCloseDelay it does that either immediately -// or sets up a timer that will do that later. -func (h *Handler) cleanupConnections() error { - if h.StreamCloseDelay == 0 { - return h.closeConnections() +// cleanupConnections closes upgraded connections. Depending on closeDelay it +// does that either immediately or after a timer. +func (ts *tunnelState) cleanupConnections() error { + if ts.closeDelay == 0 { + return ts.closeConnections() } - h.connectionsMu.Lock() - defer h.connectionsMu.Unlock() - // the handler is shut down, no new connection can appear, - // so we can skip setting up the timer when there are no connections - if len(h.connections) > 0 { - delay := time.Duration(h.StreamCloseDelay) - h.connectionsCloseTimer = time.AfterFunc(delay, func() { - if c := h.logger.Check(zapcore.DebugLevel, "closing streaming connections after delay"); c != nil { + ts.mu.Lock() + defer ts.mu.Unlock() + if len(ts.connections) > 0 { + delay := ts.closeDelay + ts.closeTimer = time.AfterFunc(delay, func() { + if c := ts.logger.Check(zapcore.DebugLevel, "closing streaming connections after delay"); c != nil { c.Write(zap.Duration("delay", delay)) } - err := h.closeConnections() + err := ts.closeConnections() if err != nil { - if c := h.logger.Check(zapcore.ErrorLevel, "failed to closed connections after delay"); c != nil { + if c := ts.logger.Check(zapcore.ErrorLevel, "failed to close connections after delay"); c != nil { c.Write( zap.Error(err), zap.Duration("delay", delay), @@ -567,11 +728,26 @@ func isWebsocket(r *http.Request) bool { httpguts.HeaderValuesContainsToken(r.Header["Upgrade"], "websocket") } -// openConnection maps an open connection to -// an optional function for graceful close. -type openConnection struct { - conn io.ReadWriteCloser - gracefulClose func() error +// closeConnectionsForUpstream closes all tracked connections that were +// established to the given upstream address. +func (ts *tunnelState) closeConnectionsForUpstream(addr string) error { + var err error + ts.mu.Lock() + defer ts.mu.Unlock() + for _, oc := range ts.connections { + if oc.upstream != addr { + continue + } + if oc.gracefulClose != nil { + if gracefulErr := oc.gracefulClose(); gracefulErr != nil && err == nil { + err = gracefulErr + } + } + if closeErr := oc.conn.Close(); closeErr != nil && err == nil { + err = closeErr + } + } + return err } type maxLatencyWriter struct { @@ -642,16 +818,23 @@ type switchProtocolCopier struct { user, backend io.ReadWriteCloser wg *sync.WaitGroup bufferSize int + // sent and received accumulate byte counts for each direction. + // They are written before wg.Done() and read after wg.Wait(), so no + // additional synchronization is needed beyond the WaitGroup barrier. + sent *int64 // bytes copied to backend; must be non-nil + received *int64 // bytes copied from backend; must be non-nil } func (c switchProtocolCopier) copyFromBackend(errc chan<- error) { - _, err := io.CopyBuffer(c.user, c.backend, c.buffer()) + n, err := io.CopyBuffer(c.user, c.backend, c.buffer()) + *c.received = n errc <- err c.wg.Done() } func (c switchProtocolCopier) copyToBackend(errc chan<- error) { - _, err := io.CopyBuffer(c.backend, c.user, c.buffer()) + n, err := io.CopyBuffer(c.backend, c.user, c.buffer()) + *c.sent = n errc <- err c.wg.Done() } diff --git a/modules/caddyhttp/reverseproxy/streaming_test.go b/modules/caddyhttp/reverseproxy/streaming_test.go index ce0db65a0..d2441739a 100644 --- a/modules/caddyhttp/reverseproxy/streaming_test.go +++ b/modules/caddyhttp/reverseproxy/streaming_test.go @@ -7,8 +7,10 @@ import ( "strings" "sync" "testing" + "time" "github.com/caddyserver/caddy/v2" + "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" ) func TestHandlerCopyResponse(t *testing.T) { @@ -41,12 +43,15 @@ func TestSwitchProtocolCopierBufferSize(t *testing.T) { var wg sync.WaitGroup var errc = make(chan error, 1) var dst bytes.Buffer + var sent, received int64 copier := switchProtocolCopier{ user: nopReadWriteCloser{Reader: strings.NewReader("hello")}, backend: nopReadWriteCloser{Writer: &dst}, wg: &wg, bufferSize: 7, + sent: &sent, + received: &received, } buf := copier.buffer() @@ -80,3 +85,132 @@ type nopReadWriteCloser struct { } func (nopReadWriteCloser) Close() error { return nil } + +type trackingReadWriteCloser struct { + closed chan struct{} + one sync.Once +} + +func newTrackingReadWriteCloser() *trackingReadWriteCloser { + return &trackingReadWriteCloser{closed: make(chan struct{})} +} + +func (c *trackingReadWriteCloser) Read(_ []byte) (int, error) { return 0, io.EOF } +func (c *trackingReadWriteCloser) Write(p []byte) (int, error) { return len(p), nil } +func (c *trackingReadWriteCloser) Close() error { + c.one.Do(func() { + close(c.closed) + }) + return nil +} + +func (c *trackingReadWriteCloser) isClosed() bool { + select { + case <-c.closed: + return true + default: + return false + } +} + +func TestHandlerCleanupLegacyModeClosesAllConnections(t *testing.T) { + ts := newTunnelState(caddy.Log(), 0) + connA := newTrackingReadWriteCloser() + connB := newTrackingReadWriteCloser() + ts.registerConnection(connA, nil, "a") + ts.registerConnection(connB, nil, "b") + + h := &Handler{ + tunnel: ts, + StreamRetainOnReload: false, + } + + if err := h.Cleanup(); err != nil { + t.Fatalf("cleanup failed: %v", err) + } + if !connA.isClosed() || !connB.isClosed() { + t.Fatalf("legacy cleanup should close all upgraded connections") + } +} + +func TestHandlerCleanupLegacyModeHonorsDelay(t *testing.T) { + ts := newTunnelState(caddy.Log(), 40*time.Millisecond) + conn := newTrackingReadWriteCloser() + ts.registerConnection(conn, nil, "a") + + h := &Handler{ + tunnel: ts, + StreamRetainOnReload: false, + } + + if err := h.Cleanup(); err != nil { + t.Fatalf("cleanup failed: %v", err) + } + if conn.isClosed() { + t.Fatal("connection should not close immediately when stream_close_delay is set") + } + + select { + case <-conn.closed: + case <-time.After(500 * time.Millisecond): + t.Fatal("connection did not close after stream_close_delay elapsed") + } +} + +func TestHandlerCleanupRetainModeClosesOnlyRemovedUpstreams(t *testing.T) { + const upstreamA = "upstream-a" + const upstreamB = "upstream-b" + + // Simulate old+new configs both referencing upstreamA (refcount 2), + // while upstreamB is only referenced by the old config (refcount 1). + hosts.LoadOrStore(upstreamA, struct{}{}) + hosts.LoadOrStore(upstreamA, struct{}{}) + hosts.LoadOrStore(upstreamB, struct{}{}) + t.Cleanup(func() { + _, _ = hosts.Delete(upstreamA) + _, _ = hosts.Delete(upstreamA) + _, _ = hosts.Delete(upstreamB) + }) + + ts := newTunnelState(caddy.Log(), 0) + connA := newTrackingReadWriteCloser() + connB := newTrackingReadWriteCloser() + ts.registerConnection(connA, nil, upstreamA) + ts.registerConnection(connB, nil, upstreamB) + + h := &Handler{ + tunnel: ts, + StreamRetainOnReload: true, + Upstreams: UpstreamPool{ + &Upstream{Dial: upstreamA}, + &Upstream{Dial: upstreamB}, + }, + } + + if err := h.Cleanup(); err != nil { + t.Fatalf("cleanup failed: %v", err) + } + + if connA.isClosed() { + t.Fatal("connection for retained upstream should remain open") + } + if !connB.isClosed() { + t.Fatal("connection for removed upstream should be closed") + } +} + +func TestHandlerUnmarshalCaddyfileStreamLogSkipHandshake(t *testing.T) { + d := caddyfile.NewTestDispenser(` + reverse_proxy localhost:9000 { + stream_log_skip_handshake + } + `) + + var h Handler + if err := h.UnmarshalCaddyfile(d); err != nil { + t.Fatalf("UnmarshalCaddyfile() error = %v", err) + } + if !h.StreamLogSkipHandshake { + t.Fatal("expected stream_log_skip_handshake to enable StreamLogSkipHandshake") + } +}