mirror of
				https://github.com/caddyserver/caddy.git
				synced 2025-11-04 03:27:23 -05:00 
			
		
		
		
	Merge pull request #1232 from mholt/fix-1229
proxy: record request Body for retry (fixes #1229)
This commit is contained in:
		
						commit
						12fd349916
					
				
							
								
								
									
										40
									
								
								caddyhttp/proxy/body.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								caddyhttp/proxy/body.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,40 @@
 | 
				
			|||||||
 | 
					package proxy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"io/ioutil"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type bufferedBody struct {
 | 
				
			||||||
 | 
						*bytes.Reader
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (*bufferedBody) Close() error {
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// rewind allows bufferedBody to be read again.
 | 
				
			||||||
 | 
					func (b *bufferedBody) rewind() error {
 | 
				
			||||||
 | 
						if b == nil {
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						_, err := b.Seek(0, io.SeekStart)
 | 
				
			||||||
 | 
						return err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// newBufferedBody returns *bufferedBody to use in place of src. Closes src
 | 
				
			||||||
 | 
					// and returns Read error on src. All content from src is buffered.
 | 
				
			||||||
 | 
					func newBufferedBody(src io.ReadCloser) (*bufferedBody, error) {
 | 
				
			||||||
 | 
						if src == nil {
 | 
				
			||||||
 | 
							return nil, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						b, err := ioutil.ReadAll(src)
 | 
				
			||||||
 | 
						src.Close()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return &bufferedBody{
 | 
				
			||||||
 | 
							Reader: bytes.NewReader(b),
 | 
				
			||||||
 | 
						}, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										69
									
								
								caddyhttp/proxy/body_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								caddyhttp/proxy/body_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,69 @@
 | 
				
			|||||||
 | 
					package proxy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"io/ioutil"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"net/http/httptest"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestBodyRetry(t *testing.T) {
 | 
				
			||||||
 | 
						ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
 | 
							io.Copy(w, r.Body)
 | 
				
			||||||
 | 
							r.Body.Close()
 | 
				
			||||||
 | 
						}))
 | 
				
			||||||
 | 
						defer ts.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						testcase := "test content"
 | 
				
			||||||
 | 
						req, err := http.NewRequest(http.MethodPost, ts.URL, bytes.NewBufferString(testcase))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						body, err := newBufferedBody(req.Body)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if body != nil {
 | 
				
			||||||
 | 
							req.Body = body
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// simulate fail request
 | 
				
			||||||
 | 
						host := req.URL.Host
 | 
				
			||||||
 | 
						req.URL.Host = "example.com"
 | 
				
			||||||
 | 
						body.rewind()
 | 
				
			||||||
 | 
						_, _ = http.DefaultTransport.RoundTrip(req)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// retry request
 | 
				
			||||||
 | 
						req.URL.Host = host
 | 
				
			||||||
 | 
						body.rewind()
 | 
				
			||||||
 | 
						resp, err := http.DefaultTransport.RoundTrip(req)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						result, err := ioutil.ReadAll(resp.Body)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						resp.Body.Close()
 | 
				
			||||||
 | 
						if string(result) != testcase {
 | 
				
			||||||
 | 
							t.Fatalf("result = %s, want %s", result, testcase)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// try one more time for body reuse
 | 
				
			||||||
 | 
						body.rewind()
 | 
				
			||||||
 | 
						resp, err = http.DefaultTransport.RoundTrip(req)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						result, err = ioutil.ReadAll(resp.Body)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						resp.Body.Close()
 | 
				
			||||||
 | 
						if string(result) != testcase {
 | 
				
			||||||
 | 
							t.Fatalf("result = %s, want %s", result, testcase)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -94,6 +94,15 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
 | 
				
			|||||||
	// outreq is the request that makes a roundtrip to the backend
 | 
						// outreq is the request that makes a roundtrip to the backend
 | 
				
			||||||
	outreq := createUpstreamRequest(r)
 | 
						outreq := createUpstreamRequest(r)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// record and replace outreq body
 | 
				
			||||||
 | 
						body, err := newBufferedBody(outreq.Body)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return http.StatusBadRequest, errors.New("failed to read downstream request body")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if body != nil {
 | 
				
			||||||
 | 
							outreq.Body = body
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// The keepRetrying function will return true if we should
 | 
						// The keepRetrying function will return true if we should
 | 
				
			||||||
	// loop and try to select another host, or false if we
 | 
						// loop and try to select another host, or false if we
 | 
				
			||||||
	// should break and stop retrying.
 | 
						// should break and stop retrying.
 | 
				
			||||||
@ -164,6 +173,11 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
 | 
				
			|||||||
			downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer)
 | 
								downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// rewind request body to its beginning
 | 
				
			||||||
 | 
							if err := body.rewind(); err != nil {
 | 
				
			||||||
 | 
								return http.StatusInternalServerError, errors.New("unable to rewind downstream request body")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// tell the proxy to serve the request
 | 
							// tell the proxy to serve the request
 | 
				
			||||||
		atomic.AddInt64(&host.Conns, 1)
 | 
							atomic.AddInt64(&host.Conns, 1)
 | 
				
			||||||
		backendErr = proxy.ServeHTTP(w, outreq, downHeaderUpdateFn)
 | 
							backendErr = proxy.ServeHTTP(w, outreq, downHeaderUpdateFn)
 | 
				
			||||||
 | 
				
			|||||||
@ -20,6 +20,7 @@ import (
 | 
				
			|||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/mholt/caddy/caddyfile"
 | 
				
			||||||
	"github.com/mholt/caddy/caddyhttp/httpserver"
 | 
						"github.com/mholt/caddy/caddyhttp/httpserver"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"golang.org/x/net/websocket"
 | 
						"golang.org/x/net/websocket"
 | 
				
			||||||
@ -836,6 +837,63 @@ func TestProxyDirectorURL(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestReverseProxyRetry(t *testing.T) {
 | 
				
			||||||
 | 
						log.SetOutput(ioutil.Discard)
 | 
				
			||||||
 | 
						defer log.SetOutput(os.Stderr)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// set up proxy
 | 
				
			||||||
 | 
						backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
 | 
							io.Copy(w, r.Body)
 | 
				
			||||||
 | 
							r.Body.Close()
 | 
				
			||||||
 | 
						}))
 | 
				
			||||||
 | 
						defer backend.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						su, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(`
 | 
				
			||||||
 | 
						proxy / localhost:65535 localhost:65534 `+backend.URL+` {
 | 
				
			||||||
 | 
							policy round_robin
 | 
				
			||||||
 | 
							fail_timeout 5s
 | 
				
			||||||
 | 
							max_fails 1
 | 
				
			||||||
 | 
							try_duration 5s
 | 
				
			||||||
 | 
							try_interval 250ms
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						`)))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						p := &Proxy{
 | 
				
			||||||
 | 
							Next:      httpserver.EmptyNext, // prevents panic in some cases when test fails
 | 
				
			||||||
 | 
							Upstreams: su,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// middle is required to simulate closable downstream request body
 | 
				
			||||||
 | 
						middle := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
 | 
							_, err = p.ServeHTTP(w, r)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Error(err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}))
 | 
				
			||||||
 | 
						defer middle.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						testcase := "test content"
 | 
				
			||||||
 | 
						r, err := http.NewRequest("POST", middle.URL, bytes.NewBufferString(testcase))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						resp, err := http.DefaultTransport.RoundTrip(r)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						b, err := ioutil.ReadAll(resp.Body)
 | 
				
			||||||
 | 
						resp.Body.Close()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if string(b) != testcase {
 | 
				
			||||||
 | 
							t.Fatalf("string(b) = %s, want %s", string(b), testcase)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func newFakeUpstream(name string, insecure bool) *fakeUpstream {
 | 
					func newFakeUpstream(name string, insecure bool) *fakeUpstream {
 | 
				
			||||||
	uri, _ := url.Parse(name)
 | 
						uri, _ := url.Parse(name)
 | 
				
			||||||
	u := &fakeUpstream{
 | 
						u := &fakeUpstream{
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user