simplify streaming handling

This commit is contained in:
WeidiDeng 2026-04-21 10:07:13 +08:00
parent 4628aea894
commit 3f257bbccc
No known key found for this signature in database
GPG Key ID: 25F87CE1741EC7CD
3 changed files with 33 additions and 128 deletions

View File

@ -422,20 +422,13 @@ func (rw *responseWriter) Close() error {
var err error var err error
if rw.w != nil { if rw.w != nil {
err = rw.w.Close() err = rw.w.Close()
rw.releaseEncoder() rw.w.Reset(nil)
rw.config.writerPools[rw.encodingName].Put(rw.w)
rw.w = nil
} }
return err return err
} }
func (rw *responseWriter) releaseEncoder() {
if rw.w == nil {
return
}
rw.w.Reset(nil)
rw.config.writerPools[rw.encodingName].Put(rw.w)
rw.w = nil
}
// Unwrap returns the underlying ResponseWriter. // Unwrap returns the underlying ResponseWriter.
func (rw *responseWriter) Unwrap() http.ResponseWriter { func (rw *responseWriter) Unwrap() http.ResponseWriter {
return rw.ResponseWriter return rw.ResponseWriter

View File

@ -191,6 +191,9 @@ type Handler struct {
// Connections using upstreams that are removed are closed during cleanup. // Connections using upstreams that are removed are closed during cleanup.
// By default this is false, preserving legacy behavior where upgraded // By default this is false, preserving legacy behavior where upgraded
// connections are closed on reload (optionally delayed by stream_close_delay). // connections are closed on reload (optionally delayed by stream_close_delay).
// Only http1.1 websocket connections are affected, websockets for h2/h3 are not affected.
// If true, bytes transferred for http1.1 in the access logs will be zero but those stats
// can be found in the stream logs for http1/2/3 regardless if this is enabled.
StreamRetainOnReload bool `json:"stream_retain_on_reload,omitempty"` StreamRetainOnReload bool `json:"stream_retain_on_reload,omitempty"`
// Controls logging behavior for upgraded stream lifecycle events. // Controls logging behavior for upgraded stream lifecycle events.
@ -1239,9 +1242,7 @@ func (h *Handler) finalizeResponse(
) error { ) error {
// deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) // deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
if res.StatusCode == http.StatusSwitchingProtocols { if res.StatusCode == http.StatusSwitchingProtocols {
var wg sync.WaitGroup h.handleUpgradeResponse(logger, rw, req, res, upstreamAddr)
h.handleUpgradeResponse(logger, &wg, rw, req, res, upstreamAddr)
wg.Wait()
return nil return nil
} }

View File

@ -40,12 +40,12 @@ import (
"github.com/caddyserver/caddy/v2/modules/caddyhttp" "github.com/caddyserver/caddy/v2/modules/caddyhttp"
) )
type h2ReadWriteCloser struct { type extendedConnectReadWriteCloser struct {
io.ReadCloser io.ReadCloser
http.ResponseWriter http.ResponseWriter
} }
func (rwc h2ReadWriteCloser) Write(p []byte) (n int, err error) { func (rwc extendedConnectReadWriteCloser) Write(p []byte) (n int, err error) {
n, err = rwc.ResponseWriter.Write(p) n, err = rwc.ResponseWriter.Write(p)
if err != nil { if err != nil {
return 0, err return 0, err
@ -59,7 +59,7 @@ func (rwc h2ReadWriteCloser) Write(p []byte) (n int, err error) {
return n, nil return n, nil
} }
func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, rw http.ResponseWriter, req *http.Request, res *http.Response, upstreamAddr string) { func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWriter, req *http.Request, res *http.Response, upstreamAddr string) {
reqUpType := upgradeType(req.Header) reqUpType := upgradeType(req.Header)
resUpType := upgradeType(res.Header) resUpType := upgradeType(res.Header)
@ -99,15 +99,25 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
bufferSize := h.StreamBufferSize bufferSize := h.StreamBufferSize
streamTimeout := time.Duration(h.StreamTimeout) streamTimeout := time.Duration(h.StreamTimeout)
if h.StreamRetainOnReload {
// the return value should be true as it's not hijacked yet, but some middleware may wrap response writers incorrectly
if !caddyhttp.DetachResponseWriterAfterHijack(rw, true) {
if c := logger.Check(zap.DebugLevel, "detaching connection failed"); c != nil {
c.Write(zap.String("tip", "check if your response writers have an Unwrap method or if already hijacked"))
}
}
}
var ( var (
conn io.ReadWriteCloser conn io.ReadWriteCloser
brw *bufio.ReadWriter brw *bufio.ReadWriter
isExtendedConnect bool detached = h.StreamRetainOnReload
) )
// websocket over http2 or http3 if extended connect is enabled, assuming backend doesn't support this, the request will be modified to http1.1 upgrade // websocket over http2 or http3 if extended connect is enabled, assuming backend doesn't support this, the request will be modified to http1.1 upgrade
// TODO: once we can reliably detect backend support this, it can be removed for those backends // TODO: once we can reliably detect backend support this, it can be removed for those backends
if body, ok := caddyhttp.GetVar(req.Context(), "extended_connect_websocket_body").(io.ReadCloser); ok { if body, ok := caddyhttp.GetVar(req.Context(), "extended_connect_websocket_body").(io.ReadCloser); ok {
isExtendedConnect = true // websocket over extended connect can't be detached. rw and req.Body are only valid while the handler goroutine is running
detached = false
req.Body = body req.Body = body
rw.Header().Del("Upgrade") rw.Header().Del("Upgrade")
rw.Header().Del("Connection") rw.Header().Del("Connection")
@ -126,7 +136,7 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
} }
return return
} }
conn = h2ReadWriteCloser{req.Body, rw} conn = extendedConnectReadWriteCloser{req.Body, rw}
// bufio is not needed, use minimal buffer // bufio is not needed, use minimal buffer
brw = bufio.NewReadWriter(bufio.NewReaderSize(conn, 1), bufio.NewWriterSize(conn, 1)) brw = bufio.NewReadWriter(bufio.NewReaderSize(conn, 1), bufio.NewWriterSize(conn, 1))
} else { } else {
@ -202,35 +212,20 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
start := time.Now() start := time.Now()
if isExtendedConnect { if !detached {
h.handleExtendedConnectUpgradeTunnel(streamLogger, streamLevel, wg, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics, streamFields) h.handleUpgradeTunnel(streamLogger, streamLevel, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics, streamFields)
} else { } else {
h.handleDetachedUpgradeTunnel(streamLogger, streamLevel, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics, streamFields) // start a new goroutine
// Return immediately without touching wg. finalizeResponse's go h.handleUpgradeTunnel(streamLogger, streamLevel, conn, backConn, deleteFrontConn, deleteBackConn, bufferSize, streamTimeout, start, finishMetrics, streamFields)
// wg.Wait() returns at once since wg was never incremented.
} }
} }
func (h *Handler) handleExtendedConnectUpgradeTunnel( // handleUpgradeTunnel returns when transfer is done.
streamLogger *zap.Logger, func (h *Handler) handleUpgradeTunnel(streamLogger *zap.Logger, streamLevel zapcore.Level, conn io.ReadWriteCloser, backConn io.ReadWriteCloser, deleteFrontConn func(), deleteBackConn func(), bufferSize int, streamTimeout time.Duration, start time.Time, finishMetrics func(result string, duration time.Duration, toBackend int64, fromBackend int64), streamFields []zap.Field) {
streamLevel zapcore.Level,
wg *sync.WaitGroup,
conn io.ReadWriteCloser,
backConn io.ReadWriteCloser,
deleteFrontConn func(),
deleteBackConn func(),
bufferSize int,
streamTimeout time.Duration,
start time.Time,
finishMetrics func(result string, duration time.Duration, toBackend, fromBackend int64),
streamFields []zap.Field,
) {
// Extended CONNECT: ServeHTTP must block because rw and req.Body are
// only valid while the handler goroutine is running. Defers clean up
// when the select below fires and this function returns.
defer deleteBackConn() defer deleteBackConn()
defer deleteFrontConn() defer deleteFrontConn()
var ( var (
wg sync.WaitGroup
toBackend int64 toBackend int64
fromBackend int64 fromBackend int64
result string result string
@ -243,7 +238,7 @@ func (h *Handler) handleExtendedConnectUpgradeTunnel(
spc := switchProtocolCopier{ spc := switchProtocolCopier{
user: conn, user: conn,
backend: backConn, backend: backConn,
wg: wg, wg: &wg,
bufferSize: bufferSize, bufferSize: bufferSize,
sent: &toBackend, sent: &toBackend,
received: &fromBackend, received: &fromBackend,
@ -290,90 +285,6 @@ func (h *Handler) handleExtendedConnectUpgradeTunnel(
} }
} }
func (h *Handler) handleDetachedUpgradeTunnel(
streamLogger *zap.Logger,
streamLevel zapcore.Level,
conn io.ReadWriteCloser,
backConn io.ReadWriteCloser,
deleteFrontConn func(),
deleteBackConn func(),
bufferSize int,
streamTimeout time.Duration,
start time.Time,
finishMetrics func(result string, duration time.Duration, toBackend, fromBackend int64),
streamFields []zap.Field,
) {
// HTTP/1.1 hijacked connection: launch a detached goroutine so that
// ServeHTTP can return immediately, allowing the Handler to be GC'd
// after a config reload. The goroutine captures only tunnel (a small
// *tunnelState), logger, conn/backConn, and scalar config values.
go func() {
var (
toBackend int64
fromBackend int64
result = "closed"
)
defer deleteBackConn()
defer deleteFrontConn()
defer func() {
finishMetrics(result, time.Since(start), toBackend, fromBackend)
if c := streamLogger.Check(streamLevel, "connection closed"); c != nil {
fields := append([]zap.Field{}, streamFields...)
fields = append(fields,
zap.Duration("duration", time.Since(start)),
zap.Int64("bytes_to_backend", toBackend),
zap.Int64("bytes_from_backend", fromBackend),
)
c.Write(fields...)
}
}()
var innerWg sync.WaitGroup
// when a stream timeout is encountered, no error will be read from errc
// a buffer size of 2 will allow both the read and write goroutines to send the error and exit
// see: https://github.com/caddyserver/caddy/issues/7418
errc := make(chan error, 2)
spc := switchProtocolCopier{
user: conn,
backend: backConn,
wg: &innerWg,
bufferSize: bufferSize,
sent: &toBackend,
received: &fromBackend,
}
innerWg.Add(2)
var timeoutc <-chan time.Time
if streamTimeout > 0 {
timer := time.NewTimer(streamTimeout)
defer timer.Stop()
timeoutc = timer.C
}
go spc.copyToBackend(errc)
go spc.copyFromBackend(errc)
select {
case err := <-errc:
result = classifyStreamResult(err)
if c := streamLogger.Check(streamLevel, "streaming error"); c != nil {
c.Write(zap.Error(err))
}
case t := <-timeoutc:
result = "timeout"
if c := streamLogger.Check(streamLevel, "stream timed out"); c != nil {
c.Write(zap.Time("timeout", t))
}
}
// Close both ends to unblock the still-running copy goroutine,
// then wait for it to finish so byte counts are accurate before
// the deferred log fires.
conn.Close()
backConn.Close()
innerWg.Wait()
}()
}
func classifyStreamResult(err error) string { func classifyStreamResult(err error) string {
if err == nil || errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) { if err == nil || errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) {
return "closed" return "closed"