mirror of
				https://github.com/caddyserver/caddy.git
				synced 2025-11-03 19:17:29 -05:00 
			
		
		
		
	Merge pull request #1232 from mholt/fix-1229
proxy: record request Body for retry (fixes #1229)
This commit is contained in:
		
						commit
						12fd349916
					
				
							
								
								
									
										40
									
								
								caddyhttp/proxy/body.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								caddyhttp/proxy/body.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,40 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"io"
 | 
			
		||||
	"io/ioutil"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type bufferedBody struct {
 | 
			
		||||
	*bytes.Reader
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (*bufferedBody) Close() error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// rewind allows bufferedBody to be read again.
 | 
			
		||||
func (b *bufferedBody) rewind() error {
 | 
			
		||||
	if b == nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	_, err := b.Seek(0, io.SeekStart)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// newBufferedBody returns *bufferedBody to use in place of src. Closes src
 | 
			
		||||
// and returns Read error on src. All content from src is buffered.
 | 
			
		||||
func newBufferedBody(src io.ReadCloser) (*bufferedBody, error) {
 | 
			
		||||
	if src == nil {
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	}
 | 
			
		||||
	b, err := ioutil.ReadAll(src)
 | 
			
		||||
	src.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return &bufferedBody{
 | 
			
		||||
		Reader: bytes.NewReader(b),
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										69
									
								
								caddyhttp/proxy/body_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								caddyhttp/proxy/body_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,69 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"io"
 | 
			
		||||
	"io/ioutil"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/http/httptest"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestBodyRetry(t *testing.T) {
 | 
			
		||||
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		io.Copy(w, r.Body)
 | 
			
		||||
		r.Body.Close()
 | 
			
		||||
	}))
 | 
			
		||||
	defer ts.Close()
 | 
			
		||||
 | 
			
		||||
	testcase := "test content"
 | 
			
		||||
	req, err := http.NewRequest(http.MethodPost, ts.URL, bytes.NewBufferString(testcase))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	body, err := newBufferedBody(req.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	if body != nil {
 | 
			
		||||
		req.Body = body
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// simulate fail request
 | 
			
		||||
	host := req.URL.Host
 | 
			
		||||
	req.URL.Host = "example.com"
 | 
			
		||||
	body.rewind()
 | 
			
		||||
	_, _ = http.DefaultTransport.RoundTrip(req)
 | 
			
		||||
 | 
			
		||||
	// retry request
 | 
			
		||||
	req.URL.Host = host
 | 
			
		||||
	body.rewind()
 | 
			
		||||
	resp, err := http.DefaultTransport.RoundTrip(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	result, err := ioutil.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	resp.Body.Close()
 | 
			
		||||
	if string(result) != testcase {
 | 
			
		||||
		t.Fatalf("result = %s, want %s", result, testcase)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// try one more time for body reuse
 | 
			
		||||
	body.rewind()
 | 
			
		||||
	resp, err = http.DefaultTransport.RoundTrip(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	result, err = ioutil.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	resp.Body.Close()
 | 
			
		||||
	if string(result) != testcase {
 | 
			
		||||
		t.Fatalf("result = %s, want %s", result, testcase)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -94,6 +94,15 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
 | 
			
		||||
	// outreq is the request that makes a roundtrip to the backend
 | 
			
		||||
	outreq := createUpstreamRequest(r)
 | 
			
		||||
 | 
			
		||||
	// record and replace outreq body
 | 
			
		||||
	body, err := newBufferedBody(outreq.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return http.StatusBadRequest, errors.New("failed to read downstream request body")
 | 
			
		||||
	}
 | 
			
		||||
	if body != nil {
 | 
			
		||||
		outreq.Body = body
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// The keepRetrying function will return true if we should
 | 
			
		||||
	// loop and try to select another host, or false if we
 | 
			
		||||
	// should break and stop retrying.
 | 
			
		||||
@ -164,6 +173,11 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
 | 
			
		||||
			downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// rewind request body to its beginning
 | 
			
		||||
		if err := body.rewind(); err != nil {
 | 
			
		||||
			return http.StatusInternalServerError, errors.New("unable to rewind downstream request body")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// tell the proxy to serve the request
 | 
			
		||||
		atomic.AddInt64(&host.Conns, 1)
 | 
			
		||||
		backendErr = proxy.ServeHTTP(w, outreq, downHeaderUpdateFn)
 | 
			
		||||
 | 
			
		||||
@ -20,6 +20,7 @@ import (
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/mholt/caddy/caddyfile"
 | 
			
		||||
	"github.com/mholt/caddy/caddyhttp/httpserver"
 | 
			
		||||
 | 
			
		||||
	"golang.org/x/net/websocket"
 | 
			
		||||
@ -836,6 +837,63 @@ func TestProxyDirectorURL(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestReverseProxyRetry(t *testing.T) {
 | 
			
		||||
	log.SetOutput(ioutil.Discard)
 | 
			
		||||
	defer log.SetOutput(os.Stderr)
 | 
			
		||||
 | 
			
		||||
	// set up proxy
 | 
			
		||||
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		io.Copy(w, r.Body)
 | 
			
		||||
		r.Body.Close()
 | 
			
		||||
	}))
 | 
			
		||||
	defer backend.Close()
 | 
			
		||||
 | 
			
		||||
	su, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(`
 | 
			
		||||
	proxy / localhost:65535 localhost:65534 `+backend.URL+` {
 | 
			
		||||
		policy round_robin
 | 
			
		||||
		fail_timeout 5s
 | 
			
		||||
		max_fails 1
 | 
			
		||||
		try_duration 5s
 | 
			
		||||
		try_interval 250ms
 | 
			
		||||
	}
 | 
			
		||||
	`)))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	p := &Proxy{
 | 
			
		||||
		Next:      httpserver.EmptyNext, // prevents panic in some cases when test fails
 | 
			
		||||
		Upstreams: su,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// middle is required to simulate closable downstream request body
 | 
			
		||||
	middle := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		_, err = p.ServeHTTP(w, r)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Error(err)
 | 
			
		||||
		}
 | 
			
		||||
	}))
 | 
			
		||||
	defer middle.Close()
 | 
			
		||||
 | 
			
		||||
	testcase := "test content"
 | 
			
		||||
	r, err := http.NewRequest("POST", middle.URL, bytes.NewBufferString(testcase))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	resp, err := http.DefaultTransport.RoundTrip(r)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	b, err := ioutil.ReadAll(resp.Body)
 | 
			
		||||
	resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	if string(b) != testcase {
 | 
			
		||||
		t.Fatalf("string(b) = %s, want %s", string(b), testcase)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newFakeUpstream(name string, insecure bool) *fakeUpstream {
 | 
			
		||||
	uri, _ := url.Parse(name)
 | 
			
		||||
	u := &fakeUpstream{
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user