record bytes read and written for response writers unless detached

This commit is contained in:
WeidiDeng 2026-04-21 10:06:30 +08:00
parent db86fdaba2
commit 4628aea894
No known key found for this signature in database
GPG Key ID: 25F87CE1741EC7CD
2 changed files with 74 additions and 24 deletions

View File

@ -21,6 +21,8 @@ import (
"io"
"net"
"net/http"
"github.com/caddyserver/caddy/v2"
)
// ResponseWriterWrapper wraps an underlying ResponseWriter and
@ -71,6 +73,7 @@ type responseRecorder struct {
wroteHeader bool
stream bool
hijacked bool
detached bool
readSize *int
}
@ -155,12 +158,6 @@ func (rr *responseRecorder) WriteHeader(statusCode int) {
// save statusCode always, in case HTTP middleware upgrades websocket
// connections by manually setting headers and writing status 101
rr.statusCode = statusCode
if statusCode == http.StatusSwitchingProtocols {
rr.stream = true
rr.wroteHeader = true
rr.ResponseWriterWrapper.WriteHeader(statusCode)
return
}
// decide whether we should buffer the response
if rr.shouldBuffer == nil {
@ -169,12 +166,12 @@ func (rr *responseRecorder) WriteHeader(statusCode int) {
rr.stream = !rr.shouldBuffer(rr.statusCode, rr.ResponseWriterWrapper.Header())
}
// 1xx responses aren't final; just informational
if statusCode < 100 || statusCode > 199 {
// 1xx responses except 101 aren't final; just informational
if statusCode < 100 || statusCode > 199 || statusCode == http.StatusSwitchingProtocols {
rr.wroteHeader = true
}
// if informational or not buffered, immediately write header
// if 1xx or not buffered, immediately write header
if rr.stream || (100 <= statusCode && statusCode <= 199) {
rr.ResponseWriterWrapper.WriteHeader(statusCode)
}
@ -230,8 +227,12 @@ func (rr *responseRecorder) Buffered() bool {
return !rr.stream
}
func (rr *responseRecorder) Hijacked() bool {
return rr.hijacked
func (rr *responseRecorder) DetachAfterHijack(detached bool) bool {
if rr.hijacked {
return false
}
rr.detached = detached
return true
}
func (rr *responseRecorder) WriteResponse() error {
@ -268,6 +269,12 @@ func (rr *responseRecorder) setReadSize(size *int) {
}
func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if !rr.wroteHeader {
// hijacking without writing status code first works as long as subsequent writes follows http1.1
// wire format, but it will show up with a status code of 0 in the access log and bytes written
// will include response headers.
caddy.Log().Debug("hijacking without writing status code first")
}
//nolint:bodyclose
conn, brw, err := http.NewResponseController(rr.ResponseWriterWrapper).Hijack()
if err != nil {
@ -276,13 +283,16 @@ func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
rr.hijacked = true
rr.stream = true
rr.wroteHeader = true
// Per http documentation, returned bufio.Writer is empty, but bufio.Read maybe not.
// Return the raw hijacked connection so upgraded stream traffic does not keep
// traversing the response recorder hot path.
if rr.detached {
return conn, brw, nil
}
// Per http documentation, returned bufio.Writer is empty, but bufio.Read maybe not
conn = &hijackedConn{conn, rr}
brw.Writer.Reset(conn)
buffered := brw.Reader.Buffered()
if buffered != 0 {
conn.(*hijackedConn).updateReadSize(buffered)
data, _ := brw.Peek(buffered)
brw.Reader.Reset(io.MultiReader(bytes.NewReader(data), conn))
// peek to make buffered data appear, as Reset will make it 0
@ -293,12 +303,49 @@ func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return conn, brw, nil
}
// ResponseWriterHijacked reports whether w or one of its wrapped response
// writers has been hijacked.
func ResponseWriterHijacked(w http.ResponseWriter) bool {
// used to track the size of hijacked response writers
type hijackedConn struct {
net.Conn
rr *responseRecorder
}
func (hc *hijackedConn) updateReadSize(n int) {
if hc.rr.readSize != nil {
*hc.rr.readSize += n
}
}
func (hc *hijackedConn) Read(p []byte) (int, error) {
n, err := hc.Conn.Read(p)
hc.updateReadSize(n)
return n, err
}
func (hc *hijackedConn) WriteTo(w io.Writer) (int64, error) {
n, err := io.Copy(w, hc.Conn)
hc.updateReadSize(int(n))
return n, err
}
func (hc *hijackedConn) Write(p []byte) (int, error) {
n, err := hc.Conn.Write(p)
hc.rr.size += n
return n, err
}
func (hc *hijackedConn) ReadFrom(r io.Reader) (int64, error) {
n, err := io.Copy(hc.Conn, r)
hc.rr.size += int(n)
return n, err
}
// DetachResponseWriterAfterHijack detaches w or one of its wrapped response
// writers when it's hijacked. Returns true if not already hijacked.
// When detached, bytes read or written stats will not be recorded for the hijacked connection, and it's safe to use the connection after http middleware returns.
func DetachResponseWriterAfterHijack(w http.ResponseWriter, detached bool) bool {
for w != nil {
if hijacked, ok := w.(interface{ Hijacked() bool }); ok && hijacked.Hijacked() {
return true
if detacher, ok := w.(interface{ DetachAfterHijack(bool) bool }); ok {
return detacher.DetachAfterHijack(detached)
}
unwrapper, ok := w.(interface{ Unwrap() http.ResponseWriter })
if !ok {
@ -321,7 +368,7 @@ type ResponseRecorder interface {
Status() int
Buffer() *bytes.Buffer
Buffered() bool
Hijacked() bool
DetachAfterHijack(bool) bool
Size() int
WriteResponse() error
}
@ -341,4 +388,7 @@ var (
// see PR #5022 (25%-50% speedup)
_ io.ReaderFrom = (*ResponseWriterWrapper)(nil)
_ io.ReaderFrom = (*responseRecorder)(nil)
_ io.ReaderFrom = (*hijackedConn)(nil)
_ io.WriterTo = (*hijackedConn)(nil)
)

View File

@ -246,11 +246,11 @@ func TestResponseRecorderSwitchingProtocolsIsHijackAware(t *testing.T) {
}
defer conn.Close()
if !rr.Hijacked() {
t.Fatal("response recorder should report hijacked state")
if rr.DetachAfterHijack(true) {
t.Fatal("response recorder should report hijacked state by returning false")
}
if !ResponseWriterHijacked(rr) {
t.Fatal("ResponseWriterHijacked() should report true after hijack")
if DetachResponseWriterAfterHijack(rr, true) {
t.Fatal("DetachResponseWriterAfterHijack() should report false after hijack")
}
if err := rr.WriteResponse(); err != nil {
t.Fatalf("WriteResponse() after hijack returned error: %v", err)