mirror of
https://github.com/caddyserver/caddy.git
synced 2026-04-24 01:49:32 -04:00
record bytes read and written for response writers unless detached
This commit is contained in:
parent
db86fdaba2
commit
4628aea894
@ -21,6 +21,8 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/caddyserver/caddy/v2"
|
||||
)
|
||||
|
||||
// ResponseWriterWrapper wraps an underlying ResponseWriter and
|
||||
@ -71,6 +73,7 @@ type responseRecorder struct {
|
||||
wroteHeader bool
|
||||
stream bool
|
||||
hijacked bool
|
||||
detached bool
|
||||
|
||||
readSize *int
|
||||
}
|
||||
@ -155,12 +158,6 @@ func (rr *responseRecorder) WriteHeader(statusCode int) {
|
||||
// save statusCode always, in case HTTP middleware upgrades websocket
|
||||
// connections by manually setting headers and writing status 101
|
||||
rr.statusCode = statusCode
|
||||
if statusCode == http.StatusSwitchingProtocols {
|
||||
rr.stream = true
|
||||
rr.wroteHeader = true
|
||||
rr.ResponseWriterWrapper.WriteHeader(statusCode)
|
||||
return
|
||||
}
|
||||
|
||||
// decide whether we should buffer the response
|
||||
if rr.shouldBuffer == nil {
|
||||
@ -169,12 +166,12 @@ func (rr *responseRecorder) WriteHeader(statusCode int) {
|
||||
rr.stream = !rr.shouldBuffer(rr.statusCode, rr.ResponseWriterWrapper.Header())
|
||||
}
|
||||
|
||||
// 1xx responses aren't final; just informational
|
||||
if statusCode < 100 || statusCode > 199 {
|
||||
// 1xx responses except 101 aren't final; just informational
|
||||
if statusCode < 100 || statusCode > 199 || statusCode == http.StatusSwitchingProtocols {
|
||||
rr.wroteHeader = true
|
||||
}
|
||||
|
||||
// if informational or not buffered, immediately write header
|
||||
// if 1xx or not buffered, immediately write header
|
||||
if rr.stream || (100 <= statusCode && statusCode <= 199) {
|
||||
rr.ResponseWriterWrapper.WriteHeader(statusCode)
|
||||
}
|
||||
@ -230,8 +227,12 @@ func (rr *responseRecorder) Buffered() bool {
|
||||
return !rr.stream
|
||||
}
|
||||
|
||||
func (rr *responseRecorder) Hijacked() bool {
|
||||
return rr.hijacked
|
||||
func (rr *responseRecorder) DetachAfterHijack(detached bool) bool {
|
||||
if rr.hijacked {
|
||||
return false
|
||||
}
|
||||
rr.detached = detached
|
||||
return true
|
||||
}
|
||||
|
||||
func (rr *responseRecorder) WriteResponse() error {
|
||||
@ -268,6 +269,12 @@ func (rr *responseRecorder) setReadSize(size *int) {
|
||||
}
|
||||
|
||||
func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if !rr.wroteHeader {
|
||||
// hijacking without writing status code first works as long as subsequent writes follows http1.1
|
||||
// wire format, but it will show up with a status code of 0 in the access log and bytes written
|
||||
// will include response headers.
|
||||
caddy.Log().Debug("hijacking without writing status code first")
|
||||
}
|
||||
//nolint:bodyclose
|
||||
conn, brw, err := http.NewResponseController(rr.ResponseWriterWrapper).Hijack()
|
||||
if err != nil {
|
||||
@ -276,13 +283,16 @@ func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
rr.hijacked = true
|
||||
rr.stream = true
|
||||
rr.wroteHeader = true
|
||||
// Per http documentation, returned bufio.Writer is empty, but bufio.Read maybe not.
|
||||
// Return the raw hijacked connection so upgraded stream traffic does not keep
|
||||
// traversing the response recorder hot path.
|
||||
if rr.detached {
|
||||
return conn, brw, nil
|
||||
}
|
||||
// Per http documentation, returned bufio.Writer is empty, but bufio.Read maybe not
|
||||
conn = &hijackedConn{conn, rr}
|
||||
brw.Writer.Reset(conn)
|
||||
|
||||
buffered := brw.Reader.Buffered()
|
||||
if buffered != 0 {
|
||||
conn.(*hijackedConn).updateReadSize(buffered)
|
||||
data, _ := brw.Peek(buffered)
|
||||
brw.Reader.Reset(io.MultiReader(bytes.NewReader(data), conn))
|
||||
// peek to make buffered data appear, as Reset will make it 0
|
||||
@ -293,12 +303,49 @@ func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return conn, brw, nil
|
||||
}
|
||||
|
||||
// ResponseWriterHijacked reports whether w or one of its wrapped response
|
||||
// writers has been hijacked.
|
||||
func ResponseWriterHijacked(w http.ResponseWriter) bool {
|
||||
// used to track the size of hijacked response writers
|
||||
type hijackedConn struct {
|
||||
net.Conn
|
||||
rr *responseRecorder
|
||||
}
|
||||
|
||||
func (hc *hijackedConn) updateReadSize(n int) {
|
||||
if hc.rr.readSize != nil {
|
||||
*hc.rr.readSize += n
|
||||
}
|
||||
}
|
||||
|
||||
func (hc *hijackedConn) Read(p []byte) (int, error) {
|
||||
n, err := hc.Conn.Read(p)
|
||||
hc.updateReadSize(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (hc *hijackedConn) WriteTo(w io.Writer) (int64, error) {
|
||||
n, err := io.Copy(w, hc.Conn)
|
||||
hc.updateReadSize(int(n))
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (hc *hijackedConn) Write(p []byte) (int, error) {
|
||||
n, err := hc.Conn.Write(p)
|
||||
hc.rr.size += n
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (hc *hijackedConn) ReadFrom(r io.Reader) (int64, error) {
|
||||
n, err := io.Copy(hc.Conn, r)
|
||||
hc.rr.size += int(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
// DetachResponseWriterAfterHijack detaches w or one of its wrapped response
|
||||
// writers when it's hijacked. Returns true if not already hijacked.
|
||||
// When detached, bytes read or written stats will not be recorded for the hijacked connection, and it's safe to use the connection after http middleware returns.
|
||||
func DetachResponseWriterAfterHijack(w http.ResponseWriter, detached bool) bool {
|
||||
for w != nil {
|
||||
if hijacked, ok := w.(interface{ Hijacked() bool }); ok && hijacked.Hijacked() {
|
||||
return true
|
||||
if detacher, ok := w.(interface{ DetachAfterHijack(bool) bool }); ok {
|
||||
return detacher.DetachAfterHijack(detached)
|
||||
}
|
||||
unwrapper, ok := w.(interface{ Unwrap() http.ResponseWriter })
|
||||
if !ok {
|
||||
@ -321,7 +368,7 @@ type ResponseRecorder interface {
|
||||
Status() int
|
||||
Buffer() *bytes.Buffer
|
||||
Buffered() bool
|
||||
Hijacked() bool
|
||||
DetachAfterHijack(bool) bool
|
||||
Size() int
|
||||
WriteResponse() error
|
||||
}
|
||||
@ -341,4 +388,7 @@ var (
|
||||
// see PR #5022 (25%-50% speedup)
|
||||
_ io.ReaderFrom = (*ResponseWriterWrapper)(nil)
|
||||
_ io.ReaderFrom = (*responseRecorder)(nil)
|
||||
_ io.ReaderFrom = (*hijackedConn)(nil)
|
||||
|
||||
_ io.WriterTo = (*hijackedConn)(nil)
|
||||
)
|
||||
|
||||
@ -246,11 +246,11 @@ func TestResponseRecorderSwitchingProtocolsIsHijackAware(t *testing.T) {
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if !rr.Hijacked() {
|
||||
t.Fatal("response recorder should report hijacked state")
|
||||
if rr.DetachAfterHijack(true) {
|
||||
t.Fatal("response recorder should report hijacked state by returning false")
|
||||
}
|
||||
if !ResponseWriterHijacked(rr) {
|
||||
t.Fatal("ResponseWriterHijacked() should report true after hijack")
|
||||
if DetachResponseWriterAfterHijack(rr, true) {
|
||||
t.Fatal("DetachResponseWriterAfterHijack() should report false after hijack")
|
||||
}
|
||||
if err := rr.WriteResponse(); err != nil {
|
||||
t.Fatalf("WriteResponse() after hijack returned error: %v", err)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user