Compare commits

..

26 Commits

Author SHA1 Message Date
Matt Holt d0a3cf0a0a Merge branch 'master' into proxy-stream-detached 2026-05-12 12:18:27 -06:00
James Hartig 77e9ce7404 reverseproxy: further prevent body closes from dial errors (#7715)
Cross-Build / build (~1.26.0, 1.26, dragonfly) (push) Successful in 1m28s
Cross-Build / build (~1.26.0, 1.26, illumos) (push) Successful in 1m26s
Cross-Build / build (~1.26.0, 1.26, freebsd) (push) Successful in 3m19s
Cross-Build / build (~1.26.0, 1.26, aix) (push) Successful in 3m55s
Cross-Build / build (~1.26.0, 1.26, darwin) (push) Successful in 3m56s
Cross-Build / build (~1.26.0, 1.26, linux) (push) Successful in 1m28s
Cross-Build / build (~1.26.0, 1.26, windows) (push) Successful in 1m26s
Cross-Build / build (~1.26.0, 1.26, openbsd) (push) Successful in 2m50s
Cross-Build / build (~1.26.0, 1.26, solaris) (push) Successful in 2m54s
Cross-Build / build (~1.26.0, 1.26, netbsd) (push) Successful in 5m14s
OpenSSF Scorecard supply-chain security / Scorecard analysis (push) Failing after 6m20s
Tests / test (./cmd/caddy/caddy, ~1.26.0, macos-14, 0, 1.26, mac) (push) Has been cancelled
Tests / test (./cmd/caddy/caddy.exe, ~1.26.0, windows-latest, True, 1.26, windows) (push) Has been cancelled
Lint / lint (macos-14, mac) (push) Has been cancelled
Lint / lint (windows-latest, windows) (push) Has been cancelled
Tests / test (./cmd/caddy/caddy, ~1.26.0, ubuntu-latest, 0, 1.26, linux) (push) Failing after 4m41s
Tests / test (s390x on IBM Z) (push) Has been skipped
Tests / goreleaser-check (push) Has been skipped
Lint / lint (ubuntu-latest, linux) (push) Successful in 4m47s
Lint / govulncheck (push) Successful in 1m16s
Lint / dependency-review (push) Failing after 1m9s
2026-05-12 12:05:50 -06:00
Matthew Holt cc58caa109 go.mod: Upgrade quic-go to v0.59.1
Tests / goreleaser-check (push) Has been skipped
Tests / test (s390x on IBM Z) (push) Has been skipped
Tests / test (./cmd/caddy/caddy, ~1.26.0, ubuntu-latest, 0, 1.26, linux) (push) Failing after 1m32s
Cross-Build / build (~1.26.0, 1.26, darwin) (push) Successful in 2m47s
Cross-Build / build (~1.26.0, 1.26, aix) (push) Successful in 2m48s
Cross-Build / build (~1.26.0, 1.26, illumos) (push) Successful in 1m26s
Cross-Build / build (~1.26.0, 1.26, freebsd) (push) Successful in 3m24s
Cross-Build / build (~1.26.0, 1.26, dragonfly) (push) Successful in 3m25s
Cross-Build / build (~1.26.0, 1.26, openbsd) (push) Successful in 1m23s
Cross-Build / build (~1.26.0, 1.26, linux) (push) Successful in 2m50s
Cross-Build / build (~1.26.0, 1.26, netbsd) (push) Successful in 2m53s
Lint / lint (ubuntu-latest, linux) (push) Successful in 2m21s
Lint / dependency-review (push) Failing after 1m14s
Lint / govulncheck (push) Successful in 1m44s
Cross-Build / build (~1.26.0, 1.26, solaris) (push) Successful in 4m26s
Cross-Build / build (~1.26.0, 1.26, windows) (push) Successful in 4m27s
OpenSSF Scorecard supply-chain security / Scorecard analysis (push) Failing after 7m5s
Tests / test (./cmd/caddy/caddy, ~1.26.0, macos-14, 0, 1.26, mac) (push) Has been cancelled
Tests / test (./cmd/caddy/caddy.exe, ~1.26.0, windows-latest, True, 1.26, windows) (push) Has been cancelled
Lint / lint (macos-14, mac) (push) Has been cancelled
Lint / lint (windows-latest, windows) (push) Has been cancelled
2026-05-11 17:33:42 -06:00
Br1an d80774cb3f metrics: Add nil check for metricsHandler in AdminMetrics.serveHTTP (#7553)
* metrics: Add nil check for metricsHandler in AdminMetrics.serveHTTP

Prevents panic when the admin metrics endpoint is accessed before
the module is fully provisioned. Returns a proper API error instead
of crashing.

* admin: provision router modules before registering routes

Instead of adding a nil check for metricsHandler, address the root
cause by provisioning admin router modules before calling Routes().
This ensures all handler state is initialized before routes are
registered on the mux.

Merge newAdminHandler and provisionAdminRouters into a single step,
removing the two-phase setup where routes were registered first and
modules provisioned later. The AdminConfig.routers field is no longer
needed since provisioning happens inline.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fix: go fmt admin.go

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-05-11 17:27:03 -06:00
Rayan Salhab a4a38c3e88 rewrite: escape file matcher paths before rewriting (#7683)
* fix: escape file matcher paths in rewrites

Preserve matched file paths containing literal '?' or '%' when try_files rewrites to http.matchers.file.relative.

* test: cover nested escaped try_files rewrite paths

* test: cover encoded slash try_files rewrite paths

* fix: assert file matcher placeholder as string

---------

Co-authored-by: cyphercodes <cyphercodes@users.noreply.github.com>
2026-05-11 17:16:33 -06:00
Matthew Holt 761347aa63 templates: Explicitly warn about misconfigurations 2026-05-11 16:45:49 -06:00
Steffen Busch 4ba16fe82c docs: add documentation for fileExists and fileStat template functions (#7700)
Tests / test (s390x on IBM Z) (push) Has been skipped
Tests / goreleaser-check (push) Has been skipped
Cross-Build / build (~1.26.0, 1.26, dragonfly) (push) Failing after 1m42s
Cross-Build / build (~1.26.0, 1.26, aix) (push) Successful in 2m37s
Cross-Build / build (~1.26.0, 1.26, darwin) (push) Successful in 3m36s
Tests / test (./cmd/caddy/caddy, ~1.26.0, ubuntu-latest, 0, 1.26, linux) (push) Failing after 3m44s
Cross-Build / build (~1.26.0, 1.26, freebsd) (push) Successful in 3m55s
Cross-Build / build (~1.26.0, 1.26, linux) (push) Successful in 1m24s
Cross-Build / build (~1.26.0, 1.26, illumos) (push) Successful in 2m44s
Cross-Build / build (~1.26.0, 1.26, windows) (push) Successful in 1m20s
Cross-Build / build (~1.26.0, 1.26, solaris) (push) Successful in 2m35s
Cross-Build / build (~1.26.0, 1.26, netbsd) (push) Successful in 2m51s
Cross-Build / build (~1.26.0, 1.26, openbsd) (push) Successful in 2m53s
Lint / govulncheck (push) Successful in 1m41s
Lint / lint (ubuntu-latest, linux) (push) Successful in 3m4s
Lint / dependency-review (push) Failing after 1m4s
OpenSSF Scorecard supply-chain security / Scorecard analysis (push) Failing after 6m4s
Tests / test (./cmd/caddy/caddy, ~1.26.0, macos-14, 0, 1.26, mac) (push) Has been cancelled
Tests / test (./cmd/caddy/caddy.exe, ~1.26.0, windows-latest, True, 1.26, windows) (push) Has been cancelled
Lint / lint (macos-14, mac) (push) Has been cancelled
Lint / lint (windows-latest, windows) (push) Has been cancelled
2026-05-12 04:23:58 +10:00
Rijul 0fab9f0f7d caddytls: avoid duplicate automation for wildcard-covered hosts (#7697)
Tests / test (s390x on IBM Z) (push) Has been skipped
Tests / goreleaser-check (push) Has been skipped
Cross-Build / build (~1.26.0, 1.26, aix) (push) Successful in 1m24s
Tests / test (./cmd/caddy/caddy, ~1.26.0, ubuntu-latest, 0, 1.26, linux) (push) Failing after 1m39s
Cross-Build / build (~1.26.0, 1.26, freebsd) (push) Successful in 1m48s
Cross-Build / build (~1.26.0, 1.26, dragonfly) (push) Successful in 2m32s
Cross-Build / build (~1.26.0, 1.26, illumos) (push) Successful in 1m24s
Cross-Build / build (~1.26.0, 1.26, darwin) (push) Successful in 3m26s
Cross-Build / build (~1.26.0, 1.26, linux) (push) Successful in 2m10s
Cross-Build / build (~1.26.0, 1.26, netbsd) (push) Successful in 1m58s
Cross-Build / build (~1.26.0, 1.26, solaris) (push) Successful in 1m23s
Lint / dependency-review (push) Failing after 24s
OpenSSF Scorecard supply-chain security / Scorecard analysis (push) Failing after 27s
Lint / govulncheck (push) Successful in 1m20s
Lint / lint (ubuntu-latest, linux) (push) Successful in 1m44s
Cross-Build / build (~1.26.0, 1.26, openbsd) (push) Successful in 3m21s
Cross-Build / build (~1.26.0, 1.26, windows) (push) Successful in 2m41s
Tests / test (./cmd/caddy/caddy, ~1.26.0, macos-14, 0, 1.26, mac) (push) Has been cancelled
Tests / test (./cmd/caddy/caddy.exe, ~1.26.0, windows-latest, True, 1.26, windows) (push) Has been cancelled
Lint / lint (macos-14, mac) (push) Has been cancelled
Lint / lint (windows-latest, windows) (push) Has been cancelled
* caddytls: Fix wildcard race in auto-HTTPS launch

When evaluating whether to skip managing an individual subdomain
due to an existing wildcard configuration, we now explicitly consult
the automate loader.

Because Caddy apps can start in any order, relying strictly on the
TLS app's internal management state was non-deterministic if the
HTTP app started first. Checking the automate loader guarantees
predictable behavior since it is fully populated during the
Provision phase, well before any apps are started.

* respond to review comments

1. update requested comment
2. remove personal domain from test
3. add regression test

* remove unnecessary mutex lock

* refactor: -integration test, +explicit cases

* refactor: remove redundant test, add comment

* rename file and add header

* update copyright year
2026-05-11 00:08:40 +10:00
Zen Dodd 5e76b5ee43 tls: add alpn to managed HTTPS records (#7653)
Tests / test (s390x on IBM Z) (push) Has been skipped
Tests / goreleaser-check (push) Has been skipped
Cross-Build / build (~1.26.0, 1.26, aix) (push) Successful in 1m28s
Tests / test (./cmd/caddy/caddy, ~1.26.0, ubuntu-latest, 0, 1.26, linux) (push) Failing after 1m56s
Cross-Build / build (~1.26.0, 1.26, freebsd) (push) Successful in 2m18s
Cross-Build / build (~1.26.0, 1.26, illumos) (push) Successful in 1m26s
Cross-Build / build (~1.26.0, 1.26, darwin) (push) Successful in 2m55s
Cross-Build / build (~1.26.0, 1.26, dragonfly) (push) Successful in 3m3s
Cross-Build / build (~1.26.0, 1.26, linux) (push) Successful in 1m50s
Cross-Build / build (~1.26.0, 1.26, openbsd) (push) Successful in 1m24s
Cross-Build / build (~1.26.0, 1.26, netbsd) (push) Successful in 2m7s
Lint / govulncheck (push) Successful in 1m14s
Lint / dependency-review (push) Failing after 1m14s
OpenSSF Scorecard supply-chain security / Scorecard analysis (push) Failing after 29s
Lint / lint (ubuntu-latest, linux) (push) Successful in 2m43s
Cross-Build / build (~1.26.0, 1.26, solaris) (push) Successful in 3m54s
Cross-Build / build (~1.26.0, 1.26, windows) (push) Successful in 3m48s
Tests / test (./cmd/caddy/caddy, ~1.26.0, macos-14, 0, 1.26, mac) (push) Has been cancelled
Tests / test (./cmd/caddy/caddy.exe, ~1.26.0, windows-latest, True, 1.26, windows) (push) Has been cancelled
Lint / lint (macos-14, mac) (push) Has been cancelled
Lint / lint (windows-latest, windows) (push) Has been cancelled
* tls: add alpn to managed HTTPS records

* tls: centralise HTTPS RR ALPN defaults and registration

Reuse shared protocol defaults instead of repeating the default HTTP protocol list, unify server name registration to carry ALPN in one experimental API and reuse the TLS default ALPN ordering for HTTPS RR publication

* http: centralise effective protocol resolution for HTTPS RR ALPN
2026-05-10 13:10:29 +10:00
Francis Lavoie eeb13f1ca8 More comments 2026-04-25 05:42:43 -04:00
Francis Lavoie 97f5fe0079 Rename to stream_detached 2026-04-25 05:38:37 -04:00
Francis Lavoie 558ec222db Add note about capturing h 2026-04-25 05:38:37 -04:00
Francis Lavoie e3b1bf80f4 Rename to tunnelTracker, reflow some comments 2026-04-25 05:38:37 -04:00
Francis Lavoie 1b8d60c459 Move type and const down to the bottom 2026-04-25 05:38:37 -04:00
WeidiDeng 733aaba102 only clean up connections when stopped 2026-04-25 05:38:37 -04:00
WeidiDeng ed44e4d3f6 change the log level if hijacking without writing a status code first 2026-04-25 05:38:37 -04:00
WeidiDeng f970f397e2 fix tests 2026-04-25 05:38:37 -04:00
WeidiDeng 6ba6cf5d13 fix tests 2026-04-25 05:38:37 -04:00
WeidiDeng ccc76ac1f6 make handleUpgradeTunnel a standalone func 2026-04-25 05:38:37 -04:00
WeidiDeng cee04ab28e correctly close detached streams 2026-04-25 05:38:37 -04:00
WeidiDeng e7055d85a4 simplify streaming handling 2026-04-25 05:38:37 -04:00
WeidiDeng b9b12025c6 record bytes read and written for response writers unless detached 2026-04-25 05:38:37 -04:00
Francis Lavoie 7ef9ecd48a Adjustments from Weidi's review 2026-04-25 05:38:37 -04:00
Francis Lavoie 307dfd0431 Improved logging facilities 2026-04-25 05:38:37 -04:00
Francis Lavoie daea7788ad lint 2026-04-25 05:38:37 -04:00
Francis Lavoie b68e9bfdd4 reverseproxy: Optionally detach stream (websockets) from config lifecycle 2026-04-25 05:38:37 -04:00
35 changed files with 2671 additions and 389 deletions
+3 -1
View File
@@ -132,6 +132,8 @@ jobs:
- name: Run tests
# id: step_test
# continue-on-error: true
env:
GODEBUG: http2xconnect=1
run: |
# (go test -v -coverprofile=cover-profile.out -race ./... 2>&1) > test-results/test-result.out
go test -v -coverprofile="cover-profile.out" -short -race ./...
@@ -191,7 +193,7 @@ jobs:
retries=3
exit_code=0
while ((retries > 0)); do
CGO_ENABLED=0 go test -p 1 -v ./...
GODEBUG=http2xconnect=1 CGO_ENABLED=0 go test -p 1 -v ./...
exit_code=$?
if ((exit_code == 0)); then
break
+13 -38
View File
@@ -120,10 +120,6 @@ type AdminConfig struct {
//
// EXPERIMENTAL: This feature is subject to change.
Remote *RemoteAdmin `json:"remote,omitempty"`
// Holds onto the routers so that we can later provision them
// if they require provisioning.
routers []AdminRouter
}
// ConfigSettings configures the management of configuration.
@@ -222,7 +218,7 @@ type AdminPermissions struct {
// newAdminHandler reads admin's config and returns an http.Handler suitable
// for use in an admin endpoint server, which will be listening on listenAddr.
func (admin *AdminConfig) newAdminHandler(addr NetworkAddress, remote bool, _ Context) adminHandler {
func (admin *AdminConfig) newAdminHandler(addr NetworkAddress, remote bool, ctx Context) (adminHandler, error) {
muxWrap := adminHandler{mux: http.NewServeMux()}
// secure the local or remote endpoint respectively
@@ -279,34 +275,21 @@ func (admin *AdminConfig) newAdminHandler(addr NetworkAddress, remote bool, _ Co
// register third-party module endpoints
for _, m := range GetModules("admin.api") {
router := m.New().(AdminRouter)
// provision the router before registering its routes, so
// handlers have access to all provisioned state
if provisioner, ok := router.(Provisioner); ok {
if err := provisioner.Provision(ctx); err != nil {
return adminHandler{}, fmt.Errorf("provisioning admin router module %s: %v", m.ID, err)
}
}
for _, route := range router.Routes() {
addRoute(route.Pattern, handlerLabel, route.Handler)
}
admin.routers = append(admin.routers, router)
}
return muxWrap
}
// provisionAdminRouters provisions all the router modules
// in the admin.api namespace that need provisioning.
func (admin *AdminConfig) provisionAdminRouters(ctx Context) error {
for _, router := range admin.routers {
provisioner, ok := router.(Provisioner)
if !ok {
continue
}
err := provisioner.Provision(ctx)
if err != nil {
return err
}
}
// We no longer need the routers once provisioned, allow for GC
admin.routers = nil
return nil
return muxWrap, nil
}
// allowedOrigins returns a list of origins that are allowed.
@@ -430,11 +413,7 @@ func replaceLocalAdminServer(cfg *Config, ctx Context) error {
return err
}
handler := cfg.Admin.newAdminHandler(addr, false, ctx)
// run the provisioners for loaded modules to make sure local
// state is properly re-initialized in the new admin server
err = cfg.Admin.provisionAdminRouters(ctx)
handler, err := cfg.Admin.newAdminHandler(addr, false, ctx)
if err != nil {
return err
}
@@ -558,11 +537,7 @@ func replaceRemoteAdminServer(ctx Context, cfg *Config) error {
// make the HTTP handler but disable Host/Origin enforcement
// because we are using TLS authentication instead
handler := cfg.Admin.newAdminHandler(addr, true, ctx)
// run the provisioners for loaded modules to make sure local
// state is properly re-initialized in the new admin server
err = cfg.Admin.provisionAdminRouters(ctx)
handler, err := cfg.Admin.newAdminHandler(addr, true, ctx)
if err != nil {
return err
}
+9 -15
View File
@@ -340,7 +340,10 @@ func TestAdminHandlerBuiltinRouteErrors(t *testing.T) {
if err != nil {
t.Fatalf("Failed to parse address: %v", err)
}
handler := cfg.Admin.newAdminHandler(addr, false, Context{})
handler, err := cfg.Admin.newAdminHandler(addr, false, Context{})
if err != nil {
t.Fatalf("Failed to create admin handler: %v", err)
}
tests := []struct {
name string
@@ -461,7 +464,10 @@ func TestNewAdminHandlerRouterRegistration(t *testing.T) {
admin := &AdminConfig{
EnforceOrigin: false,
}
handler := admin.newAdminHandler(addr, false, Context{})
handler, err := admin.newAdminHandler(addr, false, Context{})
if err != nil {
t.Fatalf("Failed to create admin handler: %v", err)
}
req := httptest.NewRequest("GET", "/mock", nil)
req.Host = "localhost:2019"
@@ -473,10 +479,6 @@ func TestNewAdminHandlerRouterRegistration(t *testing.T) {
t.Errorf("Expected status code %d but got %d", http.StatusOK, rr.Code)
t.Logf("Response body: %s", rr.Body.String())
}
if len(admin.routers) != 1 {
t.Errorf("Expected 1 router to be stored, got %d", len(admin.routers))
}
}
type mockProvisionableRouter struct {
@@ -514,19 +516,16 @@ func TestAdminRouterProvisioning(t *testing.T) {
name string
provisionErr error
wantErr bool
routersAfter int // expected number of routers after provisioning
}{
{
name: "successful provisioning",
provisionErr: nil,
wantErr: false,
routersAfter: 0,
},
{
name: "provisioning error",
provisionErr: fmt.Errorf("provision failed"),
wantErr: true,
routersAfter: 1,
},
}
@@ -562,8 +561,7 @@ func TestAdminRouterProvisioning(t *testing.T) {
t.Fatalf("Failed to parse address: %v", err)
}
_ = admin.newAdminHandler(addr, false, Context{})
err = admin.provisionAdminRouters(Context{})
_, err = admin.newAdminHandler(addr, false, Context{})
if test.wantErr {
if err == nil {
@@ -574,10 +572,6 @@ func TestAdminRouterProvisioning(t *testing.T) {
t.Errorf("Expected no error but got: %v", err)
}
}
if len(admin.routers) != test.routersAfter {
t.Errorf("Expected %d routers after provisioning, got %d", test.routersAfter, len(admin.routers))
}
})
}
}
-7
View File
@@ -440,13 +440,6 @@ func run(newCfg *Config, start bool) (Context, error) {
}
}()
// Provision any admin routers which may need to access
// some of the other apps at runtime
err = ctx.cfg.Admin.provisionAdminRouters(ctx)
if err != nil {
return ctx, err
}
// Start
err = func() error {
started := make([]string, 0, len(ctx.cfg.apps))
@@ -11,9 +11,7 @@ reverse_proxy 127.0.0.1:65535 {
@accel header X-Accel-Redirect *
handle_response @accel {
rewrite * {rp.header.X-Accel-Redirect} {
force_modify_query
}
respond "Header X-Accel-Redirect!"
}
@another {
@@ -106,12 +104,10 @@ reverse_proxy 127.0.0.1:65535 {
},
"routes": [
{
"group": "group0",
"handle": [
{
"force_modify_query": true,
"handler": "rewrite",
"uri": "{http.reverse_proxy.header.X-Accel-Redirect}"
"body": "Header X-Accel-Redirect!",
"handler": "static_response"
}
]
}
@@ -0,0 +1,328 @@
package integration
import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"testing"
"time"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"github.com/caddyserver/caddy/v2/caddytest"
)
var errExtendedConnectUnsupportedByPeer = errors.New("peer did not advertise RFC 8441 extended CONNECT support")
func TestReverseProxyExtendedConnectOverH2(t *testing.T) {
tester := caddytest.NewTester(t)
backend := newWebsocketUpgradeEchoBackend(t)
defer backend.Close()
tester.InitServer(fmt.Sprintf(`
{
admin localhost:2999
http_port 9080
https_port 9443
grace_period 1ns
skip_install_trust
servers :9443 {
protocols h2
}
}
https://localhost:9443 {
reverse_proxy %s
}
`, backend.addr), "caddyfile")
const payload = "extended-connect-echo\n"
if err := assertExtendedConnectH2Echo("localhost:9443", payload); err != nil {
if errors.Is(err, errExtendedConnectUnsupportedByPeer) {
t.Skipf("skipping extended CONNECT integration test: %v", err)
}
t.Fatalf("extended connect h2 echo failed: %v", err)
}
}
func assertExtendedConnectH2Echo(addr, payload string) error {
conn, err := tlsDialH2(addr)
if err != nil {
return fmt.Errorf("dialing h2 tls: %w", err)
}
defer conn.Close()
if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
return fmt.Errorf("setting deadline: %w", err)
}
fr := http2.NewFramer(conn, conn)
if _, err := conn.Write([]byte(http2.ClientPreface)); err != nil {
return fmt.Errorf("writing client preface: %w", err)
}
if err := fr.WriteSettings(http2.Setting{ID: http2.SettingEnableConnectProtocol, Val: 1}); err != nil {
return fmt.Errorf("writing client settings: %w", err)
}
supported, err := waitForServerSettings(fr)
if err != nil {
return err
}
if !supported {
return errExtendedConnectUnsupportedByPeer
}
if err := waitForSettingsAck(fr); err != nil {
return err
}
if err := writeExtendedConnectHeaders(fr, addr); err != nil {
return err
}
status, err := readResponseStatus(fr, 1)
if err != nil {
return err
}
if status != "200" {
return fmt.Errorf("unexpected extended connect status: got=%s want=200", status)
}
if err := fr.WriteData(1, false, []byte(payload)); err != nil {
return fmt.Errorf("writing stream data: %w", err)
}
echo, err := readStreamData(fr, 1, len(payload))
if err != nil {
return err
}
if echo != payload {
return fmt.Errorf("unexpected echoed payload: got=%q want=%q", echo, payload)
}
_ = fr.WriteRSTStream(1, http2.ErrCodeNo)
return nil
}
func tlsDialH2(addr string) (net.Conn, error) {
var lastErr error
for i := 0; i < 30; i++ {
dialer := &net.Dialer{Timeout: 2 * time.Second}
conn, err := tls.DialWithDialer(dialer, "tcp", addr, &tls.Config{
ServerName: "localhost",
InsecureSkipVerify: true,
NextProtos: []string{"h2"},
})
if err == nil {
return conn, nil
}
lastErr = err
time.Sleep(100 * time.Millisecond)
}
return nil, lastErr
}
func waitForServerSettings(fr *http2.Framer) (bool, error) {
for {
frame, err := fr.ReadFrame()
if err != nil {
return false, fmt.Errorf("reading frame before connect: %w", err)
}
settings, ok := frame.(*http2.SettingsFrame)
if !ok {
continue
}
if settings.IsAck() {
continue
}
supported := false
if err := settings.ForeachSetting(func(s http2.Setting) error {
if s.ID == http2.SettingEnableConnectProtocol && s.Val == 1 {
supported = true
}
return nil
}); err != nil {
return false, fmt.Errorf("reading server settings: %w", err)
}
if err := fr.WriteSettingsAck(); err != nil {
return false, fmt.Errorf("writing settings ack: %w", err)
}
return supported, nil
}
}
func waitForSettingsAck(fr *http2.Framer) error {
for {
frame, err := fr.ReadFrame()
if err != nil {
return fmt.Errorf("reading settings ack: %w", err)
}
settings, ok := frame.(*http2.SettingsFrame)
if ok && settings.IsAck() {
return nil
}
}
}
func writeExtendedConnectHeaders(fr *http2.Framer, addr string) error {
var hb bytes.Buffer
enc := hpack.NewEncoder(&hb)
for _, hf := range []hpack.HeaderField{
{Name: ":method", Value: "CONNECT"},
{Name: ":scheme", Value: "https"},
{Name: ":authority", Value: addr},
{Name: ":path", Value: "/upgrade"},
{Name: ":protocol", Value: "websocket"},
} {
if err := enc.WriteField(hf); err != nil {
return fmt.Errorf("encoding request headers: %w", err)
}
}
if err := fr.WriteHeaders(http2.HeadersFrameParam{
StreamID: 1,
BlockFragment: hb.Bytes(),
EndHeaders: true,
EndStream: false,
}); err != nil {
return fmt.Errorf("writing extended connect headers: %w", err)
}
return nil
}
func readResponseStatus(fr *http2.Framer, streamID uint32) (string, error) {
var block bytes.Buffer
for {
frame, err := fr.ReadFrame()
if err != nil {
return "", fmt.Errorf("reading response headers: %w", err)
}
if rst, ok := frame.(*http2.RSTStreamFrame); ok && rst.StreamID == streamID {
return "", fmt.Errorf("stream reset before response headers: %s", rst.ErrCode)
}
h, ok := frame.(*http2.HeadersFrame)
if !ok || h.StreamID != streamID {
continue
}
if _, err := block.Write(h.HeaderBlockFragment()); err != nil {
return "", fmt.Errorf("buffering response header fragment: %w", err)
}
for !h.HeadersEnded() {
next, err := fr.ReadFrame()
if err != nil {
return "", fmt.Errorf("reading continuation frame: %w", err)
}
c, ok := next.(*http2.ContinuationFrame)
if !ok || c.StreamID != streamID {
continue
}
if _, err := block.Write(c.HeaderBlockFragment()); err != nil {
return "", fmt.Errorf("buffering continuation fragment: %w", err)
}
if c.HeadersEnded() {
break
}
}
break
}
var status string
dec := hpack.NewDecoder(4096, func(f hpack.HeaderField) {
if f.Name == ":status" {
status = f.Value
}
})
if _, err := dec.Write(block.Bytes()); err != nil {
return "", fmt.Errorf("decoding response header block: %w", err)
}
if status == "" {
return "", fmt.Errorf("missing :status in response headers")
}
return status, nil
}
func readStreamData(fr *http2.Framer, streamID uint32, n int) (string, error) {
buf := make([]byte, 0, n)
for len(buf) < n {
frame, err := fr.ReadFrame()
if err != nil {
return "", fmt.Errorf("reading stream data: %w", err)
}
d, ok := frame.(*http2.DataFrame)
if !ok || d.StreamID != streamID {
continue
}
buf = append(buf, d.Data()...)
}
return string(buf[:n]), nil
}
type websocketUpgradeEchoBackend struct {
addr string
ln net.Listener
server *http.Server
}
func newWebsocketUpgradeEchoBackend(t *testing.T) *websocketUpgradeEchoBackend {
t.Helper()
backend := &websocketUpgradeEchoBackend{}
backend.server = &http.Server{
Handler: http.HandlerFunc(backend.serveHTTP),
}
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listening for websocket backend: %v", err)
}
backend.ln = ln
backend.addr = ln.Addr().String()
go func() {
_ = backend.server.Serve(ln)
}()
return backend
}
func (b *websocketUpgradeEchoBackend) serveHTTP(w http.ResponseWriter, r *http.Request) {
if !strings.EqualFold(r.Header.Get("Connection"), "Upgrade") || !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
http.Error(w, "upgrade required", http.StatusUpgradeRequired)
return
}
hijacker, ok := w.(http.Hijacker)
if !ok {
http.Error(w, "hijacking not supported", http.StatusInternalServerError)
return
}
conn, rw, err := hijacker.Hijack()
if err != nil {
return
}
_, _ = rw.WriteString("HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n")
_ = rw.Flush()
go func() {
defer conn.Close()
_, _ = io.Copy(conn, conn)
}()
}
func (b *websocketUpgradeEchoBackend) Close() {
_ = b.server.Close()
_ = b.ln.Close()
}
@@ -0,0 +1,130 @@
package integration
import (
"bufio"
"fmt"
"io"
"net"
"net/textproto"
"strings"
"testing"
"time"
"github.com/caddyserver/caddy/v2/caddytest"
)
func TestReverseProxyUpgradeWithEncode(t *testing.T) {
tester := caddytest.NewTester(t)
backend := newUpgradeEchoBackend(t)
defer backend.Close()
tester.InitServer(fmt.Sprintf(`
{
admin localhost:2999
http_port 9080
https_port 9443
grace_period 1ns
skip_install_trust
}
localhost:9080 {
route {
encode gzip
reverse_proxy %s
}
}
`, backend.addr), "caddyfile")
client := newUpgradedStreamClientWithHeaders(t, map[string]string{
"Accept-Encoding": "gzip",
})
defer client.Close()
if err := client.echo("encode-upgrade\n"); err != nil {
t.Fatalf("upgraded stream echo through encode failed: %v", err)
}
}
func TestReverseProxyUpgradeWithInterceptHandleResponse(t *testing.T) {
tester := caddytest.NewTester(t)
backend := newUpgradeEchoBackend(t)
defer backend.Close()
tester.InitServer(fmt.Sprintf(`
{
admin localhost:2999
http_port 9080
https_port 9443
grace_period 1ns
skip_install_trust
}
localhost:9080 {
route {
intercept {
@upgrade status 101
handle_response @upgrade {
respond "should-not-run"
}
}
reverse_proxy %s
}
}
`, backend.addr), "caddyfile")
client := newUpgradedStreamClientWithHeaders(t, nil)
defer client.Close()
if err := client.echo("intercept-upgrade\n"); err != nil {
t.Fatalf("upgraded stream echo through intercept failed: %v", err)
}
}
func newUpgradedStreamClientWithHeaders(t *testing.T, extraHeaders map[string]string) *upgradedStreamClient {
t.Helper()
conn, err := net.DialTimeout("tcp", "127.0.0.1:9080", 5*time.Second)
if err != nil {
t.Fatalf("dialing caddy: %v", err)
}
requestLines := []string{
"GET /upgrade HTTP/1.1",
"Host: localhost:9080",
"Connection: Upgrade",
"Upgrade: stress-stream",
}
for k, v := range extraHeaders {
requestLines = append(requestLines, k+": "+v)
}
requestLines = append(requestLines, "", "")
if _, err := io.WriteString(conn, strings.Join(requestLines, "\r\n")); err != nil {
_ = conn.Close()
t.Fatalf("writing upgrade request: %v", err)
}
reader := bufio.NewReader(conn)
tproto := textproto.NewReader(reader)
statusLine, err := tproto.ReadLine()
if err != nil {
_ = conn.Close()
t.Fatalf("reading upgrade status line: %v", err)
}
if !strings.Contains(statusLine, "101") {
_ = conn.Close()
t.Fatalf("unexpected upgrade status: %s", statusLine)
}
headers, err := tproto.ReadMIMEHeader()
if err != nil {
_ = conn.Close()
t.Fatalf("reading upgrade headers: %v", err)
}
if !strings.EqualFold(headers.Get("Connection"), "Upgrade") {
_ = conn.Close()
t.Fatalf("unexpected upgrade response headers: %v", headers)
}
return &upgradedStreamClient{conn: conn, reader: reader}
}
@@ -0,0 +1,504 @@
package integration
import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"net/http"
"net/textproto"
"os"
"runtime"
"runtime/debug"
"runtime/pprof"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/caddyserver/caddy/v2/caddytest"
)
const (
defaultStressStreamCount = 1
defaultStressReloadCount = 1
defaultStressCloseDelay = 500 * time.Millisecond
)
func TestReverseProxyReloadStressUpgradedStreamsHeapProfiles(t *testing.T) {
tester := caddytest.NewTester(t).WithDefaultOverrides(caddytest.Config{
LoadRequestTimeout: 30 * time.Second,
TestRequestTimeout: 30 * time.Second,
})
backend := newUpgradeEchoBackend(t)
defer backend.Close()
// Three scenarios, each sequential so they don't share Caddy state:
//
// legacy no delay, close on reload immediately (old default)
// close_delay stream_close_delay, the old "keep-alive workaround"
// detached stream_detached, the new explicit detached flag
//
// Reloads are spread across time and interleaved with echo-checks so
// stream health is exercised at each reload boundary, not only at the end.
legacy := runReloadStress(t, tester, backend.addr, "legacy", false, 0)
closeDelay := runReloadStress(t, tester, backend.addr, "close_delay", false, stressCloseDelay(t))
detached := runReloadStress(t, tester, backend.addr, "detached", true, 0)
if legacy.aliveAfterReloads != 0 {
t.Fatalf("legacy mode left %d upgraded streams alive after reloads", legacy.aliveAfterReloads)
}
if closeDelay.aliveBeforeDelayExpiry == 0 {
t.Fatalf("close_delay mode: all streams closed before delay expired (expected them alive)")
}
if closeDelay.aliveAfterReloads != 0 {
t.Fatalf("close_delay mode left %d upgraded streams alive after delay expiry", closeDelay.aliveAfterReloads)
}
if detached.aliveAfterReloads != detached.streamCount {
t.Fatalf("detached mode kept %d/%d upgraded streams alive after reloads", detached.aliveAfterReloads, detached.streamCount)
}
t.Logf("legacy heap: before=%s mid=%s after=%s delta(before→after)=%s objects(before=%d after=%d) handler_frames(before=%d after=%d)",
formatBytes(legacy.beforeReload.HeapInuse),
formatBytes(legacy.midReload.HeapInuse),
formatBytes(legacy.afterReload.HeapInuse),
formatBytesDiff(legacy.beforeReload.HeapInuse, legacy.afterReload.HeapInuse),
legacy.beforeReload.HeapObjects, legacy.afterReload.HeapObjects,
legacy.beforeReload.handlerFrames, legacy.afterReload.handlerFrames,
)
t.Logf("close_delay heap: before=%s mid=%s after=%s delta(before→after)=%s objects(before=%d after=%d) handler_frames(before=%d after=%d)",
formatBytes(closeDelay.beforeReload.HeapInuse),
formatBytes(closeDelay.midReload.HeapInuse),
formatBytes(closeDelay.afterReload.HeapInuse),
formatBytesDiff(closeDelay.beforeReload.HeapInuse, closeDelay.afterReload.HeapInuse),
closeDelay.beforeReload.HeapObjects, closeDelay.afterReload.HeapObjects,
closeDelay.beforeReload.handlerFrames, closeDelay.afterReload.handlerFrames,
)
t.Logf("detached heap: before=%s mid=%s after=%s delta(before→after)=%s objects(before=%d after=%d) handler_frames(before=%d after=%d)",
formatBytes(detached.beforeReload.HeapInuse),
formatBytes(detached.midReload.HeapInuse),
formatBytes(detached.afterReload.HeapInuse),
formatBytesDiff(detached.beforeReload.HeapInuse, detached.afterReload.HeapInuse),
detached.beforeReload.HeapObjects, detached.afterReload.HeapObjects,
detached.beforeReload.handlerFrames, detached.afterReload.handlerFrames,
)
}
type stressRunResult struct {
streamCount int
aliveAfterReloads int
aliveBeforeDelayExpiry int // only meaningful for close_delay mode
beforeReload heapSnapshot
midReload heapSnapshot // after all reloads, before delay expiry clean-up
afterReload heapSnapshot // after all streams have been fully cleaned up
}
type heapSnapshot struct {
HeapInuse uint64
HeapObjects uint64
handlerFrames int
profileBytes int
}
// runReloadStress opens streamCount upgraded streams, then performs reloadCount
// config reloads spread over time. An echo check is performed every 6 reloads so
// stream health is exercised at each reload boundary rather than only at the end.
// closeDelay mirrors the stream_close_delay config option; pass 0 to disable.
func runReloadStress(t *testing.T, tester *caddytest.Tester, backendAddr, mode string, detach bool, closeDelay time.Duration) stressRunResult {
t.Helper()
const echoEvery = 6 // perform an echo check every N reloads
streamCount := envIntOrDefault(t, "CADDY_STRESS_STREAM_COUNT", defaultStressStreamCount)
reloadCount := envIntOrDefault(t, "CADDY_STRESS_RELOAD_COUNT", defaultStressReloadCount)
tester.InitServer(reloadStressConfig(backendAddr, detach, closeDelay, 0), "caddyfile")
clients := make([]*upgradedStreamClient, 0, streamCount)
for i := 0; i < streamCount; i++ {
client := newUpgradedStreamClient(t)
clients = append(clients, client)
if err := client.echo(fmt.Sprintf("%s-warmup-%02d\n", mode, i)); err != nil {
closeClients(clients)
t.Fatalf("warmup echo failed in %s mode: %v", mode, err)
}
}
defer closeClients(clients)
before := captureHeapSnapshot(t)
// Reloads are spread across time; between batches of echoEvery reloads we
// pause briefly and measure stream health so the snapshot reflects real-world
// reload cadence rather than a tight loop.
for i := 1; i <= reloadCount; i++ {
loadCaddyfileConfig(t, reloadStressConfig(backendAddr, detach, closeDelay, i))
// Small pause after each reload to let connection teardown propagate.
time.Sleep(50 * time.Millisecond)
if i%echoEvery == 0 {
alive := countAliveStreams(clients)
t.Logf("%s mode: %d/%d streams alive after reload %d", mode, alive, streamCount, i)
// In detached mode, every stream must survive every reload (upstream unchanged).
if detach {
for j, client := range clients {
if err := client.echo(fmt.Sprintf("%s-mid-%02d-%02d\n", mode, i, j)); err != nil {
t.Fatalf("detached mode stream %d died at reload %d: %v", j, i, err)
}
}
}
}
}
// mid snapshot: after all reloads but before any close_delay timer has fired
// (the delay is long enough to still be running at this point).
mid := captureHeapSnapshot(t)
// For legacy mode: the reloads close streams immediately; wait for that to complete.
// For close_delay mode: streams are still alive here; wait for the delay to fire.
// For detached mode: streams survive indefinitely; no wait needed.
var aliveBeforeDelayExpiry int
aliveAfterReloads := countAliveStreams(clients)
switch {
case detach:
// nothing to wait for
case closeDelay > 0:
// streams should still be alive at this point (delay hasn't expired)
aliveBeforeDelayExpiry = aliveAfterReloads
t.Logf("%s mode: %d/%d streams alive before close_delay expires; waiting %v for cleanup",
mode, aliveBeforeDelayExpiry, streamCount, closeDelay)
time.Sleep(closeDelay + 200*time.Millisecond)
aliveAfterReloads = countAliveStreams(clients)
default:
deadline := time.Now().Add(2 * time.Second)
for aliveAfterReloads > 0 && time.Now().Before(deadline) {
time.Sleep(50 * time.Millisecond)
aliveAfterReloads = countAliveStreams(clients)
}
}
after := captureHeapSnapshot(t)
t.Logf("%s mode heap profile size: before=%dB mid=%dB after=%dB objects(before=%d mid=%d after=%d)",
mode,
before.profileBytes, mid.profileBytes, after.profileBytes,
before.HeapObjects, mid.HeapObjects, after.HeapObjects,
)
return stressRunResult{
streamCount: streamCount,
aliveAfterReloads: aliveAfterReloads,
aliveBeforeDelayExpiry: aliveBeforeDelayExpiry,
beforeReload: before,
midReload: mid,
afterReload: after,
}
}
func envIntOrDefault(t *testing.T, key string, def int) int {
t.Helper()
raw := strings.TrimSpace(os.Getenv(key))
if raw == "" {
return def
}
v, err := strconv.Atoi(raw)
if err != nil || v <= 0 {
t.Fatalf("invalid %s=%q: must be a positive integer", key, raw)
}
return v
}
func stressCloseDelay(t *testing.T) time.Duration {
t.Helper()
const key = "CADDY_STRESS_CLOSE_DELAY"
raw := strings.TrimSpace(os.Getenv(key))
if raw == "" {
return defaultStressCloseDelay
}
v, err := time.ParseDuration(raw)
if err != nil || v <= 0 {
t.Fatalf("invalid %s=%q: must be a positive duration", key, raw)
}
return v
}
func loadCaddyfileConfig(t *testing.T, rawConfig string) {
t.Helper()
client := &http.Client{Timeout: 30 * time.Second}
req, err := http.NewRequest(http.MethodPost, "http://localhost:2999/load", strings.NewReader(rawConfig))
if err != nil {
t.Fatalf("creating load request: %v", err)
}
req.Header.Set("Content-Type", "text/caddyfile")
resp, err := client.Do(req)
if err != nil {
t.Fatalf("loading config: %v", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("reading load response: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("loading config failed: status=%d body=%s", resp.StatusCode, body)
}
}
func reloadStressConfig(backendAddr string, detach bool, closeDelay time.Duration, revision int) string {
var directives string
if detach {
directives += "\n\t\tstream_detached"
}
if closeDelay > 0 {
directives += fmt.Sprintf("\n\t\tstream_close_delay %s", closeDelay)
}
return fmt.Sprintf(`
{
admin localhost:2999
http_port 9080
https_port 9443
grace_period 1ns
skip_install_trust
}
localhost:9080 {
reverse_proxy %s {
header_up X-Reload-Revision %d%s
}
}
`, backendAddr, revision, directives)
}
func captureHeapSnapshot(t *testing.T) heapSnapshot {
t.Helper()
runtime.GC()
debug.FreeOSMemory()
var mem runtime.MemStats
runtime.ReadMemStats(&mem)
var buf bytes.Buffer
if err := pprof.Lookup("heap").WriteTo(&buf, 1); err != nil {
t.Fatalf("capturing heap profile: %v", err)
}
profile := buf.String()
return heapSnapshot{
HeapInuse: mem.HeapInuse,
HeapObjects: mem.HeapObjects,
handlerFrames: strings.Count(profile, "modules/caddyhttp/reverseproxy.(*Handler)"),
profileBytes: buf.Len(),
}
}
func countAliveStreams(clients []*upgradedStreamClient) int {
alive := 0
for index, client := range clients {
if err := client.echo(fmt.Sprintf("alive-check-%02d\n", index)); err == nil {
alive++
}
}
return alive
}
func closeClients(clients []*upgradedStreamClient) {
for _, client := range clients {
if client != nil {
_ = client.Close()
}
}
}
func formatBytes(value uint64) string {
const unit = 1024
if value < unit {
return fmt.Sprintf("%d B", value)
}
div, exp := uint64(unit), 0
for n := value / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %ciB", float64(value)/float64(div), "KMGTPE"[exp])
}
func formatBytesDiff(before, after uint64) string {
if after >= before {
return "+" + formatBytes(after-before)
}
return "-" + formatBytes(before-after)
}
type upgradedStreamClient struct {
conn net.Conn
reader *bufio.Reader
mu sync.Mutex
}
func newUpgradedStreamClient(t *testing.T) *upgradedStreamClient {
t.Helper()
conn, err := net.DialTimeout("tcp", "127.0.0.1:9080", 5*time.Second)
if err != nil {
t.Fatalf("dialing caddy: %v", err)
}
request := strings.Join([]string{
"GET /upgrade HTTP/1.1",
"Host: localhost:9080",
"Connection: Upgrade",
"Upgrade: stress-stream",
"",
"",
}, "\r\n")
if _, err := io.WriteString(conn, request); err != nil {
_ = conn.Close()
t.Fatalf("writing upgrade request: %v", err)
}
reader := bufio.NewReader(conn)
tproto := textproto.NewReader(reader)
statusLine, err := tproto.ReadLine()
if err != nil {
_ = conn.Close()
t.Fatalf("reading upgrade status line: %v", err)
}
if !strings.Contains(statusLine, "101") {
_ = conn.Close()
t.Fatalf("unexpected upgrade status: %s", statusLine)
}
headers, err := tproto.ReadMIMEHeader()
if err != nil {
_ = conn.Close()
t.Fatalf("reading upgrade headers: %v", err)
}
if !strings.EqualFold(headers.Get("Connection"), "Upgrade") {
_ = conn.Close()
t.Fatalf("unexpected upgrade response headers: %v", headers)
}
return &upgradedStreamClient{conn: conn, reader: reader}
}
func (c *upgradedStreamClient) echo(payload string) error {
c.mu.Lock()
defer c.mu.Unlock()
deadline := time.Now().Add(1 * time.Second)
if err := c.conn.SetWriteDeadline(deadline); err != nil {
return err
}
if _, err := io.WriteString(c.conn, payload); err != nil {
return err
}
if err := c.conn.SetReadDeadline(deadline); err != nil {
return err
}
buf := make([]byte, len(payload))
if _, err := io.ReadFull(c.reader, buf); err != nil {
return err
}
if string(buf) != payload {
return fmt.Errorf("unexpected echoed payload: got %q want %q", string(buf), payload)
}
return nil
}
func (c *upgradedStreamClient) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
return c.conn.Close()
}
type upgradeEchoBackend struct {
addr string
ln net.Listener
mu sync.Mutex
conns map[net.Conn]struct{}
server *http.Server
}
func newUpgradeEchoBackend(t *testing.T) *upgradeEchoBackend {
t.Helper()
backend := &upgradeEchoBackend{conns: make(map[net.Conn]struct{})}
backend.server = &http.Server{
Handler: http.HandlerFunc(backend.serveHTTP),
}
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listening for backend: %v", err)
}
backend.ln = ln
backend.addr = ln.Addr().String()
go func() {
_ = backend.server.Serve(ln)
}()
return backend
}
func (b *upgradeEchoBackend) serveHTTP(w http.ResponseWriter, r *http.Request) {
if !strings.EqualFold(r.Header.Get("Connection"), "Upgrade") || !strings.EqualFold(r.Header.Get("Upgrade"), "stress-stream") {
http.Error(w, "upgrade required", http.StatusUpgradeRequired)
return
}
hijacker, ok := w.(http.Hijacker)
if !ok {
http.Error(w, "hijacking not supported", http.StatusInternalServerError)
return
}
conn, rw, err := hijacker.Hijack()
if err != nil {
return
}
b.trackConn(conn)
_, _ = rw.WriteString("HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: stress-stream\r\n\r\n")
_ = rw.Flush()
go func() {
defer b.untrackConn(conn)
defer conn.Close()
_, _ = io.Copy(conn, conn)
}()
}
func (b *upgradeEchoBackend) trackConn(conn net.Conn) {
b.mu.Lock()
b.conns[conn] = struct{}{}
b.mu.Unlock()
}
func (b *upgradeEchoBackend) untrackConn(conn net.Conn) {
b.mu.Lock()
delete(b.conns, conn)
b.mu.Unlock()
}
func (b *upgradeEchoBackend) Close() {
_ = b.server.Close()
_ = b.ln.Close()
b.mu.Lock()
defer b.mu.Unlock()
for conn := range b.conns {
_ = conn.Close()
}
clear(b.conns)
}
+1 -1
View File
@@ -20,7 +20,7 @@ require (
github.com/klauspost/cpuid/v2 v2.3.0
github.com/mholt/acmez/v3 v3.1.6
github.com/prometheus/client_golang v1.23.2
github.com/quic-go/quic-go v0.59.0
github.com/quic-go/quic-go v0.59.1
github.com/smallstep/certificates v0.30.2
github.com/smallstep/nosql v0.8.0
github.com/smallstep/truststore v0.13.0
+2 -2
View File
@@ -280,8 +280,8 @@ github.com/prometheus/procfs v0.20.1 h1:XwbrGOIplXW/AU3YhIhLODXMJYyC1isLFfYCsTEy
github.com/prometheus/procfs v0.20.1/go.mod h1:o9EMBZGRyvDrSPH1RqdxhojkuXstoe4UlK79eF5TGGo=
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII=
github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw=
github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU=
github.com/quic-go/quic-go v0.59.1 h1:0Gmua0HW1Tv7ANR7hUYwRyD0MG5OJfgvYSZasGZzBic=
github.com/quic-go/quic-go v0.59.1/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU=
+2 -32
View File
@@ -20,7 +20,6 @@ import (
"crypto/tls"
"errors"
"fmt"
"maps"
"net"
"net/http"
"strconv"
@@ -241,12 +240,7 @@ func (app *App) Provision(ctx caddy.Context) error {
// if no protocols configured explicitly, enable all except h2c
if len(srv.Protocols) == 0 {
srv.Protocols = []string{"h1", "h2", "h3"}
}
srvProtocolsUnique := map[string]struct{}{}
for _, srvProtocol := range srv.Protocols {
srvProtocolsUnique[srvProtocol] = struct{}{}
srv.Protocols = srv.protocolsWithDefaults()
}
if srv.ListenProtocols != nil {
@@ -257,31 +251,7 @@ func (app *App) Provision(ctx caddy.Context) error {
for i, lnProtocols := range srv.ListenProtocols {
if lnProtocols != nil {
// populate empty listen protocols with server protocols
lnProtocolsDefault := false
var lnProtocolsInclude []string
srvProtocolsInclude := maps.Clone(srvProtocolsUnique)
// keep existing listener protocols unless they are empty
for _, lnProtocol := range lnProtocols {
if lnProtocol == "" {
lnProtocolsDefault = true
} else {
lnProtocolsInclude = append(lnProtocolsInclude, lnProtocol)
delete(srvProtocolsInclude, lnProtocol)
}
}
// append server protocols to listener protocols if any listener protocols were empty
if lnProtocolsDefault {
for _, srvProtocol := range srv.Protocols {
if _, ok := srvProtocolsInclude[srvProtocol]; ok {
lnProtocolsInclude = append(lnProtocolsInclude, srvProtocol)
}
}
}
srv.ListenProtocols[i] = lnProtocolsInclude
srv.ListenProtocols[i] = srv.listenerProtocolsWithDefaults(lnProtocols)
}
}
}
+15 -1
View File
@@ -173,7 +173,7 @@ func (app *App) automaticHTTPSPhase1(ctx caddy.Context, repl *caddy.Replacer) er
for d := range serverDomainSet {
echDomains = append(echDomains, d)
}
app.tlsApp.RegisterServerNames(echDomains)
app.tlsApp.RegisterServerNames(echDomains, httpsRRALPNs(srv))
// nothing more to do here if there are no domains that qualify for
// automatic HTTPS and there are no explicit TLS connection policies:
@@ -574,6 +574,20 @@ func (app *App) makeRedirRoute(redirToPort uint, matcherSet MatcherSet) Route {
}
}
func httpsRRALPNs(srv *Server) []string {
alpn := make(map[string]struct{}, 3)
if srv.protocol("h3") {
alpn["h3"] = struct{}{}
}
if srv.protocol("h2") {
alpn["h2"] = struct{}{}
}
if srv.protocol("h1") {
alpn["http/1.1"] = struct{}{}
}
return caddytls.OrderedHTTPSRRALPN(alpn)
}
// createAutomationPolicies ensures that automated certificates for this
// app are managed properly. This adds up to two automation policies:
// one for the public names, and one for the internal names. If a catch-all
+33 -30
View File
@@ -1,44 +1,47 @@
package caddyhttp
import (
"reflect"
"testing"
"github.com/caddyserver/caddy/v2"
)
func TestRecordAutoHTTPSRedirectAddressPrefersHTTPSPort(t *testing.T) {
app := &App{HTTPSPort: 443}
redirDomains := make(map[string][]caddy.NetworkAddress)
func TestHTTPSRRALPNsDefaultProtocols(t *testing.T) {
srv := &Server{}
app.recordAutoHTTPSRedirectAddress(redirDomains, "example.com", caddy.NetworkAddress{Network: "tcp", StartPort: 2345, EndPort: 2345})
app.recordAutoHTTPSRedirectAddress(redirDomains, "example.com", caddy.NetworkAddress{Network: "tcp", StartPort: 443, EndPort: 443})
app.recordAutoHTTPSRedirectAddress(redirDomains, "example.com", caddy.NetworkAddress{Network: "tcp", StartPort: 8443, EndPort: 8443})
got := httpsRRALPNs(srv)
want := []string{"h3", "h2", "http/1.1"}
got := redirDomains["example.com"]
if len(got) != 1 {
t.Fatalf("expected 1 redirect address, got %d: %#v", len(got), got)
}
if got[0].StartPort != 443 {
t.Fatalf("expected redirect to prefer HTTPS port 443, got %#v", got[0])
if !reflect.DeepEqual(got, want) {
t.Fatalf("unexpected ALPN values: got %v want %v", got, want)
}
}
func TestRecordAutoHTTPSRedirectAddressKeepsAllBindAddressesOnWinningPort(t *testing.T) {
app := &App{HTTPSPort: 443}
redirDomains := make(map[string][]caddy.NetworkAddress)
app.recordAutoHTTPSRedirectAddress(redirDomains, "example.com", caddy.NetworkAddress{Network: "tcp", Host: "10.0.0.189", StartPort: 8443, EndPort: 8443})
app.recordAutoHTTPSRedirectAddress(redirDomains, "example.com", caddy.NetworkAddress{Network: "tcp", Host: "10.0.0.189", StartPort: 443, EndPort: 443})
app.recordAutoHTTPSRedirectAddress(redirDomains, "example.com", caddy.NetworkAddress{Network: "tcp", Host: "2603:c024:8002:9500:9eb:e5d3:3975:d056", StartPort: 443, EndPort: 443})
got := redirDomains["example.com"]
if len(got) != 2 {
t.Fatalf("expected 2 redirect addresses for both bind addresses on the winning port, got %d: %#v", len(got), got)
func TestHTTPSRRALPNsListenProtocolOverrides(t *testing.T) {
srv := &Server{
Protocols: []string{"h1", "h2"},
ListenProtocols: [][]string{
{"h1"},
nil,
{},
{"h3", ""},
},
}
if got[0].StartPort != 443 || got[1].StartPort != 443 {
t.Fatalf("expected both redirect addresses to stay on HTTPS port 443, got %#v", got)
}
if got[0].Host != "10.0.0.189" || got[1].Host != "2603:c024:8002:9500:9eb:e5d3:3975:d056" {
t.Fatalf("expected both bind addresses to be preserved, got %#v", got)
got := httpsRRALPNs(srv)
want := []string{"h3", "h2", "http/1.1"}
if !reflect.DeepEqual(got, want) {
t.Fatalf("unexpected ALPN values: got %v want %v", got, want)
}
}
func TestHTTPSRRALPNsIgnoresH2COnly(t *testing.T) {
srv := &Server{
Protocols: []string{"h2c"},
}
got := httpsRRALPNs(srv)
if len(got) != 0 {
t.Fatalf("unexpected ALPN values: got %v want none", got)
}
}
@@ -28,6 +28,7 @@ import (
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/internal/filesystems"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
"github.com/caddyserver/caddy/v2/modules/caddyhttp/rewrite"
)
type testCase struct {
@@ -188,6 +189,105 @@ func fileMatcherTest(t *testing.T, i int, tc testCase) {
}
}
func TestTryFilesRewriteEscapesMatchedPath(t *testing.T) {
root := t.TempDir()
tests := []struct {
name string
requestTarget string
filename string
extraFiles []string
wantPath string
wantRequestURI string
skipWindows bool
}{
{
name: "question mark in path",
requestTarget: "/%3F.html",
filename: "?.html",
wantPath: "/?.html",
wantRequestURI: "/%3F.html",
skipWindows: true,
},
{
name: "percent in path",
requestTarget: "/%25.html",
filename: "%.html",
wantPath: "/%.html",
wantRequestURI: "/%25.html",
},
{
name: "encoded question mark remains percent-encoded",
requestTarget: "/%253F.html",
filename: "%3F.html",
wantPath: "/%3F.html",
wantRequestURI: "/%253F.html",
},
{
name: "question mark in nested path",
requestTarget: "/nested/%3F.html",
filename: filepath.Join("nested", "?.html"),
wantPath: "/nested/?.html",
wantRequestURI: "/nested/%3F.html",
skipWindows: true,
},
{
name: "encoded slash in filename does not conflict with nesting",
requestTarget: "/nested%252Ffile.html",
filename: "nested%2Ffile.html",
extraFiles: []string{filepath.Join("nested", "file.html")},
wantPath: "/nested%2Ffile.html",
wantRequestURI: "/nested%252Ffile.html",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if tc.skipWindows && runtime.GOOS == "windows" {
t.Skip("Windows file names cannot contain question marks")
}
for _, name := range append([]string{tc.filename}, tc.extraFiles...) {
filename := filepath.Join(root, name)
if err := os.MkdirAll(filepath.Dir(filename), 0o700); err != nil {
t.Fatalf("creating test file parent directory: %v", err)
}
if err := os.WriteFile(filename, []byte(name), 0o600); err != nil {
t.Fatalf("writing test file: %v", err)
}
}
m := &MatchFile{
fsmap: &filesystems.FileSystemMap{},
Root: root,
TryFiles: []string{"{http.request.uri.path}"},
}
req := httptest.NewRequest(http.MethodGet, "http://example.com"+tc.requestTarget, nil)
repl := caddyhttp.NewTestReplacer(req)
matched, err := m.MatchWithError(req)
if err != nil {
t.Fatalf("matching file: %v", err)
}
if !matched {
t.Fatalf("expected request %s to match %s", tc.requestTarget, tc.filename)
}
rewrite.Rewrite{URI: "{http.matchers.file.relative}"}.Rewrite(req, repl)
if req.URL.Path != tc.wantPath {
t.Errorf("rewritten path = %q, want %q", req.URL.Path, tc.wantPath)
}
if req.RequestURI != tc.wantRequestURI {
t.Errorf("rewritten request URI = %q, want %q", req.RequestURI, tc.wantRequestURI)
}
if req.URL.RawQuery != "" {
t.Errorf("rewritten raw query = %q, want empty", req.URL.RawQuery)
}
})
}
}
func TestPHPFileMatcher(t *testing.T) {
for i, tc := range []struct {
path string
+58 -4
View File
@@ -21,6 +21,8 @@ import (
"io"
"net"
"net/http"
"github.com/caddyserver/caddy/v2"
)
// ResponseWriterWrapper wraps an underlying ResponseWriter and
@@ -70,6 +72,8 @@ type responseRecorder struct {
size int
wroteHeader bool
stream bool
hijacked bool
detached bool
readSize *int
}
@@ -144,7 +148,8 @@ func NewResponseRecorder(w http.ResponseWriter, buf *bytes.Buffer, shouldBuffer
// WriteHeader writes the headers with statusCode to the wrapped
// ResponseWriter unless the response is to be buffered instead.
// 1xx responses are never buffered.
// 1xx responses are never buffered, except 101 which is treated
// as a final upgrade response.
func (rr *responseRecorder) WriteHeader(statusCode int) {
if rr.wroteHeader {
return
@@ -161,12 +166,12 @@ func (rr *responseRecorder) WriteHeader(statusCode int) {
rr.stream = !rr.shouldBuffer(rr.statusCode, rr.ResponseWriterWrapper.Header())
}
// 1xx responses aren't final; just informational
if statusCode < 100 || statusCode > 199 {
// 1xx responses except 101 aren't final; just informational
if statusCode < 100 || statusCode > 199 || statusCode == http.StatusSwitchingProtocols {
rr.wroteHeader = true
}
// if informational or not buffered, immediately write header
// if 1xx or not buffered, immediately write header
if rr.stream || (100 <= statusCode && statusCode <= 199) {
rr.ResponseWriterWrapper.WriteHeader(statusCode)
}
@@ -222,7 +227,18 @@ func (rr *responseRecorder) Buffered() bool {
return !rr.stream
}
func (rr *responseRecorder) DetachAfterHijack(detached bool) bool {
if rr.hijacked {
return false
}
rr.detached = detached
return true
}
func (rr *responseRecorder) WriteResponse() error {
if rr.hijacked {
return nil
}
if rr.statusCode == 0 {
// could happen if no handlers actually wrote anything,
// and this prevents a panic; status must be > 0
@@ -253,11 +269,25 @@ func (rr *responseRecorder) setReadSize(size *int) {
}
func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if !rr.wroteHeader {
// hijacking without writing status code first works as long as
// subsequent writes follows http1.1 wire format, but it will
// show up with a status code of 0 in the access log and bytes
// written will include response headers. Response headers won't
// be present in the log if not set on the response writer.
caddy.Log().Warn("hijacking without writing status code first")
}
//nolint:bodyclose
conn, brw, err := http.NewResponseController(rr.ResponseWriterWrapper).Hijack()
if err != nil {
return nil, nil, err
}
rr.hijacked = true
rr.stream = true
rr.wroteHeader = true
if rr.detached {
return conn, brw, nil
}
// Per http documentation, returned bufio.Writer is empty, but bufio.Read maybe not
conn = &hijackedConn{conn, rr}
brw.Writer.Reset(conn)
@@ -311,6 +341,29 @@ func (hc *hijackedConn) ReadFrom(r io.Reader) (int64, error) {
return n, err
}
// DetachResponseWriterAfterHijack detaches w or one of its wrapped
// response writers when it's hijacked. Returns true if not already
// hijacked. When detached, bytes read or written stats will not be
// recorded for the hijacked connection, and it's safe to use the
// connection after http middleware returns.
func DetachResponseWriterAfterHijack(w http.ResponseWriter, detached bool) bool {
for w != nil {
if detacher, ok := w.(interface{ DetachAfterHijack(bool) bool }); ok {
return detacher.DetachAfterHijack(detached)
}
unwrapper, ok := w.(interface{ Unwrap() http.ResponseWriter })
if !ok {
return false
}
next := unwrapper.Unwrap()
if next == w {
return false
}
w = next
}
return false
}
// ResponseRecorder is a http.ResponseWriter that records
// responses instead of writing them to the client. See
// docs for NewResponseRecorder for proper usage.
@@ -319,6 +372,7 @@ type ResponseRecorder interface {
Status() int
Buffer() *bytes.Buffer
Buffered() bool
DetachAfterHijack(bool) bool
Size() int
WriteResponse() error
}
+93
View File
@@ -1,11 +1,14 @@
package caddyhttp
import (
"bufio"
"bytes"
"io"
"net"
"net/http"
"strings"
"testing"
"time"
)
type responseWriterSpy interface {
@@ -44,6 +47,50 @@ func (rf *readFromRespWriter) ReadFrom(r io.Reader) (int64, error) {
func (rf *readFromRespWriter) CalledReadFrom() bool { return rf.called }
type hijackRespWriter struct {
baseRespWriter
header http.Header
status int
conn net.Conn
}
func newHijackRespWriter() *hijackRespWriter {
return &hijackRespWriter{
header: make(http.Header),
conn: stubConn{},
}
}
func (hrw *hijackRespWriter) Header() http.Header {
return hrw.header
}
func (hrw *hijackRespWriter) WriteHeader(statusCode int) {
hrw.status = statusCode
}
func (hrw *hijackRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
br := bufio.NewReader(hrw.conn)
bw := bufio.NewWriter(hrw.conn)
return hrw.conn, bufio.NewReadWriter(br, bw), nil
}
type stubConn struct{}
func (stubConn) Read(_ []byte) (int, error) { return 0, io.EOF }
func (stubConn) Write(p []byte) (int, error) { return len(p), nil }
func (stubConn) Close() error { return nil }
func (stubConn) LocalAddr() net.Addr { return stubAddr("local") }
func (stubConn) RemoteAddr() net.Addr { return stubAddr("remote") }
func (stubConn) SetDeadline(time.Time) error { return nil }
func (stubConn) SetReadDeadline(time.Time) error { return nil }
func (stubConn) SetWriteDeadline(time.Time) error { return nil }
type stubAddr string
func (a stubAddr) Network() string { return "tcp" }
func (a stubAddr) String() string { return string(a) }
func TestResponseWriterWrapperReadFrom(t *testing.T) {
tests := map[string]struct {
responseWriter responseWriterSpy
@@ -169,3 +216,49 @@ func TestResponseRecorderReadFrom(t *testing.T) {
})
}
}
func TestResponseRecorderSwitchingProtocolsIsHijackAware(t *testing.T) {
w := newHijackRespWriter()
var buf bytes.Buffer
rr := NewResponseRecorder(w, &buf, func(status int, header http.Header) bool {
return true
})
rr.WriteHeader(http.StatusSwitchingProtocols)
if rr.Status() != http.StatusSwitchingProtocols {
t.Fatalf("status = %d, want %d", rr.Status(), http.StatusSwitchingProtocols)
}
if w.status != http.StatusSwitchingProtocols {
t.Fatalf("underlying status = %d, want %d", w.status, http.StatusSwitchingProtocols)
}
hj, ok := rr.(http.Hijacker)
if !ok {
t.Fatal("response recorder does not implement http.Hijacker")
}
conn, _, err := hj.Hijack()
if err != nil {
t.Fatalf("Hijack() error = %v", err)
}
defer conn.Close()
if rr.Buffered() {
t.Fatal("hijacked response should not remain buffered")
}
if rr.DetachAfterHijack(true) {
t.Fatal("response recorder should report hijacked state by returning false")
}
if DetachResponseWriterAfterHijack(rr, true) {
t.Fatal("DetachResponseWriterAfterHijack() should report false after hijack")
}
if err := rr.WriteResponse(); err != nil {
t.Fatalf("WriteResponse() after hijack returned error: %v", err)
}
if rr.Size() != 0 {
t.Fatalf("size = %d, want 0 after hijack handshake", rr.Size())
}
if got := w.Written(); got != "" {
t.Fatalf("unexpected buffered body write after hijack: %q", got)
}
}
@@ -99,6 +99,12 @@ func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error)
// stream_buffer_size <size>
// stream_timeout <duration>
// stream_close_delay <duration>
// stream_detached
// stream_logs {
// level <debug|info|warn|error>
// logger_name <name|access>
// skip_handshake
// }
// verbose_logs
//
// # request manipulation
@@ -703,6 +709,49 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
h.StreamCloseDelay = caddy.Duration(dur)
}
case "stream_detached":
if d.NextArg() {
return d.ArgErr()
}
h.StreamDetached = true
case "stream_logs":
if d.NextArg() {
return d.ArgErr()
}
if h.StreamLogs == nil {
h.StreamLogs = new(StreamLogs)
}
nesting := d.Nesting()
for d.NextBlock(nesting) {
switch d.Val() {
case "level":
if !d.NextArg() {
return d.ArgErr()
}
h.StreamLogs.Level = d.Val()
if d.NextArg() {
return d.ArgErr()
}
case "logger_name":
if !d.NextArg() {
return d.ArgErr()
}
h.StreamLogs.LoggerName = d.Val()
if d.NextArg() {
return d.ArgErr()
}
case "skip_handshake":
if d.NextArg() {
return d.ArgErr()
}
h.StreamLogs.SkipHandshake = true
default:
return d.Errf("unrecognized stream_logs option: %s", d.Val())
}
}
case "trusted_proxies":
for d.NextArg() {
if d.Val() == "private_ranges" {
@@ -80,7 +80,7 @@ func (h CopyResponseHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request
hrc.isFinalized = true
// write the response
return hrc.handler.finalizeResponse(rw, req, hrc.response, repl, hrc.start, hrc.logger)
return hrc.handler.finalizeResponse(rw, req, hrc.response, repl, hrc.start, hrc.logger, hrc.upstreamAddr)
}
// CopyResponseHeadersHandler is a special HTTP handler which may
@@ -0,0 +1,146 @@
package reverseproxy
import (
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"go.uber.org/zap"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
)
type extendedConnectCapture struct {
method string
headers http.Header
body []byte
extendedBodyPresent bool
extendedConnectBody []byte
}
type extendedConnectCaptureTransport struct {
mu sync.Mutex
capture extendedConnectCapture
}
func (tr *extendedConnectCaptureTransport) RoundTrip(req *http.Request) (*http.Response, error) {
body, err := io.ReadAll(req.Body)
if err != nil {
return nil, err
}
c := extendedConnectCapture{
method: req.Method,
headers: req.Header.Clone(),
body: body,
}
if rc, ok := caddyhttp.GetVar(req.Context(), "extended_connect_websocket_body").(io.ReadCloser); ok {
c.extendedBodyPresent = true
c.extendedConnectBody, err = io.ReadAll(rc)
if err != nil {
return nil, err
}
_ = rc.Close()
}
tr.mu.Lock()
tr.capture = c
tr.mu.Unlock()
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("ok")),
Request: req,
}, nil
}
func (tr *extendedConnectCaptureTransport) Snapshot() extendedConnectCapture {
tr.mu.Lock()
defer tr.mu.Unlock()
return tr.capture
}
func TestServeHTTPRewritesExtendedConnectWebsocketRequest(t *testing.T) {
tests := []struct {
name string
protoMajor int
proto string
headers map[string]string
}{
{
name: "h2 extended connect",
protoMajor: 2,
proto: "HTTP/2.0",
headers: map[string]string{
":protocol": "websocket",
},
},
{
name: "h3 extended connect",
protoMajor: 3,
proto: "websocket",
headers: map[string]string{},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
const payload = "extended-connect-body"
transport := new(extendedConnectCaptureTransport)
h := &Handler{
logger: zap.NewNop(),
Transport: transport,
Upstreams: UpstreamPool{
&Upstream{Host: new(Host), Dial: "127.0.0.1:8443"},
},
LoadBalancing: &LoadBalancing{
SelectionPolicy: &RoundRobinSelection{},
},
}
req := httptest.NewRequest(http.MethodConnect, "http://example.test/upgrade", strings.NewReader(payload))
req.ProtoMajor = tc.protoMajor
req.Proto = tc.proto
for key, value := range tc.headers {
req.Header.Set(key, value)
}
req = prepareTestRequest(req)
rr := httptest.NewRecorder()
err := h.ServeHTTP(rr, req, caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
return nil
}))
if err != nil {
t.Fatalf("ServeHTTP() error = %v", err)
}
captured := transport.Snapshot()
if captured.method != http.MethodGet {
t.Fatalf("upstream method = %s, want %s", captured.method, http.MethodGet)
}
if got := captured.headers.Get("Upgrade"); !strings.EqualFold(got, "websocket") {
t.Fatalf("Upgrade header = %q, want websocket", got)
}
if got := captured.headers.Get("Connection"); !strings.EqualFold(got, "Upgrade") {
t.Fatalf("Connection header = %q, want Upgrade", got)
}
if got := captured.headers.Get(":protocol"); got != "" {
t.Fatalf(":protocol header should be removed, got %q", got)
}
if len(captured.body) != 0 {
t.Fatalf("upstream request body length = %d, want 0", len(captured.body))
}
if !captured.extendedBodyPresent {
t.Fatal("extended_connect_websocket_body variable missing from request context")
}
if string(captured.extendedConnectBody) != payload {
t.Fatalf("extended_connect_websocket_body = %q, want %q", string(captured.extendedConnectBody), payload)
}
})
}
}
+79
View File
@@ -16,6 +16,10 @@ import (
var reverseProxyMetrics = struct {
once sync.Once
upstreamsHealthy *prometheus.GaugeVec
streamsActive *prometheus.GaugeVec
streamsTotal *prometheus.CounterVec
streamDuration *prometheus.HistogramVec
streamBytes *prometheus.CounterVec
logger *zap.Logger
}{}
@@ -23,6 +27,8 @@ func initReverseProxyMetrics(handler *Handler, registry *prometheus.Registry) {
const ns, sub = "caddy", "reverse_proxy"
upstreamsLabels := []string{"upstream"}
streamResultLabels := []string{"upstream", "result"}
streamBytesLabels := []string{"upstream", "direction"}
reverseProxyMetrics.once.Do(func() {
reverseProxyMetrics.upstreamsHealthy = prometheus.NewGaugeVec(prometheus.GaugeOpts{
Namespace: ns,
@@ -30,6 +36,31 @@ func initReverseProxyMetrics(handler *Handler, registry *prometheus.Registry) {
Name: "upstreams_healthy",
Help: "Health status of reverse proxy upstreams.",
}, upstreamsLabels)
reverseProxyMetrics.streamsActive = prometheus.NewGaugeVec(prometheus.GaugeOpts{
Namespace: ns,
Subsystem: sub,
Name: "streams_active",
Help: "Number of currently active upgraded reverse proxy streams.",
}, upstreamsLabels)
reverseProxyMetrics.streamsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: ns,
Subsystem: sub,
Name: "streams_total",
Help: "Total number of upgraded reverse proxy streams by close result.",
}, streamResultLabels)
reverseProxyMetrics.streamDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: ns,
Subsystem: sub,
Name: "stream_duration_seconds",
Help: "Duration of upgraded reverse proxy streams by close result.",
Buckets: prometheus.DefBuckets,
}, streamResultLabels)
reverseProxyMetrics.streamBytes = prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: ns,
Subsystem: sub,
Name: "stream_bytes_total",
Help: "Total bytes proxied across upgraded reverse proxy streams.",
}, streamBytesLabels)
})
// duplicate registration could happen if multiple sites with reverse proxy are configured; so ignore the error because
@@ -42,10 +73,58 @@ func initReverseProxyMetrics(handler *Handler, registry *prometheus.Registry) {
}) {
panic(err)
}
if err := registry.Register(reverseProxyMetrics.streamsActive); err != nil &&
!errors.Is(err, prometheus.AlreadyRegisteredError{
ExistingCollector: reverseProxyMetrics.streamsActive,
NewCollector: reverseProxyMetrics.streamsActive,
}) {
panic(err)
}
if err := registry.Register(reverseProxyMetrics.streamsTotal); err != nil &&
!errors.Is(err, prometheus.AlreadyRegisteredError{
ExistingCollector: reverseProxyMetrics.streamsTotal,
NewCollector: reverseProxyMetrics.streamsTotal,
}) {
panic(err)
}
if err := registry.Register(reverseProxyMetrics.streamDuration); err != nil &&
!errors.Is(err, prometheus.AlreadyRegisteredError{
ExistingCollector: reverseProxyMetrics.streamDuration,
NewCollector: reverseProxyMetrics.streamDuration,
}) {
panic(err)
}
if err := registry.Register(reverseProxyMetrics.streamBytes); err != nil &&
!errors.Is(err, prometheus.AlreadyRegisteredError{
ExistingCollector: reverseProxyMetrics.streamBytes,
NewCollector: reverseProxyMetrics.streamBytes,
}) {
panic(err)
}
reverseProxyMetrics.logger = handler.logger.Named("reverse_proxy.metrics")
}
func trackActiveStream(upstream string) func(result string, duration time.Duration, toBackend, fromBackend int64) {
labels := prometheus.Labels{"upstream": upstream}
reverseProxyMetrics.streamsActive.With(labels).Inc()
var once sync.Once
return func(result string, duration time.Duration, toBackend, fromBackend int64) {
once.Do(func() {
reverseProxyMetrics.streamsActive.With(labels).Dec()
reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, result).Inc()
reverseProxyMetrics.streamDuration.WithLabelValues(upstream, result).Observe(duration.Seconds())
if toBackend > 0 {
reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "to_upstream").Add(float64(toBackend))
}
if fromBackend > 0 {
reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "from_upstream").Add(float64(fromBackend))
}
})
}
}
type metricsUpstreamsHealthyUpdater struct {
handler *Handler
}
@@ -0,0 +1,67 @@
package reverseproxy
import (
"testing"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/testutil"
)
func TestTrackActiveStreamRecordsLifecycleAndBytes(t *testing.T) {
const upstream = "127.0.0.1:7443"
// Use fresh metric vectors for deterministic assertions in this unit test.
reverseProxyMetrics.streamsActive = prometheus.NewGaugeVec(prometheus.GaugeOpts{}, []string{"upstream"})
reverseProxyMetrics.streamsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "result"})
reverseProxyMetrics.streamDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{}, []string{"upstream", "result"})
reverseProxyMetrics.streamBytes = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "direction"})
finish := trackActiveStream(upstream)
if got := testutil.ToFloat64(reverseProxyMetrics.streamsActive.WithLabelValues(upstream)); got != 1 {
t.Fatalf("active streams = %v, want 1", got)
}
finish("closed", 150*time.Millisecond, 1234, 4321)
if got := testutil.ToFloat64(reverseProxyMetrics.streamsActive.WithLabelValues(upstream)); got != 0 {
t.Fatalf("active streams = %v, want 0", got)
}
if got := testutil.ToFloat64(reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, "closed")); got != 1 {
t.Fatalf("streams_total closed = %v, want 1", got)
}
if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "to_upstream")); got != 1234 {
t.Fatalf("bytes to_upstream = %v, want 1234", got)
}
if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "from_upstream")); got != 4321 {
t.Fatalf("bytes from_upstream = %v, want 4321", got)
}
// A second finish call should be ignored by the once guard.
finish("error", 1*time.Second, 111, 222)
if got := testutil.ToFloat64(reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, "error")); got != 0 {
t.Fatalf("streams_total error = %v, want 0", got)
}
}
func TestTrackActiveStreamDoesNotCountZeroBytes(t *testing.T) {
const upstream = "127.0.0.1:9000"
reverseProxyMetrics.streamsActive = prometheus.NewGaugeVec(prometheus.GaugeOpts{}, []string{"upstream"})
reverseProxyMetrics.streamsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "result"})
reverseProxyMetrics.streamDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{}, []string{"upstream", "result"})
reverseProxyMetrics.streamBytes = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "direction"})
trackActiveStream(upstream)("timeout", 250*time.Millisecond, 0, 0)
if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "to_upstream")); got != 0 {
t.Fatalf("bytes to_upstream = %v, want 0", got)
}
if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "from_upstream")); got != 0 {
t.Fatalf("bytes from_upstream = %v, want 0", got)
}
if got := testutil.ToFloat64(reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, "timeout")); got != 1 {
t.Fatalf("streams_total timeout = %v, want 1", got)
}
}
@@ -730,3 +730,58 @@ func TestRetryMatchAllowsExpressionMixedWithOtherMatchers(t *testing.T) {
})
}
}
// TestSubrouteErrorFallbackWithBody is similar to TestDialErrorBodyRetry but
// mimics Subroute's Error handler rather than testing retries specifically
func TestSubrouteErrorFallbackWithBody(t *testing.T) {
// Good upstream: echoes the request body with 200 OK.
goodServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "read body: "+err.Error(), http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
_, err = w.Write(body)
if err != nil {
t.Errorf("error writing in good server: %v", err)
}
}))
t.Cleanup(goodServer.Close)
// Handler which will dial error
badProxy := minimalHandler(0, &Upstream{Host: new(Host), Dial: deadUpstreamAddr(t)})
bodyReader := newCloseOnCloseReader("hello world")
req := httptest.NewRequest("POST", "http://localhost/", bodyReader)
// httptest.NewRequest wraps the reader in NopCloser; replace
// it with our close-aware reader so Close() is propagated.
req.Body = bodyReader
req = prepareTestRequest(req)
rec := httptest.NewRecorder()
err := badProxy.ServeHTTP(rec, req, caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
return nil
}))
if err == nil {
t.Fatalf("Expected error from badProxy.ServeHTTP")
}
// Simulate the Subroute's Error handler by calling another handler with the
// same request and recorder
goodProxy := minimalHandler(0, &Upstream{Host: new(Host), Dial: goodServer.Listener.Addr().String()})
err = goodProxy.ServeHTTP(rec, req, caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
return nil
}))
if err != nil {
t.Fatalf("Expected no error from goodProxy.ServeHTTP, got: %v", err)
}
if rec.Code != http.StatusOK {
t.Errorf("status: got %d, want %d", rec.Code, http.StatusOK)
}
expectedBody := "hello world"
if rec.Body.String() != expectedBody {
t.Errorf("body: got %q, want %q", rec.Body.String(), expectedBody)
}
}
+159 -29
View File
@@ -186,6 +186,22 @@ type Handler struct {
// by the previous config closing. Default: no delay.
StreamCloseDelay caddy.Duration `json:"stream_close_delay,omitempty"`
// If true, upgraded connections such as WebSockets are detached from
// the handler and retained across config reloads when their upstream
// still exists in the new config. Connections using upstreams that are
// removed are closed during cleanup. By default this is false, preserving
// legacy behavior where upgraded connections are closed on reload
// (optionally delayed by stream_close_delay).
// Only http1.1 websocket connections are affected, websockets for h2/h3
// are not affected. If true, bytes transferred for http1.1 in the access
// logs will be zero but those stats can be found in the stream logs for
// http1/2/3 regardless if this is enabled.
StreamDetached bool `json:"stream_detached,omitempty"`
// Controls logging behavior for upgraded stream lifecycle events.
// If omitted, defaults are used (level=DEBUG, logger_name="http.handlers.reverse_proxy.stream").
StreamLogs *StreamLogs `json:"stream_logs,omitempty"`
// If configured, rewrites the copy of the upstream request.
// Allows changing the request method and URI (path and query).
// Since the rewrite is applied to the copy, it does not persist
@@ -240,14 +256,16 @@ type Handler struct {
// Holds the handle_response Caddyfile tokens while adapting
handleResponseSegments []*caddyfile.Dispenser
// Stores upgraded requests (hijacked connections) for proper cleanup
connections map[io.ReadWriteCloser]openConnection
connectionsCloseTimer *time.Timer
connectionsMu *sync.Mutex
// Tracks hijacked/upgraded connections (WebSocket etc.) so they can be
// closed when their upstream is removed from the config.
tunnelTracker *tunnelTracker
ctx caddy.Context
logger *zap.Logger
events *caddyevents.App
streamLogLevel zapcore.Level
streamLogLoggerName string
}
// CaddyModule returns the Caddy module information.
@@ -267,8 +285,25 @@ func (h *Handler) Provision(ctx caddy.Context) error {
h.events = eventAppIface.(*caddyevents.App)
h.ctx = ctx
h.logger = ctx.Logger()
h.connections = make(map[io.ReadWriteCloser]openConnection)
h.connectionsMu = new(sync.Mutex)
h.tunnelTracker = newTunnelTracker(h.logger, time.Duration(h.StreamCloseDelay))
h.streamLogLevel = defaultStreamLogLevel
h.streamLogLoggerName = defaultStreamLoggerName
if h.StreamLogs != nil {
if h.StreamLogs.Level != "" {
lvl, err := zapcore.ParseLevel(strings.ToLower(strings.TrimSpace(h.StreamLogs.Level)))
if err != nil {
return fmt.Errorf("invalid stream_logs.level %q: %w", h.StreamLogs.Level, err)
}
h.streamLogLevel = lvl
}
if name := strings.TrimSpace(h.StreamLogs.LoggerName); name != "" {
h.streamLogLoggerName = name
}
}
if h.StreamDetached {
registerDetachedTunnelTrackers(h.tunnelTracker)
}
// warn about unsafe buffering config
if h.RequestBuffers == -1 || h.ResponseBuffers == -1 {
@@ -437,15 +472,85 @@ func (h *Handler) Provision(ctx caddy.Context) error {
return nil
}
// Cleanup cleans up the resources made by h.
func (h *Handler) Cleanup() error {
err := h.cleanupConnections()
func (h Handler) streamLogsSkipHandshake() bool {
return h.StreamLogs != nil && h.StreamLogs.SkipHandshake
}
// remove hosts from our config from the pool
for _, upstream := range h.Upstreams {
_, _ = hosts.Delete(upstream.String())
func (h Handler) streamLoggerForRequest(req *http.Request) *zap.Logger {
name := strings.TrimSpace(h.streamLogLoggerName)
if name == "" {
name = defaultStreamLoggerName
}
if name == streamLoggerNameUseAccess {
logger := caddy.Log().Named(defaultAccessLoggerBase)
names := caddyhttp.GetVar(req.Context(), caddyhttp.AccessLoggerNameVarKey)
namesSlice, ok := names.([]any)
if !ok {
return logger
}
for _, v := range namesSlice {
name, ok := v.(string)
if !ok {
continue
}
if name == "" {
return logger
}
return logger.Named(name)
}
return logger
}
return caddy.Log().Named(name)
}
var (
detachedTunnelTrackers = make(map[*tunnelTracker]struct{})
detachedTunnelTrackersMu sync.Mutex
)
func registerDetachedTunnelTrackers(ts *tunnelTracker) {
detachedTunnelTrackersMu.Lock()
defer detachedTunnelTrackersMu.Unlock()
detachedTunnelTrackers[ts] = struct{}{}
}
func notifyDetachedTunnelTrackersOfUpstreamRemoval(upstream string, self *tunnelTracker) error {
detachedTunnelTrackersMu.Lock()
defer detachedTunnelTrackersMu.Unlock()
var err error
for tunnel := range detachedTunnelTrackers {
if closeErr := tunnel.closeConnectionsForUpstream(upstream); closeErr != nil && tunnel == self && err == nil {
err = closeErr
}
}
return err
}
func unregisterDetachedTunnelTrackers(ts *tunnelTracker) {
detachedTunnelTrackersMu.Lock()
defer detachedTunnelTrackersMu.Unlock()
delete(detachedTunnelTrackers, ts)
}
// Cleanup cleans up the resources made by h.
func (h *Handler) Cleanup() error {
// even if StreamDetached is true, extended connect websockets may still be running
err := h.tunnelTracker.cleanupAttachedConnections()
for _, upstream := range h.Upstreams {
// hosts.Delete returns deleted=true when the ref count reaches zero,
// meaning no other active config references this upstream. In that
// case close any tunnels proxying to it; otherwise let them survive
// to their natural end since the upstream is still in use.
deleted, _ := hosts.Delete(upstream.String())
if deleted {
if closeErr := notifyDetachedTunnelTrackersOfUpstreamRemoval(upstream.String(), h.tunnelTracker); closeErr != nil && err == nil {
err = closeErr
}
}
}
return err
}
@@ -488,20 +593,19 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht
reqHost := clonedReq.Host
reqHeader := clonedReq.Header
// When retries are configured and there is a body, wrap it in
// io.NopCloser to prevent Go's transport from closing it on dial
// errors. cloneRequest does a shallow copy, so clonedReq.Body and
// If the request contained a body, wrap it in io.NopCloser
// to prevent Go's transport from closing it on dial errors.
// cloneRequest does a shallow copy, so clonedReq.Body and
// r.Body share the same io.ReadCloser — a dial-failure Close()
// would kill the original body for all subsequent retry attempts.
// The real body is closed by the HTTP server when the handler
// returns.
// would kill the original body for all subsequent retry
// attempts or subsequent handlers. The real body is closed by
// the HTTP server when the handler returns.
//
// If the body was already fully buffered (via request_buffers),
// we also extract the buffer so the retry loop can replay it
// from the beginning on each attempt. (see #6259, #7546)
// from the beginning on each attempt. (see #6259, #7546, #7713)
var bufferedReqBody *bytes.Buffer
if clonedReq.Body != nil && h.LoadBalancing != nil &&
(h.LoadBalancing.Retries > 0 || h.LoadBalancing.TryDuration > 0) {
if clonedReq.Body != nil {
if reqBodyBuf, ok := clonedReq.Body.(bodyReadCloser); ok && reqBodyBuf.body == nil && reqBodyBuf.buf != nil {
bufferedReqBody = reqBodyBuf.buf
reqBodyBuf.buf = nil
@@ -1138,10 +1242,11 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe
// we use the original request here, so that any routes from 'next'
// see the original request rather than the proxy cloned request.
hrc := &handleResponseContext{
handler: h,
response: res,
start: start,
logger: logger,
handler: h,
response: res,
start: start,
logger: logger,
upstreamAddr: di.Upstream.String(),
}
ctx := origReq.Context()
ctx = context.WithValue(ctx, proxyHandleResponseContextCtxKey, hrc)
@@ -1171,7 +1276,7 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe
}
// copy the response body and headers back to the upstream client
return h.finalizeResponse(rw, req, res, repl, start, logger)
return h.finalizeResponse(rw, req, res, repl, start, logger, di.Upstream.String())
}
// finalizeResponse prepares and copies the response.
@@ -1182,12 +1287,11 @@ func (h *Handler) finalizeResponse(
repl *caddy.Replacer,
start time.Time,
logger *zap.Logger,
upstreamAddr string,
) error {
// deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
if res.StatusCode == http.StatusSwitchingProtocols {
var wg sync.WaitGroup
h.handleUpgradeResponse(logger, &wg, rw, req, res)
wg.Wait()
h.handleUpgradeResponse(logger, rw, req, res, upstreamAddr)
return nil
}
@@ -1794,6 +1898,22 @@ func (brc bodyReadCloser) Close() error {
return nil
}
// StreamLogs controls logging for upgraded stream lifecycle events.
type StreamLogs struct {
// The minimum level at which stream lifecycle events are logged.
// Supported values are debug, info, warn, and error. Default: debug.
Level string `json:"level,omitempty"`
// Logger name for stream lifecycle logs. Default: "http.handlers.reverse_proxy.stream".
// Special value "access" uses the access logger namespace and, if set,
// respects the first value in access_logger_names/log_name for the request.
LoggerName string `json:"logger_name,omitempty"`
// If true, suppresses the access log entry normally emitted when an
// upgraded stream handshake completes and the request unwinds.
SkipHandshake bool `json:"skip_handshake,omitempty"`
}
// bufPool is used for buffering requests and responses.
var bufPool = sync.Pool{
New: func() any {
@@ -1826,6 +1946,9 @@ type handleResponseContext struct {
// i.e. copied and closed, to make sure that it doesn't
// happen twice.
isFinalized bool
// upstreamAddr is the selected upstream address for this request.
upstreamAddr string
}
// proxyHandleResponseContextCtxKey is the context key for the active proxy handler
@@ -1836,6 +1959,13 @@ const proxyHandleResponseContextCtxKey caddy.CtxKey = "reverse_proxy_handle_resp
// errNoUpstream occurs when there are no upstream available.
var errNoUpstream = fmt.Errorf("no upstreams available")
const (
defaultStreamLogLevel = zapcore.DebugLevel
defaultStreamLoggerName = "http.handlers.reverse_proxy.stream"
streamLoggerNameUseAccess = "access"
defaultAccessLoggerBase = "http.log.access"
)
// Interface guards
var (
_ caddy.Provisioner = (*Handler)(nil)
+272 -113
View File
@@ -26,6 +26,7 @@ import (
"io"
weakrand "math/rand/v2"
"mime"
"net"
"net/http"
"sync"
"time"
@@ -35,15 +36,16 @@ import (
"go.uber.org/zap/zapcore"
"golang.org/x/net/http/httpguts"
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
)
type h2ReadWriteCloser struct {
type extendedConnectReadWriteCloser struct {
io.ReadCloser
http.ResponseWriter
}
func (rwc h2ReadWriteCloser) Write(p []byte) (n int, err error) {
func (rwc extendedConnectReadWriteCloser) Write(p []byte) (n int, err error) {
n, err = rwc.ResponseWriter.Write(p)
if err != nil {
return 0, err
@@ -57,7 +59,7 @@ func (rwc h2ReadWriteCloser) Write(p []byte) (n int, err error) {
return n, nil
}
func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, rw http.ResponseWriter, req *http.Request, res *http.Response) {
func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWriter, req *http.Request, res *http.Response, upstreamAddr string) {
reqUpType := upgradeType(req.Header)
resUpType := upgradeType(res.Header)
@@ -90,13 +92,37 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
copyHeader(rw.Header(), res.Header)
normalizeWebsocketHeaders(rw.Header())
// Capture all h fields needed by the tunnel now, so that the Handler (h)
// is not referenced after this function returns (for HTTP/1.1 hijacked
// connections the tunnel runs in a detached goroutine).
tunnel := h.tunnelTracker
bufferSize := h.StreamBufferSize
streamTimeout := time.Duration(h.StreamTimeout)
if h.StreamDetached {
// the return value should be true as it's not hijacked yet,
// but some middleware may wrap response writers incorrectly
if !caddyhttp.DetachResponseWriterAfterHijack(rw, true) {
if c := logger.Check(zap.DebugLevel, "detaching connection failed"); c != nil {
c.Write(zap.String("tip", "check if your response writers have an Unwrap method or if already hijacked"))
}
}
}
var (
conn io.ReadWriteCloser
brw *bufio.ReadWriter
conn io.ReadWriteCloser
brw *bufio.ReadWriter
detached = h.StreamDetached
)
// websocket over http2 or http3 if extended connect is enabled, assuming backend doesn't support this, the request will be modified to http1.1 upgrade
// TODO: once we can reliably detect backend support this, it can be removed for those backends
// websocket over http2 or http3 if extended connect is enabled,
// assuming backend doesn't support this, the request will be
// modified to http1.1 upgrade
// TODO: once we can reliably detect backend support this, it can
// be removed for those backends
if body, ok := caddyhttp.GetVar(req.Context(), "extended_connect_websocket_body").(io.ReadCloser); ok {
// websocket over extended connect can't be detached. rw and req.Body
// are only valid while the handler goroutine is running
detached = false
req.Body = body
rw.Header().Del("Upgrade")
rw.Header().Del("Connection")
@@ -104,18 +130,18 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
rw.WriteHeader(http.StatusOK)
if c := logger.Check(zap.DebugLevel, "upgrading connection"); c != nil {
c.Write(zap.Int("http_version", 2))
c.Write(zap.Int("http_version", req.ProtoMajor))
}
//nolint:bodyclose
flushErr := http.NewResponseController(rw).Flush()
if flushErr != nil {
if c := h.logger.Check(zap.ErrorLevel, "failed to flush http2 websocket response"); c != nil {
if c := h.logger.Check(zap.ErrorLevel, "failed to flush extended_connect websocket response"); c != nil {
c.Write(zap.Error(flushErr))
}
return
}
conn = h2ReadWriteCloser{req.Body, rw}
conn = extendedConnectReadWriteCloser{req.Body, rw}
// bufio is not needed, use minimal buffer
brw = bufio.NewReadWriter(bufio.NewReaderSize(conn, 1), bufio.NewWriterSize(conn, 1))
} else {
@@ -143,27 +169,6 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
}
}
// adopted from https://github.com/golang/go/commit/8bcf2834afdf6a1f7937390903a41518715ef6f5
backConnCloseCh := make(chan struct{})
go func() {
// Ensure that the cancellation of a request closes the backend.
// See issue https://golang.org/issue/35559.
select {
case <-req.Context().Done():
case <-backConnCloseCh:
}
backConn.Close()
}()
defer close(backConnCloseCh)
start := time.Now()
defer func() {
conn.Close()
if c := logger.Check(zapcore.DebugLevel, "connection closed"); c != nil {
c.Write(zap.Duration("duration", time.Since(start)))
}
}()
if err := brw.Flush(); err != nil {
if c := logger.Check(zapcore.DebugLevel, "response flush"); c != nil {
c.Write(zap.Error(err))
@@ -184,13 +189,12 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
}
}
// Ensure the hijacked client connection, and the new connection established
// with the backend, are both closed in the event of a server shutdown. This
// is done by registering them. We also try to gracefully close connections
// we recognize as websockets.
// We need to make sure the client connection messages (i.e. to upstream)
// are masked, so we need to know whether the connection is considered the
// server or the client side of the proxy.
// Register both connections with the tunnel tracker. We also try to
// gracefully close connections we recognize as websockets. We need to make
// sure the client connection messages (i.e. to upstream) are masked, so we
// need to know whether the connection is considered the server or the
// client side of the proxy. Note that gracefulClose must not capture h,
// since the tunnel may outlive the handler instance.
gracefulClose := func(conn io.ReadWriteCloser, isClient bool) func() error {
if isWebsocket(req) {
return func() error {
@@ -199,43 +203,147 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
}
return nil
}
deleteFrontConn := h.registerConnection(conn, gracefulClose(conn, false))
deleteBackConn := h.registerConnection(backConn, gracefulClose(backConn, true))
defer deleteFrontConn()
defer deleteBackConn()
deleteFrontConn := tunnel.registerConnection(conn, gracefulClose(conn, false), detached, upstreamAddr)
deleteBackConn := tunnel.registerConnection(backConn, gracefulClose(backConn, true), detached, upstreamAddr)
if h.streamLogsSkipHandshake() {
caddyhttp.SetVar(req.Context(), caddyhttp.LogSkipVar, true)
}
repl := req.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
repl.Set("http.reverse_proxy.upgraded", true)
streamUUID, _ := repl.GetString("http.request.uuid")
streamFields := makeStreamLogFields(streamUUID)
streamLogger := h.streamLoggerForRequest(req)
streamLevel := h.streamLogLevel
finishMetrics := trackActiveStream(upstreamAddr)
start := time.Now()
if !detached {
handleUpgradeTunnel(
streamLogger,
streamLevel,
conn,
backConn,
deleteFrontConn,
deleteBackConn,
bufferSize,
streamTimeout,
start,
finishMetrics,
streamFields,
)
} else {
// start a new goroutine
go handleUpgradeTunnel(
streamLogger,
streamLevel,
conn,
backConn,
deleteFrontConn,
deleteBackConn,
bufferSize,
streamTimeout,
start,
finishMetrics,
streamFields,
)
}
}
// handleUpgradeTunnel returns when transfer is done.
func handleUpgradeTunnel(
streamLogger *zap.Logger,
streamLevel zapcore.Level,
conn io.ReadWriteCloser,
backConn io.ReadWriteCloser,
deleteFrontConn func(),
deleteBackConn func(),
bufferSize int,
streamTimeout time.Duration,
start time.Time,
finishMetrics func(result string, duration time.Duration, toBackend int64, fromBackend int64),
streamFields []zap.Field,
) {
defer deleteBackConn()
defer deleteFrontConn()
var (
wg sync.WaitGroup
toBackend int64
fromBackend int64
result string
)
// when a stream timeout is encountered, no error will be read from errc
// a buffer size of 2 will allow both the read and write goroutines to
// send the error and exit
// see: https://github.com/caddyserver/caddy/issues/7418
errc := make(chan error, 2)
spc := switchProtocolCopier{
user: conn,
backend: backConn,
wg: wg,
bufferSize: h.StreamBufferSize,
wg: &wg,
bufferSize: bufferSize,
sent: &toBackend,
received: &fromBackend,
}
wg.Add(2)
// setup the timeout if requested
var timeoutc <-chan time.Time
if h.StreamTimeout > 0 {
timer := time.NewTimer(time.Duration(h.StreamTimeout))
if streamTimeout > 0 {
timer := time.NewTimer(streamTimeout)
defer timer.Stop()
timeoutc = timer.C
}
// when a stream timeout is encountered, no error will be read from errc
// a buffer size of 2 will allow both the read and write goroutines to send the error and exit
// see: https://github.com/caddyserver/caddy/issues/7418
errc := make(chan error, 2)
wg.Add(2)
go spc.copyToBackend(errc)
go spc.copyFromBackend(errc)
select {
case err := <-errc:
if c := logger.Check(zapcore.DebugLevel, "streaming error"); c != nil {
result = classifyStreamResult(err)
if c := streamLogger.Check(streamLevel, "streaming error"); c != nil {
c.Write(zap.Error(err))
}
case time := <-timeoutc:
if c := logger.Check(zapcore.DebugLevel, "stream timed out"); c != nil {
c.Write(zap.Time("timeout", time))
case t := <-timeoutc:
result = "timeout"
if c := streamLogger.Check(streamLevel, "stream timed out"); c != nil {
c.Write(zap.Time("timeout", t))
}
}
// Close both ends to unblock the still-running copy goroutine,
// then wait for it so byte counts are final before metrics/logging.
conn.Close()
backConn.Close()
wg.Wait()
finishMetrics(result, time.Since(start), toBackend, fromBackend)
if c := streamLogger.Check(streamLevel, "connection closed"); c != nil {
fields := append([]zap.Field{}, streamFields...)
fields = append(fields,
zap.Duration("duration", time.Since(start)),
zap.Int64("bytes_to_backend", toBackend),
zap.Int64("bytes_from_backend", fromBackend),
)
c.Write(fields...)
}
}
func classifyStreamResult(err error) string {
if err == nil ||
errors.Is(err, io.EOF) ||
errors.Is(err, net.ErrClosed) ||
errors.Is(err, context.Canceled) {
return "closed"
}
return "error"
}
func makeStreamLogFields(streamUUID string) []zap.Field {
fields := make([]zap.Field, 0, 1)
if streamUUID != "" {
fields = append(fields, zap.String("uuid", streamUUID))
}
return fields
}
// flushInterval returns the p.FlushInterval value, conditionally
@@ -375,75 +483,101 @@ func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte, logger *za
}
}
// registerConnection holds onto conn so it can be closed in the event
// of a server shutdown. This is useful because hijacked connections or
// connections dialed to backends don't close when server is shut down.
// The caller should call the returned delete() function when the
// connection is done to remove it from memory.
func (h *Handler) registerConnection(conn io.ReadWriteCloser, gracefulClose func() error) (del func()) {
h.connectionsMu.Lock()
h.connections[conn] = openConnection{conn, gracefulClose}
h.connectionsMu.Unlock()
return func() {
h.connectionsMu.Lock()
delete(h.connections, conn)
// if there is no connection left before the connections close timer fires
if len(h.connections) == 0 && h.connectionsCloseTimer != nil {
// we release the timer that holds the reference to Handler
if (*h.connectionsCloseTimer).Stop() {
h.logger.Debug("stopped streaming connections close timer - all connections are already closed")
}
h.connectionsCloseTimer = nil
}
h.connectionsMu.Unlock()
// openConnection maps an open connection to an optional function for graceful
// close and records which upstream address the connection is proxying to.
// Also tracks whether the connection is detached, which means it should only be
// closed when the upstream is removed from the config, not on every reload.
type openConnection struct {
conn io.ReadWriteCloser
gracefulClose func() error
detached bool
upstream string
}
// tunnelTracker tracks hijacked/upgraded connections for selective cleanup.
// This exists to detach the lifecycle of streaming connections from the proxy
// Handler and config, since we typically want them to survive past config reloads.
// It also allows for selective connection cleanup based on their attachment status.
type tunnelTracker struct {
connections map[io.ReadWriteCloser]openConnection
closeTimer *time.Timer
closeDelay time.Duration
stopped bool
mu sync.Mutex
logger *zap.Logger
}
func newTunnelTracker(logger *zap.Logger, closeDelay time.Duration) *tunnelTracker {
return &tunnelTracker{
connections: make(map[io.ReadWriteCloser]openConnection),
closeDelay: closeDelay,
logger: logger,
}
}
// closeConnections immediately closes all hijacked connections (both to client and backend).
func (h *Handler) closeConnections() error {
var err error
h.connectionsMu.Lock()
defer h.connectionsMu.Unlock()
// registerConnection stores conn in the tracking map. The caller must invoke
// the returned del func when the connection is done.
func (ts *tunnelTracker) registerConnection(conn io.ReadWriteCloser, gracefulClose func() error, detached bool, upstream string) (del func()) {
ts.mu.Lock()
ts.connections[conn] = openConnection{conn, gracefulClose, detached, upstream}
ts.mu.Unlock()
return func() {
ts.mu.Lock()
delete(ts.connections, conn)
if len(ts.connections) == 0 && ts.stopped {
unregisterDetachedTunnelTrackers(ts)
if ts.closeTimer != nil {
if ts.closeTimer.Stop() {
ts.logger.Debug("stopped streaming connections close timer - all connections are already closed")
}
ts.closeTimer = nil
}
}
ts.mu.Unlock()
}
}
for _, oc := range h.connections {
// closeAttachedConnections closes all tracked attached connections.
func (ts *tunnelTracker) closeAttachedConnections() error {
var err error
ts.mu.Lock()
defer ts.mu.Unlock()
ts.stopped = true
for _, oc := range ts.connections {
// detached connections are only closed when the upstream is gone from the config
if oc.detached {
continue
}
if oc.gracefulClose != nil {
// this is potentially blocking while we have the lock on the connections
// map, but that should be OK since the server has in theory shut down
// and we are no longer using the connections map
gracefulErr := oc.gracefulClose()
if gracefulErr != nil && err == nil {
if gracefulErr := oc.gracefulClose(); gracefulErr != nil && err == nil {
err = gracefulErr
}
}
closeErr := oc.conn.Close()
if closeErr != nil && err == nil {
if closeErr := oc.conn.Close(); closeErr != nil && err == nil {
err = closeErr
}
}
return err
}
// cleanupConnections closes hijacked connections.
// Depending on the value of StreamCloseDelay it does that either immediately
// or sets up a timer that will do that later.
func (h *Handler) cleanupConnections() error {
if h.StreamCloseDelay == 0 {
return h.closeConnections()
// cleanupAttachedConnections closes upgraded attached connections.
// Depending on closeDelay it does that either immediately or after a timer.
func (ts *tunnelTracker) cleanupAttachedConnections() error {
if ts.closeDelay == 0 {
return ts.closeAttachedConnections()
}
h.connectionsMu.Lock()
defer h.connectionsMu.Unlock()
// the handler is shut down, no new connection can appear,
// so we can skip setting up the timer when there are no connections
if len(h.connections) > 0 {
delay := time.Duration(h.StreamCloseDelay)
h.connectionsCloseTimer = time.AfterFunc(delay, func() {
if c := h.logger.Check(zapcore.DebugLevel, "closing streaming connections after delay"); c != nil {
ts.mu.Lock()
defer ts.mu.Unlock()
if len(ts.connections) > 0 {
delay := ts.closeDelay
ts.closeTimer = time.AfterFunc(delay, func() {
if c := ts.logger.Check(zapcore.DebugLevel, "closing streaming connections after delay"); c != nil {
c.Write(zap.Duration("delay", delay))
}
err := h.closeConnections()
err := ts.closeAttachedConnections()
if err != nil {
if c := h.logger.Check(zapcore.ErrorLevel, "failed to closed connections after delay"); c != nil {
if c := ts.logger.Check(zapcore.ErrorLevel, "failed to close connections after delay"); c != nil {
c.Write(
zap.Error(err),
zap.Duration("delay", delay),
@@ -567,11 +701,29 @@ func isWebsocket(r *http.Request) bool {
httpguts.HeaderValuesContainsToken(r.Header["Upgrade"], "websocket")
}
// openConnection maps an open connection to
// an optional function for graceful close.
type openConnection struct {
conn io.ReadWriteCloser
gracefulClose func() error
// closeConnectionsForUpstream closes all tracked connections that were
// established to the given upstream address.
func (ts *tunnelTracker) closeConnectionsForUpstream(addr string) error {
var err error
ts.mu.Lock()
defer ts.mu.Unlock()
if !ts.stopped {
return nil
}
for _, oc := range ts.connections {
if oc.upstream != addr {
continue
}
if oc.gracefulClose != nil {
if gracefulErr := oc.gracefulClose(); gracefulErr != nil && err == nil {
err = gracefulErr
}
}
if closeErr := oc.conn.Close(); closeErr != nil && err == nil {
err = closeErr
}
}
return err
}
type maxLatencyWriter struct {
@@ -642,16 +794,23 @@ type switchProtocolCopier struct {
user, backend io.ReadWriteCloser
wg *sync.WaitGroup
bufferSize int
// sent and received accumulate byte counts for each direction.
// They are written before wg.Done() and read after wg.Wait(), so no
// additional synchronization is needed beyond the WaitGroup barrier.
sent *int64 // bytes copied to backend; must be non-nil
received *int64 // bytes copied from backend; must be non-nil
}
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
_, err := io.CopyBuffer(c.user, c.backend, c.buffer())
n, err := io.CopyBuffer(c.user, c.backend, c.buffer())
*c.received = n
errc <- err
c.wg.Done()
}
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
_, err := io.CopyBuffer(c.backend, c.user, c.buffer())
n, err := io.CopyBuffer(c.backend, c.user, c.buffer())
*c.sent = n
errc <- err
c.wg.Done()
}
@@ -7,8 +7,10 @@ import (
"strings"
"sync"
"testing"
"time"
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
)
func TestHandlerCopyResponse(t *testing.T) {
@@ -41,12 +43,15 @@ func TestSwitchProtocolCopierBufferSize(t *testing.T) {
var wg sync.WaitGroup
var errc = make(chan error, 1)
var dst bytes.Buffer
var sent, received int64
copier := switchProtocolCopier{
user: nopReadWriteCloser{Reader: strings.NewReader("hello")},
backend: nopReadWriteCloser{Writer: &dst},
wg: &wg,
bufferSize: 7,
sent: &sent,
received: &received,
}
buf := copier.buffer()
@@ -80,3 +85,146 @@ type nopReadWriteCloser struct {
}
func (nopReadWriteCloser) Close() error { return nil }
type trackingReadWriteCloser struct {
closed chan struct{}
one sync.Once
}
func newTrackingReadWriteCloser() *trackingReadWriteCloser {
return &trackingReadWriteCloser{closed: make(chan struct{})}
}
func (c *trackingReadWriteCloser) Read(_ []byte) (int, error) { return 0, io.EOF }
func (c *trackingReadWriteCloser) Write(p []byte) (int, error) { return len(p), nil }
func (c *trackingReadWriteCloser) Close() error {
c.one.Do(func() {
close(c.closed)
})
return nil
}
func (c *trackingReadWriteCloser) isClosed() bool {
select {
case <-c.closed:
return true
default:
return false
}
}
func TestHandlerCleanupLegacyModeClosesAllConnections(t *testing.T) {
ts := newTunnelTracker(caddy.Log(), 0)
connA := newTrackingReadWriteCloser()
connB := newTrackingReadWriteCloser()
ts.registerConnection(connA, nil, false, "a")
ts.registerConnection(connB, nil, false, "b")
h := &Handler{
tunnelTracker: ts,
StreamDetached: false,
}
if err := h.Cleanup(); err != nil {
t.Fatalf("cleanup failed: %v", err)
}
if !connA.isClosed() || !connB.isClosed() {
t.Fatalf("legacy cleanup should close all upgraded connections")
}
}
func TestHandlerCleanupLegacyModeHonorsDelay(t *testing.T) {
ts := newTunnelTracker(caddy.Log(), 40*time.Millisecond)
conn := newTrackingReadWriteCloser()
ts.registerConnection(conn, nil, false, "a")
h := &Handler{
tunnelTracker: ts,
StreamDetached: false,
}
if err := h.Cleanup(); err != nil {
t.Fatalf("cleanup failed: %v", err)
}
if conn.isClosed() {
t.Fatal("connection should not close immediately when stream_close_delay is set")
}
select {
case <-conn.closed:
case <-time.After(500 * time.Millisecond):
t.Fatal("connection did not close after stream_close_delay elapsed")
}
}
func TestHandlerCleanupDetachedModeClosesOnlyRemovedUpstreams(t *testing.T) {
const upstreamA = "upstream-a"
const upstreamB = "upstream-b"
// Simulate old+new configs both referencing upstreamA (refcount 2),
// while upstreamB is only referenced by the old config (refcount 1).
hosts.LoadOrStore(upstreamA, struct{}{})
hosts.LoadOrStore(upstreamA, struct{}{})
hosts.LoadOrStore(upstreamB, struct{}{})
t.Cleanup(func() {
_, _ = hosts.Delete(upstreamA)
_, _ = hosts.Delete(upstreamA)
_, _ = hosts.Delete(upstreamB)
})
ts := newTunnelTracker(caddy.Log(), 0)
registerDetachedTunnelTrackers(ts)
connA := newTrackingReadWriteCloser()
connB := newTrackingReadWriteCloser()
ts.registerConnection(connA, nil, true, upstreamA)
ts.registerConnection(connB, nil, true, upstreamB)
h := &Handler{
tunnelTracker: ts,
StreamDetached: true,
Upstreams: UpstreamPool{
&Upstream{Dial: upstreamA},
&Upstream{Dial: upstreamB},
},
}
if err := h.Cleanup(); err != nil {
t.Fatalf("cleanup failed: %v", err)
}
if connA.isClosed() {
t.Fatal("connection for detached upstream should remain open")
}
if !connB.isClosed() {
t.Fatal("connection for removed upstream should be closed")
}
}
func TestHandlerUnmarshalCaddyfileStreamLogsBlock(t *testing.T) {
d := caddyfile.NewTestDispenser(`
reverse_proxy localhost:9000 {
stream_logs {
level info
logger_name access
skip_handshake
}
}
`)
var h Handler
if err := h.UnmarshalCaddyfile(d); err != nil {
t.Fatalf("UnmarshalCaddyfile() error = %v", err)
}
if h.StreamLogs == nil {
t.Fatal("expected stream_logs to be configured")
}
if h.StreamLogs.Level != "info" {
t.Fatalf("expected stream_logs.level=info, got %q", h.StreamLogs.Level)
}
if h.StreamLogs.LoggerName != "access" {
t.Fatalf("expected stream_logs.logger_name=access, got %q", h.StreamLogs.LoggerName)
}
if !h.StreamLogs.SkipHandshake {
t.Fatal("expected stream_logs.skip_handshake=true")
}
}
+3 -28
View File
@@ -34,9 +34,7 @@ func init() {
// parseCaddyfileRewrite sets up a basic rewrite handler from Caddyfile tokens. Syntax:
//
// rewrite [<matcher>] <to> {
// force_modify_query
// }
// rewrite [<matcher>] <to>
//
// Only URI components which are given in <to> will be set in the resulting URI.
// See the docs for the rewrite handler for more information.
@@ -52,30 +50,12 @@ func parseCaddyfileRewrite(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue,
return nil, h.Errf("too many arguments; should only be a matcher and a URI")
}
parseBlock := func(rewr *Rewrite) error {
for nesting := h.Nesting(); h.NextBlock(nesting); {
switch h.Val() {
case "force_modify_query":
rewr.ForceModifyQuery = true
default:
return h.Errf("unknown subdirective: %s", h.Val())
}
}
return nil
}
// with only one arg, assume it's a rewrite URI with no matcher token
if argsCount == 1 {
if !h.NextArg() {
return nil, h.ArgErr()
}
rewr := Rewrite{URI: h.Val()}
err := parseBlock(&rewr)
if err != nil {
return nil, err
}
return h.NewRoute(nil, rewr), nil
return h.NewRoute(nil, Rewrite{URI: h.Val()}), nil
}
// parse the matcher token into a matcher set
@@ -86,12 +66,7 @@ func parseCaddyfileRewrite(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue,
h.Next() // consume directive name again, matcher parsing does a reset
h.Next() // advance to the rewrite URI
rewr := Rewrite{URI: h.Val()}
err = parseBlock(&rewr)
if err != nil {
return nil, err
}
return h.NewRoute(userMatcherSet, rewr), nil
return h.NewRoute(userMatcherSet, Rewrite{URI: h.Val()}), nil
}
// parseCaddyfileMethod sets up a basic method rewrite handler from Caddyfile tokens. Syntax:
+28 -24
View File
@@ -92,17 +92,6 @@ type Rewrite struct {
// Mutates the query string of the URI.
Query *queryOps `json:"query,omitempty"`
// If true, the rewrite will be forced to also apply to the
// query part of the URL. This is only needed if the configured
// URI does not include a '?' character which is normally used
// to determine whether the query should be modified. In other
// words, this allows rewriting both the path and query when
// using a placeholder as the replacement value, whereas otherwise
// only the path would be rewritten because the placeholder itself
// does not contain a '?' character. Only use this if the placeholder
// is trusted to not be vulnerable to query injections.
ForceModifyQuery bool `json:"force_modify_query,omitempty"`
logger *zap.Logger
}
@@ -222,12 +211,7 @@ func (rewr Rewrite) Rewrite(r *http.Request, repl *caddy.Replacer) bool {
var newPath, newQuery, newFrag string
if path != "" {
// replace the `path` placeholder to escaped path
pathPlaceholder := "{http.request.uri.path}"
if strings.Contains(path, pathPlaceholder) {
path = strings.ReplaceAll(path, pathPlaceholder, r.URL.EscapedPath())
}
path = escapePathPlaceholders(path, r, repl)
newPath = repl.ReplaceAll(path, "")
}
@@ -237,15 +221,10 @@ func (rewr Rewrite) Rewrite(r *http.Request, repl *caddy.Replacer) bool {
// recompute; new path contains a query string
var injectedQuery string
newPath, injectedQuery = before, after
// don't overwrite explicitly-configured query string,
// unless configured explicitly to do so
if query == "" || rewr.ForceModifyQuery {
// don't overwrite explicitly-configured query string
if query == "" {
query = injectedQuery
}
if rewr.ForceModifyQuery {
qsStart = 0
}
}
if query != "" {
@@ -316,6 +295,31 @@ func (rewr Rewrite) Rewrite(r *http.Request, repl *caddy.Replacer) bool {
return r.Method != oldMethod || r.RequestURI != oldURI
}
func escapePathPlaceholders(path string, r *http.Request, repl *caddy.Replacer) string {
// Replace path-valued placeholders in escaped form before the URI is parsed,
// otherwise literal '?' and '%' bytes from the path can be interpreted as URI
// delimiters or percent-escape sequences during the rewrite.
pathPlaceholder := "{http.request.uri.path}"
if strings.Contains(path, pathPlaceholder) {
path = strings.ReplaceAll(path, pathPlaceholder, r.URL.EscapedPath())
}
fileMatchRelativePlaceholder := "{http.matchers.file.relative}"
if strings.Contains(path, fileMatchRelativePlaceholder) {
if val, ok := repl.Get("http.matchers.file.relative"); ok {
if relativePath, ok := val.(string); ok {
path = strings.ReplaceAll(path, fileMatchRelativePlaceholder, escapePathPreservingSlashes(relativePath))
}
}
}
return path
}
func escapePathPreservingSlashes(path string) string {
return strings.ReplaceAll(url.PathEscape(path), "%2F", "/")
}
// buildQueryString takes an input query string and
// performs replacements on each component, returning
// the resulting query string. This function appends
-20
View File
@@ -225,23 +225,6 @@ func TestRewrite(t *testing.T) {
input: newRequest(t, "GET", "/foo#fragFirst?c=d"),
expect: newRequest(t, "GET", "/bar#fragFirst?c=d"),
},
{
rule: Rewrite{URI: "{test.path_and_query}"},
input: newRequest(t, "GET", "/"),
expect: newRequest(t, "GET", "/foo"),
},
{
// TODO: This might be an incorrect result, since it also replaces
// the path with empty string when that might not be the intent.
rule: Rewrite{URI: "{test.query}", ForceModifyQuery: true},
input: newRequest(t, "GET", "/foo"),
expect: newRequest(t, "GET", "?bar=1"),
},
{
rule: Rewrite{URI: "{test.path_and_query}", ForceModifyQuery: true},
input: newRequest(t, "GET", "/"),
expect: newRequest(t, "GET", "/foo?bar=1"),
},
{
rule: Rewrite{URI: "/api/admin/panel"},
input: newRequest(t, "GET", "/api/admin%2Fpanel"),
@@ -381,9 +364,6 @@ func TestRewrite(t *testing.T) {
repl.Set("http.request.uri", tc.input.RequestURI)
repl.Set("http.request.uri.path", tc.input.URL.Path)
repl.Set("http.request.uri.query", tc.input.URL.RawQuery)
repl.Set("test.path", "/foo")
repl.Set("test.query", "?bar=1")
repl.Set("test.path_and_query", "/foo?bar=1")
// we can't directly call Provision() without a valid caddy.Context
// (TODO: fix that) so here we ad-hoc compile the regex
+47 -9
View File
@@ -300,6 +300,8 @@ type Server struct {
onStopFuncs []func(context.Context) error // TODO: Experimental (Nov. 2023)
}
var defaultProtocols = []string{"h1", "h2", "h3"}
var (
ServerHeader = "Caddy"
serverHeader = []string{ServerHeader}
@@ -899,22 +901,58 @@ func (s *Server) logRequest(
// protocol returns true if the protocol proto is configured/enabled.
func (s *Server) protocol(proto string) bool {
if s.ListenProtocols == nil {
if slices.Contains(s.Protocols, proto) {
return slices.Contains(s.protocolsWithDefaults(), proto)
}
for _, lnProtocols := range s.ListenProtocols {
if slices.Contains(s.listenerProtocolsWithDefaults(lnProtocols), proto) {
return true
}
} else {
for _, lnProtocols := range s.ListenProtocols {
for _, lnProtocol := range lnProtocols {
if lnProtocol == "" && slices.Contains(s.Protocols, proto) || lnProtocol == proto {
return true
}
}
}
}
return false
}
func (s *Server) protocolsWithDefaults() []string {
if len(s.Protocols) == 0 {
return defaultProtocols
}
return s.Protocols
}
func (s *Server) listenerProtocolsWithDefaults(lnProtocols []string) []string {
serverProtocols := s.protocolsWithDefaults()
if len(lnProtocols) == 0 {
return serverProtocols
}
lnProtocolsDefault := false
lnProtocolsInclude := make([]string, 0, len(lnProtocols)+len(serverProtocols))
srvProtocolsInclude := make(map[string]struct{}, len(serverProtocols))
for _, srvProtocol := range serverProtocols {
srvProtocolsInclude[srvProtocol] = struct{}{}
}
for _, lnProtocol := range lnProtocols {
if lnProtocol == "" {
lnProtocolsDefault = true
continue
}
lnProtocolsInclude = append(lnProtocolsInclude, lnProtocol)
delete(srvProtocolsInclude, lnProtocol)
}
if lnProtocolsDefault {
for _, srvProtocol := range serverProtocols {
if _, ok := srvProtocolsInclude[srvProtocol]; ok {
lnProtocolsInclude = append(lnProtocolsInclude, srvProtocol)
}
}
}
return lnProtocolsInclude
}
// Listeners returns the server's listeners. These are active listeners,
// so calling Accept() or Close() on them will probably break things.
// They are made available here for read-only purposes (e.g. Addr())
+30 -2
View File
@@ -36,13 +36,22 @@ func init() {
// Templates is a middleware which executes response bodies as Go templates.
// The syntax is documented in the Go standard library's
// [text/template package](https://golang.org/pkg/text/template/).
// Note that ANY response body that matches and qualifies may be evaluated,
// even if it comes from a proxied backend.
//
// ⚠️ Template functions/actions are still experimental, so they are subject to change.
// ⚠️ Template functions/actions can access the environment, files on disk,
// and make HTTP requests. This is extremely useful, but you need to make
// sure templates are only evaluated on content that you trust, control, or
// at least sanitize properly.
//
// Custom template functions can be registered by creating a plugin module under the `http.handlers.templates.functions.*` namespace that implements the `CustomFunctions` interface.
// ⚠️ Templates are still experimental, so they are subject to change.
//
// [All Sprig functions](https://masterminds.github.io/sprig/) are supported.
//
// Custom template functions can be registered by creating a plugin module
// under the `http.handlers.templates.functions.*` namespace that implements
// the `CustomFunctions` interface.
//
// In addition to the standard functions and the Sprig library, Caddy adds
// extra functions and data that are available to a template:
//
@@ -162,6 +171,25 @@ func init() {
// {{listFiles "/mydir"}}
// ```
//
// ##### `fileExists`
//
// Returns true if the given file name, relative to the template context's file root,
// can be opened successfully.
//
// ```
// {{fileExists "path/to/file.html"}}
// ```
//
// ##### `fileStat`
//
// Returns [FileInfo](https://pkg.go.dev/io/fs#FileInfo) using [Stat](https://pkg.go.dev/io/fs#Stat)
// on the given file name, relative to the template context's file root.
//
// ```
// {{$css := fileStat "css/style.css" -}}
// <link rel="stylesheet" href="/css/style.css?v={{ $css.ModTime.Unix }}">
// ```
//
// ##### `markdown`
//
// Renders the given Markdown text as HTML and returns it. This uses the
+2 -2
View File
@@ -153,9 +153,9 @@ func (cp ConnectionPolicies) TLSConfig(ctx caddy.Context) *tls.Config {
// in its config (remember, TLS connection policies are used by *other* apps to
// run TLS servers) -- we skip names with placeholders
if tlsApp.EncryptedClientHello.Publication == nil {
var echNames []string
repl := caddy.NewReplacer()
for _, p := range cp {
var echNames []string
for _, m := range p.matchers {
if sni, ok := m.(MatchServerName); ok {
for _, name := range sni {
@@ -164,8 +164,8 @@ func (cp ConnectionPolicies) TLSConfig(ctx caddy.Context) *tls.Config {
}
}
}
tlsApp.RegisterServerNames(echNames, p.ALPN)
}
tlsApp.RegisterServerNames(echNames)
}
tlsCfg.GetEncryptedClientHelloKeys = func(chi *tls.ClientHelloInfo) ([]tls.EncryptedClientHelloKey, error) {
+26 -7
View File
@@ -440,6 +440,10 @@ func (t *TLS) publishECHConfigs(logger *zap.Logger) error {
zap.Strings("domains", dnsNamesToPublish),
zap.Uint8s("config_ids", configIDs))
if dnsPublisher, ok := publisher.(*ECHDNSPublisher); ok {
dnsPublisher.alpnByDomain = t.alpnValuesForServerNames(dnsNamesToPublish)
}
// publish this ECH config list with this publisher
pubTime := time.Now()
err := publisher.PublishECHConfigList(t.ctx, dnsNamesToPublish, echCfgListBin)
@@ -776,7 +780,8 @@ type ECHDNSPublisher struct {
ProviderRaw json.RawMessage `json:"provider,omitempty" caddy:"namespace=dns.providers inline_key=name"`
provider ECHDNSProvider
logger *zap.Logger
alpnByDomain map[string][]string
logger *zap.Logger
}
// CaddyModule returns the Caddy module information.
@@ -872,12 +877,7 @@ nextName:
continue
}
params := httpsRec.Params
if params == nil {
params = make(libdns.SvcParams)
}
// overwrite only the "ech" SvcParamKey
params["ech"] = []string{base64.StdEncoding.EncodeToString(configListBin)}
params = dnsPub.publishedSvcParams(domain, params, configListBin)
// publish record
_, err = dnsPub.provider.SetRecords(ctx, zone, []libdns.Record{
@@ -903,6 +903,25 @@ nextName:
return nil
}
func (dnsPub *ECHDNSPublisher) publishedSvcParams(domain string, existing libdns.SvcParams, configListBin []byte) libdns.SvcParams {
params := make(libdns.SvcParams, len(existing)+2)
for key, values := range existing {
params[key] = append([]string(nil), values...)
}
params["ech"] = []string{base64.StdEncoding.EncodeToString(configListBin)}
if len(dnsPub.alpnByDomain) == 0 {
return params
}
if alpn := dnsPub.alpnByDomain[strings.ToLower(domain)]; len(alpn) > 0 {
params["alpn"] = append([]string(nil), alpn...)
}
return params
}
// echConfig represents an ECHConfig from the specification,
// [draft-ietf-tls-esni-22](https://www.ietf.org/archive/id/draft-ietf-tls-esni-22.html).
type echConfig struct {
+65
View File
@@ -0,0 +1,65 @@
package caddytls
import (
"encoding/base64"
"reflect"
"sync"
"testing"
"github.com/libdns/libdns"
)
func TestRegisterServerNamesWithALPN(t *testing.T) {
tlsApp := &TLS{
serverNames: make(map[string]serverNameRegistration),
serverNamesMu: new(sync.Mutex),
}
tlsApp.RegisterServerNames([]string{
"Example.com:443",
"example.com",
"127.0.0.1:443",
}, []string{"h2", "http/1.1"})
tlsApp.RegisterServerNames([]string{"EXAMPLE.COM"}, []string{"h3"})
got := tlsApp.alpnValuesForServerNames([]string{"example.com:443", "127.0.0.1:443"})
want := map[string][]string{
"example.com": {"h3", "h2", "http/1.1"},
}
if !reflect.DeepEqual(got, want) {
t.Fatalf("unexpected ALPN values: got %#v want %#v", got, want)
}
}
func TestECHDNSPublisherPublishedSvcParams(t *testing.T) {
dnsPub := &ECHDNSPublisher{
alpnByDomain: map[string][]string{
"example.com": {"h3", "h2", "http/1.1"},
},
}
existing := libdns.SvcParams{
"alpn": {"h2"},
"ipv4hint": {"203.0.113.10"},
}
got := dnsPub.publishedSvcParams("Example.com", existing, []byte{0x01, 0x02, 0x03})
if !reflect.DeepEqual(existing["alpn"], []string{"h2"}) {
t.Fatalf("existing params mutated: got %v", existing["alpn"])
}
if !reflect.DeepEqual(got["alpn"], []string{"h3", "h2", "http/1.1"}) {
t.Fatalf("unexpected ALPN params: got %v", got["alpn"])
}
if !reflect.DeepEqual(got["ipv4hint"], []string{"203.0.113.10"}) {
t.Fatalf("unexpected preserved params: got %v", got["ipv4hint"])
}
wantECH := base64.StdEncoding.EncodeToString([]byte{0x01, 0x02, 0x03})
if !reflect.DeepEqual(got["ech"], []string{wantECH}) {
t.Fatalf("unexpected ECH params: got %v want %v", got["ech"], wantECH)
}
}
+104 -16
View File
@@ -23,6 +23,7 @@ import (
"net"
"net/http"
"runtime/debug"
"slices"
"strings"
"sync"
"time"
@@ -140,7 +141,7 @@ type TLS struct {
logger *zap.Logger
events *caddyevents.App
serverNames map[string]struct{}
serverNames map[string]serverNameRegistration
serverNamesMu *sync.Mutex
// set of subjects with managed certificates,
@@ -168,7 +169,7 @@ func (t *TLS) Provision(ctx caddy.Context) error {
t.logger = ctx.Logger()
repl := caddy.NewReplacer()
t.managing, t.loaded = make(map[string]string), make(map[string]string)
t.serverNames = make(map[string]struct{})
t.serverNames = make(map[string]serverNameRegistration)
t.serverNamesMu = new(sync.Mutex)
// set up default DNS module, if any, and make sure it implements all the
@@ -613,8 +614,8 @@ func (t *TLS) Manage(subjects map[string]struct{}) error {
// managingWildcardFor returns true if the app is managing a certificate that covers that
// subject name (including consideration of wildcards), either from its internal list of
// names that it IS managing certs for, or from the otherSubjsToManage which includes names
// that WILL be managed.
// names that it IS managing certs for, from the otherSubjsToManage which includes names
// that WILL be managed, or from names configured in the 'automate' loader.
func (t *TLS) managingWildcardFor(subj string, otherSubjsToManage map[string]struct{}) bool {
// TODO: we could also consider manually-loaded certs using t.HasCertificateForSubject(),
// but that does not account for how manually-loaded certs may be restricted as to which
@@ -629,7 +630,9 @@ func (t *TLS) managingWildcardFor(subj string, otherSubjsToManage map[string]str
return managing
}
// replace labels of the domain with wildcards until we get a match
// replace labels of the domain with wildcards until we get a match from names
// already being managed, those about to be managed in this batch, or those
// configured for automation
labels := strings.Split(subj, ".")
for i := range labels {
if labels[i] == "*" {
@@ -643,32 +646,117 @@ func (t *TLS) managingWildcardFor(subj string, otherSubjsToManage map[string]str
if _, ok := otherSubjsToManage[candidate]; ok {
return true
}
if _, ok := t.automateNames[candidate]; ok {
return true
}
}
return false
}
// RegisterServerNames registers the provided DNS names with the TLS app.
// This is currently used to auto-publish Encrypted ClientHello (ECH)
// configurations, if enabled. Use of this function by apps using the TLS
// app removes the need for the user to redundantly specify domain names
// in their configuration. This function separates hostname and port
// (keeping only the hotsname) and filters IP addresses, which can't be
// used with ECH.
// RegisterServerNames registers the provided DNS names with the TLS app and
// associates them with the given HTTPS RR ALPN values, if any. This is
// currently used to auto-publish Encrypted ClientHello (ECH) configurations,
// if enabled. Use of this function by apps using the TLS app removes the need
// for the user to redundantly specify domain names in their configuration.
// This function separates hostname and port, keeping only the hostname, and
// filters IP addresses which can't be used with ECH.
//
// EXPERIMENTAL: This function and its semantics/behavior are subject to change.
func (t *TLS) RegisterServerNames(dnsNames []string) {
func (t *TLS) RegisterServerNames(dnsNames, alpnValues []string) {
t.serverNamesMu.Lock()
defer t.serverNamesMu.Unlock()
for _, name := range dnsNames {
host, _, err := net.SplitHostPort(name)
if err != nil {
host = name
}
if strings.TrimSpace(host) != "" && !certmagic.SubjectIsIP(host) {
t.serverNames[strings.ToLower(host)] = struct{}{}
host = strings.ToLower(strings.TrimSpace(host))
if host == "" || certmagic.SubjectIsIP(host) {
continue
}
registration := t.serverNames[host]
if len(alpnValues) == 0 {
t.serverNames[host] = registration
continue
}
if registration.alpnValues == nil {
registration.alpnValues = make(map[string]struct{}, len(alpnValues))
}
for _, alpn := range alpnValues {
if alpn == "" {
continue
}
registration.alpnValues[alpn] = struct{}{}
}
t.serverNames[host] = registration
}
}
func (t *TLS) alpnValuesForServerNames(dnsNames []string) map[string][]string {
t.serverNamesMu.Lock()
defer t.serverNamesMu.Unlock()
result := make(map[string][]string, len(dnsNames))
for _, name := range dnsNames {
host, _, err := net.SplitHostPort(name)
if err != nil {
host = name
}
host = strings.ToLower(strings.TrimSpace(host))
if host == "" {
continue
}
registration, ok := t.serverNames[host]
if !ok || len(registration.alpnValues) == 0 {
continue
}
result[host] = OrderedHTTPSRRALPN(registration.alpnValues)
}
return result
}
// OrderedHTTPSRRALPN returns the HTTPS RR ALPN values in preferred order.
func OrderedHTTPSRRALPN(alpnSet map[string]struct{}) []string {
if len(alpnSet) == 0 {
return nil
}
knownOrder := append([]string{"h3"}, defaultALPN...)
ordered := make([]string, 0, len(alpnSet))
seen := make(map[string]struct{}, len(alpnSet))
for _, alpn := range knownOrder {
if _, ok := alpnSet[alpn]; ok {
ordered = append(ordered, alpn)
seen[alpn] = struct{}{}
}
}
t.serverNamesMu.Unlock()
if len(ordered) == len(alpnSet) {
return ordered
}
var remaining []string
for alpn := range alpnSet {
if _, ok := seen[alpn]; ok {
continue
}
remaining = append(remaining, alpn)
}
slices.Sort(remaining)
return append(ordered, remaining...)
}
type serverNameRegistration struct {
alpnValues map[string]struct{}
}
// HandleHTTPChallenge ensures that the ACME HTTP challenge or ZeroSSL HTTP
+96
View File
@@ -0,0 +1,96 @@
// Copyright 2015 Matthew Holt and The Caddy Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package caddytls
import (
"encoding/json"
"testing"
"github.com/caddyserver/caddy/v2"
)
func TestAvoidDuplicateAutomation(t *testing.T) {
tests := []struct {
name string
automateNames []string
expectedToManage bool
}{
{
name: "do not manage if wildcard is automated",
automateNames: []string{"*.example.com"},
expectedToManage: false,
},
{
name: "manage if no automation configured",
automateNames: []string{},
expectedToManage: true,
},
{
name: "manage if explicitly requested even when wildcard automated",
automateNames: []string{"*.example.com", "sub.example.com"},
expectedToManage: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
automateJSON, err := json.Marshal(tc.automateNames)
if err != nil {
t.Fatal(err)
}
tlsApp := &TLS{
Automation: &AutomationConfig{
Policies: []*AutomationPolicy{
{
IssuersRaw: []json.RawMessage{
[]byte(`{"module": "internal"}`),
},
},
},
},
CertificatesRaw: map[string]json.RawMessage{
"automate": automateJSON,
},
}
var cfg caddy.Config
ctx, err := caddy.ProvisionContext(&cfg)
if err != nil {
t.Fatal(err)
}
if err := tlsApp.Provision(ctx); err != nil {
t.Fatal(err)
}
// simulate a case wherein the HTTP app starts first and
// tells the TLS app about the following auto-HTTPS domains
httpDomains := map[string]struct{}{"sub.example.com": {}}
if err := tlsApp.Manage(httpDomains); err != nil {
t.Fatal(err)
}
_, actuallyManaged := tlsApp.managing["sub.example.com"]
if actuallyManaged != tc.expectedToManage {
t.Errorf(
"expected sub.example.com individually managed: %v, got: %v",
tc.expectedToManage,
actuallyManaged,
)
}
})
}
}