mirror of
https://github.com/caddyserver/caddy.git
synced 2026-05-30 10:35:18 -04:00
Adjustments from Weidi's review
This commit is contained in:
@@ -0,0 +1,146 @@
|
||||
package reverseproxy
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
|
||||
)
|
||||
|
||||
type extendedConnectCapture struct {
|
||||
method string
|
||||
headers http.Header
|
||||
body []byte
|
||||
extendedBodyPresent bool
|
||||
extendedConnectBody []byte
|
||||
}
|
||||
|
||||
type extendedConnectCaptureTransport struct {
|
||||
mu sync.Mutex
|
||||
capture extendedConnectCapture
|
||||
}
|
||||
|
||||
func (tr *extendedConnectCaptureTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
body, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c := extendedConnectCapture{
|
||||
method: req.Method,
|
||||
headers: req.Header.Clone(),
|
||||
body: body,
|
||||
}
|
||||
if rc, ok := caddyhttp.GetVar(req.Context(), "extended_connect_websocket_body").(io.ReadCloser); ok {
|
||||
c.extendedBodyPresent = true
|
||||
c.extendedConnectBody, err = io.ReadAll(rc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_ = rc.Close()
|
||||
}
|
||||
|
||||
tr.mu.Lock()
|
||||
tr.capture = c
|
||||
tr.mu.Unlock()
|
||||
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(strings.NewReader("ok")),
|
||||
Request: req,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (tr *extendedConnectCaptureTransport) Snapshot() extendedConnectCapture {
|
||||
tr.mu.Lock()
|
||||
defer tr.mu.Unlock()
|
||||
return tr.capture
|
||||
}
|
||||
|
||||
func TestServeHTTPRewritesExtendedConnectWebsocketRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
protoMajor int
|
||||
proto string
|
||||
headers map[string]string
|
||||
}{
|
||||
{
|
||||
name: "h2 extended connect",
|
||||
protoMajor: 2,
|
||||
proto: "HTTP/2.0",
|
||||
headers: map[string]string{
|
||||
":protocol": "websocket",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "h3 extended connect",
|
||||
protoMajor: 3,
|
||||
proto: "websocket",
|
||||
headers: map[string]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
const payload = "extended-connect-body"
|
||||
|
||||
transport := new(extendedConnectCaptureTransport)
|
||||
h := &Handler{
|
||||
logger: zap.NewNop(),
|
||||
Transport: transport,
|
||||
Upstreams: UpstreamPool{
|
||||
&Upstream{Host: new(Host), Dial: "127.0.0.1:8443"},
|
||||
},
|
||||
LoadBalancing: &LoadBalancing{
|
||||
SelectionPolicy: &RoundRobinSelection{},
|
||||
},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodConnect, "http://example.test/upgrade", strings.NewReader(payload))
|
||||
req.ProtoMajor = tc.protoMajor
|
||||
req.Proto = tc.proto
|
||||
for key, value := range tc.headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
req = prepareTestRequest(req)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
err := h.ServeHTTP(rr, req, caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
return nil
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("ServeHTTP() error = %v", err)
|
||||
}
|
||||
|
||||
captured := transport.Snapshot()
|
||||
if captured.method != http.MethodGet {
|
||||
t.Fatalf("upstream method = %s, want %s", captured.method, http.MethodGet)
|
||||
}
|
||||
if got := captured.headers.Get("Upgrade"); !strings.EqualFold(got, "websocket") {
|
||||
t.Fatalf("Upgrade header = %q, want websocket", got)
|
||||
}
|
||||
if got := captured.headers.Get("Connection"); !strings.EqualFold(got, "Upgrade") {
|
||||
t.Fatalf("Connection header = %q, want Upgrade", got)
|
||||
}
|
||||
if got := captured.headers.Get(":protocol"); got != "" {
|
||||
t.Fatalf(":protocol header should be removed, got %q", got)
|
||||
}
|
||||
if len(captured.body) != 0 {
|
||||
t.Fatalf("upstream request body length = %d, want 0", len(captured.body))
|
||||
}
|
||||
if !captured.extendedBodyPresent {
|
||||
t.Fatal("extended_connect_websocket_body variable missing from request context")
|
||||
}
|
||||
if string(captured.extendedConnectBody) != payload {
|
||||
t.Fatalf("extended_connect_websocket_body = %q, want %q", string(captured.extendedConnectBody), payload)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -100,14 +100,14 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
|
||||
streamTimeout := time.Duration(h.StreamTimeout)
|
||||
|
||||
var (
|
||||
conn io.ReadWriteCloser
|
||||
brw *bufio.ReadWriter
|
||||
isH2 bool
|
||||
conn io.ReadWriteCloser
|
||||
brw *bufio.ReadWriter
|
||||
isExtendedConnect bool
|
||||
)
|
||||
// websocket over http2 or http3 if extended connect is enabled, assuming backend doesn't support this, the request will be modified to http1.1 upgrade
|
||||
// TODO: once we can reliably detect backend support this, it can be removed for those backends
|
||||
if body, ok := caddyhttp.GetVar(req.Context(), "extended_connect_websocket_body").(io.ReadCloser); ok {
|
||||
isH2 = true
|
||||
isExtendedConnect = true
|
||||
req.Body = body
|
||||
rw.Header().Del("Upgrade")
|
||||
rw.Header().Del("Connection")
|
||||
@@ -115,13 +115,13 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
|
||||
if c := logger.Check(zap.DebugLevel, "upgrading connection"); c != nil {
|
||||
c.Write(zap.Int("http_version", 2))
|
||||
c.Write(zap.Int("http_version", req.ProtoMajor))
|
||||
}
|
||||
|
||||
//nolint:bodyclose
|
||||
flushErr := http.NewResponseController(rw).Flush()
|
||||
if flushErr != nil {
|
||||
if c := h.logger.Check(zap.ErrorLevel, "failed to flush http2 websocket response"); c != nil {
|
||||
if c := h.logger.Check(zap.ErrorLevel, "failed to flush extended_connect websocket response"); c != nil {
|
||||
c.Write(zap.Error(flushErr))
|
||||
}
|
||||
return
|
||||
@@ -154,25 +154,6 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
|
||||
}
|
||||
}
|
||||
|
||||
// For H2 extended connect: close backConn when the request context is
|
||||
// cancelled (e.g. client disconnects). For HTTP/1.1 hijacked connections
|
||||
// we skip this because req.Context() may be cancelled when ServeHTTP
|
||||
// returns early, which would prematurely close the backend connection.
|
||||
if isH2 {
|
||||
// adopted from https://github.com/golang/go/commit/8bcf2834afdf6a1f7937390903a41518715ef6f5
|
||||
backConnCloseCh := make(chan struct{})
|
||||
go func() {
|
||||
// Ensure that the cancellation of a request closes the backend.
|
||||
// See issue https://golang.org/issue/35559.
|
||||
select {
|
||||
case <-req.Context().Done():
|
||||
case <-backConnCloseCh:
|
||||
}
|
||||
backConn.Close()
|
||||
}()
|
||||
defer close(backConnCloseCh)
|
||||
}
|
||||
|
||||
if err := brw.Flush(); err != nil {
|
||||
if c := logger.Check(zapcore.DebugLevel, "response flush"); c != nil {
|
||||
c.Write(zap.Error(err))
|
||||
@@ -221,8 +202,8 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
|
||||
|
||||
start := time.Now()
|
||||
|
||||
if isH2 {
|
||||
h.handleH2UpgradeTunnel(streamLogger, streamLevel, wg, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics, streamFields)
|
||||
if isExtendedConnect {
|
||||
h.handleExtendedConnectUpgradeTunnel(streamLogger, streamLevel, wg, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics, streamFields)
|
||||
} else {
|
||||
h.handleDetachedUpgradeTunnel(streamLogger, streamLevel, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics, streamFields)
|
||||
// Return immediately without touching wg. finalizeResponse's
|
||||
@@ -230,7 +211,7 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) handleH2UpgradeTunnel(
|
||||
func (h *Handler) handleExtendedConnectUpgradeTunnel(
|
||||
streamLogger *zap.Logger,
|
||||
streamLevel zapcore.Level,
|
||||
wg *sync.WaitGroup,
|
||||
@@ -244,7 +225,7 @@ func (h *Handler) handleH2UpgradeTunnel(
|
||||
finishMetrics func(result string, duration time.Duration, toBackend, fromBackend int64),
|
||||
streamFields []zap.Field,
|
||||
) {
|
||||
// H2 extended connect: ServeHTTP must block because rw and req.Body are
|
||||
// Extended CONNECT: ServeHTTP must block because rw and req.Body are
|
||||
// only valid while the handler goroutine is running. Defers clean up
|
||||
// when the select below fires and this function returns.
|
||||
defer deleteBackConn()
|
||||
|
||||
Reference in New Issue
Block a user