diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2c5052723..89135cbc2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -132,6 +132,8 @@ jobs: - name: Run tests # id: step_test # continue-on-error: true + env: + GODEBUG: http2xconnect=1 run: | # (go test -v -coverprofile=cover-profile.out -race ./... 2>&1) > test-results/test-result.out go test -v -coverprofile="cover-profile.out" -short -race ./... @@ -191,7 +193,7 @@ jobs: retries=3 exit_code=0 while ((retries > 0)); do - CGO_ENABLED=0 go test -p 1 -v ./... + GODEBUG=http2xconnect=1 CGO_ENABLED=0 go test -p 1 -v ./... exit_code=$? if ((exit_code == 0)); then break diff --git a/caddytest/integration/reverseproxy_extended_connect_test.go b/caddytest/integration/reverseproxy_extended_connect_test.go new file mode 100644 index 000000000..8822988be --- /dev/null +++ b/caddytest/integration/reverseproxy_extended_connect_test.go @@ -0,0 +1,328 @@ +package integration + +import ( + "bytes" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "net/http" + "strings" + "testing" + "time" + + "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" + + "github.com/caddyserver/caddy/v2/caddytest" +) + +var errExtendedConnectUnsupportedByPeer = errors.New("peer did not advertise RFC 8441 extended CONNECT support") + +func TestReverseProxyExtendedConnectOverH2(t *testing.T) { + tester := caddytest.NewTester(t) + backend := newWebsocketUpgradeEchoBackend(t) + defer backend.Close() + + tester.InitServer(fmt.Sprintf(` +{ + admin localhost:2999 + http_port 9080 + https_port 9443 + grace_period 1ns + skip_install_trust + servers :9443 { + protocols h2 + } +} + +https://localhost:9443 { + reverse_proxy %s +} +`, backend.addr), "caddyfile") + + const payload = "extended-connect-echo\n" + if err := assertExtendedConnectH2Echo("localhost:9443", payload); err != nil { + if errors.Is(err, errExtendedConnectUnsupportedByPeer) { + t.Skipf("skipping extended CONNECT integration test: %v", err) + } + t.Fatalf("extended connect h2 echo failed: %v", err) + } +} + +func assertExtendedConnectH2Echo(addr, payload string) error { + conn, err := tlsDialH2(addr) + if err != nil { + return fmt.Errorf("dialing h2 tls: %w", err) + } + defer conn.Close() + + if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil { + return fmt.Errorf("setting deadline: %w", err) + } + + fr := http2.NewFramer(conn, conn) + + if _, err := conn.Write([]byte(http2.ClientPreface)); err != nil { + return fmt.Errorf("writing client preface: %w", err) + } + if err := fr.WriteSettings(http2.Setting{ID: http2.SettingEnableConnectProtocol, Val: 1}); err != nil { + return fmt.Errorf("writing client settings: %w", err) + } + + supported, err := waitForServerSettings(fr) + if err != nil { + return err + } + if !supported { + return errExtendedConnectUnsupportedByPeer + } + if err := waitForSettingsAck(fr); err != nil { + return err + } + + if err := writeExtendedConnectHeaders(fr, addr); err != nil { + return err + } + + status, err := readResponseStatus(fr, 1) + if err != nil { + return err + } + if status != "200" { + return fmt.Errorf("unexpected extended connect status: got=%s want=200", status) + } + + if err := fr.WriteData(1, false, []byte(payload)); err != nil { + return fmt.Errorf("writing stream data: %w", err) + } + + echo, err := readStreamData(fr, 1, len(payload)) + if err != nil { + return err + } + if echo != payload { + return fmt.Errorf("unexpected echoed payload: got=%q want=%q", echo, payload) + } + + _ = fr.WriteRSTStream(1, http2.ErrCodeNo) + return nil +} + +func tlsDialH2(addr string) (net.Conn, error) { + var lastErr error + for i := 0; i < 30; i++ { + dialer := &net.Dialer{Timeout: 2 * time.Second} + conn, err := tls.DialWithDialer(dialer, "tcp", addr, &tls.Config{ + ServerName: "localhost", + InsecureSkipVerify: true, + NextProtos: []string{"h2"}, + }) + if err == nil { + return conn, nil + } + lastErr = err + time.Sleep(100 * time.Millisecond) + } + return nil, lastErr +} + +func waitForServerSettings(fr *http2.Framer) (bool, error) { + for { + frame, err := fr.ReadFrame() + if err != nil { + return false, fmt.Errorf("reading frame before connect: %w", err) + } + settings, ok := frame.(*http2.SettingsFrame) + if !ok { + continue + } + if settings.IsAck() { + continue + } + + supported := false + if err := settings.ForeachSetting(func(s http2.Setting) error { + if s.ID == http2.SettingEnableConnectProtocol && s.Val == 1 { + supported = true + } + return nil + }); err != nil { + return false, fmt.Errorf("reading server settings: %w", err) + } + + if err := fr.WriteSettingsAck(); err != nil { + return false, fmt.Errorf("writing settings ack: %w", err) + } + return supported, nil + } +} + +func waitForSettingsAck(fr *http2.Framer) error { + for { + frame, err := fr.ReadFrame() + if err != nil { + return fmt.Errorf("reading settings ack: %w", err) + } + settings, ok := frame.(*http2.SettingsFrame) + if ok && settings.IsAck() { + return nil + } + } +} + +func writeExtendedConnectHeaders(fr *http2.Framer, addr string) error { + var hb bytes.Buffer + enc := hpack.NewEncoder(&hb) + for _, hf := range []hpack.HeaderField{ + {Name: ":method", Value: "CONNECT"}, + {Name: ":scheme", Value: "https"}, + {Name: ":authority", Value: addr}, + {Name: ":path", Value: "/upgrade"}, + {Name: ":protocol", Value: "websocket"}, + } { + if err := enc.WriteField(hf); err != nil { + return fmt.Errorf("encoding request headers: %w", err) + } + } + + if err := fr.WriteHeaders(http2.HeadersFrameParam{ + StreamID: 1, + BlockFragment: hb.Bytes(), + EndHeaders: true, + EndStream: false, + }); err != nil { + return fmt.Errorf("writing extended connect headers: %w", err) + } + return nil +} + +func readResponseStatus(fr *http2.Framer, streamID uint32) (string, error) { + var block bytes.Buffer + + for { + frame, err := fr.ReadFrame() + if err != nil { + return "", fmt.Errorf("reading response headers: %w", err) + } + if rst, ok := frame.(*http2.RSTStreamFrame); ok && rst.StreamID == streamID { + return "", fmt.Errorf("stream reset before response headers: %s", rst.ErrCode) + } + + h, ok := frame.(*http2.HeadersFrame) + if !ok || h.StreamID != streamID { + continue + } + + if _, err := block.Write(h.HeaderBlockFragment()); err != nil { + return "", fmt.Errorf("buffering response header fragment: %w", err) + } + for !h.HeadersEnded() { + next, err := fr.ReadFrame() + if err != nil { + return "", fmt.Errorf("reading continuation frame: %w", err) + } + c, ok := next.(*http2.ContinuationFrame) + if !ok || c.StreamID != streamID { + continue + } + if _, err := block.Write(c.HeaderBlockFragment()); err != nil { + return "", fmt.Errorf("buffering continuation fragment: %w", err) + } + if c.HeadersEnded() { + break + } + } + break + } + + var status string + dec := hpack.NewDecoder(4096, func(f hpack.HeaderField) { + if f.Name == ":status" { + status = f.Value + } + }) + if _, err := dec.Write(block.Bytes()); err != nil { + return "", fmt.Errorf("decoding response header block: %w", err) + } + if status == "" { + return "", fmt.Errorf("missing :status in response headers") + } + return status, nil +} + +func readStreamData(fr *http2.Framer, streamID uint32, n int) (string, error) { + buf := make([]byte, 0, n) + for len(buf) < n { + frame, err := fr.ReadFrame() + if err != nil { + return "", fmt.Errorf("reading stream data: %w", err) + } + d, ok := frame.(*http2.DataFrame) + if !ok || d.StreamID != streamID { + continue + } + buf = append(buf, d.Data()...) + } + return string(buf[:n]), nil +} + +type websocketUpgradeEchoBackend struct { + addr string + ln net.Listener + server *http.Server +} + +func newWebsocketUpgradeEchoBackend(t *testing.T) *websocketUpgradeEchoBackend { + t.Helper() + + backend := &websocketUpgradeEchoBackend{} + backend.server = &http.Server{ + Handler: http.HandlerFunc(backend.serveHTTP), + } + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listening for websocket backend: %v", err) + } + backend.ln = ln + backend.addr = ln.Addr().String() + + go func() { + _ = backend.server.Serve(ln) + }() + + return backend +} + +func (b *websocketUpgradeEchoBackend) serveHTTP(w http.ResponseWriter, r *http.Request) { + if !strings.EqualFold(r.Header.Get("Connection"), "Upgrade") || !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") { + http.Error(w, "upgrade required", http.StatusUpgradeRequired) + return + } + + hijacker, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "hijacking not supported", http.StatusInternalServerError) + return + } + + conn, rw, err := hijacker.Hijack() + if err != nil { + return + } + + _, _ = rw.WriteString("HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n") + _ = rw.Flush() + + go func() { + defer conn.Close() + _, _ = io.Copy(conn, conn) + }() +} + +func (b *websocketUpgradeEchoBackend) Close() { + _ = b.server.Close() + _ = b.ln.Close() +} diff --git a/caddytest/integration/stream_reload_stress_test.go b/caddytest/integration/stream_reload_stress_test.go index 45473e221..ff140f9a4 100644 --- a/caddytest/integration/stream_reload_stress_test.go +++ b/caddytest/integration/stream_reload_stress_test.go @@ -21,9 +21,11 @@ import ( "github.com/caddyserver/caddy/v2/caddytest" ) -// stressCloseDelay is the stream_close_delay used for the close_delay scenario. -// Long enough to outlast all test reloads; short enough to keep total test time reasonable. -const stressCloseDelay = 3 * time.Second +const ( + defaultStressStreamCount = 1 + defaultStressReloadCount = 1 + defaultStressCloseDelay = 500 * time.Millisecond +) func TestReverseProxyReloadStressUpgradedStreamsHeapProfiles(t *testing.T) { tester := caddytest.NewTester(t).WithDefaultOverrides(caddytest.Config{ @@ -43,7 +45,7 @@ func TestReverseProxyReloadStressUpgradedStreamsHeapProfiles(t *testing.T) { // Reloads are spread across time and interleaved with echo-checks so // stream health is exercised at each reload boundary, not only at the end. legacy := runReloadStress(t, tester, backend.addr, "legacy", false, 0) - closeDelay := runReloadStress(t, tester, backend.addr, "close_delay", false, stressCloseDelay) + closeDelay := runReloadStress(t, tester, backend.addr, "close_delay", false, stressCloseDelay(t)) retain := runReloadStress(t, tester, backend.addr, "retain", true, 0) if legacy.aliveAfterReloads != 0 { @@ -110,8 +112,8 @@ func runReloadStress(t *testing.T, tester *caddytest.Tester, backendAddr, mode s const echoEvery = 6 // perform an echo check every N reloads - streamCount := envIntOrDefault(t, "CADDY_STRESS_STREAM_COUNT", 12) - reloadCount := envIntOrDefault(t, "CADDY_STRESS_RELOAD_COUNT", 24) + streamCount := envIntOrDefault(t, "CADDY_STRESS_STREAM_COUNT", defaultStressStreamCount) + reloadCount := envIntOrDefault(t, "CADDY_STRESS_RELOAD_COUNT", defaultStressReloadCount) tester.InitServer(reloadStressConfig(backendAddr, retain, closeDelay, 0), "caddyfile") @@ -209,6 +211,21 @@ func envIntOrDefault(t *testing.T, key string, def int) int { return v } +func stressCloseDelay(t *testing.T) time.Duration { + t.Helper() + + const key = "CADDY_STRESS_CLOSE_DELAY" + raw := strings.TrimSpace(os.Getenv(key)) + if raw == "" { + return defaultStressCloseDelay + } + v, err := time.ParseDuration(raw) + if err != nil || v <= 0 { + t.Fatalf("invalid %s=%q: must be a positive duration", key, raw) + } + return v +} + func loadCaddyfileConfig(t *testing.T, rawConfig string) { t.Helper() diff --git a/modules/caddyhttp/encode/encode.go b/modules/caddyhttp/encode/encode.go index ecf85495a..0474768f0 100644 --- a/modules/caddyhttp/encode/encode.go +++ b/modules/caddyhttp/encode/encode.go @@ -405,11 +405,6 @@ func (rw *responseWriter) ReadFrom(r io.Reader) (int64, error) { // Close writes any remaining buffered response and // deallocates any active resources. func (rw *responseWriter) Close() error { - if caddyhttp.ResponseWriterHijacked(rw.ResponseWriter) { - rw.releaseEncoder() - return nil - } - // didn't write, probably head request if !rw.wroteHeader { cl, err := strconv.Atoi(rw.Header().Get("Content-Length")) diff --git a/modules/caddyhttp/reverseproxy/extended_connect_test.go b/modules/caddyhttp/reverseproxy/extended_connect_test.go new file mode 100644 index 000000000..5cb27d807 --- /dev/null +++ b/modules/caddyhttp/reverseproxy/extended_connect_test.go @@ -0,0 +1,146 @@ +package reverseproxy + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + + "go.uber.org/zap" + + "github.com/caddyserver/caddy/v2/modules/caddyhttp" +) + +type extendedConnectCapture struct { + method string + headers http.Header + body []byte + extendedBodyPresent bool + extendedConnectBody []byte +} + +type extendedConnectCaptureTransport struct { + mu sync.Mutex + capture extendedConnectCapture +} + +func (tr *extendedConnectCaptureTransport) RoundTrip(req *http.Request) (*http.Response, error) { + body, err := io.ReadAll(req.Body) + if err != nil { + return nil, err + } + + c := extendedConnectCapture{ + method: req.Method, + headers: req.Header.Clone(), + body: body, + } + if rc, ok := caddyhttp.GetVar(req.Context(), "extended_connect_websocket_body").(io.ReadCloser); ok { + c.extendedBodyPresent = true + c.extendedConnectBody, err = io.ReadAll(rc) + if err != nil { + return nil, err + } + _ = rc.Close() + } + + tr.mu.Lock() + tr.capture = c + tr.mu.Unlock() + + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("ok")), + Request: req, + }, nil +} + +func (tr *extendedConnectCaptureTransport) Snapshot() extendedConnectCapture { + tr.mu.Lock() + defer tr.mu.Unlock() + return tr.capture +} + +func TestServeHTTPRewritesExtendedConnectWebsocketRequest(t *testing.T) { + tests := []struct { + name string + protoMajor int + proto string + headers map[string]string + }{ + { + name: "h2 extended connect", + protoMajor: 2, + proto: "HTTP/2.0", + headers: map[string]string{ + ":protocol": "websocket", + }, + }, + { + name: "h3 extended connect", + protoMajor: 3, + proto: "websocket", + headers: map[string]string{}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + const payload = "extended-connect-body" + + transport := new(extendedConnectCaptureTransport) + h := &Handler{ + logger: zap.NewNop(), + Transport: transport, + Upstreams: UpstreamPool{ + &Upstream{Host: new(Host), Dial: "127.0.0.1:8443"}, + }, + LoadBalancing: &LoadBalancing{ + SelectionPolicy: &RoundRobinSelection{}, + }, + } + + req := httptest.NewRequest(http.MethodConnect, "http://example.test/upgrade", strings.NewReader(payload)) + req.ProtoMajor = tc.protoMajor + req.Proto = tc.proto + for key, value := range tc.headers { + req.Header.Set(key, value) + } + req = prepareTestRequest(req) + + rr := httptest.NewRecorder() + err := h.ServeHTTP(rr, req, caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { + return nil + })) + if err != nil { + t.Fatalf("ServeHTTP() error = %v", err) + } + + captured := transport.Snapshot() + if captured.method != http.MethodGet { + t.Fatalf("upstream method = %s, want %s", captured.method, http.MethodGet) + } + if got := captured.headers.Get("Upgrade"); !strings.EqualFold(got, "websocket") { + t.Fatalf("Upgrade header = %q, want websocket", got) + } + if got := captured.headers.Get("Connection"); !strings.EqualFold(got, "Upgrade") { + t.Fatalf("Connection header = %q, want Upgrade", got) + } + if got := captured.headers.Get(":protocol"); got != "" { + t.Fatalf(":protocol header should be removed, got %q", got) + } + if len(captured.body) != 0 { + t.Fatalf("upstream request body length = %d, want 0", len(captured.body)) + } + if !captured.extendedBodyPresent { + t.Fatal("extended_connect_websocket_body variable missing from request context") + } + if string(captured.extendedConnectBody) != payload { + t.Fatalf("extended_connect_websocket_body = %q, want %q", string(captured.extendedConnectBody), payload) + } + }) + } +} diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go index 9aece01e8..407c4acf9 100644 --- a/modules/caddyhttp/reverseproxy/streaming.go +++ b/modules/caddyhttp/reverseproxy/streaming.go @@ -100,14 +100,14 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, streamTimeout := time.Duration(h.StreamTimeout) var ( - conn io.ReadWriteCloser - brw *bufio.ReadWriter - isH2 bool + conn io.ReadWriteCloser + brw *bufio.ReadWriter + isExtendedConnect bool ) // websocket over http2 or http3 if extended connect is enabled, assuming backend doesn't support this, the request will be modified to http1.1 upgrade // TODO: once we can reliably detect backend support this, it can be removed for those backends if body, ok := caddyhttp.GetVar(req.Context(), "extended_connect_websocket_body").(io.ReadCloser); ok { - isH2 = true + isExtendedConnect = true req.Body = body rw.Header().Del("Upgrade") rw.Header().Del("Connection") @@ -115,13 +115,13 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, rw.WriteHeader(http.StatusOK) if c := logger.Check(zap.DebugLevel, "upgrading connection"); c != nil { - c.Write(zap.Int("http_version", 2)) + c.Write(zap.Int("http_version", req.ProtoMajor)) } //nolint:bodyclose flushErr := http.NewResponseController(rw).Flush() if flushErr != nil { - if c := h.logger.Check(zap.ErrorLevel, "failed to flush http2 websocket response"); c != nil { + if c := h.logger.Check(zap.ErrorLevel, "failed to flush extended_connect websocket response"); c != nil { c.Write(zap.Error(flushErr)) } return @@ -154,25 +154,6 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, } } - // For H2 extended connect: close backConn when the request context is - // cancelled (e.g. client disconnects). For HTTP/1.1 hijacked connections - // we skip this because req.Context() may be cancelled when ServeHTTP - // returns early, which would prematurely close the backend connection. - if isH2 { - // adopted from https://github.com/golang/go/commit/8bcf2834afdf6a1f7937390903a41518715ef6f5 - backConnCloseCh := make(chan struct{}) - go func() { - // Ensure that the cancellation of a request closes the backend. - // See issue https://golang.org/issue/35559. - select { - case <-req.Context().Done(): - case <-backConnCloseCh: - } - backConn.Close() - }() - defer close(backConnCloseCh) - } - if err := brw.Flush(); err != nil { if c := logger.Check(zapcore.DebugLevel, "response flush"); c != nil { c.Write(zap.Error(err)) @@ -221,8 +202,8 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, start := time.Now() - if isH2 { - h.handleH2UpgradeTunnel(streamLogger, streamLevel, wg, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics, streamFields) + if isExtendedConnect { + h.handleExtendedConnectUpgradeTunnel(streamLogger, streamLevel, wg, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics, streamFields) } else { h.handleDetachedUpgradeTunnel(streamLogger, streamLevel, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics, streamFields) // Return immediately without touching wg. finalizeResponse's @@ -230,7 +211,7 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, } } -func (h *Handler) handleH2UpgradeTunnel( +func (h *Handler) handleExtendedConnectUpgradeTunnel( streamLogger *zap.Logger, streamLevel zapcore.Level, wg *sync.WaitGroup, @@ -244,7 +225,7 @@ func (h *Handler) handleH2UpgradeTunnel( finishMetrics func(result string, duration time.Duration, toBackend, fromBackend int64), streamFields []zap.Field, ) { - // H2 extended connect: ServeHTTP must block because rw and req.Body are + // Extended CONNECT: ServeHTTP must block because rw and req.Body are // only valid while the handler goroutine is running. Defers clean up // when the select below fires and this function returns. defer deleteBackConn()