mirror of
https://github.com/caddyserver/caddy.git
synced 2026-04-24 01:49:32 -04:00
reverseproxy: Optionally detach stream (websockets) from config lifecycle
This commit is contained in:
parent
4430756d5c
commit
476fd0c077
130
caddytest/integration/reverseproxy_upgrade_handlers_test.go
Normal file
130
caddytest/integration/reverseproxy_upgrade_handlers_test.go
Normal file
@ -0,0 +1,130 @@
|
||||
package integration
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/textproto"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/caddyserver/caddy/v2/caddytest"
|
||||
)
|
||||
|
||||
// TestReverseProxyUpgradeWithEncode verifies that an HTTP 101 upgrade
// proxied through the encode handler still yields a usable bidirectional
// stream: the gzip encoder must step aside once the connection is hijacked
// so echoed bytes arrive unmodified.
func TestReverseProxyUpgradeWithEncode(t *testing.T) {
	tester := caddytest.NewTester(t)
	backend := newUpgradeEchoBackend(t)
	defer backend.Close()

	tester.InitServer(fmt.Sprintf(`
	{
		admin localhost:2999
		http_port 9080
		https_port 9443
		grace_period 1ns
		skip_install_trust
	}

	localhost:9080 {
		route {
			encode gzip
			reverse_proxy %s
		}
	}
	`, backend.addr), "caddyfile")

	// Advertise gzip support so the encode handler would engage for a
	// normal (non-upgrade) response.
	client := newUpgradedStreamClientWithHeaders(t, map[string]string{
		"Accept-Encoding": "gzip",
	})
	defer client.Close()

	// A successful round-trip proves the upgraded stream bypasses encoding.
	if err := client.echo("encode-upgrade\n"); err != nil {
		t.Fatalf("upgraded stream echo through encode failed: %v", err)
	}
}
|
||||
|
||||
// TestReverseProxyUpgradeWithInterceptHandleResponse verifies that the
// intercept handler lets a 101 upgrade response pass through untouched:
// the handle_response branch matching status 101 must not fire, and the
// upgraded stream must remain usable end to end.
func TestReverseProxyUpgradeWithInterceptHandleResponse(t *testing.T) {
	tester := caddytest.NewTester(t)
	backend := newUpgradeEchoBackend(t)
	defer backend.Close()

	tester.InitServer(fmt.Sprintf(`
	{
		admin localhost:2999
		http_port 9080
		https_port 9443
		grace_period 1ns
		skip_install_trust
	}

	localhost:9080 {
		route {
			intercept {
				@upgrade status 101
				handle_response @upgrade {
					respond "should-not-run"
				}
			}
			reverse_proxy %s
		}
	}
	`, backend.addr), "caddyfile")

	client := newUpgradedStreamClientWithHeaders(t, nil)
	defer client.Close()

	// If intercept had consumed the 101, this echo would fail (or return
	// "should-not-run" instead of the payload).
	if err := client.echo("intercept-upgrade\n"); err != nil {
		t.Fatalf("upgraded stream echo through intercept failed: %v", err)
	}
}
|
||||
|
||||
func newUpgradedStreamClientWithHeaders(t *testing.T, extraHeaders map[string]string) *upgradedStreamClient {
|
||||
t.Helper()
|
||||
|
||||
conn, err := net.DialTimeout("tcp", "127.0.0.1:9080", 5*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("dialing caddy: %v", err)
|
||||
}
|
||||
|
||||
requestLines := []string{
|
||||
"GET /upgrade HTTP/1.1",
|
||||
"Host: localhost:9080",
|
||||
"Connection: Upgrade",
|
||||
"Upgrade: stress-stream",
|
||||
}
|
||||
for k, v := range extraHeaders {
|
||||
requestLines = append(requestLines, k+": "+v)
|
||||
}
|
||||
requestLines = append(requestLines, "", "")
|
||||
|
||||
if _, err := io.WriteString(conn, strings.Join(requestLines, "\r\n")); err != nil {
|
||||
_ = conn.Close()
|
||||
t.Fatalf("writing upgrade request: %v", err)
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
tproto := textproto.NewReader(reader)
|
||||
statusLine, err := tproto.ReadLine()
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
t.Fatalf("reading upgrade status line: %v", err)
|
||||
}
|
||||
if !strings.Contains(statusLine, "101") {
|
||||
_ = conn.Close()
|
||||
t.Fatalf("unexpected upgrade status: %s", statusLine)
|
||||
}
|
||||
|
||||
headers, err := tproto.ReadMIMEHeader()
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
t.Fatalf("reading upgrade headers: %v", err)
|
||||
}
|
||||
if !strings.EqualFold(headers.Get("Connection"), "Upgrade") {
|
||||
_ = conn.Close()
|
||||
t.Fatalf("unexpected upgrade response headers: %v", headers)
|
||||
}
|
||||
|
||||
return &upgradedStreamClient{conn: conn, reader: reader}
|
||||
}
|
||||
487
caddytest/integration/stream_reload_stress_test.go
Normal file
487
caddytest/integration/stream_reload_stress_test.go
Normal file
@ -0,0 +1,487 @@
|
||||
package integration
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"os"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"runtime/pprof"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/caddyserver/caddy/v2/caddytest"
|
||||
)
|
||||
|
||||
// stressCloseDelay is the stream_close_delay used for the close_delay scenario.
|
||||
// Long enough to outlast all test reloads; short enough to keep total test time reasonable.
|
||||
const stressCloseDelay = 3 * time.Second
|
||||
|
||||
// TestReverseProxyReloadStressUpgradedStreamsHeapProfiles runs the reload
// stress scenario three times — legacy close-on-reload, stream_close_delay,
// and stream_retain_on_reload — asserting each mode's expected stream
// lifetime and logging heap snapshots so leak regressions are visible.
func TestReverseProxyReloadStressUpgradedStreamsHeapProfiles(t *testing.T) {
	// Generous load/test timeouts: each scenario performs many admin reloads.
	tester := caddytest.NewTester(t).WithDefaultOverrides(caddytest.Config{
		LoadRequestTimeout: 30 * time.Second,
		TestRequestTimeout: 30 * time.Second,
	})

	backend := newUpgradeEchoBackend(t)
	defer backend.Close()

	// Three scenarios, each sequential so they don't share Caddy state:
	//
	//	legacy      – no delay, close on reload immediately (old default)
	//	close_delay – stream_close_delay, the old "keep-alive workaround"
	//	retain      – stream_retain_on_reload, the new explicit retain flag
	//
	// Reloads are spread across time and interleaved with echo-checks so
	// stream health is exercised at each reload boundary, not only at the end.
	legacy := runReloadStress(t, tester, backend.addr, "legacy", false, 0)
	closeDelay := runReloadStress(t, tester, backend.addr, "close_delay", false, stressCloseDelay)
	retain := runReloadStress(t, tester, backend.addr, "retain", true, 0)

	// legacy: every stream must be torn down once its config is replaced.
	if legacy.aliveAfterReloads != 0 {
		t.Fatalf("legacy mode left %d upgraded streams alive after reloads", legacy.aliveAfterReloads)
	}
	// close_delay: streams survive the reloads, then die when the delay fires.
	if closeDelay.aliveBeforeDelayExpiry == 0 {
		t.Fatalf("close_delay mode: all streams closed before delay expired (expected them alive)")
	}
	if closeDelay.aliveAfterReloads != 0 {
		t.Fatalf("close_delay mode left %d upgraded streams alive after delay expiry", closeDelay.aliveAfterReloads)
	}
	// retain: every stream must outlive every reload.
	if retain.aliveAfterReloads != retain.streamCount {
		t.Fatalf("retain mode kept %d/%d upgraded streams alive after reloads", retain.aliveAfterReloads, retain.streamCount)
	}

	// Heap summaries are informational only; hard thresholds would be flaky.
	t.Logf("legacy heap: before=%s mid=%s after=%s delta(before→after)=%s objects(before=%d after=%d) handler_frames(before=%d after=%d)",
		formatBytes(legacy.beforeReload.HeapInuse),
		formatBytes(legacy.midReload.HeapInuse),
		formatBytes(legacy.afterReload.HeapInuse),
		formatBytesDiff(legacy.beforeReload.HeapInuse, legacy.afterReload.HeapInuse),
		legacy.beforeReload.HeapObjects, legacy.afterReload.HeapObjects,
		legacy.beforeReload.handlerFrames, legacy.afterReload.handlerFrames,
	)
	t.Logf("close_delay heap: before=%s mid=%s after=%s delta(before→after)=%s objects(before=%d after=%d) handler_frames(before=%d after=%d)",
		formatBytes(closeDelay.beforeReload.HeapInuse),
		formatBytes(closeDelay.midReload.HeapInuse),
		formatBytes(closeDelay.afterReload.HeapInuse),
		formatBytesDiff(closeDelay.beforeReload.HeapInuse, closeDelay.afterReload.HeapInuse),
		closeDelay.beforeReload.HeapObjects, closeDelay.afterReload.HeapObjects,
		closeDelay.beforeReload.handlerFrames, closeDelay.afterReload.handlerFrames,
	)
	t.Logf("retain heap: before=%s mid=%s after=%s delta(before→after)=%s objects(before=%d after=%d) handler_frames(before=%d after=%d)",
		formatBytes(retain.beforeReload.HeapInuse),
		formatBytes(retain.midReload.HeapInuse),
		formatBytes(retain.afterReload.HeapInuse),
		formatBytesDiff(retain.beforeReload.HeapInuse, retain.afterReload.HeapInuse),
		retain.beforeReload.HeapObjects, retain.afterReload.HeapObjects,
		retain.beforeReload.handlerFrames, retain.afterReload.handlerFrames,
	)
}
|
||||
|
||||
// stressRunResult aggregates what one reload-stress scenario observed:
// how many upgraded streams were opened, how many survived, and heap
// snapshots taken before/during/after the reload storm.
type stressRunResult struct {
	streamCount            int // number of upgraded streams opened for the run
	aliveAfterReloads      int // streams still echoing after all reloads (and any delay) completed
	aliveBeforeDelayExpiry int // only meaningful for close_delay mode
	beforeReload           heapSnapshot
	midReload              heapSnapshot // after all reloads, before delay expiry clean-up
	afterReload            heapSnapshot // after all streams have been fully cleaned up
}
|
||||
|
||||
// heapSnapshot is a point-in-time view of heap usage, used to compare
// memory behavior across the phases of a reload-stress run.
type heapSnapshot struct {
	HeapInuse     uint64 // runtime.MemStats.HeapInuse at snapshot time
	HeapObjects   uint64 // runtime.MemStats.HeapObjects at snapshot time
	handlerFrames int    // occurrences of the reverse proxy handler in the heap profile text
	profileBytes  int    // size in bytes of the captured debug=1 heap profile
}
|
||||
|
||||
// runReloadStress opens streamCount upgraded streams, then performs reloadCount
// config reloads spread over time. An echo check is performed every 6 reloads so
// stream health is exercised at each reload boundary rather than only at the end.
// closeDelay mirrors the stream_close_delay config option; pass 0 to disable.
// It returns the alive-stream counts and heap snapshots for the caller to assert on.
func runReloadStress(t *testing.T, tester *caddytest.Tester, backendAddr, mode string, retain bool, closeDelay time.Duration) stressRunResult {
	t.Helper()

	const echoEvery = 6 // perform an echo check every N reloads

	// Stream/reload counts are overridable via env vars for heavier local runs.
	streamCount := envIntOrDefault(t, "CADDY_STRESS_STREAM_COUNT", 12)
	reloadCount := envIntOrDefault(t, "CADDY_STRESS_RELOAD_COUNT", 24)

	tester.InitServer(reloadStressConfig(backendAddr, retain, closeDelay, 0), "caddyfile")

	// Open all streams up front and prove each one is healthy before stressing.
	clients := make([]*upgradedStreamClient, 0, streamCount)
	for i := 0; i < streamCount; i++ {
		client := newUpgradedStreamClient(t)
		clients = append(clients, client)
		if err := client.echo(fmt.Sprintf("%s-warmup-%02d\n", mode, i)); err != nil {
			closeClients(clients)
			t.Fatalf("warmup echo failed in %s mode: %v", mode, err)
		}
	}
	defer closeClients(clients)

	before := captureHeapSnapshot(t)

	// Reloads are spread across time; between batches of echoEvery reloads we
	// pause briefly and measure stream health so the snapshot reflects real-world
	// reload cadence rather than a tight loop.
	for i := 1; i <= reloadCount; i++ {
		// Each reload carries a new revision number so the config genuinely changes.
		loadCaddyfileConfig(t, reloadStressConfig(backendAddr, retain, closeDelay, i))

		// Small pause after each reload to let connection teardown propagate.
		time.Sleep(50 * time.Millisecond)

		if i%echoEvery == 0 {
			alive := countAliveStreams(clients)
			t.Logf("%s mode: %d/%d streams alive after reload %d", mode, alive, streamCount, i)

			// In retain mode every stream must survive every reload (upstream unchanged).
			if retain {
				for j, client := range clients {
					if err := client.echo(fmt.Sprintf("%s-mid-%02d-%02d\n", mode, i, j)); err != nil {
						t.Fatalf("retain mode stream %d died at reload %d: %v", j, i, err)
					}
				}
			}
		}
	}

	// mid snapshot: after all reloads but before any close_delay timer has fired
	// (the delay is long enough to still be running at this point).
	mid := captureHeapSnapshot(t)

	// For legacy mode: the reloads close streams immediately; wait for that to complete.
	// For close_delay mode: streams are still alive here; wait for the delay to fire.
	// For retain mode: streams survive indefinitely; no wait needed.
	var aliveBeforeDelayExpiry int
	aliveAfterReloads := countAliveStreams(clients)
	switch {
	case retain:
		// nothing to wait for
	case closeDelay > 0:
		// streams should still be alive at this point (delay hasn't expired)
		aliveBeforeDelayExpiry = aliveAfterReloads
		t.Logf("%s mode: %d/%d streams alive before close_delay expires; waiting %v for cleanup",
			mode, aliveBeforeDelayExpiry, streamCount, closeDelay)
		time.Sleep(closeDelay + 200*time.Millisecond)
		aliveAfterReloads = countAliveStreams(clients)
	default:
		// legacy: poll until teardown completes or a short deadline passes.
		deadline := time.Now().Add(2 * time.Second)
		for aliveAfterReloads > 0 && time.Now().Before(deadline) {
			time.Sleep(50 * time.Millisecond)
			aliveAfterReloads = countAliveStreams(clients)
		}
	}

	after := captureHeapSnapshot(t)
	t.Logf("%s mode heap profile size: before=%dB mid=%dB after=%dB objects(before=%d mid=%d after=%d)",
		mode,
		before.profileBytes, mid.profileBytes, after.profileBytes,
		before.HeapObjects, mid.HeapObjects, after.HeapObjects,
	)

	return stressRunResult{
		streamCount:            streamCount,
		aliveAfterReloads:      aliveAfterReloads,
		aliveBeforeDelayExpiry: aliveBeforeDelayExpiry,
		beforeReload:           before,
		midReload:              mid,
		afterReload:            after,
	}
}
|
||||
|
||||
// envIntOrDefault reads a positive integer from the named environment
// variable. It returns def when the variable is unset or blank, and fails
// the test when the value is present but not a positive integer.
func envIntOrDefault(t *testing.T, key string, def int) int {
	t.Helper()
	value := strings.TrimSpace(os.Getenv(key))
	if value == "" {
		return def
	}
	parsed, err := strconv.Atoi(value)
	if err != nil || parsed <= 0 {
		t.Fatalf("invalid %s=%q: must be a positive integer", key, value)
	}
	return parsed
}
|
||||
|
||||
// loadCaddyfileConfig POSTs rawConfig to the admin /load endpoint and fails
// the test unless the reload is accepted with HTTP 200.
func loadCaddyfileConfig(t *testing.T, rawConfig string) {
	t.Helper()

	req, err := http.NewRequest(http.MethodPost, "http://localhost:2999/load", strings.NewReader(rawConfig))
	if err != nil {
		t.Fatalf("creating load request: %v", err)
	}
	req.Header.Set("Content-Type", "text/caddyfile")

	httpClient := &http.Client{Timeout: 30 * time.Second}
	resp, err := httpClient.Do(req)
	if err != nil {
		t.Fatalf("loading config: %v", err)
	}
	defer resp.Body.Close()

	// Read the body regardless of status so failures include the admin's message.
	body, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		t.Fatalf("reading load response: %v", readErr)
	}
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("loading config failed: status=%d body=%s", resp.StatusCode, body)
	}
}
|
||||
|
||||
func reloadStressConfig(backendAddr string, retain bool, closeDelay time.Duration, revision int) string {
|
||||
var directives string
|
||||
if retain {
|
||||
directives += "\n\t\tstream_retain_on_reload"
|
||||
}
|
||||
if closeDelay > 0 {
|
||||
directives += fmt.Sprintf("\n\t\tstream_close_delay %s", closeDelay)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(`
|
||||
{
|
||||
admin localhost:2999
|
||||
http_port 9080
|
||||
https_port 9443
|
||||
grace_period 1ns
|
||||
skip_install_trust
|
||||
}
|
||||
|
||||
localhost:9080 {
|
||||
reverse_proxy %s {
|
||||
header_up X-Reload-Revision %d%s
|
||||
}
|
||||
}
|
||||
`, backendAddr, revision, directives)
|
||||
}
|
||||
|
||||
func captureHeapSnapshot(t *testing.T) heapSnapshot {
|
||||
t.Helper()
|
||||
|
||||
runtime.GC()
|
||||
debug.FreeOSMemory()
|
||||
|
||||
var mem runtime.MemStats
|
||||
runtime.ReadMemStats(&mem)
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := pprof.Lookup("heap").WriteTo(&buf, 1); err != nil {
|
||||
t.Fatalf("capturing heap profile: %v", err)
|
||||
}
|
||||
profile := buf.String()
|
||||
|
||||
return heapSnapshot{
|
||||
HeapInuse: mem.HeapInuse,
|
||||
HeapObjects: mem.HeapObjects,
|
||||
handlerFrames: strings.Count(profile, "modules/caddyhttp/reverseproxy.(*Handler)"),
|
||||
profileBytes: buf.Len(),
|
||||
}
|
||||
}
|
||||
|
||||
func countAliveStreams(clients []*upgradedStreamClient) int {
|
||||
alive := 0
|
||||
for index, client := range clients {
|
||||
if err := client.echo(fmt.Sprintf("alive-check-%02d\n", index)); err == nil {
|
||||
alive++
|
||||
}
|
||||
}
|
||||
return alive
|
||||
}
|
||||
|
||||
func closeClients(clients []*upgradedStreamClient) {
|
||||
for _, client := range clients {
|
||||
if client != nil {
|
||||
_ = client.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// formatBytes renders value as a human-readable IEC size string
// ("512 B", "1.5 KiB", "2.0 MiB", ...), one decimal place above 1023 bytes.
func formatBytes(value uint64) string {
	const unit = 1024
	if value < unit {
		return fmt.Sprintf("%d B", value)
	}
	const suffixes = "KMGTPE"
	divisor := uint64(unit)
	idx := 0
	for rest := value / unit; rest >= unit; rest /= unit {
		divisor *= unit
		idx++
	}
	return fmt.Sprintf("%.1f %ciB", float64(value)/float64(divisor), suffixes[idx])
}
|
||||
|
||||
func formatBytesDiff(before, after uint64) string {
|
||||
if after >= before {
|
||||
return "+" + formatBytes(after-before)
|
||||
}
|
||||
return "-" + formatBytes(before-after)
|
||||
}
|
||||
|
||||
// upgradedStreamClient is a raw TCP client that has completed a
// "stress-stream" HTTP upgrade handshake against the proxy.
type upgradedStreamClient struct {
	conn   net.Conn
	reader *bufio.Reader // buffered reader over conn; reads go through here
	mu     sync.Mutex    // serializes echo/Close so deadlines and reads don't race
}
|
||||
|
||||
func newUpgradedStreamClient(t *testing.T) *upgradedStreamClient {
|
||||
t.Helper()
|
||||
|
||||
conn, err := net.DialTimeout("tcp", "127.0.0.1:9080", 5*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("dialing caddy: %v", err)
|
||||
}
|
||||
|
||||
request := strings.Join([]string{
|
||||
"GET /upgrade HTTP/1.1",
|
||||
"Host: localhost:9080",
|
||||
"Connection: Upgrade",
|
||||
"Upgrade: stress-stream",
|
||||
"",
|
||||
"",
|
||||
}, "\r\n")
|
||||
if _, err := io.WriteString(conn, request); err != nil {
|
||||
_ = conn.Close()
|
||||
t.Fatalf("writing upgrade request: %v", err)
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
tproto := textproto.NewReader(reader)
|
||||
statusLine, err := tproto.ReadLine()
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
t.Fatalf("reading upgrade status line: %v", err)
|
||||
}
|
||||
if !strings.Contains(statusLine, "101") {
|
||||
_ = conn.Close()
|
||||
t.Fatalf("unexpected upgrade status: %s", statusLine)
|
||||
}
|
||||
|
||||
headers, err := tproto.ReadMIMEHeader()
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
t.Fatalf("reading upgrade headers: %v", err)
|
||||
}
|
||||
if !strings.EqualFold(headers.Get("Connection"), "Upgrade") {
|
||||
_ = conn.Close()
|
||||
t.Fatalf("unexpected upgrade response headers: %v", headers)
|
||||
}
|
||||
|
||||
return &upgradedStreamClient{conn: conn, reader: reader}
|
||||
}
|
||||
|
||||
func (c *upgradedStreamClient) echo(payload string) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
deadline := time.Now().Add(1 * time.Second)
|
||||
if err := c.conn.SetWriteDeadline(deadline); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.WriteString(c.conn, payload); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.conn.SetReadDeadline(deadline); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf := make([]byte, len(payload))
|
||||
if _, err := io.ReadFull(c.reader, buf); err != nil {
|
||||
return err
|
||||
}
|
||||
if string(buf) != payload {
|
||||
return fmt.Errorf("unexpected echoed payload: got %q want %q", string(buf), payload)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *upgradedStreamClient) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
// upgradeEchoBackend is a minimal upstream that accepts "stress-stream"
// upgrade requests and echoes all subsequent bytes. It tracks hijacked
// connections so Close can terminate them.
type upgradeEchoBackend struct {
	addr   string // host:port the listener is bound to
	ln     net.Listener
	mu     sync.Mutex // guards conns
	conns  map[net.Conn]struct{} // hijacked connections still open
	server *http.Server
}
|
||||
|
||||
func newUpgradeEchoBackend(t *testing.T) *upgradeEchoBackend {
|
||||
t.Helper()
|
||||
|
||||
backend := &upgradeEchoBackend{conns: make(map[net.Conn]struct{})}
|
||||
backend.server = &http.Server{
|
||||
Handler: http.HandlerFunc(backend.serveHTTP),
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listening for backend: %v", err)
|
||||
}
|
||||
backend.ln = ln
|
||||
backend.addr = ln.Addr().String()
|
||||
|
||||
go func() {
|
||||
_ = backend.server.Serve(ln)
|
||||
}()
|
||||
|
||||
return backend
|
||||
}
|
||||
|
||||
func (b *upgradeEchoBackend) serveHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.EqualFold(r.Header.Get("Connection"), "Upgrade") || !strings.EqualFold(r.Header.Get("Upgrade"), "stress-stream") {
|
||||
http.Error(w, "upgrade required", http.StatusUpgradeRequired)
|
||||
return
|
||||
}
|
||||
|
||||
hijacker, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
http.Error(w, "hijacking not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
conn, rw, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
b.trackConn(conn)
|
||||
_, _ = rw.WriteString("HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: stress-stream\r\n\r\n")
|
||||
_ = rw.Flush()
|
||||
|
||||
go func() {
|
||||
defer b.untrackConn(conn)
|
||||
defer conn.Close()
|
||||
_, _ = io.Copy(conn, conn)
|
||||
}()
|
||||
}
|
||||
|
||||
func (b *upgradeEchoBackend) trackConn(conn net.Conn) {
|
||||
b.mu.Lock()
|
||||
b.conns[conn] = struct{}{}
|
||||
b.mu.Unlock()
|
||||
}
|
||||
|
||||
func (b *upgradeEchoBackend) untrackConn(conn net.Conn) {
|
||||
b.mu.Lock()
|
||||
delete(b.conns, conn)
|
||||
b.mu.Unlock()
|
||||
}
|
||||
|
||||
func (b *upgradeEchoBackend) Close() {
|
||||
_ = b.server.Close()
|
||||
_ = b.ln.Close()
|
||||
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
for conn := range b.conns {
|
||||
_ = conn.Close()
|
||||
}
|
||||
clear(b.conns)
|
||||
}
|
||||
@ -405,6 +405,11 @@ func (rw *responseWriter) ReadFrom(r io.Reader) (int64, error) {
|
||||
// Close writes any remaining buffered response and
|
||||
// deallocates any active resources.
|
||||
func (rw *responseWriter) Close() error {
|
||||
if caddyhttp.ResponseWriterHijacked(rw.ResponseWriter) {
|
||||
rw.releaseEncoder()
|
||||
return nil
|
||||
}
|
||||
|
||||
// didn't write, probably head request
|
||||
if !rw.wroteHeader {
|
||||
cl, err := strconv.Atoi(rw.Header().Get("Content-Length"))
|
||||
@ -422,13 +427,20 @@ func (rw *responseWriter) Close() error {
|
||||
var err error
|
||||
if rw.w != nil {
|
||||
err = rw.w.Close()
|
||||
rw.w.Reset(nil)
|
||||
rw.config.writerPools[rw.encodingName].Put(rw.w)
|
||||
rw.w = nil
|
||||
rw.releaseEncoder()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (rw *responseWriter) releaseEncoder() {
|
||||
if rw.w == nil {
|
||||
return
|
||||
}
|
||||
rw.w.Reset(nil)
|
||||
rw.config.writerPools[rw.encodingName].Put(rw.w)
|
||||
rw.w = nil
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying ResponseWriter.
|
||||
func (rw *responseWriter) Unwrap() http.ResponseWriter {
|
||||
return rw.ResponseWriter
|
||||
|
||||
@ -70,6 +70,7 @@ type responseRecorder struct {
|
||||
size int
|
||||
wroteHeader bool
|
||||
stream bool
|
||||
hijacked bool
|
||||
|
||||
readSize *int
|
||||
}
|
||||
@ -144,7 +145,8 @@ func NewResponseRecorder(w http.ResponseWriter, buf *bytes.Buffer, shouldBuffer
|
||||
|
||||
// WriteHeader writes the headers with statusCode to the wrapped
|
||||
// ResponseWriter unless the response is to be buffered instead.
|
||||
// 1xx responses are never buffered.
|
||||
// 1xx responses are never buffered, except 101 which is treated
|
||||
// as a final upgrade response.
|
||||
func (rr *responseRecorder) WriteHeader(statusCode int) {
|
||||
if rr.wroteHeader {
|
||||
return
|
||||
@ -153,6 +155,12 @@ func (rr *responseRecorder) WriteHeader(statusCode int) {
|
||||
// save statusCode always, in case HTTP middleware upgrades websocket
|
||||
// connections by manually setting headers and writing status 101
|
||||
rr.statusCode = statusCode
|
||||
if statusCode == http.StatusSwitchingProtocols {
|
||||
rr.stream = true
|
||||
rr.wroteHeader = true
|
||||
rr.ResponseWriterWrapper.WriteHeader(statusCode)
|
||||
return
|
||||
}
|
||||
|
||||
// decide whether we should buffer the response
|
||||
if rr.shouldBuffer == nil {
|
||||
@ -222,7 +230,14 @@ func (rr *responseRecorder) Buffered() bool {
|
||||
return !rr.stream
|
||||
}
|
||||
|
||||
// Hijacked reports whether the underlying connection was taken over via
// Hijack; when true, WriteResponse becomes a no-op.
func (rr *responseRecorder) Hijacked() bool {
	return rr.hijacked
}
|
||||
|
||||
func (rr *responseRecorder) WriteResponse() error {
|
||||
if rr.hijacked {
|
||||
return nil
|
||||
}
|
||||
if rr.statusCode == 0 {
|
||||
// could happen if no handlers actually wrote anything,
|
||||
// and this prevents a panic; status must be > 0
|
||||
@ -258,13 +273,16 @@ func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
// Per http documentation, returned bufio.Writer is empty, but bufio.Read maybe not
|
||||
conn = &hijackedConn{conn, rr}
|
||||
rr.hijacked = true
|
||||
rr.stream = true
|
||||
rr.wroteHeader = true
|
||||
// Per http documentation, returned bufio.Writer is empty, but bufio.Read maybe not.
|
||||
// Return the raw hijacked connection so upgraded stream traffic does not keep
|
||||
// traversing the response recorder hot path.
|
||||
brw.Writer.Reset(conn)
|
||||
|
||||
buffered := brw.Reader.Buffered()
|
||||
if buffered != 0 {
|
||||
conn.(*hijackedConn).updateReadSize(buffered)
|
||||
data, _ := brw.Peek(buffered)
|
||||
brw.Reader.Reset(io.MultiReader(bytes.NewReader(data), conn))
|
||||
// peek to make buffered data appear, as Reset will make it 0
|
||||
@ -275,40 +293,24 @@ func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return conn, brw, nil
|
||||
}
|
||||
|
||||
// used to track the size of hijacked response writers
|
||||
type hijackedConn struct {
|
||||
net.Conn
|
||||
rr *responseRecorder
|
||||
}
|
||||
|
||||
func (hc *hijackedConn) updateReadSize(n int) {
|
||||
if hc.rr.readSize != nil {
|
||||
*hc.rr.readSize += n
|
||||
// ResponseWriterHijacked reports whether w or one of its wrapped response
|
||||
// writers has been hijacked.
|
||||
func ResponseWriterHijacked(w http.ResponseWriter) bool {
|
||||
for w != nil {
|
||||
if hijacked, ok := w.(interface{ Hijacked() bool }); ok && hijacked.Hijacked() {
|
||||
return true
|
||||
}
|
||||
unwrapper, ok := w.(interface{ Unwrap() http.ResponseWriter })
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
next := unwrapper.Unwrap()
|
||||
if next == w {
|
||||
return false
|
||||
}
|
||||
w = next
|
||||
}
|
||||
}
|
||||
|
||||
func (hc *hijackedConn) Read(p []byte) (int, error) {
|
||||
n, err := hc.Conn.Read(p)
|
||||
hc.updateReadSize(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (hc *hijackedConn) WriteTo(w io.Writer) (int64, error) {
|
||||
n, err := io.Copy(w, hc.Conn)
|
||||
hc.updateReadSize(int(n))
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (hc *hijackedConn) Write(p []byte) (int, error) {
|
||||
n, err := hc.Conn.Write(p)
|
||||
hc.rr.size += n
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (hc *hijackedConn) ReadFrom(r io.Reader) (int64, error) {
|
||||
n, err := io.Copy(hc.Conn, r)
|
||||
hc.rr.size += int(n)
|
||||
return n, err
|
||||
return false
|
||||
}
|
||||
|
||||
// ResponseRecorder is a http.ResponseWriter that records
|
||||
@ -319,6 +321,7 @@ type ResponseRecorder interface {
|
||||
Status() int
|
||||
Buffer() *bytes.Buffer
|
||||
Buffered() bool
|
||||
Hijacked() bool
|
||||
Size() int
|
||||
WriteResponse() error
|
||||
}
|
||||
@ -338,7 +341,4 @@ var (
|
||||
// see PR #5022 (25%-50% speedup)
|
||||
_ io.ReaderFrom = (*ResponseWriterWrapper)(nil)
|
||||
_ io.ReaderFrom = (*responseRecorder)(nil)
|
||||
_ io.ReaderFrom = (*hijackedConn)(nil)
|
||||
|
||||
_ io.WriterTo = (*hijackedConn)(nil)
|
||||
)
|
||||
|
||||
@ -1,11 +1,14 @@
|
||||
package caddyhttp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type responseWriterSpy interface {
|
||||
@ -44,6 +47,50 @@ func (rf *readFromRespWriter) ReadFrom(r io.Reader) (int64, error) {
|
||||
|
||||
func (rf *readFromRespWriter) CalledReadFrom() bool { return rf.called }
|
||||
|
||||
// hijackRespWriter is a test double whose Hijack hands out a stub
// connection, letting recorder tests exercise the 101/hijack path.
type hijackRespWriter struct {
	baseRespWriter
	header http.Header // headers set by the code under test
	status int         // last status passed to WriteHeader
	conn   net.Conn    // connection returned by Hijack
}
||||
|
||||
func newHijackRespWriter() *hijackRespWriter {
|
||||
return &hijackRespWriter{
|
||||
header: make(http.Header),
|
||||
conn: stubConn{},
|
||||
}
|
||||
}
|
||||
|
||||
// Header returns the writer's header map.
func (hrw *hijackRespWriter) Header() http.Header {
	return hrw.header
}
|
||||
|
||||
// WriteHeader records the status code without emitting anything.
func (hrw *hijackRespWriter) WriteHeader(statusCode int) {
	hrw.status = statusCode
}
|
||||
|
||||
func (hrw *hijackRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
br := bufio.NewReader(hrw.conn)
|
||||
bw := bufio.NewWriter(hrw.conn)
|
||||
return hrw.conn, bufio.NewReadWriter(br, bw), nil
|
||||
}
|
||||
|
||||
type stubConn struct{}
|
||||
|
||||
func (stubConn) Read(_ []byte) (int, error) { return 0, io.EOF }
|
||||
func (stubConn) Write(p []byte) (int, error) { return len(p), nil }
|
||||
func (stubConn) Close() error { return nil }
|
||||
func (stubConn) LocalAddr() net.Addr { return stubAddr("local") }
|
||||
func (stubConn) RemoteAddr() net.Addr { return stubAddr("remote") }
|
||||
func (stubConn) SetDeadline(time.Time) error { return nil }
|
||||
func (stubConn) SetReadDeadline(time.Time) error { return nil }
|
||||
func (stubConn) SetWriteDeadline(time.Time) error { return nil }
|
||||
|
||||
type stubAddr string
|
||||
|
||||
func (a stubAddr) Network() string { return "tcp" }
|
||||
func (a stubAddr) String() string { return string(a) }
|
||||
|
||||
func TestResponseWriterWrapperReadFrom(t *testing.T) {
|
||||
tests := map[string]struct {
|
||||
responseWriter responseWriterSpy
|
||||
@ -169,3 +216,49 @@ func TestResponseRecorderReadFrom(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestResponseRecorderSwitchingProtocolsIsHijackAware checks that a 101
// response passes straight through the recorder (never buffered, even when
// shouldBuffer says yes) and that a subsequent Hijack flips the recorder
// into hijacked mode: WriteResponse becomes a no-op, no body bytes are
// counted, and ResponseWriterHijacked reports true through the wrapper chain.
func TestResponseRecorderSwitchingProtocolsIsHijackAware(t *testing.T) {
	w := newHijackRespWriter()
	var buf bytes.Buffer

	// shouldBuffer always returns true; the 101 path must override it.
	rr := NewResponseRecorder(w, &buf, func(status int, header http.Header) bool {
		return true
	})
	rr.WriteHeader(http.StatusSwitchingProtocols)

	if rr.Buffered() {
		t.Fatal("101 switching protocols response should not remain buffered")
	}
	if rr.Status() != http.StatusSwitchingProtocols {
		t.Fatalf("status = %d, want %d", rr.Status(), http.StatusSwitchingProtocols)
	}
	// The 101 must reach the underlying writer immediately, not on WriteResponse.
	if w.status != http.StatusSwitchingProtocols {
		t.Fatalf("underlying status = %d, want %d", w.status, http.StatusSwitchingProtocols)
	}

	hj, ok := rr.(http.Hijacker)
	if !ok {
		t.Fatal("response recorder does not implement http.Hijacker")
	}
	conn, _, err := hj.Hijack()
	if err != nil {
		t.Fatalf("Hijack() error = %v", err)
	}
	defer conn.Close()

	if !rr.Hijacked() {
		t.Fatal("response recorder should report hijacked state")
	}
	if !ResponseWriterHijacked(rr) {
		t.Fatal("ResponseWriterHijacked() should report true after hijack")
	}
	if err := rr.WriteResponse(); err != nil {
		t.Fatalf("WriteResponse() after hijack returned error: %v", err)
	}
	if rr.Size() != 0 {
		t.Fatalf("size = %d, want 0 after hijack handshake", rr.Size())
	}
	if got := w.Written(); got != "" {
		t.Fatalf("unexpected buffered body write after hijack: %q", got)
	}
}
|
||||
|
||||
@ -99,6 +99,8 @@ func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error)
|
||||
// stream_buffer_size <size>
|
||||
// stream_timeout <duration>
|
||||
// stream_close_delay <duration>
|
||||
// stream_retain_on_reload
|
||||
// stream_log_skip_handshake
|
||||
// verbose_logs
|
||||
//
|
||||
// # request manipulation
|
||||
@ -703,6 +705,18 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
||||
h.StreamCloseDelay = caddy.Duration(dur)
|
||||
}
|
||||
|
||||
case "stream_retain_on_reload":
|
||||
if d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
h.StreamRetainOnReload = true
|
||||
|
||||
case "stream_log_skip_handshake":
|
||||
if d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
h.StreamLogSkipHandshake = true
|
||||
|
||||
case "trusted_proxies":
|
||||
for d.NextArg() {
|
||||
if d.Val() == "private_ranges" {
|
||||
|
||||
@ -80,7 +80,7 @@ func (h CopyResponseHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request
|
||||
hrc.isFinalized = true
|
||||
|
||||
// write the response
|
||||
return hrc.handler.finalizeResponse(rw, req, hrc.response, repl, hrc.start, hrc.logger)
|
||||
return hrc.handler.finalizeResponse(rw, req, hrc.response, repl, hrc.start, hrc.logger, hrc.upstreamAddr)
|
||||
}
|
||||
|
||||
// CopyResponseHeadersHandler is a special HTTP handler which may
|
||||
|
||||
@ -16,6 +16,10 @@ import (
|
||||
var reverseProxyMetrics = struct {
|
||||
once sync.Once
|
||||
upstreamsHealthy *prometheus.GaugeVec
|
||||
streamsActive *prometheus.GaugeVec
|
||||
streamsTotal *prometheus.CounterVec
|
||||
streamDuration *prometheus.HistogramVec
|
||||
streamBytes *prometheus.CounterVec
|
||||
logger *zap.Logger
|
||||
}{}
|
||||
|
||||
@ -23,6 +27,8 @@ func initReverseProxyMetrics(handler *Handler, registry *prometheus.Registry) {
|
||||
const ns, sub = "caddy", "reverse_proxy"
|
||||
|
||||
upstreamsLabels := []string{"upstream"}
|
||||
streamResultLabels := []string{"upstream", "result"}
|
||||
streamBytesLabels := []string{"upstream", "direction"}
|
||||
reverseProxyMetrics.once.Do(func() {
|
||||
reverseProxyMetrics.upstreamsHealthy = prometheus.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Namespace: ns,
|
||||
@ -30,6 +36,31 @@ func initReverseProxyMetrics(handler *Handler, registry *prometheus.Registry) {
|
||||
Name: "upstreams_healthy",
|
||||
Help: "Health status of reverse proxy upstreams.",
|
||||
}, upstreamsLabels)
|
||||
reverseProxyMetrics.streamsActive = prometheus.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Namespace: ns,
|
||||
Subsystem: sub,
|
||||
Name: "streams_active",
|
||||
Help: "Number of currently active upgraded reverse proxy streams.",
|
||||
}, upstreamsLabels)
|
||||
reverseProxyMetrics.streamsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: ns,
|
||||
Subsystem: sub,
|
||||
Name: "streams_total",
|
||||
Help: "Total number of upgraded reverse proxy streams by close result.",
|
||||
}, streamResultLabels)
|
||||
reverseProxyMetrics.streamDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Namespace: ns,
|
||||
Subsystem: sub,
|
||||
Name: "stream_duration_seconds",
|
||||
Help: "Duration of upgraded reverse proxy streams by close result.",
|
||||
Buckets: prometheus.DefBuckets,
|
||||
}, streamResultLabels)
|
||||
reverseProxyMetrics.streamBytes = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: ns,
|
||||
Subsystem: sub,
|
||||
Name: "stream_bytes_total",
|
||||
Help: "Total bytes proxied across upgraded reverse proxy streams.",
|
||||
}, streamBytesLabels)
|
||||
})
|
||||
|
||||
// duplicate registration could happen if multiple sites with reverse proxy are configured; so ignore the error because
|
||||
@ -42,10 +73,58 @@ func initReverseProxyMetrics(handler *Handler, registry *prometheus.Registry) {
|
||||
}) {
|
||||
panic(err)
|
||||
}
|
||||
if err := registry.Register(reverseProxyMetrics.streamsActive); err != nil &&
|
||||
!errors.Is(err, prometheus.AlreadyRegisteredError{
|
||||
ExistingCollector: reverseProxyMetrics.streamsActive,
|
||||
NewCollector: reverseProxyMetrics.streamsActive,
|
||||
}) {
|
||||
panic(err)
|
||||
}
|
||||
if err := registry.Register(reverseProxyMetrics.streamsTotal); err != nil &&
|
||||
!errors.Is(err, prometheus.AlreadyRegisteredError{
|
||||
ExistingCollector: reverseProxyMetrics.streamsTotal,
|
||||
NewCollector: reverseProxyMetrics.streamsTotal,
|
||||
}) {
|
||||
panic(err)
|
||||
}
|
||||
if err := registry.Register(reverseProxyMetrics.streamDuration); err != nil &&
|
||||
!errors.Is(err, prometheus.AlreadyRegisteredError{
|
||||
ExistingCollector: reverseProxyMetrics.streamDuration,
|
||||
NewCollector: reverseProxyMetrics.streamDuration,
|
||||
}) {
|
||||
panic(err)
|
||||
}
|
||||
if err := registry.Register(reverseProxyMetrics.streamBytes); err != nil &&
|
||||
!errors.Is(err, prometheus.AlreadyRegisteredError{
|
||||
ExistingCollector: reverseProxyMetrics.streamBytes,
|
||||
NewCollector: reverseProxyMetrics.streamBytes,
|
||||
}) {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
reverseProxyMetrics.logger = handler.logger.Named("reverse_proxy.metrics")
|
||||
}
|
||||
|
||||
func trackActiveStream(upstream string) func(result string, duration time.Duration, toBackend, fromBackend int64) {
|
||||
labels := prometheus.Labels{"upstream": upstream}
|
||||
reverseProxyMetrics.streamsActive.With(labels).Inc()
|
||||
|
||||
var once sync.Once
|
||||
return func(result string, duration time.Duration, toBackend, fromBackend int64) {
|
||||
once.Do(func() {
|
||||
reverseProxyMetrics.streamsActive.With(labels).Dec()
|
||||
reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, result).Inc()
|
||||
reverseProxyMetrics.streamDuration.WithLabelValues(upstream, result).Observe(duration.Seconds())
|
||||
if toBackend > 0 {
|
||||
reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "to_upstream").Add(float64(toBackend))
|
||||
}
|
||||
if fromBackend > 0 {
|
||||
reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "from_upstream").Add(float64(fromBackend))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type metricsUpstreamsHealthyUpdater struct {
|
||||
handler *Handler
|
||||
}
|
||||
|
||||
67
modules/caddyhttp/reverseproxy/metrics_test.go
Normal file
67
modules/caddyhttp/reverseproxy/metrics_test.go
Normal file
@ -0,0 +1,67 @@
|
||||
package reverseproxy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/testutil"
|
||||
)
|
||||
|
||||
func TestTrackActiveStreamRecordsLifecycleAndBytes(t *testing.T) {
|
||||
const upstream = "127.0.0.1:7443"
|
||||
|
||||
// Use fresh metric vectors for deterministic assertions in this unit test.
|
||||
reverseProxyMetrics.streamsActive = prometheus.NewGaugeVec(prometheus.GaugeOpts{}, []string{"upstream"})
|
||||
reverseProxyMetrics.streamsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "result"})
|
||||
reverseProxyMetrics.streamDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{}, []string{"upstream", "result"})
|
||||
reverseProxyMetrics.streamBytes = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "direction"})
|
||||
|
||||
finish := trackActiveStream(upstream)
|
||||
|
||||
if got := testutil.ToFloat64(reverseProxyMetrics.streamsActive.WithLabelValues(upstream)); got != 1 {
|
||||
t.Fatalf("active streams = %v, want 1", got)
|
||||
}
|
||||
|
||||
finish("closed", 150*time.Millisecond, 1234, 4321)
|
||||
|
||||
if got := testutil.ToFloat64(reverseProxyMetrics.streamsActive.WithLabelValues(upstream)); got != 0 {
|
||||
t.Fatalf("active streams = %v, want 0", got)
|
||||
}
|
||||
if got := testutil.ToFloat64(reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, "closed")); got != 1 {
|
||||
t.Fatalf("streams_total closed = %v, want 1", got)
|
||||
}
|
||||
if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "to_upstream")); got != 1234 {
|
||||
t.Fatalf("bytes to_upstream = %v, want 1234", got)
|
||||
}
|
||||
if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "from_upstream")); got != 4321 {
|
||||
t.Fatalf("bytes from_upstream = %v, want 4321", got)
|
||||
}
|
||||
|
||||
// A second finish call should be ignored by the once guard.
|
||||
finish("error", 1*time.Second, 111, 222)
|
||||
if got := testutil.ToFloat64(reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, "error")); got != 0 {
|
||||
t.Fatalf("streams_total error = %v, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrackActiveStreamDoesNotCountZeroBytes(t *testing.T) {
|
||||
const upstream = "127.0.0.1:9000"
|
||||
|
||||
reverseProxyMetrics.streamsActive = prometheus.NewGaugeVec(prometheus.GaugeOpts{}, []string{"upstream"})
|
||||
reverseProxyMetrics.streamsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "result"})
|
||||
reverseProxyMetrics.streamDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{}, []string{"upstream", "result"})
|
||||
reverseProxyMetrics.streamBytes = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "direction"})
|
||||
|
||||
trackActiveStream(upstream)("timeout", 250*time.Millisecond, 0, 0)
|
||||
|
||||
if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "to_upstream")); got != 0 {
|
||||
t.Fatalf("bytes to_upstream = %v, want 0", got)
|
||||
}
|
||||
if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "from_upstream")); got != 0 {
|
||||
t.Fatalf("bytes from_upstream = %v, want 0", got)
|
||||
}
|
||||
if got := testutil.ToFloat64(reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, "timeout")); got != 1 {
|
||||
t.Fatalf("streams_total timeout = %v, want 1", got)
|
||||
}
|
||||
}
|
||||
@ -186,6 +186,18 @@ type Handler struct {
|
||||
// by the previous config closing. Default: no delay.
|
||||
StreamCloseDelay caddy.Duration `json:"stream_close_delay,omitempty"`
|
||||
|
||||
// If true, upgraded connections such as WebSockets are retained across
|
||||
// config reloads when their upstream still exists in the new config.
|
||||
// Connections using upstreams that are removed are closed during cleanup.
|
||||
// By default this is false, preserving legacy behavior where upgraded
|
||||
// connections are closed on reload (optionally delayed by stream_close_delay).
|
||||
StreamRetainOnReload bool `json:"stream_retain_on_reload,omitempty"`
|
||||
|
||||
// If true, suppresses the access log entry normally emitted when an
|
||||
// upgraded stream handshake completes and the request unwinds. By default
|
||||
// the handshake is still logged as a normal request with status 101.
|
||||
StreamLogSkipHandshake bool `json:"stream_log_skip_handshake,omitempty"`
|
||||
|
||||
// If configured, rewrites the copy of the upstream request.
|
||||
// Allows changing the request method and URI (path and query).
|
||||
// Since the rewrite is applied to the copy, it does not persist
|
||||
@ -240,10 +252,9 @@ type Handler struct {
|
||||
// Holds the handle_response Caddyfile tokens while adapting
|
||||
handleResponseSegments []*caddyfile.Dispenser
|
||||
|
||||
// Stores upgraded requests (hijacked connections) for proper cleanup
|
||||
connections map[io.ReadWriteCloser]openConnection
|
||||
connectionsCloseTimer *time.Timer
|
||||
connectionsMu *sync.Mutex
|
||||
// Tracks hijacked/upgraded connections (WebSocket etc.) so they can be
|
||||
// closed when their upstream is removed from the config.
|
||||
tunnel *tunnelState
|
||||
|
||||
ctx caddy.Context
|
||||
logger *zap.Logger
|
||||
@ -267,8 +278,7 @@ func (h *Handler) Provision(ctx caddy.Context) error {
|
||||
h.events = eventAppIface.(*caddyevents.App)
|
||||
h.ctx = ctx
|
||||
h.logger = ctx.Logger()
|
||||
h.connections = make(map[io.ReadWriteCloser]openConnection)
|
||||
h.connectionsMu = new(sync.Mutex)
|
||||
h.tunnel = newTunnelState(h.logger, time.Duration(h.StreamCloseDelay))
|
||||
|
||||
// warn about unsafe buffering config
|
||||
if h.RequestBuffers == -1 || h.ResponseBuffers == -1 {
|
||||
@ -439,13 +449,29 @@ func (h *Handler) Provision(ctx caddy.Context) error {
|
||||
|
||||
// Cleanup cleans up the resources made by h.
|
||||
func (h *Handler) Cleanup() error {
|
||||
err := h.cleanupConnections()
|
||||
|
||||
// remove hosts from our config from the pool
|
||||
for _, upstream := range h.Upstreams {
|
||||
_, _ = hosts.Delete(upstream.String())
|
||||
if !h.StreamRetainOnReload {
|
||||
// Legacy behavior: close all upgraded connections on reload, either
|
||||
// immediately or after StreamCloseDelay.
|
||||
err := h.tunnel.cleanupConnections()
|
||||
for _, upstream := range h.Upstreams {
|
||||
_, _ = hosts.Delete(upstream.String())
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var err error
|
||||
for _, upstream := range h.Upstreams {
|
||||
// hosts.Delete returns deleted=true when the ref count reaches zero,
|
||||
// meaning no other active config references this upstream. In that
|
||||
// case close any tunnels proxying to it; otherwise let them survive
|
||||
// to their natural end since the upstream is still in use.
|
||||
deleted, _ := hosts.Delete(upstream.String())
|
||||
if deleted {
|
||||
if closeErr := h.tunnel.closeConnectionsForUpstream(upstream.String()); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
}
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@ -1092,10 +1118,11 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe
|
||||
// we use the original request here, so that any routes from 'next'
|
||||
// see the original request rather than the proxy cloned request.
|
||||
hrc := &handleResponseContext{
|
||||
handler: h,
|
||||
response: res,
|
||||
start: start,
|
||||
logger: logger,
|
||||
handler: h,
|
||||
response: res,
|
||||
start: start,
|
||||
logger: logger,
|
||||
upstreamAddr: di.Upstream.String(),
|
||||
}
|
||||
ctx := origReq.Context()
|
||||
ctx = context.WithValue(ctx, proxyHandleResponseContextCtxKey, hrc)
|
||||
@ -1125,7 +1152,7 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe
|
||||
}
|
||||
|
||||
// copy the response body and headers back to the upstream client
|
||||
return h.finalizeResponse(rw, req, res, repl, start, logger)
|
||||
return h.finalizeResponse(rw, req, res, repl, start, logger, di.Upstream.String())
|
||||
}
|
||||
|
||||
// finalizeResponse prepares and copies the response.
|
||||
@ -1136,11 +1163,12 @@ func (h *Handler) finalizeResponse(
|
||||
repl *caddy.Replacer,
|
||||
start time.Time,
|
||||
logger *zap.Logger,
|
||||
upstreamAddr string,
|
||||
) error {
|
||||
// deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
|
||||
if res.StatusCode == http.StatusSwitchingProtocols {
|
||||
var wg sync.WaitGroup
|
||||
h.handleUpgradeResponse(logger, &wg, rw, req, res)
|
||||
h.handleUpgradeResponse(logger, &wg, rw, req, res, upstreamAddr)
|
||||
wg.Wait()
|
||||
return nil
|
||||
}
|
||||
@ -1719,6 +1747,9 @@ type handleResponseContext struct {
|
||||
// i.e. copied and closed, to make sure that it doesn't
|
||||
// happen twice.
|
||||
isFinalized bool
|
||||
|
||||
// upstreamAddr is the selected upstream address for this request.
|
||||
upstreamAddr string
|
||||
}
|
||||
|
||||
// proxyHandleResponseContextCtxKey is the context key for the active proxy handler
|
||||
|
||||
@ -26,6 +26,7 @@ import (
|
||||
"io"
|
||||
weakrand "math/rand/v2"
|
||||
"mime"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
@ -35,6 +36,7 @@ import (
|
||||
"go.uber.org/zap/zapcore"
|
||||
"golang.org/x/net/http/httpguts"
|
||||
|
||||
"github.com/caddyserver/caddy/v2"
|
||||
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
|
||||
)
|
||||
|
||||
@ -57,7 +59,7 @@ func (rwc h2ReadWriteCloser) Write(p []byte) (n int, err error) {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, rw http.ResponseWriter, req *http.Request, res *http.Response) {
|
||||
func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, rw http.ResponseWriter, req *http.Request, res *http.Response, upstreamAddr string) {
|
||||
reqUpType := upgradeType(req.Header)
|
||||
resUpType := upgradeType(res.Header)
|
||||
|
||||
@ -90,13 +92,22 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
|
||||
copyHeader(rw.Header(), res.Header)
|
||||
normalizeWebsocketHeaders(rw.Header())
|
||||
|
||||
// Capture all h fields needed by the tunnel now, so that the Handler (h)
|
||||
// is not referenced after this function returns (for HTTP/1.1 hijacked
|
||||
// connections the tunnel runs in a detached goroutine).
|
||||
tunnel := h.tunnel
|
||||
bufferSize := h.StreamBufferSize
|
||||
streamTimeout := time.Duration(h.StreamTimeout)
|
||||
|
||||
var (
|
||||
conn io.ReadWriteCloser
|
||||
brw *bufio.ReadWriter
|
||||
isH2 bool
|
||||
)
|
||||
// websocket over http2 or http3 if extended connect is enabled, assuming backend doesn't support this, the request will be modified to http1.1 upgrade
|
||||
// TODO: once we can reliably detect backend support this, it can be removed for those backends
|
||||
if body, ok := caddyhttp.GetVar(req.Context(), "extended_connect_websocket_body").(io.ReadCloser); ok {
|
||||
isH2 = true
|
||||
req.Body = body
|
||||
rw.Header().Del("Upgrade")
|
||||
rw.Header().Del("Connection")
|
||||
@ -143,26 +154,24 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
|
||||
}
|
||||
}
|
||||
|
||||
// adopted from https://github.com/golang/go/commit/8bcf2834afdf6a1f7937390903a41518715ef6f5
|
||||
backConnCloseCh := make(chan struct{})
|
||||
go func() {
|
||||
// Ensure that the cancellation of a request closes the backend.
|
||||
// See issue https://golang.org/issue/35559.
|
||||
select {
|
||||
case <-req.Context().Done():
|
||||
case <-backConnCloseCh:
|
||||
}
|
||||
backConn.Close()
|
||||
}()
|
||||
defer close(backConnCloseCh)
|
||||
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
conn.Close()
|
||||
if c := logger.Check(zapcore.DebugLevel, "connection closed"); c != nil {
|
||||
c.Write(zap.Duration("duration", time.Since(start)))
|
||||
}
|
||||
}()
|
||||
// For H2 extended connect: close backConn when the request context is
|
||||
// cancelled (e.g. client disconnects). For HTTP/1.1 hijacked connections
|
||||
// we skip this because req.Context() may be cancelled when ServeHTTP
|
||||
// returns early, which would prematurely close the backend connection.
|
||||
if isH2 {
|
||||
// adopted from https://github.com/golang/go/commit/8bcf2834afdf6a1f7937390903a41518715ef6f5
|
||||
backConnCloseCh := make(chan struct{})
|
||||
go func() {
|
||||
// Ensure that the cancellation of a request closes the backend.
|
||||
// See issue https://golang.org/issue/35559.
|
||||
select {
|
||||
case <-req.Context().Done():
|
||||
case <-backConnCloseCh:
|
||||
}
|
||||
backConn.Close()
|
||||
}()
|
||||
defer close(backConnCloseCh)
|
||||
}
|
||||
|
||||
if err := brw.Flush(); err != nil {
|
||||
if c := logger.Check(zapcore.DebugLevel, "response flush"); c != nil {
|
||||
@ -184,13 +193,11 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure the hijacked client connection, and the new connection established
|
||||
// with the backend, are both closed in the event of a server shutdown. This
|
||||
// is done by registering them. We also try to gracefully close connections
|
||||
// we recognize as websockets.
|
||||
// We need to make sure the client connection messages (i.e. to upstream)
|
||||
// are masked, so we need to know whether the connection is considered the
|
||||
// server or the client side of the proxy.
|
||||
// Register both connections with the tunnel tracker. We also try to
|
||||
// gracefully close connections we recognize as websockets. We need to make
|
||||
// sure the client connection messages (i.e. to upstream) are masked, so we
|
||||
// need to know whether the connection is considered the server or the
|
||||
// client side of the proxy.
|
||||
gracefulClose := func(conn io.ReadWriteCloser, isClient bool) func() error {
|
||||
if isWebsocket(req) {
|
||||
return func() error {
|
||||
@ -199,43 +206,186 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
deleteFrontConn := h.registerConnection(conn, gracefulClose(conn, false))
|
||||
deleteBackConn := h.registerConnection(backConn, gracefulClose(backConn, true))
|
||||
defer deleteFrontConn()
|
||||
deleteFrontConn := tunnel.registerConnection(conn, gracefulClose(conn, false), upstreamAddr)
|
||||
deleteBackConn := tunnel.registerConnection(backConn, gracefulClose(backConn, true), upstreamAddr)
|
||||
if h.StreamLogSkipHandshake {
|
||||
caddyhttp.SetVar(req.Context(), caddyhttp.LogSkipVar, true)
|
||||
}
|
||||
repl := req.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
|
||||
repl.Set("http.reverse_proxy.upgraded", true)
|
||||
finishMetrics := trackActiveStream(upstreamAddr)
|
||||
|
||||
start := time.Now()
|
||||
|
||||
if isH2 {
|
||||
h.handleH2UpgradeTunnel(logger, wg, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics)
|
||||
} else {
|
||||
h.handleDetachedUpgradeTunnel(logger, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics)
|
||||
// Return immediately without touching wg. finalizeResponse's
|
||||
// wg.Wait() returns at once since wg was never incremented.
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) handleH2UpgradeTunnel(
|
||||
logger *zap.Logger,
|
||||
wg *sync.WaitGroup,
|
||||
conn io.ReadWriteCloser,
|
||||
backConn io.ReadWriteCloser,
|
||||
deleteFrontConn func(),
|
||||
deleteBackConn func(),
|
||||
bufferSize int,
|
||||
streamTimeout time.Duration,
|
||||
start time.Time,
|
||||
finishMetrics func(result string, duration time.Duration, toBackend, fromBackend int64),
|
||||
) {
|
||||
// H2 extended connect: ServeHTTP must block because rw and req.Body are
|
||||
// only valid while the handler goroutine is running. Defers clean up
|
||||
// when the select below fires and this function returns.
|
||||
defer deleteBackConn()
|
||||
|
||||
spc := switchProtocolCopier{
|
||||
user: conn,
|
||||
backend: backConn,
|
||||
wg: wg,
|
||||
bufferSize: h.StreamBufferSize,
|
||||
}
|
||||
|
||||
// setup the timeout if requested
|
||||
var timeoutc <-chan time.Time
|
||||
if h.StreamTimeout > 0 {
|
||||
timer := time.NewTimer(time.Duration(h.StreamTimeout))
|
||||
defer timer.Stop()
|
||||
timeoutc = timer.C
|
||||
}
|
||||
defer deleteFrontConn()
|
||||
var (
|
||||
toBackend int64
|
||||
fromBackend int64
|
||||
result = "closed"
|
||||
)
|
||||
|
||||
// when a stream timeout is encountered, no error will be read from errc
|
||||
// a buffer size of 2 will allow both the read and write goroutines to send the error and exit
|
||||
// see: https://github.com/caddyserver/caddy/issues/7418
|
||||
errc := make(chan error, 2)
|
||||
spc := switchProtocolCopier{
|
||||
user: conn,
|
||||
backend: backConn,
|
||||
wg: wg,
|
||||
bufferSize: bufferSize,
|
||||
sent: &toBackend,
|
||||
received: &fromBackend,
|
||||
}
|
||||
wg.Add(2)
|
||||
|
||||
var timeoutc <-chan time.Time
|
||||
if streamTimeout > 0 {
|
||||
timer := time.NewTimer(streamTimeout)
|
||||
defer timer.Stop()
|
||||
timeoutc = timer.C
|
||||
}
|
||||
|
||||
go spc.copyToBackend(errc)
|
||||
go spc.copyFromBackend(errc)
|
||||
select {
|
||||
case err := <-errc:
|
||||
result = classifyStreamResult(err)
|
||||
if c := logger.Check(zapcore.DebugLevel, "streaming error"); c != nil {
|
||||
c.Write(zap.Error(err))
|
||||
}
|
||||
case time := <-timeoutc:
|
||||
case t := <-timeoutc:
|
||||
result = "timeout"
|
||||
if c := logger.Check(zapcore.DebugLevel, "stream timed out"); c != nil {
|
||||
c.Write(zap.Time("timeout", time))
|
||||
c.Write(zap.Time("timeout", t))
|
||||
}
|
||||
}
|
||||
|
||||
// Close both ends to unblock the still-running copy goroutine,
|
||||
// then wait for it so byte counts are final before metrics/logging.
|
||||
conn.Close()
|
||||
backConn.Close()
|
||||
wg.Wait()
|
||||
|
||||
finishMetrics(result, time.Since(start), toBackend, fromBackend)
|
||||
if c := logger.Check(zapcore.DebugLevel, "connection closed"); c != nil {
|
||||
c.Write(
|
||||
zap.Duration("duration", time.Since(start)),
|
||||
zap.Int64("bytes_to_backend", toBackend),
|
||||
zap.Int64("bytes_from_backend", fromBackend),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) handleDetachedUpgradeTunnel(
|
||||
logger *zap.Logger,
|
||||
conn io.ReadWriteCloser,
|
||||
backConn io.ReadWriteCloser,
|
||||
deleteFrontConn func(),
|
||||
deleteBackConn func(),
|
||||
bufferSize int,
|
||||
streamTimeout time.Duration,
|
||||
start time.Time,
|
||||
finishMetrics func(result string, duration time.Duration, toBackend, fromBackend int64),
|
||||
) {
|
||||
// HTTP/1.1 hijacked connection: launch a detached goroutine so that
|
||||
// ServeHTTP can return immediately, allowing the Handler to be GC'd
|
||||
// after a config reload. The goroutine captures only tunnel (a small
|
||||
// *tunnelState), logger, conn/backConn, and scalar config values.
|
||||
go func() {
|
||||
var (
|
||||
toBackend int64
|
||||
fromBackend int64
|
||||
result = "closed"
|
||||
)
|
||||
defer deleteBackConn()
|
||||
defer deleteFrontConn()
|
||||
defer func() {
|
||||
finishMetrics(result, time.Since(start), toBackend, fromBackend)
|
||||
if c := logger.Check(zapcore.DebugLevel, "connection closed"); c != nil {
|
||||
c.Write(
|
||||
zap.Duration("duration", time.Since(start)),
|
||||
zap.Int64("bytes_to_backend", toBackend),
|
||||
zap.Int64("bytes_from_backend", fromBackend),
|
||||
)
|
||||
}
|
||||
}()
|
||||
|
||||
var innerWg sync.WaitGroup
|
||||
// when a stream timeout is encountered, no error will be read from errc
|
||||
// a buffer size of 2 will allow both the read and write goroutines to send the error and exit
|
||||
// see: https://github.com/caddyserver/caddy/issues/7418
|
||||
errc := make(chan error, 2)
|
||||
spc := switchProtocolCopier{
|
||||
user: conn,
|
||||
backend: backConn,
|
||||
wg: &innerWg,
|
||||
bufferSize: bufferSize,
|
||||
sent: &toBackend,
|
||||
received: &fromBackend,
|
||||
}
|
||||
innerWg.Add(2)
|
||||
|
||||
var timeoutc <-chan time.Time
|
||||
if streamTimeout > 0 {
|
||||
timer := time.NewTimer(streamTimeout)
|
||||
defer timer.Stop()
|
||||
timeoutc = timer.C
|
||||
}
|
||||
|
||||
go spc.copyToBackend(errc)
|
||||
go spc.copyFromBackend(errc)
|
||||
select {
|
||||
case err := <-errc:
|
||||
result = classifyStreamResult(err)
|
||||
if c := logger.Check(zapcore.DebugLevel, "streaming error"); c != nil {
|
||||
c.Write(zap.Error(err))
|
||||
}
|
||||
case t := <-timeoutc:
|
||||
result = "timeout"
|
||||
if c := logger.Check(zapcore.DebugLevel, "stream timed out"); c != nil {
|
||||
c.Write(zap.Time("timeout", t))
|
||||
}
|
||||
}
|
||||
|
||||
// Close both ends to unblock the still-running copy goroutine,
|
||||
// then wait for it to finish so byte counts are accurate before
|
||||
// the deferred log fires.
|
||||
conn.Close()
|
||||
backConn.Close()
|
||||
innerWg.Wait()
|
||||
}()
|
||||
}
|
||||
|
||||
func classifyStreamResult(err error) string {
|
||||
if err == nil || errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) {
|
||||
return "closed"
|
||||
}
|
||||
return "error"
|
||||
}
|
||||
|
||||
// flushInterval returns the p.FlushInterval value, conditionally
|
||||
@ -375,75 +525,86 @@ func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte, logger *za
|
||||
}
|
||||
}
|
||||
|
||||
// registerConnection holds onto conn so it can be closed in the event
|
||||
// of a server shutdown. This is useful because hijacked connections or
|
||||
// connections dialed to backends don't close when server is shut down.
|
||||
// The caller should call the returned delete() function when the
|
||||
// connection is done to remove it from memory.
|
||||
func (h *Handler) registerConnection(conn io.ReadWriteCloser, gracefulClose func() error) (del func()) {
|
||||
h.connectionsMu.Lock()
|
||||
h.connections[conn] = openConnection{conn, gracefulClose}
|
||||
h.connectionsMu.Unlock()
|
||||
return func() {
|
||||
h.connectionsMu.Lock()
|
||||
delete(h.connections, conn)
|
||||
// if there is no connection left before the connections close timer fires
|
||||
if len(h.connections) == 0 && h.connectionsCloseTimer != nil {
|
||||
// we release the timer that holds the reference to Handler
|
||||
if (*h.connectionsCloseTimer).Stop() {
|
||||
h.logger.Debug("stopped streaming connections close timer - all connections are already closed")
|
||||
}
|
||||
h.connectionsCloseTimer = nil
|
||||
}
|
||||
h.connectionsMu.Unlock()
|
||||
// openConnection maps an open connection to an optional function for graceful
|
||||
// close and records which upstream address the connection is proxying to.
|
||||
type openConnection struct {
|
||||
conn io.ReadWriteCloser
|
||||
gracefulClose func() error
|
||||
upstream string
|
||||
}
|
||||
|
||||
// tunnelState tracks hijacked/upgraded connections for selective cleanup.
|
||||
type tunnelState struct {
|
||||
connections map[io.ReadWriteCloser]openConnection
|
||||
closeTimer *time.Timer
|
||||
closeDelay time.Duration
|
||||
mu sync.Mutex
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
func newTunnelState(logger *zap.Logger, closeDelay time.Duration) *tunnelState {
|
||||
return &tunnelState{
|
||||
connections: make(map[io.ReadWriteCloser]openConnection),
|
||||
closeDelay: closeDelay,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// closeConnections immediately closes all hijacked connections (both to client and backend).
|
||||
func (h *Handler) closeConnections() error {
|
||||
var err error
|
||||
h.connectionsMu.Lock()
|
||||
defer h.connectionsMu.Unlock()
|
||||
// registerConnection stores conn in the tracking map. The caller must invoke
|
||||
// the returned del func when the connection is done.
|
||||
func (ts *tunnelState) registerConnection(conn io.ReadWriteCloser, gracefulClose func() error, upstream string) (del func()) {
|
||||
ts.mu.Lock()
|
||||
ts.connections[conn] = openConnection{conn, gracefulClose, upstream}
|
||||
ts.mu.Unlock()
|
||||
return func() {
|
||||
ts.mu.Lock()
|
||||
delete(ts.connections, conn)
|
||||
if len(ts.connections) == 0 && ts.closeTimer != nil {
|
||||
if ts.closeTimer.Stop() {
|
||||
ts.logger.Debug("stopped streaming connections close timer - all connections are already closed")
|
||||
}
|
||||
ts.closeTimer = nil
|
||||
}
|
||||
ts.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
for _, oc := range h.connections {
|
||||
// closeConnections closes all tracked connections.
|
||||
func (ts *tunnelState) closeConnections() error {
|
||||
var err error
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
for _, oc := range ts.connections {
|
||||
if oc.gracefulClose != nil {
|
||||
// this is potentially blocking while we have the lock on the connections
|
||||
// map, but that should be OK since the server has in theory shut down
|
||||
// and we are no longer using the connections map
|
||||
gracefulErr := oc.gracefulClose()
|
||||
if gracefulErr != nil && err == nil {
|
||||
if gracefulErr := oc.gracefulClose(); gracefulErr != nil && err == nil {
|
||||
err = gracefulErr
|
||||
}
|
||||
}
|
||||
closeErr := oc.conn.Close()
|
||||
if closeErr != nil && err == nil {
|
||||
if closeErr := oc.conn.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// cleanupConnections closes hijacked connections.
|
||||
// Depending on the value of StreamCloseDelay it does that either immediately
|
||||
// or sets up a timer that will do that later.
|
||||
func (h *Handler) cleanupConnections() error {
|
||||
if h.StreamCloseDelay == 0 {
|
||||
return h.closeConnections()
|
||||
// cleanupConnections closes upgraded connections. Depending on closeDelay it
|
||||
// does that either immediately or after a timer.
|
||||
func (ts *tunnelState) cleanupConnections() error {
|
||||
if ts.closeDelay == 0 {
|
||||
return ts.closeConnections()
|
||||
}
|
||||
|
||||
h.connectionsMu.Lock()
|
||||
defer h.connectionsMu.Unlock()
|
||||
// the handler is shut down, no new connection can appear,
|
||||
// so we can skip setting up the timer when there are no connections
|
||||
if len(h.connections) > 0 {
|
||||
delay := time.Duration(h.StreamCloseDelay)
|
||||
h.connectionsCloseTimer = time.AfterFunc(delay, func() {
|
||||
if c := h.logger.Check(zapcore.DebugLevel, "closing streaming connections after delay"); c != nil {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
if len(ts.connections) > 0 {
|
||||
delay := ts.closeDelay
|
||||
ts.closeTimer = time.AfterFunc(delay, func() {
|
||||
if c := ts.logger.Check(zapcore.DebugLevel, "closing streaming connections after delay"); c != nil {
|
||||
c.Write(zap.Duration("delay", delay))
|
||||
}
|
||||
err := h.closeConnections()
|
||||
err := ts.closeConnections()
|
||||
if err != nil {
|
||||
if c := h.logger.Check(zapcore.ErrorLevel, "failed to closed connections after delay"); c != nil {
|
||||
if c := ts.logger.Check(zapcore.ErrorLevel, "failed to close connections after delay"); c != nil {
|
||||
c.Write(
|
||||
zap.Error(err),
|
||||
zap.Duration("delay", delay),
|
||||
@ -567,11 +728,26 @@ func isWebsocket(r *http.Request) bool {
|
||||
httpguts.HeaderValuesContainsToken(r.Header["Upgrade"], "websocket")
|
||||
}
|
||||
|
||||
// openConnection maps an open connection to
|
||||
// an optional function for graceful close.
|
||||
type openConnection struct {
|
||||
conn io.ReadWriteCloser
|
||||
gracefulClose func() error
|
||||
// closeConnectionsForUpstream closes all tracked connections that were
|
||||
// established to the given upstream address.
|
||||
func (ts *tunnelState) closeConnectionsForUpstream(addr string) error {
|
||||
var err error
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
for _, oc := range ts.connections {
|
||||
if oc.upstream != addr {
|
||||
continue
|
||||
}
|
||||
if oc.gracefulClose != nil {
|
||||
if gracefulErr := oc.gracefulClose(); gracefulErr != nil && err == nil {
|
||||
err = gracefulErr
|
||||
}
|
||||
}
|
||||
if closeErr := oc.conn.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type maxLatencyWriter struct {
|
||||
@ -642,16 +818,23 @@ type switchProtocolCopier struct {
|
||||
user, backend io.ReadWriteCloser
|
||||
wg *sync.WaitGroup
|
||||
bufferSize int
|
||||
// sent and received accumulate byte counts for each direction.
|
||||
// They are written before wg.Done() and read after wg.Wait(), so no
|
||||
// additional synchronization is needed beyond the WaitGroup barrier.
|
||||
sent *int64 // bytes copied to backend; must be non-nil
|
||||
received *int64 // bytes copied from backend; must be non-nil
|
||||
}
|
||||
|
||||
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
|
||||
_, err := io.CopyBuffer(c.user, c.backend, c.buffer())
|
||||
n, err := io.CopyBuffer(c.user, c.backend, c.buffer())
|
||||
*c.received = n
|
||||
errc <- err
|
||||
c.wg.Done()
|
||||
}
|
||||
|
||||
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
|
||||
_, err := io.CopyBuffer(c.backend, c.user, c.buffer())
|
||||
n, err := io.CopyBuffer(c.backend, c.user, c.buffer())
|
||||
*c.sent = n
|
||||
errc <- err
|
||||
c.wg.Done()
|
||||
}
|
||||
|
||||
@ -7,8 +7,10 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/caddyserver/caddy/v2"
|
||||
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
|
||||
)
|
||||
|
||||
func TestHandlerCopyResponse(t *testing.T) {
|
||||
@ -41,12 +43,15 @@ func TestSwitchProtocolCopierBufferSize(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
var errc = make(chan error, 1)
|
||||
var dst bytes.Buffer
|
||||
var sent, received int64
|
||||
|
||||
copier := switchProtocolCopier{
|
||||
user: nopReadWriteCloser{Reader: strings.NewReader("hello")},
|
||||
backend: nopReadWriteCloser{Writer: &dst},
|
||||
wg: &wg,
|
||||
bufferSize: 7,
|
||||
sent: &sent,
|
||||
received: &received,
|
||||
}
|
||||
|
||||
buf := copier.buffer()
|
||||
@ -80,3 +85,132 @@ type nopReadWriteCloser struct {
|
||||
}
|
||||
|
||||
func (nopReadWriteCloser) Close() error { return nil }
|
||||
|
||||
type trackingReadWriteCloser struct {
|
||||
closed chan struct{}
|
||||
one sync.Once
|
||||
}
|
||||
|
||||
func newTrackingReadWriteCloser() *trackingReadWriteCloser {
|
||||
return &trackingReadWriteCloser{closed: make(chan struct{})}
|
||||
}
|
||||
|
||||
func (c *trackingReadWriteCloser) Read(_ []byte) (int, error) { return 0, io.EOF }
|
||||
func (c *trackingReadWriteCloser) Write(p []byte) (int, error) { return len(p), nil }
|
||||
func (c *trackingReadWriteCloser) Close() error {
|
||||
c.one.Do(func() {
|
||||
close(c.closed)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *trackingReadWriteCloser) isClosed() bool {
|
||||
select {
|
||||
case <-c.closed:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerCleanupLegacyModeClosesAllConnections(t *testing.T) {
|
||||
ts := newTunnelState(caddy.Log(), 0)
|
||||
connA := newTrackingReadWriteCloser()
|
||||
connB := newTrackingReadWriteCloser()
|
||||
ts.registerConnection(connA, nil, "a")
|
||||
ts.registerConnection(connB, nil, "b")
|
||||
|
||||
h := &Handler{
|
||||
tunnel: ts,
|
||||
StreamRetainOnReload: false,
|
||||
}
|
||||
|
||||
if err := h.Cleanup(); err != nil {
|
||||
t.Fatalf("cleanup failed: %v", err)
|
||||
}
|
||||
if !connA.isClosed() || !connB.isClosed() {
|
||||
t.Fatalf("legacy cleanup should close all upgraded connections")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerCleanupLegacyModeHonorsDelay(t *testing.T) {
|
||||
ts := newTunnelState(caddy.Log(), 40*time.Millisecond)
|
||||
conn := newTrackingReadWriteCloser()
|
||||
ts.registerConnection(conn, nil, "a")
|
||||
|
||||
h := &Handler{
|
||||
tunnel: ts,
|
||||
StreamRetainOnReload: false,
|
||||
}
|
||||
|
||||
if err := h.Cleanup(); err != nil {
|
||||
t.Fatalf("cleanup failed: %v", err)
|
||||
}
|
||||
if conn.isClosed() {
|
||||
t.Fatal("connection should not close immediately when stream_close_delay is set")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-conn.closed:
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Fatal("connection did not close after stream_close_delay elapsed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerCleanupRetainModeClosesOnlyRemovedUpstreams(t *testing.T) {
|
||||
const upstreamA = "upstream-a"
|
||||
const upstreamB = "upstream-b"
|
||||
|
||||
// Simulate old+new configs both referencing upstreamA (refcount 2),
|
||||
// while upstreamB is only referenced by the old config (refcount 1).
|
||||
hosts.LoadOrStore(upstreamA, struct{}{})
|
||||
hosts.LoadOrStore(upstreamA, struct{}{})
|
||||
hosts.LoadOrStore(upstreamB, struct{}{})
|
||||
t.Cleanup(func() {
|
||||
_, _ = hosts.Delete(upstreamA)
|
||||
_, _ = hosts.Delete(upstreamA)
|
||||
_, _ = hosts.Delete(upstreamB)
|
||||
})
|
||||
|
||||
ts := newTunnelState(caddy.Log(), 0)
|
||||
connA := newTrackingReadWriteCloser()
|
||||
connB := newTrackingReadWriteCloser()
|
||||
ts.registerConnection(connA, nil, upstreamA)
|
||||
ts.registerConnection(connB, nil, upstreamB)
|
||||
|
||||
h := &Handler{
|
||||
tunnel: ts,
|
||||
StreamRetainOnReload: true,
|
||||
Upstreams: UpstreamPool{
|
||||
&Upstream{Dial: upstreamA},
|
||||
&Upstream{Dial: upstreamB},
|
||||
},
|
||||
}
|
||||
|
||||
if err := h.Cleanup(); err != nil {
|
||||
t.Fatalf("cleanup failed: %v", err)
|
||||
}
|
||||
|
||||
if connA.isClosed() {
|
||||
t.Fatal("connection for retained upstream should remain open")
|
||||
}
|
||||
if !connB.isClosed() {
|
||||
t.Fatal("connection for removed upstream should be closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerUnmarshalCaddyfileStreamLogSkipHandshake(t *testing.T) {
|
||||
d := caddyfile.NewTestDispenser(`
|
||||
reverse_proxy localhost:9000 {
|
||||
stream_log_skip_handshake
|
||||
}
|
||||
`)
|
||||
|
||||
var h Handler
|
||||
if err := h.UnmarshalCaddyfile(d); err != nil {
|
||||
t.Fatalf("UnmarshalCaddyfile() error = %v", err)
|
||||
}
|
||||
if !h.StreamLogSkipHandshake {
|
||||
t.Fatal("expected stream_log_skip_handshake to enable StreamLogSkipHandshake")
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user