diff --git a/middleware/errors/errors.go b/middleware/errors/errors.go index ccd7e6af8..33a152692 100644 --- a/middleware/errors/errors.go +++ b/middleware/errors/errors.go @@ -43,9 +43,7 @@ func (h ErrorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, er } if status >= 400 { - if w.Header().Get("Content-Length") == "" { - h.errorPage(w, r, status) - } + h.errorPage(w, r, status) return 0, err } diff --git a/middleware/errors/errors_test.go b/middleware/errors/errors_test.go index c0cf63259..49af3e4f4 100644 --- a/middleware/errors/errors_test.go +++ b/middleware/errors/errors_test.go @@ -79,13 +79,6 @@ func TestErrors(t *testing.T) { expectedLog: "", expectedErr: nil, }, - { - next: genErrorHandler(http.StatusNotFound, nil, "normal"), - expectedCode: 0, - expectedBody: "normal", - expectedLog: "", - expectedErr: nil, - }, { next: genErrorHandler(http.StatusForbidden, nil, ""), expectedCode: 0, @@ -168,8 +161,8 @@ func genErrorHandler(status int, err error, body string) middleware.Handler { return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { if len(body) > 0 { w.Header().Set("Content-Length", strconv.Itoa(len(body))) + fmt.Fprint(w, body) } - fmt.Fprint(w, body) return status, err }) } diff --git a/middleware/fastcgi/fastcgi.go b/middleware/fastcgi/fastcgi.go old mode 100644 new mode 100755 index fa9a6c469..33b21d435 --- a/middleware/fastcgi/fastcgi.go +++ b/middleware/fastcgi/fastcgi.go @@ -4,7 +4,6 @@ package fastcgi import ( - "bytes" "errors" "io" "net/http" @@ -106,43 +105,28 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) return http.StatusBadGateway, err } - var responseBody io.Reader = resp.Body - if resp.Header.Get("Content-Length") == "" { - // If the upstream app didn't set a Content-Length (shame on them), - // we need to do it to prevent error messages being appended to - // an already-written response, and other problematic behavior. - // So we copy it to a buffer and read its size before flushing - // the response out to the client. See issues #567 and #614. - buf := new(bytes.Buffer) - _, err := io.Copy(buf, resp.Body) - if err != nil { - return http.StatusBadGateway, err - } - w.Header().Set("Content-Length", strconv.Itoa(buf.Len())) - responseBody = buf - } - - // Write the status code and header fields + // Write response header writeHeader(w, resp) // Write the response body - _, err = io.Copy(w, responseBody) + _, err = io.Copy(w, resp.Body) if err != nil { return http.StatusBadGateway, err } - // FastCGI stderr outputs + // Log any stderr output from upstream if fcgiBackend.stderr.Len() != 0 { // Remove trailing newline, error logger already does this. err = LogError(strings.TrimSuffix(fcgiBackend.stderr.String(), "\n")) } - // Normally we should only return a status >= 400 if no response - // body is written yet, however, upstream apps don't know about - // this contract and we still want the correct code logged, so error - // handling code in our stack needs to check Content-Length before - // writing an error message... oh well. - return resp.StatusCode, err + // Normally we would return the status code if it is an error status (>= 400), + // however, upstream FastCGI apps don't know about our contract and have + // probably already written an error page. So we just return 0, indicating + // that the response body is already written. However, we do return any + // error value so it can be logged. + // Note that the proxy middleware works the same way, returning status=0. + return 0, err } } diff --git a/middleware/fastcgi/fastcgi_test.go b/middleware/fastcgi/fastcgi_test.go index 5fbba23f1..001f38721 100644 --- a/middleware/fastcgi/fastcgi_test.go +++ b/middleware/fastcgi/fastcgi_test.go @@ -10,49 +10,44 @@ import ( "testing" ) -func TestServeHTTPContentLength(t *testing.T) { - testWithBackend := func(body string, setContentLength bool) { - bodyLenStr := strconv.Itoa(len(body)) - listener, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("BackendSetsContentLength=%v: Unable to create listener for test: %v", setContentLength, err) - } - defer listener.Close() - go fcgi.Serve(listener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if setContentLength { - w.Header().Set("Content-Length", bodyLenStr) - } - w.Write([]byte(body)) - })) +func TestServeHTTP(t *testing.T) { + body := "This is some test body content" - handler := Handler{ - Next: nil, - Rules: []Rule{{Path: "/", Address: listener.Addr().String()}}, - } - r, err := http.NewRequest("GET", "/", nil) - if err != nil { - t.Fatalf("BackendSetsContentLength=%v: Unable to create request: %v", setContentLength, err) - } - w := httptest.NewRecorder() - - status, err := handler.ServeHTTP(w, r) - - if got, want := status, http.StatusOK; got != want { - t.Errorf("BackendSetsContentLength=%v: Expected returned status code to be %d, got %d", setContentLength, want, got) - } - if err != nil { - t.Errorf("BackendSetsContentLength=%v: Expected nil error, got: %v", setContentLength, err) - } - if got, want := w.Header().Get("Content-Length"), bodyLenStr; got != want { - t.Errorf("BackendSetsContentLength=%v: Expected Content-Length to be '%s', got: '%s'", setContentLength, want, got) - } - if got, want := w.Body.String(), body; got != want { - t.Errorf("BackendSetsContentLength=%v: Expected response body to be '%s', got: '%s'", setContentLength, want, got) - } + bodyLenStr := strconv.Itoa(len(body)) + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Unable to create listener for test: %v", err) } + defer listener.Close() + go fcgi.Serve(listener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", bodyLenStr) + w.Write([]byte(body)) + })) - testWithBackend("Backend does NOT set Content-Length", false) - testWithBackend("Backend sets Content-Length", true) + handler := Handler{ + Next: nil, + Rules: []Rule{{Path: "/", Address: listener.Addr().String()}}, + } + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatalf("Unable to create request: %v", err) + } + w := httptest.NewRecorder() + + status, err := handler.ServeHTTP(w, r) + + if got, want := status, 0; got != want { + t.Errorf("Expected returned status code to be %d, got %d", want, got) + } + if err != nil { + t.Errorf("Expected nil error, got: %v", err) + } + if got, want := w.Header().Get("Content-Length"), bodyLenStr; got != want { + t.Errorf("Expected Content-Length to be '%s', got: '%s'", want, got) + } + if got, want := w.Body.String(), body; got != want { + t.Errorf("Expected response body to be '%s', got: '%s'", want, got) + } } func TestRuleParseAddress(t *testing.T) { diff --git a/middleware/log/log.go b/middleware/log/log.go index acb695c5e..feb6182ad 100644 --- a/middleware/log/log.go +++ b/middleware/log/log.go @@ -26,7 +26,7 @@ func (l Logger) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { // The error must be handled here so the log entry will record the response size. if l.ErrorFunc != nil { l.ErrorFunc(responseRecorder, r, status) - } else if responseRecorder.Header().Get("Content-Length") == "" { // ensure no body written since proxy backends may write an error page + } else { // Default failover error handler responseRecorder.WriteHeader(status) fmt.Fprintf(responseRecorder, "%d %s", status, http.StatusText(status)) diff --git a/middleware/middleware.go b/middleware/middleware.go index c7036f3c9..d91044ebe 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -13,30 +13,24 @@ type ( // passed the next Handler in the chain. Middleware func(Handler) Handler - // Handler is like http.Handler except ServeHTTP returns a status code - // and an error. The status code is for the client's benefit; the error - // value is for the server's benefit. The status code will be sent to - // the client while the error value will be logged privately. Sometimes, - // an error status code (4xx or 5xx) may be returned with a nil error - // when there is no reason to log the error on the server. + // Handler is like http.Handler except ServeHTTP may return a status + // code and/or error. // - // If a HandlerFunc returns an error (status >= 400), it should NOT - // write to the response. This philosophy makes middleware.Handler - // different from http.Handler: error handling should happen at the - // application layer or in dedicated error-handling middleware only - // rather than with an "every middleware for itself" paradigm. + // If ServeHTTP writes to the response body, it should return a status + // code of 0. This signals to other handlers above it that the response + // body is already written, and that they should not write to it also. // - // The application or error-handling middleware should incorporate logic - // to ensure that the client always gets a proper response according to - // the status code. For security reasons, it should probably not reveal - // the actual error message. (Instead it should be logged, for example.) + // If ServeHTTP encounters an error, it should return the error value + // so it can be logged by designated error-handling middleware. // - // Handlers which do write to the response should return a status value - // < 400 as a signal that a response has been written. In other words, - // only error-handling middleware or the application will write to the - // response for a status code >= 400. When ANY handler writes to the - // response, it should return a status code < 400 to signal others to - // NOT write to the response again, which would be erroneous. + // If writing a response after calling another ServeHTTP method, the + // returned status code SHOULD be used when writing the response. + // + // If handling errors after calling another ServeHTTP method, the + // returned error value SHOULD be logged or handled accordingly. + // + // Otherwise, return values should be propagated down the middleware + // chain by returning them unchanged. Handler interface { ServeHTTP(http.ResponseWriter, *http.Request) (int, error) } diff --git a/server/server.go b/server/server.go index 2df0deac3..d687b8378 100644 --- a/server/server.go +++ b/server/server.go @@ -319,7 +319,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { status, _ := vh.stack.ServeHTTP(w, r) // Fallback error response in case error handling wasn't chained in - if status >= 400 && w.Header().Get("Content-Length") == "" { + if status >= 400 { DefaultErrorFunc(w, r, status) } } else {