mirror of
https://github.com/caddyserver/caddy.git
synced 2026-04-23 17:39:33 -04:00
Adjustments from Weidi's review
This commit is contained in:
parent
5a1ace3e91
commit
db86fdaba2
4
.github/workflows/ci.yml
vendored
4
.github/workflows/ci.yml
vendored
@ -132,6 +132,8 @@ jobs:
|
||||
- name: Run tests
|
||||
# id: step_test
|
||||
# continue-on-error: true
|
||||
env:
|
||||
GODEBUG: http2xconnect=1
|
||||
run: |
|
||||
# (go test -v -coverprofile=cover-profile.out -race ./... 2>&1) > test-results/test-result.out
|
||||
go test -v -coverprofile="cover-profile.out" -short -race ./...
|
||||
@ -191,7 +193,7 @@ jobs:
|
||||
retries=3
|
||||
exit_code=0
|
||||
while ((retries > 0)); do
|
||||
CGO_ENABLED=0 go test -p 1 -v ./...
|
||||
GODEBUG=http2xconnect=1 CGO_ENABLED=0 go test -p 1 -v ./...
|
||||
exit_code=$?
|
||||
if ((exit_code == 0)); then
|
||||
break
|
||||
|
||||
328
caddytest/integration/reverseproxy_extended_connect_test.go
Normal file
328
caddytest/integration/reverseproxy_extended_connect_test.go
Normal file
@ -0,0 +1,328 @@
|
||||
package integration
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
|
||||
"github.com/caddyserver/caddy/v2/caddytest"
|
||||
)
|
||||
|
||||
var errExtendedConnectUnsupportedByPeer = errors.New("peer did not advertise RFC 8441 extended CONNECT support")
|
||||
|
||||
func TestReverseProxyExtendedConnectOverH2(t *testing.T) {
|
||||
tester := caddytest.NewTester(t)
|
||||
backend := newWebsocketUpgradeEchoBackend(t)
|
||||
defer backend.Close()
|
||||
|
||||
tester.InitServer(fmt.Sprintf(`
|
||||
{
|
||||
admin localhost:2999
|
||||
http_port 9080
|
||||
https_port 9443
|
||||
grace_period 1ns
|
||||
skip_install_trust
|
||||
servers :9443 {
|
||||
protocols h2
|
||||
}
|
||||
}
|
||||
|
||||
https://localhost:9443 {
|
||||
reverse_proxy %s
|
||||
}
|
||||
`, backend.addr), "caddyfile")
|
||||
|
||||
const payload = "extended-connect-echo\n"
|
||||
if err := assertExtendedConnectH2Echo("localhost:9443", payload); err != nil {
|
||||
if errors.Is(err, errExtendedConnectUnsupportedByPeer) {
|
||||
t.Skipf("skipping extended CONNECT integration test: %v", err)
|
||||
}
|
||||
t.Fatalf("extended connect h2 echo failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func assertExtendedConnectH2Echo(addr, payload string) error {
|
||||
conn, err := tlsDialH2(addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing h2 tls: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
return fmt.Errorf("setting deadline: %w", err)
|
||||
}
|
||||
|
||||
fr := http2.NewFramer(conn, conn)
|
||||
|
||||
if _, err := conn.Write([]byte(http2.ClientPreface)); err != nil {
|
||||
return fmt.Errorf("writing client preface: %w", err)
|
||||
}
|
||||
if err := fr.WriteSettings(http2.Setting{ID: http2.SettingEnableConnectProtocol, Val: 1}); err != nil {
|
||||
return fmt.Errorf("writing client settings: %w", err)
|
||||
}
|
||||
|
||||
supported, err := waitForServerSettings(fr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !supported {
|
||||
return errExtendedConnectUnsupportedByPeer
|
||||
}
|
||||
if err := waitForSettingsAck(fr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := writeExtendedConnectHeaders(fr, addr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
status, err := readResponseStatus(fr, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if status != "200" {
|
||||
return fmt.Errorf("unexpected extended connect status: got=%s want=200", status)
|
||||
}
|
||||
|
||||
if err := fr.WriteData(1, false, []byte(payload)); err != nil {
|
||||
return fmt.Errorf("writing stream data: %w", err)
|
||||
}
|
||||
|
||||
echo, err := readStreamData(fr, 1, len(payload))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if echo != payload {
|
||||
return fmt.Errorf("unexpected echoed payload: got=%q want=%q", echo, payload)
|
||||
}
|
||||
|
||||
_ = fr.WriteRSTStream(1, http2.ErrCodeNo)
|
||||
return nil
|
||||
}
|
||||
|
||||
func tlsDialH2(addr string) (net.Conn, error) {
|
||||
var lastErr error
|
||||
for i := 0; i < 30; i++ {
|
||||
dialer := &net.Dialer{Timeout: 2 * time.Second}
|
||||
conn, err := tls.DialWithDialer(dialer, "tcp", addr, &tls.Config{
|
||||
ServerName: "localhost",
|
||||
InsecureSkipVerify: true,
|
||||
NextProtos: []string{"h2"},
|
||||
})
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
lastErr = err
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
func waitForServerSettings(fr *http2.Framer) (bool, error) {
|
||||
for {
|
||||
frame, err := fr.ReadFrame()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("reading frame before connect: %w", err)
|
||||
}
|
||||
settings, ok := frame.(*http2.SettingsFrame)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if settings.IsAck() {
|
||||
continue
|
||||
}
|
||||
|
||||
supported := false
|
||||
if err := settings.ForeachSetting(func(s http2.Setting) error {
|
||||
if s.ID == http2.SettingEnableConnectProtocol && s.Val == 1 {
|
||||
supported = true
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return false, fmt.Errorf("reading server settings: %w", err)
|
||||
}
|
||||
|
||||
if err := fr.WriteSettingsAck(); err != nil {
|
||||
return false, fmt.Errorf("writing settings ack: %w", err)
|
||||
}
|
||||
return supported, nil
|
||||
}
|
||||
}
|
||||
|
||||
func waitForSettingsAck(fr *http2.Framer) error {
|
||||
for {
|
||||
frame, err := fr.ReadFrame()
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading settings ack: %w", err)
|
||||
}
|
||||
settings, ok := frame.(*http2.SettingsFrame)
|
||||
if ok && settings.IsAck() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func writeExtendedConnectHeaders(fr *http2.Framer, addr string) error {
|
||||
var hb bytes.Buffer
|
||||
enc := hpack.NewEncoder(&hb)
|
||||
for _, hf := range []hpack.HeaderField{
|
||||
{Name: ":method", Value: "CONNECT"},
|
||||
{Name: ":scheme", Value: "https"},
|
||||
{Name: ":authority", Value: addr},
|
||||
{Name: ":path", Value: "/upgrade"},
|
||||
{Name: ":protocol", Value: "websocket"},
|
||||
} {
|
||||
if err := enc.WriteField(hf); err != nil {
|
||||
return fmt.Errorf("encoding request headers: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := fr.WriteHeaders(http2.HeadersFrameParam{
|
||||
StreamID: 1,
|
||||
BlockFragment: hb.Bytes(),
|
||||
EndHeaders: true,
|
||||
EndStream: false,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("writing extended connect headers: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func readResponseStatus(fr *http2.Framer, streamID uint32) (string, error) {
|
||||
var block bytes.Buffer
|
||||
|
||||
for {
|
||||
frame, err := fr.ReadFrame()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reading response headers: %w", err)
|
||||
}
|
||||
if rst, ok := frame.(*http2.RSTStreamFrame); ok && rst.StreamID == streamID {
|
||||
return "", fmt.Errorf("stream reset before response headers: %s", rst.ErrCode)
|
||||
}
|
||||
|
||||
h, ok := frame.(*http2.HeadersFrame)
|
||||
if !ok || h.StreamID != streamID {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err := block.Write(h.HeaderBlockFragment()); err != nil {
|
||||
return "", fmt.Errorf("buffering response header fragment: %w", err)
|
||||
}
|
||||
for !h.HeadersEnded() {
|
||||
next, err := fr.ReadFrame()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reading continuation frame: %w", err)
|
||||
}
|
||||
c, ok := next.(*http2.ContinuationFrame)
|
||||
if !ok || c.StreamID != streamID {
|
||||
continue
|
||||
}
|
||||
if _, err := block.Write(c.HeaderBlockFragment()); err != nil {
|
||||
return "", fmt.Errorf("buffering continuation fragment: %w", err)
|
||||
}
|
||||
if c.HeadersEnded() {
|
||||
break
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
var status string
|
||||
dec := hpack.NewDecoder(4096, func(f hpack.HeaderField) {
|
||||
if f.Name == ":status" {
|
||||
status = f.Value
|
||||
}
|
||||
})
|
||||
if _, err := dec.Write(block.Bytes()); err != nil {
|
||||
return "", fmt.Errorf("decoding response header block: %w", err)
|
||||
}
|
||||
if status == "" {
|
||||
return "", fmt.Errorf("missing :status in response headers")
|
||||
}
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func readStreamData(fr *http2.Framer, streamID uint32, n int) (string, error) {
|
||||
buf := make([]byte, 0, n)
|
||||
for len(buf) < n {
|
||||
frame, err := fr.ReadFrame()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reading stream data: %w", err)
|
||||
}
|
||||
d, ok := frame.(*http2.DataFrame)
|
||||
if !ok || d.StreamID != streamID {
|
||||
continue
|
||||
}
|
||||
buf = append(buf, d.Data()...)
|
||||
}
|
||||
return string(buf[:n]), nil
|
||||
}
|
||||
|
||||
type websocketUpgradeEchoBackend struct {
|
||||
addr string
|
||||
ln net.Listener
|
||||
server *http.Server
|
||||
}
|
||||
|
||||
func newWebsocketUpgradeEchoBackend(t *testing.T) *websocketUpgradeEchoBackend {
|
||||
t.Helper()
|
||||
|
||||
backend := &websocketUpgradeEchoBackend{}
|
||||
backend.server = &http.Server{
|
||||
Handler: http.HandlerFunc(backend.serveHTTP),
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listening for websocket backend: %v", err)
|
||||
}
|
||||
backend.ln = ln
|
||||
backend.addr = ln.Addr().String()
|
||||
|
||||
go func() {
|
||||
_ = backend.server.Serve(ln)
|
||||
}()
|
||||
|
||||
return backend
|
||||
}
|
||||
|
||||
func (b *websocketUpgradeEchoBackend) serveHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.EqualFold(r.Header.Get("Connection"), "Upgrade") || !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
|
||||
http.Error(w, "upgrade required", http.StatusUpgradeRequired)
|
||||
return
|
||||
}
|
||||
|
||||
hijacker, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
http.Error(w, "hijacking not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
conn, rw, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, _ = rw.WriteString("HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n")
|
||||
_ = rw.Flush()
|
||||
|
||||
go func() {
|
||||
defer conn.Close()
|
||||
_, _ = io.Copy(conn, conn)
|
||||
}()
|
||||
}
|
||||
|
||||
func (b *websocketUpgradeEchoBackend) Close() {
|
||||
_ = b.server.Close()
|
||||
_ = b.ln.Close()
|
||||
}
|
||||
@ -21,9 +21,11 @@ import (
|
||||
"github.com/caddyserver/caddy/v2/caddytest"
|
||||
)
|
||||
|
||||
// stressCloseDelay is the stream_close_delay used for the close_delay scenario.
|
||||
// Long enough to outlast all test reloads; short enough to keep total test time reasonable.
|
||||
const stressCloseDelay = 3 * time.Second
|
||||
const (
|
||||
defaultStressStreamCount = 1
|
||||
defaultStressReloadCount = 1
|
||||
defaultStressCloseDelay = 500 * time.Millisecond
|
||||
)
|
||||
|
||||
func TestReverseProxyReloadStressUpgradedStreamsHeapProfiles(t *testing.T) {
|
||||
tester := caddytest.NewTester(t).WithDefaultOverrides(caddytest.Config{
|
||||
@ -43,7 +45,7 @@ func TestReverseProxyReloadStressUpgradedStreamsHeapProfiles(t *testing.T) {
|
||||
// Reloads are spread across time and interleaved with echo-checks so
|
||||
// stream health is exercised at each reload boundary, not only at the end.
|
||||
legacy := runReloadStress(t, tester, backend.addr, "legacy", false, 0)
|
||||
closeDelay := runReloadStress(t, tester, backend.addr, "close_delay", false, stressCloseDelay)
|
||||
closeDelay := runReloadStress(t, tester, backend.addr, "close_delay", false, stressCloseDelay(t))
|
||||
retain := runReloadStress(t, tester, backend.addr, "retain", true, 0)
|
||||
|
||||
if legacy.aliveAfterReloads != 0 {
|
||||
@ -110,8 +112,8 @@ func runReloadStress(t *testing.T, tester *caddytest.Tester, backendAddr, mode s
|
||||
|
||||
const echoEvery = 6 // perform an echo check every N reloads
|
||||
|
||||
streamCount := envIntOrDefault(t, "CADDY_STRESS_STREAM_COUNT", 12)
|
||||
reloadCount := envIntOrDefault(t, "CADDY_STRESS_RELOAD_COUNT", 24)
|
||||
streamCount := envIntOrDefault(t, "CADDY_STRESS_STREAM_COUNT", defaultStressStreamCount)
|
||||
reloadCount := envIntOrDefault(t, "CADDY_STRESS_RELOAD_COUNT", defaultStressReloadCount)
|
||||
|
||||
tester.InitServer(reloadStressConfig(backendAddr, retain, closeDelay, 0), "caddyfile")
|
||||
|
||||
@ -209,6 +211,21 @@ func envIntOrDefault(t *testing.T, key string, def int) int {
|
||||
return v
|
||||
}
|
||||
|
||||
func stressCloseDelay(t *testing.T) time.Duration {
|
||||
t.Helper()
|
||||
|
||||
const key = "CADDY_STRESS_CLOSE_DELAY"
|
||||
raw := strings.TrimSpace(os.Getenv(key))
|
||||
if raw == "" {
|
||||
return defaultStressCloseDelay
|
||||
}
|
||||
v, err := time.ParseDuration(raw)
|
||||
if err != nil || v <= 0 {
|
||||
t.Fatalf("invalid %s=%q: must be a positive duration", key, raw)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func loadCaddyfileConfig(t *testing.T, rawConfig string) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@ -405,11 +405,6 @@ func (rw *responseWriter) ReadFrom(r io.Reader) (int64, error) {
|
||||
// Close writes any remaining buffered response and
|
||||
// deallocates any active resources.
|
||||
func (rw *responseWriter) Close() error {
|
||||
if caddyhttp.ResponseWriterHijacked(rw.ResponseWriter) {
|
||||
rw.releaseEncoder()
|
||||
return nil
|
||||
}
|
||||
|
||||
// didn't write, probably head request
|
||||
if !rw.wroteHeader {
|
||||
cl, err := strconv.Atoi(rw.Header().Get("Content-Length"))
|
||||
|
||||
146
modules/caddyhttp/reverseproxy/extended_connect_test.go
Normal file
146
modules/caddyhttp/reverseproxy/extended_connect_test.go
Normal file
@ -0,0 +1,146 @@
|
||||
package reverseproxy
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
|
||||
)
|
||||
|
||||
type extendedConnectCapture struct {
|
||||
method string
|
||||
headers http.Header
|
||||
body []byte
|
||||
extendedBodyPresent bool
|
||||
extendedConnectBody []byte
|
||||
}
|
||||
|
||||
type extendedConnectCaptureTransport struct {
|
||||
mu sync.Mutex
|
||||
capture extendedConnectCapture
|
||||
}
|
||||
|
||||
func (tr *extendedConnectCaptureTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
body, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c := extendedConnectCapture{
|
||||
method: req.Method,
|
||||
headers: req.Header.Clone(),
|
||||
body: body,
|
||||
}
|
||||
if rc, ok := caddyhttp.GetVar(req.Context(), "extended_connect_websocket_body").(io.ReadCloser); ok {
|
||||
c.extendedBodyPresent = true
|
||||
c.extendedConnectBody, err = io.ReadAll(rc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_ = rc.Close()
|
||||
}
|
||||
|
||||
tr.mu.Lock()
|
||||
tr.capture = c
|
||||
tr.mu.Unlock()
|
||||
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(strings.NewReader("ok")),
|
||||
Request: req,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (tr *extendedConnectCaptureTransport) Snapshot() extendedConnectCapture {
|
||||
tr.mu.Lock()
|
||||
defer tr.mu.Unlock()
|
||||
return tr.capture
|
||||
}
|
||||
|
||||
func TestServeHTTPRewritesExtendedConnectWebsocketRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
protoMajor int
|
||||
proto string
|
||||
headers map[string]string
|
||||
}{
|
||||
{
|
||||
name: "h2 extended connect",
|
||||
protoMajor: 2,
|
||||
proto: "HTTP/2.0",
|
||||
headers: map[string]string{
|
||||
":protocol": "websocket",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "h3 extended connect",
|
||||
protoMajor: 3,
|
||||
proto: "websocket",
|
||||
headers: map[string]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
const payload = "extended-connect-body"
|
||||
|
||||
transport := new(extendedConnectCaptureTransport)
|
||||
h := &Handler{
|
||||
logger: zap.NewNop(),
|
||||
Transport: transport,
|
||||
Upstreams: UpstreamPool{
|
||||
&Upstream{Host: new(Host), Dial: "127.0.0.1:8443"},
|
||||
},
|
||||
LoadBalancing: &LoadBalancing{
|
||||
SelectionPolicy: &RoundRobinSelection{},
|
||||
},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodConnect, "http://example.test/upgrade", strings.NewReader(payload))
|
||||
req.ProtoMajor = tc.protoMajor
|
||||
req.Proto = tc.proto
|
||||
for key, value := range tc.headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
req = prepareTestRequest(req)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
err := h.ServeHTTP(rr, req, caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
return nil
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("ServeHTTP() error = %v", err)
|
||||
}
|
||||
|
||||
captured := transport.Snapshot()
|
||||
if captured.method != http.MethodGet {
|
||||
t.Fatalf("upstream method = %s, want %s", captured.method, http.MethodGet)
|
||||
}
|
||||
if got := captured.headers.Get("Upgrade"); !strings.EqualFold(got, "websocket") {
|
||||
t.Fatalf("Upgrade header = %q, want websocket", got)
|
||||
}
|
||||
if got := captured.headers.Get("Connection"); !strings.EqualFold(got, "Upgrade") {
|
||||
t.Fatalf("Connection header = %q, want Upgrade", got)
|
||||
}
|
||||
if got := captured.headers.Get(":protocol"); got != "" {
|
||||
t.Fatalf(":protocol header should be removed, got %q", got)
|
||||
}
|
||||
if len(captured.body) != 0 {
|
||||
t.Fatalf("upstream request body length = %d, want 0", len(captured.body))
|
||||
}
|
||||
if !captured.extendedBodyPresent {
|
||||
t.Fatal("extended_connect_websocket_body variable missing from request context")
|
||||
}
|
||||
if string(captured.extendedConnectBody) != payload {
|
||||
t.Fatalf("extended_connect_websocket_body = %q, want %q", string(captured.extendedConnectBody), payload)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -100,14 +100,14 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
|
||||
streamTimeout := time.Duration(h.StreamTimeout)
|
||||
|
||||
var (
|
||||
conn io.ReadWriteCloser
|
||||
brw *bufio.ReadWriter
|
||||
isH2 bool
|
||||
conn io.ReadWriteCloser
|
||||
brw *bufio.ReadWriter
|
||||
isExtendedConnect bool
|
||||
)
|
||||
// websocket over http2 or http3 if extended connect is enabled, assuming backend doesn't support this, the request will be modified to http1.1 upgrade
|
||||
// TODO: once we can reliably detect backend support this, it can be removed for those backends
|
||||
if body, ok := caddyhttp.GetVar(req.Context(), "extended_connect_websocket_body").(io.ReadCloser); ok {
|
||||
isH2 = true
|
||||
isExtendedConnect = true
|
||||
req.Body = body
|
||||
rw.Header().Del("Upgrade")
|
||||
rw.Header().Del("Connection")
|
||||
@ -115,13 +115,13 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
|
||||
if c := logger.Check(zap.DebugLevel, "upgrading connection"); c != nil {
|
||||
c.Write(zap.Int("http_version", 2))
|
||||
c.Write(zap.Int("http_version", req.ProtoMajor))
|
||||
}
|
||||
|
||||
//nolint:bodyclose
|
||||
flushErr := http.NewResponseController(rw).Flush()
|
||||
if flushErr != nil {
|
||||
if c := h.logger.Check(zap.ErrorLevel, "failed to flush http2 websocket response"); c != nil {
|
||||
if c := h.logger.Check(zap.ErrorLevel, "failed to flush extended_connect websocket response"); c != nil {
|
||||
c.Write(zap.Error(flushErr))
|
||||
}
|
||||
return
|
||||
@ -154,25 +154,6 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
|
||||
}
|
||||
}
|
||||
|
||||
// For H2 extended connect: close backConn when the request context is
|
||||
// cancelled (e.g. client disconnects). For HTTP/1.1 hijacked connections
|
||||
// we skip this because req.Context() may be cancelled when ServeHTTP
|
||||
// returns early, which would prematurely close the backend connection.
|
||||
if isH2 {
|
||||
// adopted from https://github.com/golang/go/commit/8bcf2834afdf6a1f7937390903a41518715ef6f5
|
||||
backConnCloseCh := make(chan struct{})
|
||||
go func() {
|
||||
// Ensure that the cancellation of a request closes the backend.
|
||||
// See issue https://golang.org/issue/35559.
|
||||
select {
|
||||
case <-req.Context().Done():
|
||||
case <-backConnCloseCh:
|
||||
}
|
||||
backConn.Close()
|
||||
}()
|
||||
defer close(backConnCloseCh)
|
||||
}
|
||||
|
||||
if err := brw.Flush(); err != nil {
|
||||
if c := logger.Check(zapcore.DebugLevel, "response flush"); c != nil {
|
||||
c.Write(zap.Error(err))
|
||||
@ -221,8 +202,8 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
|
||||
|
||||
start := time.Now()
|
||||
|
||||
if isH2 {
|
||||
h.handleH2UpgradeTunnel(streamLogger, streamLevel, wg, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics, streamFields)
|
||||
if isExtendedConnect {
|
||||
h.handleExtendedConnectUpgradeTunnel(streamLogger, streamLevel, wg, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics, streamFields)
|
||||
} else {
|
||||
h.handleDetachedUpgradeTunnel(streamLogger, streamLevel, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics, streamFields)
|
||||
// Return immediately without touching wg. finalizeResponse's
|
||||
@ -230,7 +211,7 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) handleH2UpgradeTunnel(
|
||||
func (h *Handler) handleExtendedConnectUpgradeTunnel(
|
||||
streamLogger *zap.Logger,
|
||||
streamLevel zapcore.Level,
|
||||
wg *sync.WaitGroup,
|
||||
@ -244,7 +225,7 @@ func (h *Handler) handleH2UpgradeTunnel(
|
||||
finishMetrics func(result string, duration time.Duration, toBackend, fromBackend int64),
|
||||
streamFields []zap.Field,
|
||||
) {
|
||||
// H2 extended connect: ServeHTTP must block because rw and req.Body are
|
||||
// Extended CONNECT: ServeHTTP must block because rw and req.Body are
|
||||
// only valid while the handler goroutine is running. Defers clean up
|
||||
// when the select below fires and this function returns.
|
||||
defer deleteBackConn()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user