diff --git a/modules/caddyhttp/responsewriter.go b/modules/caddyhttp/responsewriter.go index d5b43bf42..a477c7abe 100644 --- a/modules/caddyhttp/responsewriter.go +++ b/modules/caddyhttp/responsewriter.go @@ -21,6 +21,8 @@ import ( "io" "net" "net/http" + + "github.com/caddyserver/caddy/v2" ) // ResponseWriterWrapper wraps an underlying ResponseWriter and @@ -71,6 +73,7 @@ type responseRecorder struct { wroteHeader bool stream bool hijacked bool + detached bool readSize *int } @@ -155,12 +158,6 @@ func (rr *responseRecorder) WriteHeader(statusCode int) { // save statusCode always, in case HTTP middleware upgrades websocket // connections by manually setting headers and writing status 101 rr.statusCode = statusCode - if statusCode == http.StatusSwitchingProtocols { - rr.stream = true - rr.wroteHeader = true - rr.ResponseWriterWrapper.WriteHeader(statusCode) - return - } // decide whether we should buffer the response if rr.shouldBuffer == nil { @@ -169,12 +166,12 @@ func (rr *responseRecorder) WriteHeader(statusCode int) { rr.stream = !rr.shouldBuffer(rr.statusCode, rr.ResponseWriterWrapper.Header()) } - // 1xx responses aren't final; just informational - if statusCode < 100 || statusCode > 199 { + // 1xx responses except 101 aren't final; just informational + if statusCode < 100 || statusCode > 199 || statusCode == http.StatusSwitchingProtocols { rr.wroteHeader = true } - // if informational or not buffered, immediately write header + // if 1xx or not buffered, immediately write header if rr.stream || (100 <= statusCode && statusCode <= 199) { rr.ResponseWriterWrapper.WriteHeader(statusCode) } @@ -230,8 +227,12 @@ func (rr *responseRecorder) Buffered() bool { return !rr.stream } -func (rr *responseRecorder) Hijacked() bool { - return rr.hijacked +func (rr *responseRecorder) DetachAfterHijack(detached bool) bool { + if rr.hijacked { + return false + } + rr.detached = detached + return true } func (rr *responseRecorder) WriteResponse() error { @@ -268,6 +269,12 @@ func (rr *responseRecorder) setReadSize(size *int) { } func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if !rr.wroteHeader { + // hijacking without writing status code first works as long as subsequent writes follows http1.1 + // wire format, but it will show up with a status code of 0 in the access log and bytes written + // will include response headers. + caddy.Log().Debug("hijacking without writing status code first") + } //nolint:bodyclose conn, brw, err := http.NewResponseController(rr.ResponseWriterWrapper).Hijack() if err != nil { @@ -276,13 +283,16 @@ func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { rr.hijacked = true rr.stream = true rr.wroteHeader = true - // Per http documentation, returned bufio.Writer is empty, but bufio.Read maybe not. - // Return the raw hijacked connection so upgraded stream traffic does not keep - // traversing the response recorder hot path. + if rr.detached { + return conn, brw, nil + } + // Per http documentation, returned bufio.Writer is empty, but bufio.Read maybe not + conn = &hijackedConn{conn, rr} brw.Writer.Reset(conn) buffered := brw.Reader.Buffered() if buffered != 0 { + conn.(*hijackedConn).updateReadSize(buffered) data, _ := brw.Peek(buffered) brw.Reader.Reset(io.MultiReader(bytes.NewReader(data), conn)) // peek to make buffered data appear, as Reset will make it 0 @@ -293,12 +303,49 @@ func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { return conn, brw, nil } -// ResponseWriterHijacked reports whether w or one of its wrapped response -// writers has been hijacked. -func ResponseWriterHijacked(w http.ResponseWriter) bool { +// used to track the size of hijacked response writers +type hijackedConn struct { + net.Conn + rr *responseRecorder +} + +func (hc *hijackedConn) updateReadSize(n int) { + if hc.rr.readSize != nil { + *hc.rr.readSize += n + } +} + +func (hc *hijackedConn) Read(p []byte) (int, error) { + n, err := hc.Conn.Read(p) + hc.updateReadSize(n) + return n, err +} + +func (hc *hijackedConn) WriteTo(w io.Writer) (int64, error) { + n, err := io.Copy(w, hc.Conn) + hc.updateReadSize(int(n)) + return n, err +} + +func (hc *hijackedConn) Write(p []byte) (int, error) { + n, err := hc.Conn.Write(p) + hc.rr.size += n + return n, err +} + +func (hc *hijackedConn) ReadFrom(r io.Reader) (int64, error) { + n, err := io.Copy(hc.Conn, r) + hc.rr.size += int(n) + return n, err +} + +// DetachResponseWriterAfterHijack detaches w or one of its wrapped response +// writers when it's hijacked. Returns true if not already hijacked. +// When detached, bytes read or written stats will not be recorded for the hijacked connection, and it's safe to use the connection after http middleware returns. +func DetachResponseWriterAfterHijack(w http.ResponseWriter, detached bool) bool { for w != nil { - if hijacked, ok := w.(interface{ Hijacked() bool }); ok && hijacked.Hijacked() { - return true + if detacher, ok := w.(interface{ DetachAfterHijack(bool) bool }); ok { + return detacher.DetachAfterHijack(detached) } unwrapper, ok := w.(interface{ Unwrap() http.ResponseWriter }) if !ok { @@ -321,7 +368,7 @@ type ResponseRecorder interface { Status() int Buffer() *bytes.Buffer Buffered() bool - Hijacked() bool + DetachAfterHijack(bool) bool Size() int WriteResponse() error } @@ -341,4 +388,7 @@ var ( // see PR #5022 (25%-50% speedup) _ io.ReaderFrom = (*ResponseWriterWrapper)(nil) _ io.ReaderFrom = (*responseRecorder)(nil) + _ io.ReaderFrom = (*hijackedConn)(nil) + + _ io.WriterTo = (*hijackedConn)(nil) ) diff --git a/modules/caddyhttp/responsewriter_test.go b/modules/caddyhttp/responsewriter_test.go index 72e416db1..ec8c3b5ab 100644 --- a/modules/caddyhttp/responsewriter_test.go +++ b/modules/caddyhttp/responsewriter_test.go @@ -246,11 +246,11 @@ func TestResponseRecorderSwitchingProtocolsIsHijackAware(t *testing.T) { } defer conn.Close() - if !rr.Hijacked() { - t.Fatal("response recorder should report hijacked state") + if rr.DetachAfterHijack(true) { + t.Fatal("response recorder should report hijacked state by returning false") } - if !ResponseWriterHijacked(rr) { - t.Fatal("ResponseWriterHijacked() should report true after hijack") + if DetachResponseWriterAfterHijack(rr, true) { + t.Fatal("DetachResponseWriterAfterHijack() should report false after hijack") } if err := rr.WriteResponse(); err != nil { t.Fatalf("WriteResponse() after hijack returned error: %v", err)