Compare commits

..

2 Commits

Author SHA1 Message Date
copilot-swe-agent[bot] 1ca4406a36 Merge origin/master into rewrite-modify-query
Co-authored-by: francislavoie <2111701+francislavoie@users.noreply.github.com>
2026-05-08 18:18:51 +00:00
Francis Lavoie 2ad19885b5 rewrite: Add option to force modifying the query 2026-03-01 16:11:03 -05:00
35 changed files with 383 additions and 2665 deletions
+1 -3
View File
@@ -132,8 +132,6 @@ jobs:
- name: Run tests
# id: step_test
# continue-on-error: true
env:
GODEBUG: http2xconnect=1
run: |
# (go test -v -coverprofile=cover-profile.out -race ./... 2>&1) > test-results/test-result.out
go test -v -coverprofile="cover-profile.out" -short -race ./...
@@ -193,7 +191,7 @@ jobs:
retries=3
exit_code=0
while ((retries > 0)); do
GODEBUG=http2xconnect=1 CGO_ENABLED=0 go test -p 1 -v ./...
CGO_ENABLED=0 go test -p 1 -v ./...
exit_code=$?
if ((exit_code == 0)); then
break
+38 -13
View File
@@ -120,6 +120,10 @@ type AdminConfig struct {
//
// EXPERIMENTAL: This feature is subject to change.
Remote *RemoteAdmin `json:"remote,omitempty"`
// Holds onto the routers so that we can later provision them
// if they require provisioning.
routers []AdminRouter
}
// ConfigSettings configures the management of configuration.
@@ -218,7 +222,7 @@ type AdminPermissions struct {
// newAdminHandler reads admin's config and returns an http.Handler suitable
// for use in an admin endpoint server, which will be listening on listenAddr.
func (admin *AdminConfig) newAdminHandler(addr NetworkAddress, remote bool, ctx Context) (adminHandler, error) {
func (admin *AdminConfig) newAdminHandler(addr NetworkAddress, remote bool, _ Context) adminHandler {
muxWrap := adminHandler{mux: http.NewServeMux()}
// secure the local or remote endpoint respectively
@@ -275,21 +279,34 @@ func (admin *AdminConfig) newAdminHandler(addr NetworkAddress, remote bool, ctx
// register third-party module endpoints
for _, m := range GetModules("admin.api") {
router := m.New().(AdminRouter)
// provision the router before registering its routes, so
// handlers have access to all provisioned state
if provisioner, ok := router.(Provisioner); ok {
if err := provisioner.Provision(ctx); err != nil {
return adminHandler{}, fmt.Errorf("provisioning admin router module %s: %v", m.ID, err)
}
}
for _, route := range router.Routes() {
addRoute(route.Pattern, handlerLabel, route.Handler)
}
admin.routers = append(admin.routers, router)
}
return muxWrap, nil
return muxWrap
}
// provisionAdminRouters provisions all the router modules
// in the admin.api namespace that need provisioning.
func (admin *AdminConfig) provisionAdminRouters(ctx Context) error {
for _, router := range admin.routers {
provisioner, ok := router.(Provisioner)
if !ok {
continue
}
err := provisioner.Provision(ctx)
if err != nil {
return err
}
}
// We no longer need the routers once provisioned, allow for GC
admin.routers = nil
return nil
}
// allowedOrigins returns a list of origins that are allowed.
@@ -413,7 +430,11 @@ func replaceLocalAdminServer(cfg *Config, ctx Context) error {
return err
}
handler, err := cfg.Admin.newAdminHandler(addr, false, ctx)
handler := cfg.Admin.newAdminHandler(addr, false, ctx)
// run the provisioners for loaded modules to make sure local
// state is properly re-initialized in the new admin server
err = cfg.Admin.provisionAdminRouters(ctx)
if err != nil {
return err
}
@@ -537,7 +558,11 @@ func replaceRemoteAdminServer(ctx Context, cfg *Config) error {
// make the HTTP handler but disable Host/Origin enforcement
// because we are using TLS authentication instead
handler, err := cfg.Admin.newAdminHandler(addr, true, ctx)
handler := cfg.Admin.newAdminHandler(addr, true, ctx)
// run the provisioners for loaded modules to make sure local
// state is properly re-initialized in the new admin server
err = cfg.Admin.provisionAdminRouters(ctx)
if err != nil {
return err
}
+15 -9
View File
@@ -340,10 +340,7 @@ func TestAdminHandlerBuiltinRouteErrors(t *testing.T) {
if err != nil {
t.Fatalf("Failed to parse address: %v", err)
}
handler, err := cfg.Admin.newAdminHandler(addr, false, Context{})
if err != nil {
t.Fatalf("Failed to create admin handler: %v", err)
}
handler := cfg.Admin.newAdminHandler(addr, false, Context{})
tests := []struct {
name string
@@ -464,10 +461,7 @@ func TestNewAdminHandlerRouterRegistration(t *testing.T) {
admin := &AdminConfig{
EnforceOrigin: false,
}
handler, err := admin.newAdminHandler(addr, false, Context{})
if err != nil {
t.Fatalf("Failed to create admin handler: %v", err)
}
handler := admin.newAdminHandler(addr, false, Context{})
req := httptest.NewRequest("GET", "/mock", nil)
req.Host = "localhost:2019"
@@ -479,6 +473,10 @@ func TestNewAdminHandlerRouterRegistration(t *testing.T) {
t.Errorf("Expected status code %d but got %d", http.StatusOK, rr.Code)
t.Logf("Response body: %s", rr.Body.String())
}
if len(admin.routers) != 1 {
t.Errorf("Expected 1 router to be stored, got %d", len(admin.routers))
}
}
type mockProvisionableRouter struct {
@@ -516,16 +514,19 @@ func TestAdminRouterProvisioning(t *testing.T) {
name string
provisionErr error
wantErr bool
routersAfter int // expected number of routers after provisioning
}{
{
name: "successful provisioning",
provisionErr: nil,
wantErr: false,
routersAfter: 0,
},
{
name: "provisioning error",
provisionErr: fmt.Errorf("provision failed"),
wantErr: true,
routersAfter: 1,
},
}
@@ -561,7 +562,8 @@ func TestAdminRouterProvisioning(t *testing.T) {
t.Fatalf("Failed to parse address: %v", err)
}
_, err = admin.newAdminHandler(addr, false, Context{})
_ = admin.newAdminHandler(addr, false, Context{})
err = admin.provisionAdminRouters(Context{})
if test.wantErr {
if err == nil {
@@ -572,6 +574,10 @@ func TestAdminRouterProvisioning(t *testing.T) {
t.Errorf("Expected no error but got: %v", err)
}
}
if len(admin.routers) != test.routersAfter {
t.Errorf("Expected %d routers after provisioning, got %d", test.routersAfter, len(admin.routers))
}
})
}
}
+7
View File
@@ -440,6 +440,13 @@ func run(newCfg *Config, start bool) (Context, error) {
}
}()
// Provision any admin routers which may need to access
// some of the other apps at runtime
err = ctx.cfg.Admin.provisionAdminRouters(ctx)
if err != nil {
return ctx, err
}
// Start
err = func() error {
started := make([]string, 0, len(ctx.cfg.apps))
@@ -11,7 +11,9 @@ reverse_proxy 127.0.0.1:65535 {
@accel header X-Accel-Redirect *
handle_response @accel {
respond "Header X-Accel-Redirect!"
rewrite * {rp.header.X-Accel-Redirect} {
force_modify_query
}
}
@another {
@@ -104,10 +106,12 @@ reverse_proxy 127.0.0.1:65535 {
},
"routes": [
{
"group": "group0",
"handle": [
{
"body": "Header X-Accel-Redirect!",
"handler": "static_response"
"force_modify_query": true,
"handler": "rewrite",
"uri": "{http.reverse_proxy.header.X-Accel-Redirect}"
}
]
}
@@ -1,328 +0,0 @@
package integration
import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"testing"
"time"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"github.com/caddyserver/caddy/v2/caddytest"
)
var errExtendedConnectUnsupportedByPeer = errors.New("peer did not advertise RFC 8441 extended CONNECT support")
func TestReverseProxyExtendedConnectOverH2(t *testing.T) {
tester := caddytest.NewTester(t)
backend := newWebsocketUpgradeEchoBackend(t)
defer backend.Close()
tester.InitServer(fmt.Sprintf(`
{
admin localhost:2999
http_port 9080
https_port 9443
grace_period 1ns
skip_install_trust
servers :9443 {
protocols h2
}
}
https://localhost:9443 {
reverse_proxy %s
}
`, backend.addr), "caddyfile")
const payload = "extended-connect-echo\n"
if err := assertExtendedConnectH2Echo("localhost:9443", payload); err != nil {
if errors.Is(err, errExtendedConnectUnsupportedByPeer) {
t.Skipf("skipping extended CONNECT integration test: %v", err)
}
t.Fatalf("extended connect h2 echo failed: %v", err)
}
}
func assertExtendedConnectH2Echo(addr, payload string) error {
conn, err := tlsDialH2(addr)
if err != nil {
return fmt.Errorf("dialing h2 tls: %w", err)
}
defer conn.Close()
if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
return fmt.Errorf("setting deadline: %w", err)
}
fr := http2.NewFramer(conn, conn)
if _, err := conn.Write([]byte(http2.ClientPreface)); err != nil {
return fmt.Errorf("writing client preface: %w", err)
}
if err := fr.WriteSettings(http2.Setting{ID: http2.SettingEnableConnectProtocol, Val: 1}); err != nil {
return fmt.Errorf("writing client settings: %w", err)
}
supported, err := waitForServerSettings(fr)
if err != nil {
return err
}
if !supported {
return errExtendedConnectUnsupportedByPeer
}
if err := waitForSettingsAck(fr); err != nil {
return err
}
if err := writeExtendedConnectHeaders(fr, addr); err != nil {
return err
}
status, err := readResponseStatus(fr, 1)
if err != nil {
return err
}
if status != "200" {
return fmt.Errorf("unexpected extended connect status: got=%s want=200", status)
}
if err := fr.WriteData(1, false, []byte(payload)); err != nil {
return fmt.Errorf("writing stream data: %w", err)
}
echo, err := readStreamData(fr, 1, len(payload))
if err != nil {
return err
}
if echo != payload {
return fmt.Errorf("unexpected echoed payload: got=%q want=%q", echo, payload)
}
_ = fr.WriteRSTStream(1, http2.ErrCodeNo)
return nil
}
func tlsDialH2(addr string) (net.Conn, error) {
var lastErr error
for i := 0; i < 30; i++ {
dialer := &net.Dialer{Timeout: 2 * time.Second}
conn, err := tls.DialWithDialer(dialer, "tcp", addr, &tls.Config{
ServerName: "localhost",
InsecureSkipVerify: true,
NextProtos: []string{"h2"},
})
if err == nil {
return conn, nil
}
lastErr = err
time.Sleep(100 * time.Millisecond)
}
return nil, lastErr
}
func waitForServerSettings(fr *http2.Framer) (bool, error) {
for {
frame, err := fr.ReadFrame()
if err != nil {
return false, fmt.Errorf("reading frame before connect: %w", err)
}
settings, ok := frame.(*http2.SettingsFrame)
if !ok {
continue
}
if settings.IsAck() {
continue
}
supported := false
if err := settings.ForeachSetting(func(s http2.Setting) error {
if s.ID == http2.SettingEnableConnectProtocol && s.Val == 1 {
supported = true
}
return nil
}); err != nil {
return false, fmt.Errorf("reading server settings: %w", err)
}
if err := fr.WriteSettingsAck(); err != nil {
return false, fmt.Errorf("writing settings ack: %w", err)
}
return supported, nil
}
}
func waitForSettingsAck(fr *http2.Framer) error {
for {
frame, err := fr.ReadFrame()
if err != nil {
return fmt.Errorf("reading settings ack: %w", err)
}
settings, ok := frame.(*http2.SettingsFrame)
if ok && settings.IsAck() {
return nil
}
}
}
func writeExtendedConnectHeaders(fr *http2.Framer, addr string) error {
var hb bytes.Buffer
enc := hpack.NewEncoder(&hb)
for _, hf := range []hpack.HeaderField{
{Name: ":method", Value: "CONNECT"},
{Name: ":scheme", Value: "https"},
{Name: ":authority", Value: addr},
{Name: ":path", Value: "/upgrade"},
{Name: ":protocol", Value: "websocket"},
} {
if err := enc.WriteField(hf); err != nil {
return fmt.Errorf("encoding request headers: %w", err)
}
}
if err := fr.WriteHeaders(http2.HeadersFrameParam{
StreamID: 1,
BlockFragment: hb.Bytes(),
EndHeaders: true,
EndStream: false,
}); err != nil {
return fmt.Errorf("writing extended connect headers: %w", err)
}
return nil
}
func readResponseStatus(fr *http2.Framer, streamID uint32) (string, error) {
var block bytes.Buffer
for {
frame, err := fr.ReadFrame()
if err != nil {
return "", fmt.Errorf("reading response headers: %w", err)
}
if rst, ok := frame.(*http2.RSTStreamFrame); ok && rst.StreamID == streamID {
return "", fmt.Errorf("stream reset before response headers: %s", rst.ErrCode)
}
h, ok := frame.(*http2.HeadersFrame)
if !ok || h.StreamID != streamID {
continue
}
if _, err := block.Write(h.HeaderBlockFragment()); err != nil {
return "", fmt.Errorf("buffering response header fragment: %w", err)
}
for !h.HeadersEnded() {
next, err := fr.ReadFrame()
if err != nil {
return "", fmt.Errorf("reading continuation frame: %w", err)
}
c, ok := next.(*http2.ContinuationFrame)
if !ok || c.StreamID != streamID {
continue
}
if _, err := block.Write(c.HeaderBlockFragment()); err != nil {
return "", fmt.Errorf("buffering continuation fragment: %w", err)
}
if c.HeadersEnded() {
break
}
}
break
}
var status string
dec := hpack.NewDecoder(4096, func(f hpack.HeaderField) {
if f.Name == ":status" {
status = f.Value
}
})
if _, err := dec.Write(block.Bytes()); err != nil {
return "", fmt.Errorf("decoding response header block: %w", err)
}
if status == "" {
return "", fmt.Errorf("missing :status in response headers")
}
return status, nil
}
func readStreamData(fr *http2.Framer, streamID uint32, n int) (string, error) {
buf := make([]byte, 0, n)
for len(buf) < n {
frame, err := fr.ReadFrame()
if err != nil {
return "", fmt.Errorf("reading stream data: %w", err)
}
d, ok := frame.(*http2.DataFrame)
if !ok || d.StreamID != streamID {
continue
}
buf = append(buf, d.Data()...)
}
return string(buf[:n]), nil
}
type websocketUpgradeEchoBackend struct {
addr string
ln net.Listener
server *http.Server
}
func newWebsocketUpgradeEchoBackend(t *testing.T) *websocketUpgradeEchoBackend {
t.Helper()
backend := &websocketUpgradeEchoBackend{}
backend.server = &http.Server{
Handler: http.HandlerFunc(backend.serveHTTP),
}
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listening for websocket backend: %v", err)
}
backend.ln = ln
backend.addr = ln.Addr().String()
go func() {
_ = backend.server.Serve(ln)
}()
return backend
}
func (b *websocketUpgradeEchoBackend) serveHTTP(w http.ResponseWriter, r *http.Request) {
if !strings.EqualFold(r.Header.Get("Connection"), "Upgrade") || !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
http.Error(w, "upgrade required", http.StatusUpgradeRequired)
return
}
hijacker, ok := w.(http.Hijacker)
if !ok {
http.Error(w, "hijacking not supported", http.StatusInternalServerError)
return
}
conn, rw, err := hijacker.Hijack()
if err != nil {
return
}
_, _ = rw.WriteString("HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n")
_ = rw.Flush()
go func() {
defer conn.Close()
_, _ = io.Copy(conn, conn)
}()
}
func (b *websocketUpgradeEchoBackend) Close() {
_ = b.server.Close()
_ = b.ln.Close()
}
@@ -1,130 +0,0 @@
package integration
import (
"bufio"
"fmt"
"io"
"net"
"net/textproto"
"strings"
"testing"
"time"
"github.com/caddyserver/caddy/v2/caddytest"
)
func TestReverseProxyUpgradeWithEncode(t *testing.T) {
tester := caddytest.NewTester(t)
backend := newUpgradeEchoBackend(t)
defer backend.Close()
tester.InitServer(fmt.Sprintf(`
{
admin localhost:2999
http_port 9080
https_port 9443
grace_period 1ns
skip_install_trust
}
localhost:9080 {
route {
encode gzip
reverse_proxy %s
}
}
`, backend.addr), "caddyfile")
client := newUpgradedStreamClientWithHeaders(t, map[string]string{
"Accept-Encoding": "gzip",
})
defer client.Close()
if err := client.echo("encode-upgrade\n"); err != nil {
t.Fatalf("upgraded stream echo through encode failed: %v", err)
}
}
func TestReverseProxyUpgradeWithInterceptHandleResponse(t *testing.T) {
tester := caddytest.NewTester(t)
backend := newUpgradeEchoBackend(t)
defer backend.Close()
tester.InitServer(fmt.Sprintf(`
{
admin localhost:2999
http_port 9080
https_port 9443
grace_period 1ns
skip_install_trust
}
localhost:9080 {
route {
intercept {
@upgrade status 101
handle_response @upgrade {
respond "should-not-run"
}
}
reverse_proxy %s
}
}
`, backend.addr), "caddyfile")
client := newUpgradedStreamClientWithHeaders(t, nil)
defer client.Close()
if err := client.echo("intercept-upgrade\n"); err != nil {
t.Fatalf("upgraded stream echo through intercept failed: %v", err)
}
}
func newUpgradedStreamClientWithHeaders(t *testing.T, extraHeaders map[string]string) *upgradedStreamClient {
t.Helper()
conn, err := net.DialTimeout("tcp", "127.0.0.1:9080", 5*time.Second)
if err != nil {
t.Fatalf("dialing caddy: %v", err)
}
requestLines := []string{
"GET /upgrade HTTP/1.1",
"Host: localhost:9080",
"Connection: Upgrade",
"Upgrade: stress-stream",
}
for k, v := range extraHeaders {
requestLines = append(requestLines, k+": "+v)
}
requestLines = append(requestLines, "", "")
if _, err := io.WriteString(conn, strings.Join(requestLines, "\r\n")); err != nil {
_ = conn.Close()
t.Fatalf("writing upgrade request: %v", err)
}
reader := bufio.NewReader(conn)
tproto := textproto.NewReader(reader)
statusLine, err := tproto.ReadLine()
if err != nil {
_ = conn.Close()
t.Fatalf("reading upgrade status line: %v", err)
}
if !strings.Contains(statusLine, "101") {
_ = conn.Close()
t.Fatalf("unexpected upgrade status: %s", statusLine)
}
headers, err := tproto.ReadMIMEHeader()
if err != nil {
_ = conn.Close()
t.Fatalf("reading upgrade headers: %v", err)
}
if !strings.EqualFold(headers.Get("Connection"), "Upgrade") {
_ = conn.Close()
t.Fatalf("unexpected upgrade response headers: %v", headers)
}
return &upgradedStreamClient{conn: conn, reader: reader}
}
@@ -1,504 +0,0 @@
package integration
import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"net/http"
"net/textproto"
"os"
"runtime"
"runtime/debug"
"runtime/pprof"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/caddyserver/caddy/v2/caddytest"
)
const (
defaultStressStreamCount = 1
defaultStressReloadCount = 1
defaultStressCloseDelay = 500 * time.Millisecond
)
func TestReverseProxyReloadStressUpgradedStreamsHeapProfiles(t *testing.T) {
tester := caddytest.NewTester(t).WithDefaultOverrides(caddytest.Config{
LoadRequestTimeout: 30 * time.Second,
TestRequestTimeout: 30 * time.Second,
})
backend := newUpgradeEchoBackend(t)
defer backend.Close()
// Three scenarios, each sequential so they don't share Caddy state:
//
// legacy no delay, close on reload immediately (old default)
// close_delay stream_close_delay, the old "keep-alive workaround"
// detached stream_detached, the new explicit detached flag
//
// Reloads are spread across time and interleaved with echo-checks so
// stream health is exercised at each reload boundary, not only at the end.
legacy := runReloadStress(t, tester, backend.addr, "legacy", false, 0)
closeDelay := runReloadStress(t, tester, backend.addr, "close_delay", false, stressCloseDelay(t))
detached := runReloadStress(t, tester, backend.addr, "detached", true, 0)
if legacy.aliveAfterReloads != 0 {
t.Fatalf("legacy mode left %d upgraded streams alive after reloads", legacy.aliveAfterReloads)
}
if closeDelay.aliveBeforeDelayExpiry == 0 {
t.Fatalf("close_delay mode: all streams closed before delay expired (expected them alive)")
}
if closeDelay.aliveAfterReloads != 0 {
t.Fatalf("close_delay mode left %d upgraded streams alive after delay expiry", closeDelay.aliveAfterReloads)
}
if detached.aliveAfterReloads != detached.streamCount {
t.Fatalf("detached mode kept %d/%d upgraded streams alive after reloads", detached.aliveAfterReloads, detached.streamCount)
}
t.Logf("legacy heap: before=%s mid=%s after=%s delta(before→after)=%s objects(before=%d after=%d) handler_frames(before=%d after=%d)",
formatBytes(legacy.beforeReload.HeapInuse),
formatBytes(legacy.midReload.HeapInuse),
formatBytes(legacy.afterReload.HeapInuse),
formatBytesDiff(legacy.beforeReload.HeapInuse, legacy.afterReload.HeapInuse),
legacy.beforeReload.HeapObjects, legacy.afterReload.HeapObjects,
legacy.beforeReload.handlerFrames, legacy.afterReload.handlerFrames,
)
t.Logf("close_delay heap: before=%s mid=%s after=%s delta(before→after)=%s objects(before=%d after=%d) handler_frames(before=%d after=%d)",
formatBytes(closeDelay.beforeReload.HeapInuse),
formatBytes(closeDelay.midReload.HeapInuse),
formatBytes(closeDelay.afterReload.HeapInuse),
formatBytesDiff(closeDelay.beforeReload.HeapInuse, closeDelay.afterReload.HeapInuse),
closeDelay.beforeReload.HeapObjects, closeDelay.afterReload.HeapObjects,
closeDelay.beforeReload.handlerFrames, closeDelay.afterReload.handlerFrames,
)
t.Logf("detached heap: before=%s mid=%s after=%s delta(before→after)=%s objects(before=%d after=%d) handler_frames(before=%d after=%d)",
formatBytes(detached.beforeReload.HeapInuse),
formatBytes(detached.midReload.HeapInuse),
formatBytes(detached.afterReload.HeapInuse),
formatBytesDiff(detached.beforeReload.HeapInuse, detached.afterReload.HeapInuse),
detached.beforeReload.HeapObjects, detached.afterReload.HeapObjects,
detached.beforeReload.handlerFrames, detached.afterReload.handlerFrames,
)
}
type stressRunResult struct {
streamCount int
aliveAfterReloads int
aliveBeforeDelayExpiry int // only meaningful for close_delay mode
beforeReload heapSnapshot
midReload heapSnapshot // after all reloads, before delay expiry clean-up
afterReload heapSnapshot // after all streams have been fully cleaned up
}
type heapSnapshot struct {
HeapInuse uint64
HeapObjects uint64
handlerFrames int
profileBytes int
}
// runReloadStress opens streamCount upgraded streams, then performs reloadCount
// config reloads spread over time. An echo check is performed every 6 reloads so
// stream health is exercised at each reload boundary rather than only at the end.
// closeDelay mirrors the stream_close_delay config option; pass 0 to disable.
func runReloadStress(t *testing.T, tester *caddytest.Tester, backendAddr, mode string, detach bool, closeDelay time.Duration) stressRunResult {
t.Helper()
const echoEvery = 6 // perform an echo check every N reloads
streamCount := envIntOrDefault(t, "CADDY_STRESS_STREAM_COUNT", defaultStressStreamCount)
reloadCount := envIntOrDefault(t, "CADDY_STRESS_RELOAD_COUNT", defaultStressReloadCount)
tester.InitServer(reloadStressConfig(backendAddr, detach, closeDelay, 0), "caddyfile")
clients := make([]*upgradedStreamClient, 0, streamCount)
for i := 0; i < streamCount; i++ {
client := newUpgradedStreamClient(t)
clients = append(clients, client)
if err := client.echo(fmt.Sprintf("%s-warmup-%02d\n", mode, i)); err != nil {
closeClients(clients)
t.Fatalf("warmup echo failed in %s mode: %v", mode, err)
}
}
defer closeClients(clients)
before := captureHeapSnapshot(t)
// Reloads are spread across time; between batches of echoEvery reloads we
// pause briefly and measure stream health so the snapshot reflects real-world
// reload cadence rather than a tight loop.
for i := 1; i <= reloadCount; i++ {
loadCaddyfileConfig(t, reloadStressConfig(backendAddr, detach, closeDelay, i))
// Small pause after each reload to let connection teardown propagate.
time.Sleep(50 * time.Millisecond)
if i%echoEvery == 0 {
alive := countAliveStreams(clients)
t.Logf("%s mode: %d/%d streams alive after reload %d", mode, alive, streamCount, i)
// In detached mode, every stream must survive every reload (upstream unchanged).
if detach {
for j, client := range clients {
if err := client.echo(fmt.Sprintf("%s-mid-%02d-%02d\n", mode, i, j)); err != nil {
t.Fatalf("detached mode stream %d died at reload %d: %v", j, i, err)
}
}
}
}
}
// mid snapshot: after all reloads but before any close_delay timer has fired
// (the delay is long enough to still be running at this point).
mid := captureHeapSnapshot(t)
// For legacy mode: the reloads close streams immediately; wait for that to complete.
// For close_delay mode: streams are still alive here; wait for the delay to fire.
// For detached mode: streams survive indefinitely; no wait needed.
var aliveBeforeDelayExpiry int
aliveAfterReloads := countAliveStreams(clients)
switch {
case detach:
// nothing to wait for
case closeDelay > 0:
// streams should still be alive at this point (delay hasn't expired)
aliveBeforeDelayExpiry = aliveAfterReloads
t.Logf("%s mode: %d/%d streams alive before close_delay expires; waiting %v for cleanup",
mode, aliveBeforeDelayExpiry, streamCount, closeDelay)
time.Sleep(closeDelay + 200*time.Millisecond)
aliveAfterReloads = countAliveStreams(clients)
default:
deadline := time.Now().Add(2 * time.Second)
for aliveAfterReloads > 0 && time.Now().Before(deadline) {
time.Sleep(50 * time.Millisecond)
aliveAfterReloads = countAliveStreams(clients)
}
}
after := captureHeapSnapshot(t)
t.Logf("%s mode heap profile size: before=%dB mid=%dB after=%dB objects(before=%d mid=%d after=%d)",
mode,
before.profileBytes, mid.profileBytes, after.profileBytes,
before.HeapObjects, mid.HeapObjects, after.HeapObjects,
)
return stressRunResult{
streamCount: streamCount,
aliveAfterReloads: aliveAfterReloads,
aliveBeforeDelayExpiry: aliveBeforeDelayExpiry,
beforeReload: before,
midReload: mid,
afterReload: after,
}
}
func envIntOrDefault(t *testing.T, key string, def int) int {
t.Helper()
raw := strings.TrimSpace(os.Getenv(key))
if raw == "" {
return def
}
v, err := strconv.Atoi(raw)
if err != nil || v <= 0 {
t.Fatalf("invalid %s=%q: must be a positive integer", key, raw)
}
return v
}
func stressCloseDelay(t *testing.T) time.Duration {
t.Helper()
const key = "CADDY_STRESS_CLOSE_DELAY"
raw := strings.TrimSpace(os.Getenv(key))
if raw == "" {
return defaultStressCloseDelay
}
v, err := time.ParseDuration(raw)
if err != nil || v <= 0 {
t.Fatalf("invalid %s=%q: must be a positive duration", key, raw)
}
return v
}
func loadCaddyfileConfig(t *testing.T, rawConfig string) {
t.Helper()
client := &http.Client{Timeout: 30 * time.Second}
req, err := http.NewRequest(http.MethodPost, "http://localhost:2999/load", strings.NewReader(rawConfig))
if err != nil {
t.Fatalf("creating load request: %v", err)
}
req.Header.Set("Content-Type", "text/caddyfile")
resp, err := client.Do(req)
if err != nil {
t.Fatalf("loading config: %v", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("reading load response: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("loading config failed: status=%d body=%s", resp.StatusCode, body)
}
}
func reloadStressConfig(backendAddr string, detach bool, closeDelay time.Duration, revision int) string {
var directives string
if detach {
directives += "\n\t\tstream_detached"
}
if closeDelay > 0 {
directives += fmt.Sprintf("\n\t\tstream_close_delay %s", closeDelay)
}
return fmt.Sprintf(`
{
admin localhost:2999
http_port 9080
https_port 9443
grace_period 1ns
skip_install_trust
}
localhost:9080 {
reverse_proxy %s {
header_up X-Reload-Revision %d%s
}
}
`, backendAddr, revision, directives)
}
func captureHeapSnapshot(t *testing.T) heapSnapshot {
t.Helper()
runtime.GC()
debug.FreeOSMemory()
var mem runtime.MemStats
runtime.ReadMemStats(&mem)
var buf bytes.Buffer
if err := pprof.Lookup("heap").WriteTo(&buf, 1); err != nil {
t.Fatalf("capturing heap profile: %v", err)
}
profile := buf.String()
return heapSnapshot{
HeapInuse: mem.HeapInuse,
HeapObjects: mem.HeapObjects,
handlerFrames: strings.Count(profile, "modules/caddyhttp/reverseproxy.(*Handler)"),
profileBytes: buf.Len(),
}
}
func countAliveStreams(clients []*upgradedStreamClient) int {
alive := 0
for index, client := range clients {
if err := client.echo(fmt.Sprintf("alive-check-%02d\n", index)); err == nil {
alive++
}
}
return alive
}
func closeClients(clients []*upgradedStreamClient) {
for _, client := range clients {
if client != nil {
_ = client.Close()
}
}
}
func formatBytes(value uint64) string {
const unit = 1024
if value < unit {
return fmt.Sprintf("%d B", value)
}
div, exp := uint64(unit), 0
for n := value / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %ciB", float64(value)/float64(div), "KMGTPE"[exp])
}
func formatBytesDiff(before, after uint64) string {
if after >= before {
return "+" + formatBytes(after-before)
}
return "-" + formatBytes(before-after)
}
type upgradedStreamClient struct {
conn net.Conn
reader *bufio.Reader
mu sync.Mutex
}
func newUpgradedStreamClient(t *testing.T) *upgradedStreamClient {
t.Helper()
conn, err := net.DialTimeout("tcp", "127.0.0.1:9080", 5*time.Second)
if err != nil {
t.Fatalf("dialing caddy: %v", err)
}
request := strings.Join([]string{
"GET /upgrade HTTP/1.1",
"Host: localhost:9080",
"Connection: Upgrade",
"Upgrade: stress-stream",
"",
"",
}, "\r\n")
if _, err := io.WriteString(conn, request); err != nil {
_ = conn.Close()
t.Fatalf("writing upgrade request: %v", err)
}
reader := bufio.NewReader(conn)
tproto := textproto.NewReader(reader)
statusLine, err := tproto.ReadLine()
if err != nil {
_ = conn.Close()
t.Fatalf("reading upgrade status line: %v", err)
}
if !strings.Contains(statusLine, "101") {
_ = conn.Close()
t.Fatalf("unexpected upgrade status: %s", statusLine)
}
headers, err := tproto.ReadMIMEHeader()
if err != nil {
_ = conn.Close()
t.Fatalf("reading upgrade headers: %v", err)
}
if !strings.EqualFold(headers.Get("Connection"), "Upgrade") {
_ = conn.Close()
t.Fatalf("unexpected upgrade response headers: %v", headers)
}
return &upgradedStreamClient{conn: conn, reader: reader}
}
func (c *upgradedStreamClient) echo(payload string) error {
c.mu.Lock()
defer c.mu.Unlock()
deadline := time.Now().Add(1 * time.Second)
if err := c.conn.SetWriteDeadline(deadline); err != nil {
return err
}
if _, err := io.WriteString(c.conn, payload); err != nil {
return err
}
if err := c.conn.SetReadDeadline(deadline); err != nil {
return err
}
buf := make([]byte, len(payload))
if _, err := io.ReadFull(c.reader, buf); err != nil {
return err
}
if string(buf) != payload {
return fmt.Errorf("unexpected echoed payload: got %q want %q", string(buf), payload)
}
return nil
}
func (c *upgradedStreamClient) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
return c.conn.Close()
}
type upgradeEchoBackend struct {
addr string
ln net.Listener
mu sync.Mutex
conns map[net.Conn]struct{}
server *http.Server
}
func newUpgradeEchoBackend(t *testing.T) *upgradeEchoBackend {
t.Helper()
backend := &upgradeEchoBackend{conns: make(map[net.Conn]struct{})}
backend.server = &http.Server{
Handler: http.HandlerFunc(backend.serveHTTP),
}
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listening for backend: %v", err)
}
backend.ln = ln
backend.addr = ln.Addr().String()
go func() {
_ = backend.server.Serve(ln)
}()
return backend
}
func (b *upgradeEchoBackend) serveHTTP(w http.ResponseWriter, r *http.Request) {
if !strings.EqualFold(r.Header.Get("Connection"), "Upgrade") || !strings.EqualFold(r.Header.Get("Upgrade"), "stress-stream") {
http.Error(w, "upgrade required", http.StatusUpgradeRequired)
return
}
hijacker, ok := w.(http.Hijacker)
if !ok {
http.Error(w, "hijacking not supported", http.StatusInternalServerError)
return
}
conn, rw, err := hijacker.Hijack()
if err != nil {
return
}
b.trackConn(conn)
_, _ = rw.WriteString("HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: stress-stream\r\n\r\n")
_ = rw.Flush()
go func() {
defer b.untrackConn(conn)
defer conn.Close()
_, _ = io.Copy(conn, conn)
}()
}
func (b *upgradeEchoBackend) trackConn(conn net.Conn) {
b.mu.Lock()
b.conns[conn] = struct{}{}
b.mu.Unlock()
}
func (b *upgradeEchoBackend) untrackConn(conn net.Conn) {
b.mu.Lock()
delete(b.conns, conn)
b.mu.Unlock()
}
func (b *upgradeEchoBackend) Close() {
_ = b.server.Close()
_ = b.ln.Close()
b.mu.Lock()
defer b.mu.Unlock()
for conn := range b.conns {
_ = conn.Close()
}
clear(b.conns)
}
+1 -1
View File
@@ -20,7 +20,7 @@ require (
github.com/klauspost/cpuid/v2 v2.3.0
github.com/mholt/acmez/v3 v3.1.6
github.com/prometheus/client_golang v1.23.2
github.com/quic-go/quic-go v0.59.1
github.com/quic-go/quic-go v0.59.0
github.com/smallstep/certificates v0.30.2
github.com/smallstep/nosql v0.8.0
github.com/smallstep/truststore v0.13.0
+2 -2
View File
@@ -280,8 +280,8 @@ github.com/prometheus/procfs v0.20.1 h1:XwbrGOIplXW/AU3YhIhLODXMJYyC1isLFfYCsTEy
github.com/prometheus/procfs v0.20.1/go.mod h1:o9EMBZGRyvDrSPH1RqdxhojkuXstoe4UlK79eF5TGGo=
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII=
github.com/quic-go/quic-go v0.59.1 h1:0Gmua0HW1Tv7ANR7hUYwRyD0MG5OJfgvYSZasGZzBic=
github.com/quic-go/quic-go v0.59.1/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU=
github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw=
github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU=
+32 -2
View File
@@ -20,6 +20,7 @@ import (
"crypto/tls"
"errors"
"fmt"
"maps"
"net"
"net/http"
"strconv"
@@ -240,7 +241,12 @@ func (app *App) Provision(ctx caddy.Context) error {
// if no protocols configured explicitly, enable all except h2c
if len(srv.Protocols) == 0 {
srv.Protocols = srv.protocolsWithDefaults()
srv.Protocols = []string{"h1", "h2", "h3"}
}
srvProtocolsUnique := map[string]struct{}{}
for _, srvProtocol := range srv.Protocols {
srvProtocolsUnique[srvProtocol] = struct{}{}
}
if srv.ListenProtocols != nil {
@@ -251,7 +257,31 @@ func (app *App) Provision(ctx caddy.Context) error {
for i, lnProtocols := range srv.ListenProtocols {
if lnProtocols != nil {
srv.ListenProtocols[i] = srv.listenerProtocolsWithDefaults(lnProtocols)
// populate empty listen protocols with server protocols
lnProtocolsDefault := false
var lnProtocolsInclude []string
srvProtocolsInclude := maps.Clone(srvProtocolsUnique)
// keep existing listener protocols unless they are empty
for _, lnProtocol := range lnProtocols {
if lnProtocol == "" {
lnProtocolsDefault = true
} else {
lnProtocolsInclude = append(lnProtocolsInclude, lnProtocol)
delete(srvProtocolsInclude, lnProtocol)
}
}
// append server protocols to listener protocols if any listener protocols were empty
if lnProtocolsDefault {
for _, srvProtocol := range srv.Protocols {
if _, ok := srvProtocolsInclude[srvProtocol]; ok {
lnProtocolsInclude = append(lnProtocolsInclude, srvProtocol)
}
}
}
srv.ListenProtocols[i] = lnProtocolsInclude
}
}
}
+1 -15
View File
@@ -173,7 +173,7 @@ func (app *App) automaticHTTPSPhase1(ctx caddy.Context, repl *caddy.Replacer) er
for d := range serverDomainSet {
echDomains = append(echDomains, d)
}
app.tlsApp.RegisterServerNames(echDomains, httpsRRALPNs(srv))
app.tlsApp.RegisterServerNames(echDomains)
// nothing more to do here if there are no domains that qualify for
// automatic HTTPS and there are no explicit TLS connection policies:
@@ -574,20 +574,6 @@ func (app *App) makeRedirRoute(redirToPort uint, matcherSet MatcherSet) Route {
}
}
func httpsRRALPNs(srv *Server) []string {
alpn := make(map[string]struct{}, 3)
if srv.protocol("h3") {
alpn["h3"] = struct{}{}
}
if srv.protocol("h2") {
alpn["h2"] = struct{}{}
}
if srv.protocol("h1") {
alpn["http/1.1"] = struct{}{}
}
return caddytls.OrderedHTTPSRRALPN(alpn)
}
// createAutomationPolicies ensures that automated certificates for this
// app are managed properly. This adds up to two automation policies:
// one for the public names, and one for the internal names. If a catch-all
+30 -33
View File
@@ -1,47 +1,44 @@
package caddyhttp
import (
"reflect"
"testing"
"github.com/caddyserver/caddy/v2"
)
func TestHTTPSRRALPNsDefaultProtocols(t *testing.T) {
srv := &Server{}
func TestRecordAutoHTTPSRedirectAddressPrefersHTTPSPort(t *testing.T) {
app := &App{HTTPSPort: 443}
redirDomains := make(map[string][]caddy.NetworkAddress)
got := httpsRRALPNs(srv)
want := []string{"h3", "h2", "http/1.1"}
app.recordAutoHTTPSRedirectAddress(redirDomains, "example.com", caddy.NetworkAddress{Network: "tcp", StartPort: 2345, EndPort: 2345})
app.recordAutoHTTPSRedirectAddress(redirDomains, "example.com", caddy.NetworkAddress{Network: "tcp", StartPort: 443, EndPort: 443})
app.recordAutoHTTPSRedirectAddress(redirDomains, "example.com", caddy.NetworkAddress{Network: "tcp", StartPort: 8443, EndPort: 8443})
if !reflect.DeepEqual(got, want) {
t.Fatalf("unexpected ALPN values: got %v want %v", got, want)
got := redirDomains["example.com"]
if len(got) != 1 {
t.Fatalf("expected 1 redirect address, got %d: %#v", len(got), got)
}
if got[0].StartPort != 443 {
t.Fatalf("expected redirect to prefer HTTPS port 443, got %#v", got[0])
}
}
func TestHTTPSRRALPNsListenProtocolOverrides(t *testing.T) {
srv := &Server{
Protocols: []string{"h1", "h2"},
ListenProtocols: [][]string{
{"h1"},
nil,
{},
{"h3", ""},
},
func TestRecordAutoHTTPSRedirectAddressKeepsAllBindAddressesOnWinningPort(t *testing.T) {
app := &App{HTTPSPort: 443}
redirDomains := make(map[string][]caddy.NetworkAddress)
app.recordAutoHTTPSRedirectAddress(redirDomains, "example.com", caddy.NetworkAddress{Network: "tcp", Host: "10.0.0.189", StartPort: 8443, EndPort: 8443})
app.recordAutoHTTPSRedirectAddress(redirDomains, "example.com", caddy.NetworkAddress{Network: "tcp", Host: "10.0.0.189", StartPort: 443, EndPort: 443})
app.recordAutoHTTPSRedirectAddress(redirDomains, "example.com", caddy.NetworkAddress{Network: "tcp", Host: "2603:c024:8002:9500:9eb:e5d3:3975:d056", StartPort: 443, EndPort: 443})
got := redirDomains["example.com"]
if len(got) != 2 {
t.Fatalf("expected 2 redirect addresses for both bind addresses on the winning port, got %d: %#v", len(got), got)
}
got := httpsRRALPNs(srv)
want := []string{"h3", "h2", "http/1.1"}
if !reflect.DeepEqual(got, want) {
t.Fatalf("unexpected ALPN values: got %v want %v", got, want)
}
}
func TestHTTPSRRALPNsIgnoresH2COnly(t *testing.T) {
srv := &Server{
Protocols: []string{"h2c"},
}
got := httpsRRALPNs(srv)
if len(got) != 0 {
t.Fatalf("unexpected ALPN values: got %v want none", got)
if got[0].StartPort != 443 || got[1].StartPort != 443 {
t.Fatalf("expected both redirect addresses to stay on HTTPS port 443, got %#v", got)
}
if got[0].Host != "10.0.0.189" || got[1].Host != "2603:c024:8002:9500:9eb:e5d3:3975:d056" {
t.Fatalf("expected both bind addresses to be preserved, got %#v", got)
}
}
@@ -28,7 +28,6 @@ import (
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/internal/filesystems"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
"github.com/caddyserver/caddy/v2/modules/caddyhttp/rewrite"
)
type testCase struct {
@@ -189,105 +188,6 @@ func fileMatcherTest(t *testing.T, i int, tc testCase) {
}
}
func TestTryFilesRewriteEscapesMatchedPath(t *testing.T) {
root := t.TempDir()
tests := []struct {
name string
requestTarget string
filename string
extraFiles []string
wantPath string
wantRequestURI string
skipWindows bool
}{
{
name: "question mark in path",
requestTarget: "/%3F.html",
filename: "?.html",
wantPath: "/?.html",
wantRequestURI: "/%3F.html",
skipWindows: true,
},
{
name: "percent in path",
requestTarget: "/%25.html",
filename: "%.html",
wantPath: "/%.html",
wantRequestURI: "/%25.html",
},
{
name: "encoded question mark remains percent-encoded",
requestTarget: "/%253F.html",
filename: "%3F.html",
wantPath: "/%3F.html",
wantRequestURI: "/%253F.html",
},
{
name: "question mark in nested path",
requestTarget: "/nested/%3F.html",
filename: filepath.Join("nested", "?.html"),
wantPath: "/nested/?.html",
wantRequestURI: "/nested/%3F.html",
skipWindows: true,
},
{
name: "encoded slash in filename does not conflict with nesting",
requestTarget: "/nested%252Ffile.html",
filename: "nested%2Ffile.html",
extraFiles: []string{filepath.Join("nested", "file.html")},
wantPath: "/nested%2Ffile.html",
wantRequestURI: "/nested%252Ffile.html",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if tc.skipWindows && runtime.GOOS == "windows" {
t.Skip("Windows file names cannot contain question marks")
}
for _, name := range append([]string{tc.filename}, tc.extraFiles...) {
filename := filepath.Join(root, name)
if err := os.MkdirAll(filepath.Dir(filename), 0o700); err != nil {
t.Fatalf("creating test file parent directory: %v", err)
}
if err := os.WriteFile(filename, []byte(name), 0o600); err != nil {
t.Fatalf("writing test file: %v", err)
}
}
m := &MatchFile{
fsmap: &filesystems.FileSystemMap{},
Root: root,
TryFiles: []string{"{http.request.uri.path}"},
}
req := httptest.NewRequest(http.MethodGet, "http://example.com"+tc.requestTarget, nil)
repl := caddyhttp.NewTestReplacer(req)
matched, err := m.MatchWithError(req)
if err != nil {
t.Fatalf("matching file: %v", err)
}
if !matched {
t.Fatalf("expected request %s to match %s", tc.requestTarget, tc.filename)
}
rewrite.Rewrite{URI: "{http.matchers.file.relative}"}.Rewrite(req, repl)
if req.URL.Path != tc.wantPath {
t.Errorf("rewritten path = %q, want %q", req.URL.Path, tc.wantPath)
}
if req.RequestURI != tc.wantRequestURI {
t.Errorf("rewritten request URI = %q, want %q", req.RequestURI, tc.wantRequestURI)
}
if req.URL.RawQuery != "" {
t.Errorf("rewritten raw query = %q, want empty", req.URL.RawQuery)
}
})
}
}
func TestPHPFileMatcher(t *testing.T) {
for i, tc := range []struct {
path string
+4 -58
View File
@@ -21,8 +21,6 @@ import (
"io"
"net"
"net/http"
"github.com/caddyserver/caddy/v2"
)
// ResponseWriterWrapper wraps an underlying ResponseWriter and
@@ -72,8 +70,6 @@ type responseRecorder struct {
size int
wroteHeader bool
stream bool
hijacked bool
detached bool
readSize *int
}
@@ -148,8 +144,7 @@ func NewResponseRecorder(w http.ResponseWriter, buf *bytes.Buffer, shouldBuffer
// WriteHeader writes the headers with statusCode to the wrapped
// ResponseWriter unless the response is to be buffered instead.
// 1xx responses are never buffered, except 101 which is treated
// as a final upgrade response.
// 1xx responses are never buffered.
func (rr *responseRecorder) WriteHeader(statusCode int) {
if rr.wroteHeader {
return
@@ -166,12 +161,12 @@ func (rr *responseRecorder) WriteHeader(statusCode int) {
rr.stream = !rr.shouldBuffer(rr.statusCode, rr.ResponseWriterWrapper.Header())
}
// 1xx responses except 101 aren't final; just informational
if statusCode < 100 || statusCode > 199 || statusCode == http.StatusSwitchingProtocols {
// 1xx responses aren't final; just informational
if statusCode < 100 || statusCode > 199 {
rr.wroteHeader = true
}
// if 1xx or not buffered, immediately write header
// if informational or not buffered, immediately write header
if rr.stream || (100 <= statusCode && statusCode <= 199) {
rr.ResponseWriterWrapper.WriteHeader(statusCode)
}
@@ -227,18 +222,7 @@ func (rr *responseRecorder) Buffered() bool {
return !rr.stream
}
func (rr *responseRecorder) DetachAfterHijack(detached bool) bool {
if rr.hijacked {
return false
}
rr.detached = detached
return true
}
func (rr *responseRecorder) WriteResponse() error {
if rr.hijacked {
return nil
}
if rr.statusCode == 0 {
// could happen if no handlers actually wrote anything,
// and this prevents a panic; status must be > 0
@@ -269,25 +253,11 @@ func (rr *responseRecorder) setReadSize(size *int) {
}
func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if !rr.wroteHeader {
// hijacking without writing status code first works as long as
// subsequent writes follows http1.1 wire format, but it will
// show up with a status code of 0 in the access log and bytes
// written will include response headers. Response headers won't
// be present in the log if not set on the response writer.
caddy.Log().Warn("hijacking without writing status code first")
}
//nolint:bodyclose
conn, brw, err := http.NewResponseController(rr.ResponseWriterWrapper).Hijack()
if err != nil {
return nil, nil, err
}
rr.hijacked = true
rr.stream = true
rr.wroteHeader = true
if rr.detached {
return conn, brw, nil
}
// Per http documentation, returned bufio.Writer is empty, but bufio.Read maybe not
conn = &hijackedConn{conn, rr}
brw.Writer.Reset(conn)
@@ -341,29 +311,6 @@ func (hc *hijackedConn) ReadFrom(r io.Reader) (int64, error) {
return n, err
}
// DetachResponseWriterAfterHijack detaches w or one of its wrapped
// response writers when it's hijacked. Returns true if not already
// hijacked. When detached, bytes read or written stats will not be
// recorded for the hijacked connection, and it's safe to use the
// connection after http middleware returns.
func DetachResponseWriterAfterHijack(w http.ResponseWriter, detached bool) bool {
for w != nil {
if detacher, ok := w.(interface{ DetachAfterHijack(bool) bool }); ok {
return detacher.DetachAfterHijack(detached)
}
unwrapper, ok := w.(interface{ Unwrap() http.ResponseWriter })
if !ok {
return false
}
next := unwrapper.Unwrap()
if next == w {
return false
}
w = next
}
return false
}
// ResponseRecorder is a http.ResponseWriter that records
// responses instead of writing them to the client. See
// docs for NewResponseRecorder for proper usage.
@@ -372,7 +319,6 @@ type ResponseRecorder interface {
Status() int
Buffer() *bytes.Buffer
Buffered() bool
DetachAfterHijack(bool) bool
Size() int
WriteResponse() error
}
-93
View File
@@ -1,14 +1,11 @@
package caddyhttp
import (
"bufio"
"bytes"
"io"
"net"
"net/http"
"strings"
"testing"
"time"
)
type responseWriterSpy interface {
@@ -47,50 +44,6 @@ func (rf *readFromRespWriter) ReadFrom(r io.Reader) (int64, error) {
func (rf *readFromRespWriter) CalledReadFrom() bool { return rf.called }
type hijackRespWriter struct {
baseRespWriter
header http.Header
status int
conn net.Conn
}
func newHijackRespWriter() *hijackRespWriter {
return &hijackRespWriter{
header: make(http.Header),
conn: stubConn{},
}
}
func (hrw *hijackRespWriter) Header() http.Header {
return hrw.header
}
func (hrw *hijackRespWriter) WriteHeader(statusCode int) {
hrw.status = statusCode
}
func (hrw *hijackRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
br := bufio.NewReader(hrw.conn)
bw := bufio.NewWriter(hrw.conn)
return hrw.conn, bufio.NewReadWriter(br, bw), nil
}
type stubConn struct{}
func (stubConn) Read(_ []byte) (int, error) { return 0, io.EOF }
func (stubConn) Write(p []byte) (int, error) { return len(p), nil }
func (stubConn) Close() error { return nil }
func (stubConn) LocalAddr() net.Addr { return stubAddr("local") }
func (stubConn) RemoteAddr() net.Addr { return stubAddr("remote") }
func (stubConn) SetDeadline(time.Time) error { return nil }
func (stubConn) SetReadDeadline(time.Time) error { return nil }
func (stubConn) SetWriteDeadline(time.Time) error { return nil }
type stubAddr string
func (a stubAddr) Network() string { return "tcp" }
func (a stubAddr) String() string { return string(a) }
func TestResponseWriterWrapperReadFrom(t *testing.T) {
tests := map[string]struct {
responseWriter responseWriterSpy
@@ -216,49 +169,3 @@ func TestResponseRecorderReadFrom(t *testing.T) {
})
}
}
func TestResponseRecorderSwitchingProtocolsIsHijackAware(t *testing.T) {
w := newHijackRespWriter()
var buf bytes.Buffer
rr := NewResponseRecorder(w, &buf, func(status int, header http.Header) bool {
return true
})
rr.WriteHeader(http.StatusSwitchingProtocols)
if rr.Status() != http.StatusSwitchingProtocols {
t.Fatalf("status = %d, want %d", rr.Status(), http.StatusSwitchingProtocols)
}
if w.status != http.StatusSwitchingProtocols {
t.Fatalf("underlying status = %d, want %d", w.status, http.StatusSwitchingProtocols)
}
hj, ok := rr.(http.Hijacker)
if !ok {
t.Fatal("response recorder does not implement http.Hijacker")
}
conn, _, err := hj.Hijack()
if err != nil {
t.Fatalf("Hijack() error = %v", err)
}
defer conn.Close()
if rr.Buffered() {
t.Fatal("hijacked response should not remain buffered")
}
if rr.DetachAfterHijack(true) {
t.Fatal("response recorder should report hijacked state by returning false")
}
if DetachResponseWriterAfterHijack(rr, true) {
t.Fatal("DetachResponseWriterAfterHijack() should report false after hijack")
}
if err := rr.WriteResponse(); err != nil {
t.Fatalf("WriteResponse() after hijack returned error: %v", err)
}
if rr.Size() != 0 {
t.Fatalf("size = %d, want 0 after hijack handshake", rr.Size())
}
if got := w.Written(); got != "" {
t.Fatalf("unexpected buffered body write after hijack: %q", got)
}
}
@@ -99,12 +99,6 @@ func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error)
// stream_buffer_size <size>
// stream_timeout <duration>
// stream_close_delay <duration>
// stream_detached
// stream_logs {
// level <debug|info|warn|error>
// logger_name <name|access>
// skip_handshake
// }
// verbose_logs
//
// # request manipulation
@@ -709,49 +703,6 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
h.StreamCloseDelay = caddy.Duration(dur)
}
case "stream_detached":
if d.NextArg() {
return d.ArgErr()
}
h.StreamDetached = true
case "stream_logs":
if d.NextArg() {
return d.ArgErr()
}
if h.StreamLogs == nil {
h.StreamLogs = new(StreamLogs)
}
nesting := d.Nesting()
for d.NextBlock(nesting) {
switch d.Val() {
case "level":
if !d.NextArg() {
return d.ArgErr()
}
h.StreamLogs.Level = d.Val()
if d.NextArg() {
return d.ArgErr()
}
case "logger_name":
if !d.NextArg() {
return d.ArgErr()
}
h.StreamLogs.LoggerName = d.Val()
if d.NextArg() {
return d.ArgErr()
}
case "skip_handshake":
if d.NextArg() {
return d.ArgErr()
}
h.StreamLogs.SkipHandshake = true
default:
return d.Errf("unrecognized stream_logs option: %s", d.Val())
}
}
case "trusted_proxies":
for d.NextArg() {
if d.Val() == "private_ranges" {
@@ -80,7 +80,7 @@ func (h CopyResponseHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request
hrc.isFinalized = true
// write the response
return hrc.handler.finalizeResponse(rw, req, hrc.response, repl, hrc.start, hrc.logger, hrc.upstreamAddr)
return hrc.handler.finalizeResponse(rw, req, hrc.response, repl, hrc.start, hrc.logger)
}
// CopyResponseHeadersHandler is a special HTTP handler which may
@@ -1,146 +0,0 @@
package reverseproxy
import (
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"go.uber.org/zap"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
)
type extendedConnectCapture struct {
method string
headers http.Header
body []byte
extendedBodyPresent bool
extendedConnectBody []byte
}
type extendedConnectCaptureTransport struct {
mu sync.Mutex
capture extendedConnectCapture
}
func (tr *extendedConnectCaptureTransport) RoundTrip(req *http.Request) (*http.Response, error) {
body, err := io.ReadAll(req.Body)
if err != nil {
return nil, err
}
c := extendedConnectCapture{
method: req.Method,
headers: req.Header.Clone(),
body: body,
}
if rc, ok := caddyhttp.GetVar(req.Context(), "extended_connect_websocket_body").(io.ReadCloser); ok {
c.extendedBodyPresent = true
c.extendedConnectBody, err = io.ReadAll(rc)
if err != nil {
return nil, err
}
_ = rc.Close()
}
tr.mu.Lock()
tr.capture = c
tr.mu.Unlock()
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("ok")),
Request: req,
}, nil
}
func (tr *extendedConnectCaptureTransport) Snapshot() extendedConnectCapture {
tr.mu.Lock()
defer tr.mu.Unlock()
return tr.capture
}
func TestServeHTTPRewritesExtendedConnectWebsocketRequest(t *testing.T) {
tests := []struct {
name string
protoMajor int
proto string
headers map[string]string
}{
{
name: "h2 extended connect",
protoMajor: 2,
proto: "HTTP/2.0",
headers: map[string]string{
":protocol": "websocket",
},
},
{
name: "h3 extended connect",
protoMajor: 3,
proto: "websocket",
headers: map[string]string{},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
const payload = "extended-connect-body"
transport := new(extendedConnectCaptureTransport)
h := &Handler{
logger: zap.NewNop(),
Transport: transport,
Upstreams: UpstreamPool{
&Upstream{Host: new(Host), Dial: "127.0.0.1:8443"},
},
LoadBalancing: &LoadBalancing{
SelectionPolicy: &RoundRobinSelection{},
},
}
req := httptest.NewRequest(http.MethodConnect, "http://example.test/upgrade", strings.NewReader(payload))
req.ProtoMajor = tc.protoMajor
req.Proto = tc.proto
for key, value := range tc.headers {
req.Header.Set(key, value)
}
req = prepareTestRequest(req)
rr := httptest.NewRecorder()
err := h.ServeHTTP(rr, req, caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
return nil
}))
if err != nil {
t.Fatalf("ServeHTTP() error = %v", err)
}
captured := transport.Snapshot()
if captured.method != http.MethodGet {
t.Fatalf("upstream method = %s, want %s", captured.method, http.MethodGet)
}
if got := captured.headers.Get("Upgrade"); !strings.EqualFold(got, "websocket") {
t.Fatalf("Upgrade header = %q, want websocket", got)
}
if got := captured.headers.Get("Connection"); !strings.EqualFold(got, "Upgrade") {
t.Fatalf("Connection header = %q, want Upgrade", got)
}
if got := captured.headers.Get(":protocol"); got != "" {
t.Fatalf(":protocol header should be removed, got %q", got)
}
if len(captured.body) != 0 {
t.Fatalf("upstream request body length = %d, want 0", len(captured.body))
}
if !captured.extendedBodyPresent {
t.Fatal("extended_connect_websocket_body variable missing from request context")
}
if string(captured.extendedConnectBody) != payload {
t.Fatalf("extended_connect_websocket_body = %q, want %q", string(captured.extendedConnectBody), payload)
}
})
}
}
-79
View File
@@ -16,10 +16,6 @@ import (
var reverseProxyMetrics = struct {
once sync.Once
upstreamsHealthy *prometheus.GaugeVec
streamsActive *prometheus.GaugeVec
streamsTotal *prometheus.CounterVec
streamDuration *prometheus.HistogramVec
streamBytes *prometheus.CounterVec
logger *zap.Logger
}{}
@@ -27,8 +23,6 @@ func initReverseProxyMetrics(handler *Handler, registry *prometheus.Registry) {
const ns, sub = "caddy", "reverse_proxy"
upstreamsLabels := []string{"upstream"}
streamResultLabels := []string{"upstream", "result"}
streamBytesLabels := []string{"upstream", "direction"}
reverseProxyMetrics.once.Do(func() {
reverseProxyMetrics.upstreamsHealthy = prometheus.NewGaugeVec(prometheus.GaugeOpts{
Namespace: ns,
@@ -36,31 +30,6 @@ func initReverseProxyMetrics(handler *Handler, registry *prometheus.Registry) {
Name: "upstreams_healthy",
Help: "Health status of reverse proxy upstreams.",
}, upstreamsLabels)
reverseProxyMetrics.streamsActive = prometheus.NewGaugeVec(prometheus.GaugeOpts{
Namespace: ns,
Subsystem: sub,
Name: "streams_active",
Help: "Number of currently active upgraded reverse proxy streams.",
}, upstreamsLabels)
reverseProxyMetrics.streamsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: ns,
Subsystem: sub,
Name: "streams_total",
Help: "Total number of upgraded reverse proxy streams by close result.",
}, streamResultLabels)
reverseProxyMetrics.streamDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: ns,
Subsystem: sub,
Name: "stream_duration_seconds",
Help: "Duration of upgraded reverse proxy streams by close result.",
Buckets: prometheus.DefBuckets,
}, streamResultLabels)
reverseProxyMetrics.streamBytes = prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: ns,
Subsystem: sub,
Name: "stream_bytes_total",
Help: "Total bytes proxied across upgraded reverse proxy streams.",
}, streamBytesLabels)
})
// duplicate registration could happen if multiple sites with reverse proxy are configured; so ignore the error because
@@ -73,58 +42,10 @@ func initReverseProxyMetrics(handler *Handler, registry *prometheus.Registry) {
}) {
panic(err)
}
if err := registry.Register(reverseProxyMetrics.streamsActive); err != nil &&
!errors.Is(err, prometheus.AlreadyRegisteredError{
ExistingCollector: reverseProxyMetrics.streamsActive,
NewCollector: reverseProxyMetrics.streamsActive,
}) {
panic(err)
}
if err := registry.Register(reverseProxyMetrics.streamsTotal); err != nil &&
!errors.Is(err, prometheus.AlreadyRegisteredError{
ExistingCollector: reverseProxyMetrics.streamsTotal,
NewCollector: reverseProxyMetrics.streamsTotal,
}) {
panic(err)
}
if err := registry.Register(reverseProxyMetrics.streamDuration); err != nil &&
!errors.Is(err, prometheus.AlreadyRegisteredError{
ExistingCollector: reverseProxyMetrics.streamDuration,
NewCollector: reverseProxyMetrics.streamDuration,
}) {
panic(err)
}
if err := registry.Register(reverseProxyMetrics.streamBytes); err != nil &&
!errors.Is(err, prometheus.AlreadyRegisteredError{
ExistingCollector: reverseProxyMetrics.streamBytes,
NewCollector: reverseProxyMetrics.streamBytes,
}) {
panic(err)
}
reverseProxyMetrics.logger = handler.logger.Named("reverse_proxy.metrics")
}
func trackActiveStream(upstream string) func(result string, duration time.Duration, toBackend, fromBackend int64) {
labels := prometheus.Labels{"upstream": upstream}
reverseProxyMetrics.streamsActive.With(labels).Inc()
var once sync.Once
return func(result string, duration time.Duration, toBackend, fromBackend int64) {
once.Do(func() {
reverseProxyMetrics.streamsActive.With(labels).Dec()
reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, result).Inc()
reverseProxyMetrics.streamDuration.WithLabelValues(upstream, result).Observe(duration.Seconds())
if toBackend > 0 {
reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "to_upstream").Add(float64(toBackend))
}
if fromBackend > 0 {
reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "from_upstream").Add(float64(fromBackend))
}
})
}
}
type metricsUpstreamsHealthyUpdater struct {
handler *Handler
}
@@ -1,67 +0,0 @@
package reverseproxy
import (
"testing"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/testutil"
)
func TestTrackActiveStreamRecordsLifecycleAndBytes(t *testing.T) {
const upstream = "127.0.0.1:7443"
// Use fresh metric vectors for deterministic assertions in this unit test.
reverseProxyMetrics.streamsActive = prometheus.NewGaugeVec(prometheus.GaugeOpts{}, []string{"upstream"})
reverseProxyMetrics.streamsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "result"})
reverseProxyMetrics.streamDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{}, []string{"upstream", "result"})
reverseProxyMetrics.streamBytes = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "direction"})
finish := trackActiveStream(upstream)
if got := testutil.ToFloat64(reverseProxyMetrics.streamsActive.WithLabelValues(upstream)); got != 1 {
t.Fatalf("active streams = %v, want 1", got)
}
finish("closed", 150*time.Millisecond, 1234, 4321)
if got := testutil.ToFloat64(reverseProxyMetrics.streamsActive.WithLabelValues(upstream)); got != 0 {
t.Fatalf("active streams = %v, want 0", got)
}
if got := testutil.ToFloat64(reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, "closed")); got != 1 {
t.Fatalf("streams_total closed = %v, want 1", got)
}
if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "to_upstream")); got != 1234 {
t.Fatalf("bytes to_upstream = %v, want 1234", got)
}
if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "from_upstream")); got != 4321 {
t.Fatalf("bytes from_upstream = %v, want 4321", got)
}
// A second finish call should be ignored by the once guard.
finish("error", 1*time.Second, 111, 222)
if got := testutil.ToFloat64(reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, "error")); got != 0 {
t.Fatalf("streams_total error = %v, want 0", got)
}
}
func TestTrackActiveStreamDoesNotCountZeroBytes(t *testing.T) {
const upstream = "127.0.0.1:9000"
reverseProxyMetrics.streamsActive = prometheus.NewGaugeVec(prometheus.GaugeOpts{}, []string{"upstream"})
reverseProxyMetrics.streamsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "result"})
reverseProxyMetrics.streamDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{}, []string{"upstream", "result"})
reverseProxyMetrics.streamBytes = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "direction"})
trackActiveStream(upstream)("timeout", 250*time.Millisecond, 0, 0)
if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "to_upstream")); got != 0 {
t.Fatalf("bytes to_upstream = %v, want 0", got)
}
if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "from_upstream")); got != 0 {
t.Fatalf("bytes from_upstream = %v, want 0", got)
}
if got := testutil.ToFloat64(reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, "timeout")); got != 1 {
t.Fatalf("streams_total timeout = %v, want 1", got)
}
}
@@ -730,58 +730,3 @@ func TestRetryMatchAllowsExpressionMixedWithOtherMatchers(t *testing.T) {
})
}
}
// TestSubrouteErrorFallbackWithBody is similar to TestDialErrorBodyRetry but
// mimics Subroute's Error handler rather than testing retries specifically
func TestSubrouteErrorFallbackWithBody(t *testing.T) {
// Good upstream: echoes the request body with 200 OK.
goodServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "read body: "+err.Error(), http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
_, err = w.Write(body)
if err != nil {
t.Errorf("error writing in good server: %v", err)
}
}))
t.Cleanup(goodServer.Close)
// Handler which will dial error
badProxy := minimalHandler(0, &Upstream{Host: new(Host), Dial: deadUpstreamAddr(t)})
bodyReader := newCloseOnCloseReader("hello world")
req := httptest.NewRequest("POST", "http://localhost/", bodyReader)
// httptest.NewRequest wraps the reader in NopCloser; replace
// it with our close-aware reader so Close() is propagated.
req.Body = bodyReader
req = prepareTestRequest(req)
rec := httptest.NewRecorder()
err := badProxy.ServeHTTP(rec, req, caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
return nil
}))
if err == nil {
t.Fatalf("Expected error from badProxy.ServeHTTP")
}
// Simulate the Subroute's Error handler by calling another handler with the
// same request and recorder
goodProxy := minimalHandler(0, &Upstream{Host: new(Host), Dial: goodServer.Listener.Addr().String()})
err = goodProxy.ServeHTTP(rec, req, caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
return nil
}))
if err != nil {
t.Fatalf("Expected no error from goodProxy.ServeHTTP, got: %v", err)
}
if rec.Code != http.StatusOK {
t.Errorf("status: got %d, want %d", rec.Code, http.StatusOK)
}
expectedBody := "hello world"
if rec.Body.String() != expectedBody {
t.Errorf("body: got %q, want %q", rec.Body.String(), expectedBody)
}
}
+28 -158
View File
@@ -186,22 +186,6 @@ type Handler struct {
// by the previous config closing. Default: no delay.
StreamCloseDelay caddy.Duration `json:"stream_close_delay,omitempty"`
// If true, upgraded connections such as WebSockets are detached from
// the handler and retained across config reloads when their upstream
// still exists in the new config. Connections using upstreams that are
// removed are closed during cleanup. By default this is false, preserving
// legacy behavior where upgraded connections are closed on reload
// (optionally delayed by stream_close_delay).
// Only http1.1 websocket connections are affected, websockets for h2/h3
// are not affected. If true, bytes transferred for http1.1 in the access
// logs will be zero but those stats can be found in the stream logs for
// http1/2/3 regardless if this is enabled.
StreamDetached bool `json:"stream_detached,omitempty"`
// Controls logging behavior for upgraded stream lifecycle events.
// If omitted, defaults are used (level=DEBUG, logger_name="http.handlers.reverse_proxy.stream").
StreamLogs *StreamLogs `json:"stream_logs,omitempty"`
// If configured, rewrites the copy of the upstream request.
// Allows changing the request method and URI (path and query).
// Since the rewrite is applied to the copy, it does not persist
@@ -256,16 +240,14 @@ type Handler struct {
// Holds the handle_response Caddyfile tokens while adapting
handleResponseSegments []*caddyfile.Dispenser
// Tracks hijacked/upgraded connections (WebSocket etc.) so they can be
// closed when their upstream is removed from the config.
tunnelTracker *tunnelTracker
// Stores upgraded requests (hijacked connections) for proper cleanup
connections map[io.ReadWriteCloser]openConnection
connectionsCloseTimer *time.Timer
connectionsMu *sync.Mutex
ctx caddy.Context
logger *zap.Logger
events *caddyevents.App
streamLogLevel zapcore.Level
streamLogLoggerName string
}
// CaddyModule returns the Caddy module information.
@@ -285,25 +267,8 @@ func (h *Handler) Provision(ctx caddy.Context) error {
h.events = eventAppIface.(*caddyevents.App)
h.ctx = ctx
h.logger = ctx.Logger()
h.tunnelTracker = newTunnelTracker(h.logger, time.Duration(h.StreamCloseDelay))
h.streamLogLevel = defaultStreamLogLevel
h.streamLogLoggerName = defaultStreamLoggerName
if h.StreamLogs != nil {
if h.StreamLogs.Level != "" {
lvl, err := zapcore.ParseLevel(strings.ToLower(strings.TrimSpace(h.StreamLogs.Level)))
if err != nil {
return fmt.Errorf("invalid stream_logs.level %q: %w", h.StreamLogs.Level, err)
}
h.streamLogLevel = lvl
}
if name := strings.TrimSpace(h.StreamLogs.LoggerName); name != "" {
h.streamLogLoggerName = name
}
}
if h.StreamDetached {
registerDetachedTunnelTrackers(h.tunnelTracker)
}
h.connections = make(map[io.ReadWriteCloser]openConnection)
h.connectionsMu = new(sync.Mutex)
// warn about unsafe buffering config
if h.RequestBuffers == -1 || h.ResponseBuffers == -1 {
@@ -472,85 +437,15 @@ func (h *Handler) Provision(ctx caddy.Context) error {
return nil
}
func (h Handler) streamLogsSkipHandshake() bool {
return h.StreamLogs != nil && h.StreamLogs.SkipHandshake
}
func (h Handler) streamLoggerForRequest(req *http.Request) *zap.Logger {
name := strings.TrimSpace(h.streamLogLoggerName)
if name == "" {
name = defaultStreamLoggerName
}
if name == streamLoggerNameUseAccess {
logger := caddy.Log().Named(defaultAccessLoggerBase)
names := caddyhttp.GetVar(req.Context(), caddyhttp.AccessLoggerNameVarKey)
namesSlice, ok := names.([]any)
if !ok {
return logger
}
for _, v := range namesSlice {
name, ok := v.(string)
if !ok {
continue
}
if name == "" {
return logger
}
return logger.Named(name)
}
return logger
}
return caddy.Log().Named(name)
}
var (
detachedTunnelTrackers = make(map[*tunnelTracker]struct{})
detachedTunnelTrackersMu sync.Mutex
)
func registerDetachedTunnelTrackers(ts *tunnelTracker) {
detachedTunnelTrackersMu.Lock()
defer detachedTunnelTrackersMu.Unlock()
detachedTunnelTrackers[ts] = struct{}{}
}
func notifyDetachedTunnelTrackersOfUpstreamRemoval(upstream string, self *tunnelTracker) error {
detachedTunnelTrackersMu.Lock()
defer detachedTunnelTrackersMu.Unlock()
var err error
for tunnel := range detachedTunnelTrackers {
if closeErr := tunnel.closeConnectionsForUpstream(upstream); closeErr != nil && tunnel == self && err == nil {
err = closeErr
}
}
return err
}
func unregisterDetachedTunnelTrackers(ts *tunnelTracker) {
detachedTunnelTrackersMu.Lock()
defer detachedTunnelTrackersMu.Unlock()
delete(detachedTunnelTrackers, ts)
}
// Cleanup cleans up the resources made by h.
func (h *Handler) Cleanup() error {
// even if StreamDetached is true, extended connect websockets may still be running
err := h.tunnelTracker.cleanupAttachedConnections()
err := h.cleanupConnections()
// remove hosts from our config from the pool
for _, upstream := range h.Upstreams {
// hosts.Delete returns deleted=true when the ref count reaches zero,
// meaning no other active config references this upstream. In that
// case close any tunnels proxying to it; otherwise let them survive
// to their natural end since the upstream is still in use.
deleted, _ := hosts.Delete(upstream.String())
if deleted {
if closeErr := notifyDetachedTunnelTrackersOfUpstreamRemoval(upstream.String(), h.tunnelTracker); closeErr != nil && err == nil {
err = closeErr
}
}
_, _ = hosts.Delete(upstream.String())
}
return err
}
@@ -593,19 +488,20 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht
reqHost := clonedReq.Host
reqHeader := clonedReq.Header
// If the request contained a body, wrap it in io.NopCloser
// to prevent Go's transport from closing it on dial errors.
// cloneRequest does a shallow copy, so clonedReq.Body and
// When retries are configured and there is a body, wrap it in
// io.NopCloser to prevent Go's transport from closing it on dial
// errors. cloneRequest does a shallow copy, so clonedReq.Body and
// r.Body share the same io.ReadCloser — a dial-failure Close()
// would kill the original body for all subsequent retry
// attempts or subsequent handlers. The real body is closed by
// the HTTP server when the handler returns.
// would kill the original body for all subsequent retry attempts.
// The real body is closed by the HTTP server when the handler
// returns.
//
// If the body was already fully buffered (via request_buffers),
// we also extract the buffer so the retry loop can replay it
// from the beginning on each attempt. (see #6259, #7546, #7713)
// from the beginning on each attempt. (see #6259, #7546)
var bufferedReqBody *bytes.Buffer
if clonedReq.Body != nil {
if clonedReq.Body != nil && h.LoadBalancing != nil &&
(h.LoadBalancing.Retries > 0 || h.LoadBalancing.TryDuration > 0) {
if reqBodyBuf, ok := clonedReq.Body.(bodyReadCloser); ok && reqBodyBuf.body == nil && reqBodyBuf.buf != nil {
bufferedReqBody = reqBodyBuf.buf
reqBodyBuf.buf = nil
@@ -1242,11 +1138,10 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe
// we use the original request here, so that any routes from 'next'
// see the original request rather than the proxy cloned request.
hrc := &handleResponseContext{
handler: h,
response: res,
start: start,
logger: logger,
upstreamAddr: di.Upstream.String(),
handler: h,
response: res,
start: start,
logger: logger,
}
ctx := origReq.Context()
ctx = context.WithValue(ctx, proxyHandleResponseContextCtxKey, hrc)
@@ -1276,7 +1171,7 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe
}
// copy the response body and headers back to the upstream client
return h.finalizeResponse(rw, req, res, repl, start, logger, di.Upstream.String())
return h.finalizeResponse(rw, req, res, repl, start, logger)
}
// finalizeResponse prepares and copies the response.
@@ -1287,11 +1182,12 @@ func (h *Handler) finalizeResponse(
repl *caddy.Replacer,
start time.Time,
logger *zap.Logger,
upstreamAddr string,
) error {
// deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
if res.StatusCode == http.StatusSwitchingProtocols {
h.handleUpgradeResponse(logger, rw, req, res, upstreamAddr)
var wg sync.WaitGroup
h.handleUpgradeResponse(logger, &wg, rw, req, res)
wg.Wait()
return nil
}
@@ -1898,22 +1794,6 @@ func (brc bodyReadCloser) Close() error {
return nil
}
// StreamLogs controls logging for upgraded stream lifecycle events.
type StreamLogs struct {
// The minimum level at which stream lifecycle events are logged.
// Supported values are debug, info, warn, and error. Default: debug.
Level string `json:"level,omitempty"`
// Logger name for stream lifecycle logs. Default: "http.handlers.reverse_proxy.stream".
// Special value "access" uses the access logger namespace and, if set,
// respects the first value in access_logger_names/log_name for the request.
LoggerName string `json:"logger_name,omitempty"`
// If true, suppresses the access log entry normally emitted when an
// upgraded stream handshake completes and the request unwinds.
SkipHandshake bool `json:"skip_handshake,omitempty"`
}
// bufPool is used for buffering requests and responses.
var bufPool = sync.Pool{
New: func() any {
@@ -1946,9 +1826,6 @@ type handleResponseContext struct {
// i.e. copied and closed, to make sure that it doesn't
// happen twice.
isFinalized bool
// upstreamAddr is the selected upstream address for this request.
upstreamAddr string
}
// proxyHandleResponseContextCtxKey is the context key for the active proxy handler
@@ -1959,13 +1836,6 @@ const proxyHandleResponseContextCtxKey caddy.CtxKey = "reverse_proxy_handle_resp
// errNoUpstream occurs when there are no upstream available.
var errNoUpstream = fmt.Errorf("no upstreams available")
const (
defaultStreamLogLevel = zapcore.DebugLevel
defaultStreamLoggerName = "http.handlers.reverse_proxy.stream"
streamLoggerNameUseAccess = "access"
defaultAccessLoggerBase = "http.log.access"
)
// Interface guards
var (
_ caddy.Provisioner = (*Handler)(nil)
+109 -268
View File
@@ -26,7 +26,6 @@ import (
"io"
weakrand "math/rand/v2"
"mime"
"net"
"net/http"
"sync"
"time"
@@ -36,16 +35,15 @@ import (
"go.uber.org/zap/zapcore"
"golang.org/x/net/http/httpguts"
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
)
type extendedConnectReadWriteCloser struct {
type h2ReadWriteCloser struct {
io.ReadCloser
http.ResponseWriter
}
func (rwc extendedConnectReadWriteCloser) Write(p []byte) (n int, err error) {
func (rwc h2ReadWriteCloser) Write(p []byte) (n int, err error) {
n, err = rwc.ResponseWriter.Write(p)
if err != nil {
return 0, err
@@ -59,7 +57,7 @@ func (rwc extendedConnectReadWriteCloser) Write(p []byte) (n int, err error) {
return n, nil
}
func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWriter, req *http.Request, res *http.Response, upstreamAddr string) {
func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, rw http.ResponseWriter, req *http.Request, res *http.Response) {
reqUpType := upgradeType(req.Header)
resUpType := upgradeType(res.Header)
@@ -92,37 +90,13 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrit
copyHeader(rw.Header(), res.Header)
normalizeWebsocketHeaders(rw.Header())
// Capture all h fields needed by the tunnel now, so that the Handler (h)
// is not referenced after this function returns (for HTTP/1.1 hijacked
// connections the tunnel runs in a detached goroutine).
tunnel := h.tunnelTracker
bufferSize := h.StreamBufferSize
streamTimeout := time.Duration(h.StreamTimeout)
if h.StreamDetached {
// the return value should be true as it's not hijacked yet,
// but some middleware may wrap response writers incorrectly
if !caddyhttp.DetachResponseWriterAfterHijack(rw, true) {
if c := logger.Check(zap.DebugLevel, "detaching connection failed"); c != nil {
c.Write(zap.String("tip", "check if your response writers have an Unwrap method or if already hijacked"))
}
}
}
var (
conn io.ReadWriteCloser
brw *bufio.ReadWriter
detached = h.StreamDetached
conn io.ReadWriteCloser
brw *bufio.ReadWriter
)
// websocket over http2 or http3 if extended connect is enabled,
// assuming backend doesn't support this, the request will be
// modified to http1.1 upgrade
// TODO: once we can reliably detect backend support this, it can
// be removed for those backends
// websocket over http2 or http3 if extended connect is enabled, assuming backend doesn't support this, the request will be modified to http1.1 upgrade
// TODO: once we can reliably detect backend support this, it can be removed for those backends
if body, ok := caddyhttp.GetVar(req.Context(), "extended_connect_websocket_body").(io.ReadCloser); ok {
// websocket over extended connect can't be detached. rw and req.Body
// are only valid while the handler goroutine is running
detached = false
req.Body = body
rw.Header().Del("Upgrade")
rw.Header().Del("Connection")
@@ -130,18 +104,18 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrit
rw.WriteHeader(http.StatusOK)
if c := logger.Check(zap.DebugLevel, "upgrading connection"); c != nil {
c.Write(zap.Int("http_version", req.ProtoMajor))
c.Write(zap.Int("http_version", 2))
}
//nolint:bodyclose
flushErr := http.NewResponseController(rw).Flush()
if flushErr != nil {
if c := h.logger.Check(zap.ErrorLevel, "failed to flush extended_connect websocket response"); c != nil {
if c := h.logger.Check(zap.ErrorLevel, "failed to flush http2 websocket response"); c != nil {
c.Write(zap.Error(flushErr))
}
return
}
conn = extendedConnectReadWriteCloser{req.Body, rw}
conn = h2ReadWriteCloser{req.Body, rw}
// bufio is not needed, use minimal buffer
brw = bufio.NewReadWriter(bufio.NewReaderSize(conn, 1), bufio.NewWriterSize(conn, 1))
} else {
@@ -169,6 +143,27 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrit
}
}
// adopted from https://github.com/golang/go/commit/8bcf2834afdf6a1f7937390903a41518715ef6f5
backConnCloseCh := make(chan struct{})
go func() {
// Ensure that the cancellation of a request closes the backend.
// See issue https://golang.org/issue/35559.
select {
case <-req.Context().Done():
case <-backConnCloseCh:
}
backConn.Close()
}()
defer close(backConnCloseCh)
start := time.Now()
defer func() {
conn.Close()
if c := logger.Check(zapcore.DebugLevel, "connection closed"); c != nil {
c.Write(zap.Duration("duration", time.Since(start)))
}
}()
if err := brw.Flush(); err != nil {
if c := logger.Check(zapcore.DebugLevel, "response flush"); c != nil {
c.Write(zap.Error(err))
@@ -189,12 +184,13 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrit
}
}
// Register both connections with the tunnel tracker. We also try to
// gracefully close connections we recognize as websockets. We need to make
// sure the client connection messages (i.e. to upstream) are masked, so we
// need to know whether the connection is considered the server or the
// client side of the proxy. Note that gracefulClose must not capture h,
// since the tunnel may outlive the handler instance.
// Ensure the hijacked client connection, and the new connection established
// with the backend, are both closed in the event of a server shutdown. This
// is done by registering them. We also try to gracefully close connections
// we recognize as websockets.
// We need to make sure the client connection messages (i.e. to upstream)
// are masked, so we need to know whether the connection is considered the
// server or the client side of the proxy.
gracefulClose := func(conn io.ReadWriteCloser, isClient bool) func() error {
if isWebsocket(req) {
return func() error {
@@ -203,147 +199,43 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrit
}
return nil
}
deleteFrontConn := tunnel.registerConnection(conn, gracefulClose(conn, false), detached, upstreamAddr)
deleteBackConn := tunnel.registerConnection(backConn, gracefulClose(backConn, true), detached, upstreamAddr)
if h.streamLogsSkipHandshake() {
caddyhttp.SetVar(req.Context(), caddyhttp.LogSkipVar, true)
}
repl := req.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
repl.Set("http.reverse_proxy.upgraded", true)
streamUUID, _ := repl.GetString("http.request.uuid")
streamFields := makeStreamLogFields(streamUUID)
streamLogger := h.streamLoggerForRequest(req)
streamLevel := h.streamLogLevel
finishMetrics := trackActiveStream(upstreamAddr)
start := time.Now()
if !detached {
handleUpgradeTunnel(
streamLogger,
streamLevel,
conn,
backConn,
deleteFrontConn,
deleteBackConn,
bufferSize,
streamTimeout,
start,
finishMetrics,
streamFields,
)
} else {
// start a new goroutine
go handleUpgradeTunnel(
streamLogger,
streamLevel,
conn,
backConn,
deleteFrontConn,
deleteBackConn,
bufferSize,
streamTimeout,
start,
finishMetrics,
streamFields,
)
}
}
// handleUpgradeTunnel returns when transfer is done.
func handleUpgradeTunnel(
streamLogger *zap.Logger,
streamLevel zapcore.Level,
conn io.ReadWriteCloser,
backConn io.ReadWriteCloser,
deleteFrontConn func(),
deleteBackConn func(),
bufferSize int,
streamTimeout time.Duration,
start time.Time,
finishMetrics func(result string, duration time.Duration, toBackend int64, fromBackend int64),
streamFields []zap.Field,
) {
defer deleteBackConn()
deleteFrontConn := h.registerConnection(conn, gracefulClose(conn, false))
deleteBackConn := h.registerConnection(backConn, gracefulClose(backConn, true))
defer deleteFrontConn()
var (
wg sync.WaitGroup
toBackend int64
fromBackend int64
result string
)
defer deleteBackConn()
// when a stream timeout is encountered, no error will be read from errc
// a buffer size of 2 will allow both the read and write goroutines to
// send the error and exit
// see: https://github.com/caddyserver/caddy/issues/7418
errc := make(chan error, 2)
spc := switchProtocolCopier{
user: conn,
backend: backConn,
wg: &wg,
bufferSize: bufferSize,
sent: &toBackend,
received: &fromBackend,
wg: wg,
bufferSize: h.StreamBufferSize,
}
wg.Add(2)
// setup the timeout if requested
var timeoutc <-chan time.Time
if streamTimeout > 0 {
timer := time.NewTimer(streamTimeout)
if h.StreamTimeout > 0 {
timer := time.NewTimer(time.Duration(h.StreamTimeout))
defer timer.Stop()
timeoutc = timer.C
}
// when a stream timeout is encountered, no error will be read from errc
// a buffer size of 2 will allow both the read and write goroutines to send the error and exit
// see: https://github.com/caddyserver/caddy/issues/7418
errc := make(chan error, 2)
wg.Add(2)
go spc.copyToBackend(errc)
go spc.copyFromBackend(errc)
select {
case err := <-errc:
result = classifyStreamResult(err)
if c := streamLogger.Check(streamLevel, "streaming error"); c != nil {
if c := logger.Check(zapcore.DebugLevel, "streaming error"); c != nil {
c.Write(zap.Error(err))
}
case t := <-timeoutc:
result = "timeout"
if c := streamLogger.Check(streamLevel, "stream timed out"); c != nil {
c.Write(zap.Time("timeout", t))
case time := <-timeoutc:
if c := logger.Check(zapcore.DebugLevel, "stream timed out"); c != nil {
c.Write(zap.Time("timeout", time))
}
}
// Close both ends to unblock the still-running copy goroutine,
// then wait for it so byte counts are final before metrics/logging.
conn.Close()
backConn.Close()
wg.Wait()
finishMetrics(result, time.Since(start), toBackend, fromBackend)
if c := streamLogger.Check(streamLevel, "connection closed"); c != nil {
fields := append([]zap.Field{}, streamFields...)
fields = append(fields,
zap.Duration("duration", time.Since(start)),
zap.Int64("bytes_to_backend", toBackend),
zap.Int64("bytes_from_backend", fromBackend),
)
c.Write(fields...)
}
}
func classifyStreamResult(err error) string {
if err == nil ||
errors.Is(err, io.EOF) ||
errors.Is(err, net.ErrClosed) ||
errors.Is(err, context.Canceled) {
return "closed"
}
return "error"
}
func makeStreamLogFields(streamUUID string) []zap.Field {
fields := make([]zap.Field, 0, 1)
if streamUUID != "" {
fields = append(fields, zap.String("uuid", streamUUID))
}
return fields
}
// flushInterval returns the p.FlushInterval value, conditionally
@@ -483,101 +375,75 @@ func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte, logger *za
}
}
// openConnection maps an open connection to an optional function for graceful
// close and records which upstream address the connection is proxying to.
// Also tracks whether the connection is detached, which means it should only be
// closed when the upstream is removed from the config, not on every reload.
type openConnection struct {
conn io.ReadWriteCloser
gracefulClose func() error
detached bool
upstream string
}
// tunnelTracker tracks hijacked/upgraded connections for selective cleanup.
// This exists to detach the lifecycle of streaming connections from the proxy
// Handler and config, since we typically want them to survive past config reloads.
// It also allows for selective connection cleanup based on their attachment status.
type tunnelTracker struct {
connections map[io.ReadWriteCloser]openConnection
closeTimer *time.Timer
closeDelay time.Duration
stopped bool
mu sync.Mutex
logger *zap.Logger
}
func newTunnelTracker(logger *zap.Logger, closeDelay time.Duration) *tunnelTracker {
return &tunnelTracker{
connections: make(map[io.ReadWriteCloser]openConnection),
closeDelay: closeDelay,
logger: logger,
}
}
// registerConnection stores conn in the tracking map. The caller must invoke
// the returned del func when the connection is done.
func (ts *tunnelTracker) registerConnection(conn io.ReadWriteCloser, gracefulClose func() error, detached bool, upstream string) (del func()) {
ts.mu.Lock()
ts.connections[conn] = openConnection{conn, gracefulClose, detached, upstream}
ts.mu.Unlock()
// registerConnection holds onto conn so it can be closed in the event
// of a server shutdown. This is useful because hijacked connections or
// connections dialed to backends don't close when server is shut down.
// The caller should call the returned delete() function when the
// connection is done to remove it from memory.
func (h *Handler) registerConnection(conn io.ReadWriteCloser, gracefulClose func() error) (del func()) {
h.connectionsMu.Lock()
h.connections[conn] = openConnection{conn, gracefulClose}
h.connectionsMu.Unlock()
return func() {
ts.mu.Lock()
delete(ts.connections, conn)
if len(ts.connections) == 0 && ts.stopped {
unregisterDetachedTunnelTrackers(ts)
if ts.closeTimer != nil {
if ts.closeTimer.Stop() {
ts.logger.Debug("stopped streaming connections close timer - all connections are already closed")
}
ts.closeTimer = nil
h.connectionsMu.Lock()
delete(h.connections, conn)
// if there is no connection left before the connections close timer fires
if len(h.connections) == 0 && h.connectionsCloseTimer != nil {
// we release the timer that holds the reference to Handler
if (*h.connectionsCloseTimer).Stop() {
h.logger.Debug("stopped streaming connections close timer - all connections are already closed")
}
h.connectionsCloseTimer = nil
}
ts.mu.Unlock()
h.connectionsMu.Unlock()
}
}
// closeAttachedConnections closes all tracked attached connections.
func (ts *tunnelTracker) closeAttachedConnections() error {
// closeConnections immediately closes all hijacked connections (both to client and backend).
func (h *Handler) closeConnections() error {
var err error
ts.mu.Lock()
defer ts.mu.Unlock()
ts.stopped = true
for _, oc := range ts.connections {
// detached connections are only closed when the upstream is gone from the config
if oc.detached {
continue
}
h.connectionsMu.Lock()
defer h.connectionsMu.Unlock()
for _, oc := range h.connections {
if oc.gracefulClose != nil {
if gracefulErr := oc.gracefulClose(); gracefulErr != nil && err == nil {
// this is potentially blocking while we have the lock on the connections
// map, but that should be OK since the server has in theory shut down
// and we are no longer using the connections map
gracefulErr := oc.gracefulClose()
if gracefulErr != nil && err == nil {
err = gracefulErr
}
}
if closeErr := oc.conn.Close(); closeErr != nil && err == nil {
closeErr := oc.conn.Close()
if closeErr != nil && err == nil {
err = closeErr
}
}
return err
}
// cleanupAttachedConnections closes upgraded attached connections.
// Depending on closeDelay it does that either immediately or after a timer.
func (ts *tunnelTracker) cleanupAttachedConnections() error {
if ts.closeDelay == 0 {
return ts.closeAttachedConnections()
// cleanupConnections closes hijacked connections.
// Depending on the value of StreamCloseDelay it does that either immediately
// or sets up a timer that will do that later.
func (h *Handler) cleanupConnections() error {
if h.StreamCloseDelay == 0 {
return h.closeConnections()
}
ts.mu.Lock()
defer ts.mu.Unlock()
if len(ts.connections) > 0 {
delay := ts.closeDelay
ts.closeTimer = time.AfterFunc(delay, func() {
if c := ts.logger.Check(zapcore.DebugLevel, "closing streaming connections after delay"); c != nil {
h.connectionsMu.Lock()
defer h.connectionsMu.Unlock()
// the handler is shut down, no new connection can appear,
// so we can skip setting up the timer when there are no connections
if len(h.connections) > 0 {
delay := time.Duration(h.StreamCloseDelay)
h.connectionsCloseTimer = time.AfterFunc(delay, func() {
if c := h.logger.Check(zapcore.DebugLevel, "closing streaming connections after delay"); c != nil {
c.Write(zap.Duration("delay", delay))
}
err := ts.closeAttachedConnections()
err := h.closeConnections()
if err != nil {
if c := ts.logger.Check(zapcore.ErrorLevel, "failed to close connections after delay"); c != nil {
if c := h.logger.Check(zapcore.ErrorLevel, "failed to closed connections after delay"); c != nil {
c.Write(
zap.Error(err),
zap.Duration("delay", delay),
@@ -701,29 +567,11 @@ func isWebsocket(r *http.Request) bool {
httpguts.HeaderValuesContainsToken(r.Header["Upgrade"], "websocket")
}
// closeConnectionsForUpstream closes all tracked connections that were
// established to the given upstream address.
func (ts *tunnelTracker) closeConnectionsForUpstream(addr string) error {
var err error
ts.mu.Lock()
defer ts.mu.Unlock()
if !ts.stopped {
return nil
}
for _, oc := range ts.connections {
if oc.upstream != addr {
continue
}
if oc.gracefulClose != nil {
if gracefulErr := oc.gracefulClose(); gracefulErr != nil && err == nil {
err = gracefulErr
}
}
if closeErr := oc.conn.Close(); closeErr != nil && err == nil {
err = closeErr
}
}
return err
// openConnection maps an open connection to
// an optional function for graceful close.
type openConnection struct {
conn io.ReadWriteCloser
gracefulClose func() error
}
type maxLatencyWriter struct {
@@ -794,23 +642,16 @@ type switchProtocolCopier struct {
user, backend io.ReadWriteCloser
wg *sync.WaitGroup
bufferSize int
// sent and received accumulate byte counts for each direction.
// They are written before wg.Done() and read after wg.Wait(), so no
// additional synchronization is needed beyond the WaitGroup barrier.
sent *int64 // bytes copied to backend; must be non-nil
received *int64 // bytes copied from backend; must be non-nil
}
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
n, err := io.CopyBuffer(c.user, c.backend, c.buffer())
*c.received = n
_, err := io.CopyBuffer(c.user, c.backend, c.buffer())
errc <- err
c.wg.Done()
}
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
n, err := io.CopyBuffer(c.backend, c.user, c.buffer())
*c.sent = n
_, err := io.CopyBuffer(c.backend, c.user, c.buffer())
errc <- err
c.wg.Done()
}
@@ -7,10 +7,8 @@ import (
"strings"
"sync"
"testing"
"time"
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
)
func TestHandlerCopyResponse(t *testing.T) {
@@ -43,15 +41,12 @@ func TestSwitchProtocolCopierBufferSize(t *testing.T) {
var wg sync.WaitGroup
var errc = make(chan error, 1)
var dst bytes.Buffer
var sent, received int64
copier := switchProtocolCopier{
user: nopReadWriteCloser{Reader: strings.NewReader("hello")},
backend: nopReadWriteCloser{Writer: &dst},
wg: &wg,
bufferSize: 7,
sent: &sent,
received: &received,
}
buf := copier.buffer()
@@ -85,146 +80,3 @@ type nopReadWriteCloser struct {
}
func (nopReadWriteCloser) Close() error { return nil }
type trackingReadWriteCloser struct {
closed chan struct{}
one sync.Once
}
func newTrackingReadWriteCloser() *trackingReadWriteCloser {
return &trackingReadWriteCloser{closed: make(chan struct{})}
}
func (c *trackingReadWriteCloser) Read(_ []byte) (int, error) { return 0, io.EOF }
func (c *trackingReadWriteCloser) Write(p []byte) (int, error) { return len(p), nil }
func (c *trackingReadWriteCloser) Close() error {
c.one.Do(func() {
close(c.closed)
})
return nil
}
func (c *trackingReadWriteCloser) isClosed() bool {
select {
case <-c.closed:
return true
default:
return false
}
}
func TestHandlerCleanupLegacyModeClosesAllConnections(t *testing.T) {
ts := newTunnelTracker(caddy.Log(), 0)
connA := newTrackingReadWriteCloser()
connB := newTrackingReadWriteCloser()
ts.registerConnection(connA, nil, false, "a")
ts.registerConnection(connB, nil, false, "b")
h := &Handler{
tunnelTracker: ts,
StreamDetached: false,
}
if err := h.Cleanup(); err != nil {
t.Fatalf("cleanup failed: %v", err)
}
if !connA.isClosed() || !connB.isClosed() {
t.Fatalf("legacy cleanup should close all upgraded connections")
}
}
func TestHandlerCleanupLegacyModeHonorsDelay(t *testing.T) {
ts := newTunnelTracker(caddy.Log(), 40*time.Millisecond)
conn := newTrackingReadWriteCloser()
ts.registerConnection(conn, nil, false, "a")
h := &Handler{
tunnelTracker: ts,
StreamDetached: false,
}
if err := h.Cleanup(); err != nil {
t.Fatalf("cleanup failed: %v", err)
}
if conn.isClosed() {
t.Fatal("connection should not close immediately when stream_close_delay is set")
}
select {
case <-conn.closed:
case <-time.After(500 * time.Millisecond):
t.Fatal("connection did not close after stream_close_delay elapsed")
}
}
func TestHandlerCleanupDetachedModeClosesOnlyRemovedUpstreams(t *testing.T) {
const upstreamA = "upstream-a"
const upstreamB = "upstream-b"
// Simulate old+new configs both referencing upstreamA (refcount 2),
// while upstreamB is only referenced by the old config (refcount 1).
hosts.LoadOrStore(upstreamA, struct{}{})
hosts.LoadOrStore(upstreamA, struct{}{})
hosts.LoadOrStore(upstreamB, struct{}{})
t.Cleanup(func() {
_, _ = hosts.Delete(upstreamA)
_, _ = hosts.Delete(upstreamA)
_, _ = hosts.Delete(upstreamB)
})
ts := newTunnelTracker(caddy.Log(), 0)
registerDetachedTunnelTrackers(ts)
connA := newTrackingReadWriteCloser()
connB := newTrackingReadWriteCloser()
ts.registerConnection(connA, nil, true, upstreamA)
ts.registerConnection(connB, nil, true, upstreamB)
h := &Handler{
tunnelTracker: ts,
StreamDetached: true,
Upstreams: UpstreamPool{
&Upstream{Dial: upstreamA},
&Upstream{Dial: upstreamB},
},
}
if err := h.Cleanup(); err != nil {
t.Fatalf("cleanup failed: %v", err)
}
if connA.isClosed() {
t.Fatal("connection for detached upstream should remain open")
}
if !connB.isClosed() {
t.Fatal("connection for removed upstream should be closed")
}
}
func TestHandlerUnmarshalCaddyfileStreamLogsBlock(t *testing.T) {
d := caddyfile.NewTestDispenser(`
reverse_proxy localhost:9000 {
stream_logs {
level info
logger_name access
skip_handshake
}
}
`)
var h Handler
if err := h.UnmarshalCaddyfile(d); err != nil {
t.Fatalf("UnmarshalCaddyfile() error = %v", err)
}
if h.StreamLogs == nil {
t.Fatal("expected stream_logs to be configured")
}
if h.StreamLogs.Level != "info" {
t.Fatalf("expected stream_logs.level=info, got %q", h.StreamLogs.Level)
}
if h.StreamLogs.LoggerName != "access" {
t.Fatalf("expected stream_logs.logger_name=access, got %q", h.StreamLogs.LoggerName)
}
if !h.StreamLogs.SkipHandshake {
t.Fatal("expected stream_logs.skip_handshake=true")
}
}
+28 -3
View File
@@ -34,7 +34,9 @@ func init() {
// parseCaddyfileRewrite sets up a basic rewrite handler from Caddyfile tokens. Syntax:
//
// rewrite [<matcher>] <to>
// rewrite [<matcher>] <to> {
// force_modify_query
// }
//
// Only URI components which are given in <to> will be set in the resulting URI.
// See the docs for the rewrite handler for more information.
@@ -50,12 +52,30 @@ func parseCaddyfileRewrite(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue,
return nil, h.Errf("too many arguments; should only be a matcher and a URI")
}
parseBlock := func(rewr *Rewrite) error {
for nesting := h.Nesting(); h.NextBlock(nesting); {
switch h.Val() {
case "force_modify_query":
rewr.ForceModifyQuery = true
default:
return h.Errf("unknown subdirective: %s", h.Val())
}
}
return nil
}
// with only one arg, assume it's a rewrite URI with no matcher token
if argsCount == 1 {
if !h.NextArg() {
return nil, h.ArgErr()
}
return h.NewRoute(nil, Rewrite{URI: h.Val()}), nil
rewr := Rewrite{URI: h.Val()}
err := parseBlock(&rewr)
if err != nil {
return nil, err
}
return h.NewRoute(nil, rewr), nil
}
// parse the matcher token into a matcher set
@@ -66,7 +86,12 @@ func parseCaddyfileRewrite(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue,
h.Next() // consume directive name again, matcher parsing does a reset
h.Next() // advance to the rewrite URI
return h.NewRoute(userMatcherSet, Rewrite{URI: h.Val()}), nil
rewr := Rewrite{URI: h.Val()}
err = parseBlock(&rewr)
if err != nil {
return nil, err
}
return h.NewRoute(userMatcherSet, rewr), nil
}
// parseCaddyfileMethod sets up a basic method rewrite handler from Caddyfile tokens. Syntax:
+24 -28
View File
@@ -92,6 +92,17 @@ type Rewrite struct {
// Mutates the query string of the URI.
Query *queryOps `json:"query,omitempty"`
// If true, the rewrite will be forced to also apply to the
// query part of the URL. This is only needed if the configured
// URI does not include a '?' character which is normally used
// to determine whether the query should be modified. In other
// words, this allows rewriting both the path and query when
// using a placeholder as the replacement value, whereas otherwise
// only the path would be rewritten because the placeholder itself
// does not contain a '?' character. Only use this if the placeholder
// is trusted to not be vulnerable to query injections.
ForceModifyQuery bool `json:"force_modify_query,omitempty"`
logger *zap.Logger
}
@@ -211,7 +222,12 @@ func (rewr Rewrite) Rewrite(r *http.Request, repl *caddy.Replacer) bool {
var newPath, newQuery, newFrag string
if path != "" {
path = escapePathPlaceholders(path, r, repl)
// replace the `path` placeholder to escaped path
pathPlaceholder := "{http.request.uri.path}"
if strings.Contains(path, pathPlaceholder) {
path = strings.ReplaceAll(path, pathPlaceholder, r.URL.EscapedPath())
}
newPath = repl.ReplaceAll(path, "")
}
@@ -221,10 +237,15 @@ func (rewr Rewrite) Rewrite(r *http.Request, repl *caddy.Replacer) bool {
// recompute; new path contains a query string
var injectedQuery string
newPath, injectedQuery = before, after
// don't overwrite explicitly-configured query string
if query == "" {
// don't overwrite explicitly-configured query string,
// unless configured explicitly to do so
if query == "" || rewr.ForceModifyQuery {
query = injectedQuery
}
if rewr.ForceModifyQuery {
qsStart = 0
}
}
if query != "" {
@@ -295,31 +316,6 @@ func (rewr Rewrite) Rewrite(r *http.Request, repl *caddy.Replacer) bool {
return r.Method != oldMethod || r.RequestURI != oldURI
}
func escapePathPlaceholders(path string, r *http.Request, repl *caddy.Replacer) string {
// Replace path-valued placeholders in escaped form before the URI is parsed,
// otherwise literal '?' and '%' bytes from the path can be interpreted as URI
// delimiters or percent-escape sequences during the rewrite.
pathPlaceholder := "{http.request.uri.path}"
if strings.Contains(path, pathPlaceholder) {
path = strings.ReplaceAll(path, pathPlaceholder, r.URL.EscapedPath())
}
fileMatchRelativePlaceholder := "{http.matchers.file.relative}"
if strings.Contains(path, fileMatchRelativePlaceholder) {
if val, ok := repl.Get("http.matchers.file.relative"); ok {
if relativePath, ok := val.(string); ok {
path = strings.ReplaceAll(path, fileMatchRelativePlaceholder, escapePathPreservingSlashes(relativePath))
}
}
}
return path
}
func escapePathPreservingSlashes(path string) string {
return strings.ReplaceAll(url.PathEscape(path), "%2F", "/")
}
// buildQueryString takes an input query string and
// performs replacements on each component, returning
// the resulting query string. This function appends
+20
View File
@@ -225,6 +225,23 @@ func TestRewrite(t *testing.T) {
input: newRequest(t, "GET", "/foo#fragFirst?c=d"),
expect: newRequest(t, "GET", "/bar#fragFirst?c=d"),
},
{
rule: Rewrite{URI: "{test.path_and_query}"},
input: newRequest(t, "GET", "/"),
expect: newRequest(t, "GET", "/foo"),
},
{
// TODO: This might be an incorrect result, since it also replaces
// the path with empty string when that might not be the intent.
rule: Rewrite{URI: "{test.query}", ForceModifyQuery: true},
input: newRequest(t, "GET", "/foo"),
expect: newRequest(t, "GET", "?bar=1"),
},
{
rule: Rewrite{URI: "{test.path_and_query}", ForceModifyQuery: true},
input: newRequest(t, "GET", "/"),
expect: newRequest(t, "GET", "/foo?bar=1"),
},
{
rule: Rewrite{URI: "/api/admin/panel"},
input: newRequest(t, "GET", "/api/admin%2Fpanel"),
@@ -364,6 +381,9 @@ func TestRewrite(t *testing.T) {
repl.Set("http.request.uri", tc.input.RequestURI)
repl.Set("http.request.uri.path", tc.input.URL.Path)
repl.Set("http.request.uri.query", tc.input.URL.RawQuery)
repl.Set("test.path", "/foo")
repl.Set("test.query", "?bar=1")
repl.Set("test.path_and_query", "/foo?bar=1")
// we can't directly call Provision() without a valid caddy.Context
// (TODO: fix that) so here we ad-hoc compile the regex
+8 -46
View File
@@ -300,8 +300,6 @@ type Server struct {
onStopFuncs []func(context.Context) error // TODO: Experimental (Nov. 2023)
}
var defaultProtocols = []string{"h1", "h2", "h3"}
var (
ServerHeader = "Caddy"
serverHeader = []string{ServerHeader}
@@ -901,56 +899,20 @@ func (s *Server) logRequest(
// protocol returns true if the protocol proto is configured/enabled.
func (s *Server) protocol(proto string) bool {
if s.ListenProtocols == nil {
return slices.Contains(s.protocolsWithDefaults(), proto)
}
for _, lnProtocols := range s.ListenProtocols {
if slices.Contains(s.listenerProtocolsWithDefaults(lnProtocols), proto) {
if slices.Contains(s.Protocols, proto) {
return true
}
}
return false
}
func (s *Server) protocolsWithDefaults() []string {
if len(s.Protocols) == 0 {
return defaultProtocols
}
return s.Protocols
}
func (s *Server) listenerProtocolsWithDefaults(lnProtocols []string) []string {
serverProtocols := s.protocolsWithDefaults()
if len(lnProtocols) == 0 {
return serverProtocols
}
lnProtocolsDefault := false
lnProtocolsInclude := make([]string, 0, len(lnProtocols)+len(serverProtocols))
srvProtocolsInclude := make(map[string]struct{}, len(serverProtocols))
for _, srvProtocol := range serverProtocols {
srvProtocolsInclude[srvProtocol] = struct{}{}
}
for _, lnProtocol := range lnProtocols {
if lnProtocol == "" {
lnProtocolsDefault = true
continue
}
lnProtocolsInclude = append(lnProtocolsInclude, lnProtocol)
delete(srvProtocolsInclude, lnProtocol)
}
if lnProtocolsDefault {
for _, srvProtocol := range serverProtocols {
if _, ok := srvProtocolsInclude[srvProtocol]; ok {
lnProtocolsInclude = append(lnProtocolsInclude, srvProtocol)
} else {
for _, lnProtocols := range s.ListenProtocols {
for _, lnProtocol := range lnProtocols {
if lnProtocol == "" && slices.Contains(s.Protocols, proto) || lnProtocol == proto {
return true
}
}
}
}
return lnProtocolsInclude
return false
}
// Listeners returns the server's listeners. These are active listeners,
+2 -30
View File
@@ -36,22 +36,13 @@ func init() {
// Templates is a middleware which executes response bodies as Go templates.
// The syntax is documented in the Go standard library's
// [text/template package](https://golang.org/pkg/text/template/).
// Note that ANY response body that matches and qualifies may be evaluated,
// even if it comes from a proxied backend.
//
// ⚠️ Template functions/actions can access the environment, files on disk,
// and make HTTP requests. This is extremely useful, but you need to make
// sure templates are only evaluated on content that you trust, control, or
// at least sanitize properly.
// ⚠️ Template functions/actions are still experimental, so they are subject to change.
//
// ⚠️ Templates are still experimental, so they are subject to change.
// Custom template functions can be registered by creating a plugin module under the `http.handlers.templates.functions.*` namespace that implements the `CustomFunctions` interface.
//
// [All Sprig functions](https://masterminds.github.io/sprig/) are supported.
//
// Custom template functions can be registered by creating a plugin module
// under the `http.handlers.templates.functions.*` namespace that implements
// the `CustomFunctions` interface.
//
// In addition to the standard functions and the Sprig library, Caddy adds
// extra functions and data that are available to a template:
//
@@ -171,25 +162,6 @@ func init() {
// {{listFiles "/mydir"}}
// ```
//
// ##### `fileExists`
//
// Returns true if the given file name, relative to the template context's file root,
// can be opened successfully.
//
// ```
// {{fileExists "path/to/file.html"}}
// ```
//
// ##### `fileStat`
//
// Returns [FileInfo](https://pkg.go.dev/io/fs#FileInfo) using [Stat](https://pkg.go.dev/io/fs#Stat)
// on the given file name, relative to the template context's file root.
//
// ```
// {{$css := fileStat "css/style.css" -}}
// <link rel="stylesheet" href="/css/style.css?v={{ $css.ModTime.Unix }}">
// ```
//
// ##### `markdown`
//
// Renders the given Markdown text as HTML and returns it. This uses the
+2 -2
View File
@@ -153,9 +153,9 @@ func (cp ConnectionPolicies) TLSConfig(ctx caddy.Context) *tls.Config {
// in its config (remember, TLS connection policies are used by *other* apps to
// run TLS servers) -- we skip names with placeholders
if tlsApp.EncryptedClientHello.Publication == nil {
var echNames []string
repl := caddy.NewReplacer()
for _, p := range cp {
var echNames []string
for _, m := range p.matchers {
if sni, ok := m.(MatchServerName); ok {
for _, name := range sni {
@@ -164,8 +164,8 @@ func (cp ConnectionPolicies) TLSConfig(ctx caddy.Context) *tls.Config {
}
}
}
tlsApp.RegisterServerNames(echNames, p.ALPN)
}
tlsApp.RegisterServerNames(echNames)
}
tlsCfg.GetEncryptedClientHelloKeys = func(chi *tls.ClientHelloInfo) ([]tls.EncryptedClientHelloKey, error) {
+7 -26
View File
@@ -440,10 +440,6 @@ func (t *TLS) publishECHConfigs(logger *zap.Logger) error {
zap.Strings("domains", dnsNamesToPublish),
zap.Uint8s("config_ids", configIDs))
if dnsPublisher, ok := publisher.(*ECHDNSPublisher); ok {
dnsPublisher.alpnByDomain = t.alpnValuesForServerNames(dnsNamesToPublish)
}
// publish this ECH config list with this publisher
pubTime := time.Now()
err := publisher.PublishECHConfigList(t.ctx, dnsNamesToPublish, echCfgListBin)
@@ -780,8 +776,7 @@ type ECHDNSPublisher struct {
ProviderRaw json.RawMessage `json:"provider,omitempty" caddy:"namespace=dns.providers inline_key=name"`
provider ECHDNSProvider
alpnByDomain map[string][]string
logger *zap.Logger
logger *zap.Logger
}
// CaddyModule returns the Caddy module information.
@@ -877,7 +872,12 @@ nextName:
continue
}
params := httpsRec.Params
params = dnsPub.publishedSvcParams(domain, params, configListBin)
if params == nil {
params = make(libdns.SvcParams)
}
// overwrite only the "ech" SvcParamKey
params["ech"] = []string{base64.StdEncoding.EncodeToString(configListBin)}
// publish record
_, err = dnsPub.provider.SetRecords(ctx, zone, []libdns.Record{
@@ -903,25 +903,6 @@ nextName:
return nil
}
func (dnsPub *ECHDNSPublisher) publishedSvcParams(domain string, existing libdns.SvcParams, configListBin []byte) libdns.SvcParams {
params := make(libdns.SvcParams, len(existing)+2)
for key, values := range existing {
params[key] = append([]string(nil), values...)
}
params["ech"] = []string{base64.StdEncoding.EncodeToString(configListBin)}
if len(dnsPub.alpnByDomain) == 0 {
return params
}
if alpn := dnsPub.alpnByDomain[strings.ToLower(domain)]; len(alpn) > 0 {
params["alpn"] = append([]string(nil), alpn...)
}
return params
}
// echConfig represents an ECHConfig from the specification,
// [draft-ietf-tls-esni-22](https://www.ietf.org/archive/id/draft-ietf-tls-esni-22.html).
type echConfig struct {
-65
View File
@@ -1,65 +0,0 @@
package caddytls
import (
"encoding/base64"
"reflect"
"sync"
"testing"
"github.com/libdns/libdns"
)
func TestRegisterServerNamesWithALPN(t *testing.T) {
tlsApp := &TLS{
serverNames: make(map[string]serverNameRegistration),
serverNamesMu: new(sync.Mutex),
}
tlsApp.RegisterServerNames([]string{
"Example.com:443",
"example.com",
"127.0.0.1:443",
}, []string{"h2", "http/1.1"})
tlsApp.RegisterServerNames([]string{"EXAMPLE.COM"}, []string{"h3"})
got := tlsApp.alpnValuesForServerNames([]string{"example.com:443", "127.0.0.1:443"})
want := map[string][]string{
"example.com": {"h3", "h2", "http/1.1"},
}
if !reflect.DeepEqual(got, want) {
t.Fatalf("unexpected ALPN values: got %#v want %#v", got, want)
}
}
func TestECHDNSPublisherPublishedSvcParams(t *testing.T) {
dnsPub := &ECHDNSPublisher{
alpnByDomain: map[string][]string{
"example.com": {"h3", "h2", "http/1.1"},
},
}
existing := libdns.SvcParams{
"alpn": {"h2"},
"ipv4hint": {"203.0.113.10"},
}
got := dnsPub.publishedSvcParams("Example.com", existing, []byte{0x01, 0x02, 0x03})
if !reflect.DeepEqual(existing["alpn"], []string{"h2"}) {
t.Fatalf("existing params mutated: got %v", existing["alpn"])
}
if !reflect.DeepEqual(got["alpn"], []string{"h3", "h2", "http/1.1"}) {
t.Fatalf("unexpected ALPN params: got %v", got["alpn"])
}
if !reflect.DeepEqual(got["ipv4hint"], []string{"203.0.113.10"}) {
t.Fatalf("unexpected preserved params: got %v", got["ipv4hint"])
}
wantECH := base64.StdEncoding.EncodeToString([]byte{0x01, 0x02, 0x03})
if !reflect.DeepEqual(got["ech"], []string{wantECH}) {
t.Fatalf("unexpected ECH params: got %v want %v", got["ech"], wantECH)
}
}
+16 -104
View File
@@ -23,7 +23,6 @@ import (
"net"
"net/http"
"runtime/debug"
"slices"
"strings"
"sync"
"time"
@@ -141,7 +140,7 @@ type TLS struct {
logger *zap.Logger
events *caddyevents.App
serverNames map[string]serverNameRegistration
serverNames map[string]struct{}
serverNamesMu *sync.Mutex
// set of subjects with managed certificates,
@@ -169,7 +168,7 @@ func (t *TLS) Provision(ctx caddy.Context) error {
t.logger = ctx.Logger()
repl := caddy.NewReplacer()
t.managing, t.loaded = make(map[string]string), make(map[string]string)
t.serverNames = make(map[string]serverNameRegistration)
t.serverNames = make(map[string]struct{})
t.serverNamesMu = new(sync.Mutex)
// set up default DNS module, if any, and make sure it implements all the
@@ -614,8 +613,8 @@ func (t *TLS) Manage(subjects map[string]struct{}) error {
// managingWildcardFor returns true if the app is managing a certificate that covers that
// subject name (including consideration of wildcards), either from its internal list of
// names that it IS managing certs for, from the otherSubjsToManage which includes names
// that WILL be managed, or from names configured in the 'automate' loader.
// names that it IS managing certs for, or from the otherSubjsToManage which includes names
// that WILL be managed.
func (t *TLS) managingWildcardFor(subj string, otherSubjsToManage map[string]struct{}) bool {
// TODO: we could also consider manually-loaded certs using t.HasCertificateForSubject(),
// but that does not account for how manually-loaded certs may be restricted as to which
@@ -630,9 +629,7 @@ func (t *TLS) managingWildcardFor(subj string, otherSubjsToManage map[string]str
return managing
}
// replace labels of the domain with wildcards until we get a match from names
// already being managed, those about to be managed in this batch, or those
// configured for automation
// replace labels of the domain with wildcards until we get a match
labels := strings.Split(subj, ".")
for i := range labels {
if labels[i] == "*" {
@@ -646,117 +643,32 @@ func (t *TLS) managingWildcardFor(subj string, otherSubjsToManage map[string]str
if _, ok := otherSubjsToManage[candidate]; ok {
return true
}
if _, ok := t.automateNames[candidate]; ok {
return true
}
}
return false
}
// RegisterServerNames registers the provided DNS names with the TLS app and
// associates them with the given HTTPS RR ALPN values, if any. This is
// currently used to auto-publish Encrypted ClientHello (ECH) configurations,
// if enabled. Use of this function by apps using the TLS app removes the need
// for the user to redundantly specify domain names in their configuration.
// This function separates hostname and port, keeping only the hostname, and
// filters IP addresses which can't be used with ECH.
// RegisterServerNames registers the provided DNS names with the TLS app.
// This is currently used to auto-publish Encrypted ClientHello (ECH)
// configurations, if enabled. Use of this function by apps using the TLS
// app removes the need for the user to redundantly specify domain names
// in their configuration. This function separates hostname and port
// (keeping only the hotsname) and filters IP addresses, which can't be
// used with ECH.
//
// EXPERIMENTAL: This function and its semantics/behavior are subject to change.
func (t *TLS) RegisterServerNames(dnsNames, alpnValues []string) {
func (t *TLS) RegisterServerNames(dnsNames []string) {
t.serverNamesMu.Lock()
defer t.serverNamesMu.Unlock()
for _, name := range dnsNames {
host, _, err := net.SplitHostPort(name)
if err != nil {
host = name
}
host = strings.ToLower(strings.TrimSpace(host))
if host == "" || certmagic.SubjectIsIP(host) {
continue
}
registration := t.serverNames[host]
if len(alpnValues) == 0 {
t.serverNames[host] = registration
continue
}
if registration.alpnValues == nil {
registration.alpnValues = make(map[string]struct{}, len(alpnValues))
}
for _, alpn := range alpnValues {
if alpn == "" {
continue
}
registration.alpnValues[alpn] = struct{}{}
}
t.serverNames[host] = registration
}
}
func (t *TLS) alpnValuesForServerNames(dnsNames []string) map[string][]string {
t.serverNamesMu.Lock()
defer t.serverNamesMu.Unlock()
result := make(map[string][]string, len(dnsNames))
for _, name := range dnsNames {
host, _, err := net.SplitHostPort(name)
if err != nil {
host = name
}
host = strings.ToLower(strings.TrimSpace(host))
if host == "" {
continue
}
registration, ok := t.serverNames[host]
if !ok || len(registration.alpnValues) == 0 {
continue
}
result[host] = OrderedHTTPSRRALPN(registration.alpnValues)
}
return result
}
// OrderedHTTPSRRALPN returns the HTTPS RR ALPN values in preferred order.
func OrderedHTTPSRRALPN(alpnSet map[string]struct{}) []string {
if len(alpnSet) == 0 {
return nil
}
knownOrder := append([]string{"h3"}, defaultALPN...)
ordered := make([]string, 0, len(alpnSet))
seen := make(map[string]struct{}, len(alpnSet))
for _, alpn := range knownOrder {
if _, ok := alpnSet[alpn]; ok {
ordered = append(ordered, alpn)
seen[alpn] = struct{}{}
if strings.TrimSpace(host) != "" && !certmagic.SubjectIsIP(host) {
t.serverNames[strings.ToLower(host)] = struct{}{}
}
}
if len(ordered) == len(alpnSet) {
return ordered
}
var remaining []string
for alpn := range alpnSet {
if _, ok := seen[alpn]; ok {
continue
}
remaining = append(remaining, alpn)
}
slices.Sort(remaining)
return append(ordered, remaining...)
}
type serverNameRegistration struct {
alpnValues map[string]struct{}
t.serverNamesMu.Unlock()
}
// HandleHTTPChallenge ensures that the ACME HTTP challenge or ZeroSSL HTTP
-96
View File
@@ -1,96 +0,0 @@
// Copyright 2015 Matthew Holt and The Caddy Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package caddytls
import (
"encoding/json"
"testing"
"github.com/caddyserver/caddy/v2"
)
func TestAvoidDuplicateAutomation(t *testing.T) {
tests := []struct {
name string
automateNames []string
expectedToManage bool
}{
{
name: "do not manage if wildcard is automated",
automateNames: []string{"*.example.com"},
expectedToManage: false,
},
{
name: "manage if no automation configured",
automateNames: []string{},
expectedToManage: true,
},
{
name: "manage if explicitly requested even when wildcard automated",
automateNames: []string{"*.example.com", "sub.example.com"},
expectedToManage: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
automateJSON, err := json.Marshal(tc.automateNames)
if err != nil {
t.Fatal(err)
}
tlsApp := &TLS{
Automation: &AutomationConfig{
Policies: []*AutomationPolicy{
{
IssuersRaw: []json.RawMessage{
[]byte(`{"module": "internal"}`),
},
},
},
},
CertificatesRaw: map[string]json.RawMessage{
"automate": automateJSON,
},
}
var cfg caddy.Config
ctx, err := caddy.ProvisionContext(&cfg)
if err != nil {
t.Fatal(err)
}
if err := tlsApp.Provision(ctx); err != nil {
t.Fatal(err)
}
// simulate a case wherein the HTTP app starts first and
// tells the TLS app about the following auto-HTTPS domains
httpDomains := map[string]struct{}{"sub.example.com": {}}
if err := tlsApp.Manage(httpDomains); err != nil {
t.Fatal(err)
}
_, actuallyManaged := tlsApp.managing["sub.example.com"]
if actuallyManaged != tc.expectedToManage {
t.Errorf(
"expected sub.example.com individually managed: %v, got: %v",
tc.expectedToManage,
actuallyManaged,
)
}
})
}
}