Adjustments from Weidi's review

This commit is contained in:
Francis Lavoie
2026-04-18 14:16:20 -04:00
parent 5a1ace3e91
commit db86fdaba2
6 changed files with 510 additions and 41 deletions
@@ -0,0 +1,146 @@
package reverseproxy
import (
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"go.uber.org/zap"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
)
type extendedConnectCapture struct {
method string
headers http.Header
body []byte
extendedBodyPresent bool
extendedConnectBody []byte
}
type extendedConnectCaptureTransport struct {
mu sync.Mutex
capture extendedConnectCapture
}
func (tr *extendedConnectCaptureTransport) RoundTrip(req *http.Request) (*http.Response, error) {
body, err := io.ReadAll(req.Body)
if err != nil {
return nil, err
}
c := extendedConnectCapture{
method: req.Method,
headers: req.Header.Clone(),
body: body,
}
if rc, ok := caddyhttp.GetVar(req.Context(), "extended_connect_websocket_body").(io.ReadCloser); ok {
c.extendedBodyPresent = true
c.extendedConnectBody, err = io.ReadAll(rc)
if err != nil {
return nil, err
}
_ = rc.Close()
}
tr.mu.Lock()
tr.capture = c
tr.mu.Unlock()
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("ok")),
Request: req,
}, nil
}
func (tr *extendedConnectCaptureTransport) Snapshot() extendedConnectCapture {
tr.mu.Lock()
defer tr.mu.Unlock()
return tr.capture
}
func TestServeHTTPRewritesExtendedConnectWebsocketRequest(t *testing.T) {
tests := []struct {
name string
protoMajor int
proto string
headers map[string]string
}{
{
name: "h2 extended connect",
protoMajor: 2,
proto: "HTTP/2.0",
headers: map[string]string{
":protocol": "websocket",
},
},
{
name: "h3 extended connect",
protoMajor: 3,
proto: "websocket",
headers: map[string]string{},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
const payload = "extended-connect-body"
transport := new(extendedConnectCaptureTransport)
h := &Handler{
logger: zap.NewNop(),
Transport: transport,
Upstreams: UpstreamPool{
&Upstream{Host: new(Host), Dial: "127.0.0.1:8443"},
},
LoadBalancing: &LoadBalancing{
SelectionPolicy: &RoundRobinSelection{},
},
}
req := httptest.NewRequest(http.MethodConnect, "http://example.test/upgrade", strings.NewReader(payload))
req.ProtoMajor = tc.protoMajor
req.Proto = tc.proto
for key, value := range tc.headers {
req.Header.Set(key, value)
}
req = prepareTestRequest(req)
rr := httptest.NewRecorder()
err := h.ServeHTTP(rr, req, caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
return nil
}))
if err != nil {
t.Fatalf("ServeHTTP() error = %v", err)
}
captured := transport.Snapshot()
if captured.method != http.MethodGet {
t.Fatalf("upstream method = %s, want %s", captured.method, http.MethodGet)
}
if got := captured.headers.Get("Upgrade"); !strings.EqualFold(got, "websocket") {
t.Fatalf("Upgrade header = %q, want websocket", got)
}
if got := captured.headers.Get("Connection"); !strings.EqualFold(got, "Upgrade") {
t.Fatalf("Connection header = %q, want Upgrade", got)
}
if got := captured.headers.Get(":protocol"); got != "" {
t.Fatalf(":protocol header should be removed, got %q", got)
}
if len(captured.body) != 0 {
t.Fatalf("upstream request body length = %d, want 0", len(captured.body))
}
if !captured.extendedBodyPresent {
t.Fatal("extended_connect_websocket_body variable missing from request context")
}
if string(captured.extendedConnectBody) != payload {
t.Fatalf("extended_connect_websocket_body = %q, want %q", string(captured.extendedConnectBody), payload)
}
})
}
}
+10 -29
View File
@@ -100,14 +100,14 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
streamTimeout := time.Duration(h.StreamTimeout)
var (
conn io.ReadWriteCloser
brw *bufio.ReadWriter
isH2 bool
conn io.ReadWriteCloser
brw *bufio.ReadWriter
isExtendedConnect bool
)
// websocket over http2 or http3 if extended connect is enabled, assuming backend doesn't support this, the request will be modified to http1.1 upgrade
// TODO: once we can reliably detect backend support this, it can be removed for those backends
if body, ok := caddyhttp.GetVar(req.Context(), "extended_connect_websocket_body").(io.ReadCloser); ok {
isH2 = true
isExtendedConnect = true
req.Body = body
rw.Header().Del("Upgrade")
rw.Header().Del("Connection")
@@ -115,13 +115,13 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
rw.WriteHeader(http.StatusOK)
if c := logger.Check(zap.DebugLevel, "upgrading connection"); c != nil {
c.Write(zap.Int("http_version", 2))
c.Write(zap.Int("http_version", req.ProtoMajor))
}
//nolint:bodyclose
flushErr := http.NewResponseController(rw).Flush()
if flushErr != nil {
if c := h.logger.Check(zap.ErrorLevel, "failed to flush http2 websocket response"); c != nil {
if c := h.logger.Check(zap.ErrorLevel, "failed to flush extended_connect websocket response"); c != nil {
c.Write(zap.Error(flushErr))
}
return
@@ -154,25 +154,6 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
}
}
// For H2 extended connect: close backConn when the request context is
// cancelled (e.g. client disconnects). For HTTP/1.1 hijacked connections
// we skip this because req.Context() may be cancelled when ServeHTTP
// returns early, which would prematurely close the backend connection.
if isH2 {
// adopted from https://github.com/golang/go/commit/8bcf2834afdf6a1f7937390903a41518715ef6f5
backConnCloseCh := make(chan struct{})
go func() {
// Ensure that the cancellation of a request closes the backend.
// See issue https://golang.org/issue/35559.
select {
case <-req.Context().Done():
case <-backConnCloseCh:
}
backConn.Close()
}()
defer close(backConnCloseCh)
}
if err := brw.Flush(); err != nil {
if c := logger.Check(zapcore.DebugLevel, "response flush"); c != nil {
c.Write(zap.Error(err))
@@ -221,8 +202,8 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
start := time.Now()
if isH2 {
h.handleH2UpgradeTunnel(streamLogger, streamLevel, wg, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics, streamFields)
if isExtendedConnect {
h.handleExtendedConnectUpgradeTunnel(streamLogger, streamLevel, wg, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics, streamFields)
} else {
h.handleDetachedUpgradeTunnel(streamLogger, streamLevel, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics, streamFields)
// Return immediately without touching wg. finalizeResponse's
@@ -230,7 +211,7 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
}
}
func (h *Handler) handleH2UpgradeTunnel(
func (h *Handler) handleExtendedConnectUpgradeTunnel(
streamLogger *zap.Logger,
streamLevel zapcore.Level,
wg *sync.WaitGroup,
@@ -244,7 +225,7 @@ func (h *Handler) handleH2UpgradeTunnel(
finishMetrics func(result string, duration time.Duration, toBackend, fromBackend int64),
streamFields []zap.Field,
) {
// H2 extended connect: ServeHTTP must block because rw and req.Body are
// Extended CONNECT: ServeHTTP must block because rw and req.Body are
// only valid while the handler goroutine is running. Defers clean up
// when the select below fires and this function returns.
defer deleteBackConn()