mirror of
				https://github.com/caddyserver/caddy.git
				synced 2025-10-31 10:37:24 -04:00 
			
		
		
		
	proxy: record request Body for retry (fixes #1229)
This commit is contained in:
		
							parent
							
								
									0cdaaba4b8
								
							
						
					
					
						commit
						dd4c4d7eb6
					
				
							
								
								
									
										40
									
								
								caddyhttp/proxy/body.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								caddyhttp/proxy/body.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,40 @@ | ||||
| package proxy | ||||
| 
 | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"io" | ||||
| 	"io/ioutil" | ||||
| ) | ||||
| 
 | ||||
| type bufferedBody struct { | ||||
| 	*bytes.Reader | ||||
| } | ||||
| 
 | ||||
| func (*bufferedBody) Close() error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // rewind allows bufferedBody to be read again. | ||||
| func (b *bufferedBody) rewind() error { | ||||
| 	if b == nil { | ||||
| 		return nil | ||||
| 	} | ||||
| 	_, err := b.Seek(0, io.SeekStart) | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| // newBufferedBody returns *bufferedBody to use in place of src. Closes src | ||||
| // and returns Read error on src. All content from src is buffered. | ||||
| func newBufferedBody(src io.ReadCloser) (*bufferedBody, error) { | ||||
| 	if src == nil { | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 	b, err := ioutil.ReadAll(src) | ||||
| 	src.Close() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &bufferedBody{ | ||||
| 		Reader: bytes.NewReader(b), | ||||
| 	}, nil | ||||
| } | ||||
							
								
								
									
										69
									
								
								caddyhttp/proxy/body_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								caddyhttp/proxy/body_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,69 @@ | ||||
| package proxy | ||||
| 
 | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"io" | ||||
| 	"io/ioutil" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"testing" | ||||
| ) | ||||
| 
 | ||||
| func TestBodyRetry(t *testing.T) { | ||||
| 	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		io.Copy(w, r.Body) | ||||
| 		r.Body.Close() | ||||
| 	})) | ||||
| 	defer ts.Close() | ||||
| 
 | ||||
| 	testcase := "test content" | ||||
| 	req, err := http.NewRequest(http.MethodPost, ts.URL, bytes.NewBufferString(testcase)) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 
 | ||||
| 	body, err := newBufferedBody(req.Body) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	if body != nil { | ||||
| 		req.Body = body | ||||
| 	} | ||||
| 
 | ||||
| 	// simulate fail request | ||||
| 	host := req.URL.Host | ||||
| 	req.URL.Host = "example.com" | ||||
| 	body.rewind() | ||||
| 	_, _ = http.DefaultTransport.RoundTrip(req) | ||||
| 
 | ||||
| 	// retry request | ||||
| 	req.URL.Host = host | ||||
| 	body.rewind() | ||||
| 	resp, err := http.DefaultTransport.RoundTrip(req) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	result, err := ioutil.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	resp.Body.Close() | ||||
| 	if string(result) != testcase { | ||||
| 		t.Fatalf("result = %s, want %s", result, testcase) | ||||
| 	} | ||||
| 
 | ||||
| 	// try one more time for body reuse | ||||
| 	body.rewind() | ||||
| 	resp, err = http.DefaultTransport.RoundTrip(req) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	result, err = ioutil.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	resp.Body.Close() | ||||
| 	if string(result) != testcase { | ||||
| 		t.Fatalf("result = %s, want %s", result, testcase) | ||||
| 	} | ||||
| } | ||||
| @ -94,6 +94,15 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { | ||||
| 	// outreq is the request that makes a roundtrip to the backend | ||||
| 	outreq := createUpstreamRequest(r) | ||||
| 
 | ||||
| 	// record and replace outreq body | ||||
| 	body, err := newBufferedBody(outreq.Body) | ||||
| 	if err != nil { | ||||
| 		return http.StatusBadRequest, errors.New("failed to read downstream request body") | ||||
| 	} | ||||
| 	if body != nil { | ||||
| 		outreq.Body = body | ||||
| 	} | ||||
| 
 | ||||
| 	// The keepRetrying function will return true if we should | ||||
| 	// loop and try to select another host, or false if we | ||||
| 	// should break and stop retrying. | ||||
| @ -164,6 +173,11 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { | ||||
| 			downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer) | ||||
| 		} | ||||
| 
 | ||||
| 		// rewind request body to its beginning | ||||
| 		if err := body.rewind(); err != nil { | ||||
| 			return http.StatusInternalServerError, errors.New("unable to rewind downstream request body") | ||||
| 		} | ||||
| 
 | ||||
| 		// tell the proxy to serve the request | ||||
| 		atomic.AddInt64(&host.Conns, 1) | ||||
| 		backendErr = proxy.ServeHTTP(w, outreq, downHeaderUpdateFn) | ||||
|  | ||||
| @ -20,6 +20,7 @@ import ( | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/mholt/caddy/caddyfile" | ||||
| 	"github.com/mholt/caddy/caddyhttp/httpserver" | ||||
| 
 | ||||
| 	"golang.org/x/net/websocket" | ||||
| @ -836,6 +837,63 @@ func TestProxyDirectorURL(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestReverseProxyRetry(t *testing.T) { | ||||
| 	log.SetOutput(ioutil.Discard) | ||||
| 	defer log.SetOutput(os.Stderr) | ||||
| 
 | ||||
| 	// set up proxy | ||||
| 	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		io.Copy(w, r.Body) | ||||
| 		r.Body.Close() | ||||
| 	})) | ||||
| 	defer backend.Close() | ||||
| 
 | ||||
| 	su, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(` | ||||
| 	proxy / localhost:65535 localhost:65534 `+backend.URL+` { | ||||
| 		policy round_robin | ||||
| 		fail_timeout 5s | ||||
| 		max_fails 1 | ||||
| 		try_duration 5s | ||||
| 		try_interval 250ms | ||||
| 	} | ||||
| 	`))) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 
 | ||||
| 	p := &Proxy{ | ||||
| 		Next:      httpserver.EmptyNext, // prevents panic in some cases when test fails | ||||
| 		Upstreams: su, | ||||
| 	} | ||||
| 
 | ||||
| 	// middle is required to simulate closable downstream request body | ||||
| 	middle := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		_, err = p.ServeHTTP(w, r) | ||||
| 		if err != nil { | ||||
| 			t.Error(err) | ||||
| 		} | ||||
| 	})) | ||||
| 	defer middle.Close() | ||||
| 
 | ||||
| 	testcase := "test content" | ||||
| 	r, err := http.NewRequest("POST", middle.URL, bytes.NewBufferString(testcase)) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	resp, err := http.DefaultTransport.RoundTrip(r) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	b, err := ioutil.ReadAll(resp.Body) | ||||
| 	resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	if string(b) != testcase { | ||||
| 		t.Fatalf("string(b) = %s, want %s", string(b), testcase) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func newFakeUpstream(name string, insecure bool) *fakeUpstream { | ||||
| 	uri, _ := url.Parse(name) | ||||
| 	u := &fakeUpstream{ | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user