caddy/caddytest/integration/reverseproxy_extended_connect_test.go
2026-04-19 18:16:00 -04:00

329 lines
7.7 KiB
Go

package integration
import (
	"bytes"
	"crypto/tls"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"strings"
	"sync"
	"testing"
	"time"

	"golang.org/x/net/http2"
	"golang.org/x/net/http2/hpack"

	"github.com/caddyserver/caddy/v2/caddytest"
)
// errExtendedConnectUnsupportedByPeer is returned by assertExtendedConnectH2Echo
// when the server's SETTINGS frame does not set SETTINGS_ENABLE_CONNECT_PROTOCOL
// to 1. The test treats it as a reason to skip rather than fail.
var (
	errExtendedConnectUnsupportedByPeer = errors.New("peer did not advertise RFC 8441 extended CONNECT support")
)
// TestReverseProxyExtendedConnectOverH2 sends an RFC 8441 extended CONNECT
// request (":protocol" websocket) over HTTP/2 to Caddy, which reverse-proxies
// it to a local HTTP/1.1 upgrade echo backend. The test then checks that a
// payload written on the stream comes back unchanged. It is skipped when
// Caddy does not advertise extended CONNECT support in its SETTINGS.
func TestReverseProxyExtendedConnectOverH2(t *testing.T) {
	tester := caddytest.NewTester(t)
	upstream := newWebsocketUpgradeEchoBackend(t)
	defer upstream.Close()

	caddyfile := fmt.Sprintf(`
{
admin localhost:2999
http_port 9080
https_port 9443
grace_period 1ns
skip_install_trust
servers :9443 {
protocols h2
}
}
https://localhost:9443 {
reverse_proxy %s
}
`, upstream.addr)
	tester.InitServer(caddyfile, "caddyfile")

	const payload = "extended-connect-echo\n"
	err := assertExtendedConnectH2Echo("localhost:9443", payload)
	switch {
	case err == nil:
		// The payload round-tripped through the proxy.
	case errors.Is(err, errExtendedConnectUnsupportedByPeer):
		t.Skipf("skipping extended CONNECT integration test: %v", err)
	default:
		t.Fatalf("extended connect h2 echo failed: %v", err)
	}
}
// assertExtendedConnectH2Echo speaks raw HTTP/2 to addr and checks that an
// extended CONNECT stream echoes data. It:
//   - dials TLS with ALPN "h2" and sends the client preface plus a SETTINGS
//     frame enabling extended CONNECT;
//   - ACKs the server's SETTINGS and waits for the ACK of its own;
//   - opens stream 1 with an extended CONNECT request and expects ":status" 200;
//   - writes payload on the stream and expects the same bytes back.
//
// It returns errExtendedConnectUnsupportedByPeer when the server's SETTINGS
// do not enable extended CONNECT. Every other failure comes back as a wrapped
// error. All I/O shares one 5-second connection deadline.
func assertExtendedConnectH2Echo(addr, payload string) error {
	// writeExtendedConnectHeaders always opens stream 1, so every per-stream
	// call below must use the same ID.
	const streamID uint32 = 1

	conn, err := tlsDialH2(addr)
	if err != nil {
		return fmt.Errorf("dialing h2 tls: %w", err)
	}
	defer conn.Close()

	deadline := time.Now().Add(5 * time.Second)
	if err := conn.SetDeadline(deadline); err != nil {
		return fmt.Errorf("setting deadline: %w", err)
	}

	framer := http2.NewFramer(conn, conn)
	// The connection preface must precede the client's first SETTINGS frame.
	if _, err := conn.Write([]byte(http2.ClientPreface)); err != nil {
		return fmt.Errorf("writing client preface: %w", err)
	}
	enableConnect := http2.Setting{ID: http2.SettingEnableConnectProtocol, Val: 1}
	if err := framer.WriteSettings(enableConnect); err != nil {
		return fmt.Errorf("writing client settings: %w", err)
	}

	supported, err := waitForServerSettings(framer)
	switch {
	case err != nil:
		return err
	case !supported:
		return errExtendedConnectUnsupportedByPeer
	}
	if err := waitForSettingsAck(framer); err != nil {
		return err
	}

	if err := writeExtendedConnectHeaders(framer, addr); err != nil {
		return err
	}
	status, err := readResponseStatus(framer, streamID)
	if err != nil {
		return err
	}
	if status != "200" {
		return fmt.Errorf("unexpected extended connect status: got=%s want=200", status)
	}

	if err := framer.WriteData(streamID, false, []byte(payload)); err != nil {
		return fmt.Errorf("writing stream data: %w", err)
	}
	echo, err := readStreamData(framer, streamID, len(payload))
	if err != nil {
		return err
	}
	if echo != payload {
		return fmt.Errorf("unexpected echoed payload: got=%q want=%q", echo, payload)
	}

	// Best-effort stream teardown; the deferred Close drops the connection anyway.
	_ = framer.WriteRSTStream(streamID, http2.ErrCodeNo)
	return nil
}
// tlsDialH2 opens a TLS connection to addr that offers only the "h2" ALPN
// protocol and returns it once the handshake has negotiated "h2".
//
// The server under test may still be starting, so dial/handshake failures
// are retried: up to 30 attempts, 100ms apart, each with a 2-second dial
// timeout. The last error is returned if every attempt fails. A handshake
// that succeeds without negotiating "h2" is not retried: the connection is
// closed and an error is returned. Callers would otherwise get confusing
// HTTP/2 framing failures later.
//
// Certificate verification is disabled (InsecureSkipVerify); the servers
// this helper targets are not expected to present publicly trusted
// certificates.
func tlsDialH2(addr string) (net.Conn, error) {
	const (
		attempts   = 30
		retryDelay = 100 * time.Millisecond
	)
	// Loop-invariant: one dialer and one (never mutated) TLS config for all attempts.
	dialer := &net.Dialer{Timeout: 2 * time.Second}
	cfg := &tls.Config{
		ServerName:         "localhost",
		InsecureSkipVerify: true,
		NextProtos:         []string{"h2"},
	}

	var lastErr error
	for i := 0; i < attempts; i++ {
		if i > 0 {
			time.Sleep(retryDelay)
		}
		conn, err := tls.DialWithDialer(dialer, "tcp", addr, cfg)
		if err != nil {
			lastErr = err
			continue
		}
		if proto := conn.ConnectionState().NegotiatedProtocol; proto != "h2" {
			// Best-effort close; the mismatch is the error we report.
			_ = conn.Close()
			return nil, fmt.Errorf("server negotiated ALPN protocol %q, want %q", proto, "h2")
		}
		return conn, nil
	}
	return nil, lastErr
}
// waitForServerSettings reads frames until the first SETTINGS frame without
// the ACK flag arrives. It discards every other frame, including SETTINGS
// ACKs. It then acknowledges that frame and reports whether it set
// SETTINGS_ENABLE_CONNECT_PROTOCOL to 1 (RFC 8441 extended CONNECT).
func waitForServerSettings(fr *http2.Framer) (bool, error) {
	var serverSettings *http2.SettingsFrame
	for serverSettings == nil {
		frame, err := fr.ReadFrame()
		if err != nil {
			return false, fmt.Errorf("reading frame before connect: %w", err)
		}
		if sf, isSettings := frame.(*http2.SettingsFrame); isSettings && !sf.IsAck() {
			serverSettings = sf
		}
	}

	// No ReadFrame happens between capturing the frame and inspecting it,
	// so its payload is still valid here.
	connectEnabled := false
	err := serverSettings.ForeachSetting(func(setting http2.Setting) error {
		if setting.ID == http2.SettingEnableConnectProtocol && setting.Val == 1 {
			connectEnabled = true
		}
		return nil
	})
	if err != nil {
		return false, fmt.Errorf("reading server settings: %w", err)
	}

	if err := fr.WriteSettingsAck(); err != nil {
		return false, fmt.Errorf("writing settings ack: %w", err)
	}
	return connectEnabled, nil
}
// waitForSettingsAck discards incoming frames until a SETTINGS frame with the
// ACK flag arrives, i.e. the server has acknowledged the client's SETTINGS.
// Read errors (including the connection deadline) are returned wrapped.
func waitForSettingsAck(fr *http2.Framer) error {
	for acked := false; !acked; {
		frame, err := fr.ReadFrame()
		if err != nil {
			return fmt.Errorf("reading settings ack: %w", err)
		}
		settings, isSettings := frame.(*http2.SettingsFrame)
		acked = isSettings && settings.IsAck()
	}
	return nil
}
// writeExtendedConnectHeaders HPACK-encodes and sends, on stream 1, the
// pseudo-headers of an RFC 8441 extended CONNECT request:
// :method CONNECT, :scheme https, :authority addr, :path /upgrade and
// :protocol websocket. The whole block goes in one HEADERS frame with
// END_HEADERS set. END_STREAM is not set, so DATA can follow on the stream.
func writeExtendedConnectHeaders(fr *http2.Framer, addr string) error {
	// Field order is preserved on the wire; pseudo-headers only, no regular headers.
	pseudoHeaders := []hpack.HeaderField{
		{Name: ":method", Value: "CONNECT"},
		{Name: ":scheme", Value: "https"},
		{Name: ":authority", Value: addr},
		{Name: ":path", Value: "/upgrade"},
		{Name: ":protocol", Value: "websocket"},
	}

	encoded := new(bytes.Buffer)
	encoder := hpack.NewEncoder(encoded)
	for i := range pseudoHeaders {
		if err := encoder.WriteField(pseudoHeaders[i]); err != nil {
			return fmt.Errorf("encoding request headers: %w", err)
		}
	}

	err := fr.WriteHeaders(http2.HeadersFrameParam{
		StreamID:      1,
		BlockFragment: encoded.Bytes(),
		EndHeaders:    true,
	})
	if err != nil {
		return fmt.Errorf("writing extended connect headers: %w", err)
	}
	return nil
}
// readResponseStatus reads frames until it has the complete response header
// block for streamID: the HEADERS frame plus any CONTINUATION frames. It then
// HPACK-decodes the block and returns the ":status" pseudo-header value.
// Frames for other streams and non-header frames are skipped.
//
// It returns an error when:
//   - the stream is reset before the headers arrive;
//   - an unfinished header block is followed by anything other than a
//     CONTINUATION frame for the same stream (RFC 9113 §6.10);
//   - the block fails to decode, including when it is truncated;
//   - ":status" is absent.
//
// A fresh HPACK decoder (4096-byte table) is used, so this is only correct
// for the first header block the server sends on the connection.
func readResponseStatus(fr *http2.Framer, streamID uint32) (string, error) {
	var block bytes.Buffer
	for complete := false; !complete; {
		frame, err := fr.ReadFrame()
		if err != nil {
			return "", fmt.Errorf("reading response headers: %w", err)
		}
		switch f := frame.(type) {
		case *http2.RSTStreamFrame:
			if f.StreamID == streamID {
				return "", fmt.Errorf("stream reset before response headers: %s", f.ErrCode)
			}
		case *http2.HeadersFrame:
			if f.StreamID != streamID {
				continue
			}
			// Copy the fragment out before the next ReadFrame can invalidate it.
			// bytes.Buffer.Write always returns a nil error.
			block.Write(f.HeaderBlockFragment())
			complete = f.HeadersEnded()
			// An unfinished header block must be continued immediately by
			// CONTINUATION frames on the same stream; anything else is a
			// protocol violation, not a frame to skip.
			for !complete {
				next, err := fr.ReadFrame()
				if err != nil {
					return "", fmt.Errorf("reading continuation frame: %w", err)
				}
				c, ok := next.(*http2.ContinuationFrame)
				if !ok || c.StreamID != streamID {
					return "", fmt.Errorf("expected CONTINUATION for stream %d, got %v", streamID, next.Header())
				}
				block.Write(c.HeaderBlockFragment())
				complete = c.HeadersEnded()
			}
		}
	}

	// DecodeFull also closes the decoder, which rejects a truncated block
	// that a bare Write would silently accept.
	fields, err := hpack.NewDecoder(4096, nil).DecodeFull(block.Bytes())
	if err != nil {
		return "", fmt.Errorf("decoding response header block: %w", err)
	}
	for _, f := range fields {
		if f.Name == ":status" {
			return f.Value, nil
		}
	}
	return "", errors.New("missing :status in response headers")
}
// readStreamData collects DATA payload for streamID until at least n bytes
// have arrived and returns the first n of them as a string. Frames for other
// streams and unrelated frame types are skipped.
//
// It fails fast, instead of waiting for the connection deadline, when the
// stream is reset or the peer ends it (END_STREAM) before n bytes arrive.
// No WINDOW_UPDATE frames are sent, so this assumes n is small enough to fit
// in the peer's flow-control window.
func readStreamData(fr *http2.Framer, streamID uint32, n int) (string, error) {
	buf := make([]byte, 0, n)
	for len(buf) < n {
		frame, err := fr.ReadFrame()
		if err != nil {
			return "", fmt.Errorf("reading stream data: %w", err)
		}
		switch f := frame.(type) {
		case *http2.DataFrame:
			if f.StreamID != streamID {
				continue
			}
			// append copies the payload, which the next ReadFrame may invalidate.
			buf = append(buf, f.Data()...)
			if f.StreamEnded() && len(buf) < n {
				return "", fmt.Errorf("stream ended after %d of %d expected bytes", len(buf), n)
			}
		case *http2.RSTStreamFrame:
			if f.StreamID == streamID {
				return "", fmt.Errorf("stream reset while reading data: %s", f.ErrCode)
			}
		}
	}
	return string(buf[:n]), nil
}
// websocketUpgradeEchoBackend is a plain HTTP/1.1 upstream for the reverse
// proxy tests. Requests whose Connection header contains the "Upgrade" token
// and whose Upgrade header contains "websocket" get a 101 Switching Protocols
// response. The connection is then hijacked and every byte received on it is
// echoed back. Other requests get 426 Upgrade Required.
type websocketUpgradeEchoBackend struct {
	addr   string
	ln     net.Listener
	server *http.Server

	// http.Server.Close does not close hijacked connections, so the backend
	// tracks them itself. mu guards conns and closed; closed makes a
	// connection hijacked after Close shut down immediately.
	mu     sync.Mutex
	conns  map[net.Conn]struct{}
	closed bool

	// echoes counts running echo goroutines so Close can wait for them.
	echoes sync.WaitGroup
}

// newWebsocketUpgradeEchoBackend starts the echo backend on an ephemeral
// 127.0.0.1 port and returns it; the caller must call Close. A listen failure
// aborts the test via t.Fatalf.
func newWebsocketUpgradeEchoBackend(t *testing.T) *websocketUpgradeEchoBackend {
	t.Helper()
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listening for websocket backend: %v", err)
	}
	backend := &websocketUpgradeEchoBackend{
		addr:  ln.Addr().String(),
		ln:    ln,
		conns: make(map[net.Conn]struct{}),
	}
	backend.server = &http.Server{
		Handler: http.HandlerFunc(backend.serveHTTP),
	}
	go func() {
		// Serve's error (http.ErrServerClosed after Close) is irrelevant here.
		_ = backend.server.Serve(ln)
	}()
	return backend
}

// serveHTTP performs the upgrade handshake and starts the echo goroutine.
func (b *websocketUpgradeEchoBackend) serveHTTP(w http.ResponseWriter, r *http.Request) {
	// Connection and Upgrade are comma-separated token lists (e.g.
	// "keep-alive, Upgrade"), so match individual tokens case-insensitively.
	hasToken := func(header, token string) bool {
		for _, value := range r.Header.Values(header) {
			for _, part := range strings.Split(value, ",") {
				if strings.EqualFold(strings.TrimSpace(part), token) {
					return true
				}
			}
		}
		return false
	}
	if !hasToken("Connection", "Upgrade") || !hasToken("Upgrade", "websocket") {
		http.Error(w, "upgrade required", http.StatusUpgradeRequired)
		return
	}
	hijacker, ok := w.(http.Hijacker)
	if !ok {
		http.Error(w, "hijacking not supported", http.StatusInternalServerError)
		return
	}
	conn, rw, err := hijacker.Hijack()
	if err != nil {
		return
	}
	if !b.trackConn(conn) {
		// Backend already closed: nobody would ever close this connection.
		_ = conn.Close()
		return
	}
	// Write errors are deliberately ignored: this is a best-effort test backend.
	_, _ = rw.WriteString("HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n")
	_ = rw.Flush()
	go func() {
		defer b.untrackConn(conn)
		// Read through rw.Reader, not conn: it may already buffer bytes the
		// client sent right after the request headers, which reading conn
		// directly would lose.
		_, _ = io.Copy(conn, rw.Reader)
	}()
}

// trackConn registers a hijacked connection and its upcoming echo goroutine.
// It reports false if the backend is already closed; the caller then owns
// closing conn.
func (b *websocketUpgradeEchoBackend) trackConn(conn net.Conn) bool {
	b.mu.Lock()
	defer b.mu.Unlock()
	if b.closed {
		return false
	}
	if b.conns == nil {
		b.conns = make(map[net.Conn]struct{})
	}
	b.conns[conn] = struct{}{}
	// Added under mu, before Close can set closed, so it happens before Close's Wait.
	b.echoes.Add(1)
	return true
}

// untrackConn forgets and closes conn, then marks its echo goroutine done.
func (b *websocketUpgradeEchoBackend) untrackConn(conn net.Conn) {
	b.mu.Lock()
	delete(b.conns, conn)
	b.mu.Unlock()
	_ = conn.Close()
	b.echoes.Done()
}

// Close stops the HTTP server and closes the listener. It also closes every
// hijacked connection, which http.Server.Close leaves open, and waits for
// the echo goroutines to return.
func (b *websocketUpgradeEchoBackend) Close() {
	_ = b.server.Close()
	// server.Close normally closes the listener already; the error is ignored.
	_ = b.ln.Close()

	b.mu.Lock()
	b.closed = true
	for conn := range b.conns {
		_ = conn.Close()
	}
	b.mu.Unlock()

	b.echoes.Wait()
}