mirror of
				https://github.com/caddyserver/caddy.git
				synced 2025-11-04 03:27:23 -05:00 
			
		
		
		
	Merge pull request #40 from ChannelMeter/proxy-middleware
Proxy Middleware: Add support for multiple backends, load balancing & healthchecks
This commit is contained in:
		
						commit
						5f32f9b1c8
					
				
							
								
								
									
										91
									
								
								middleware/proxy/policy.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										91
									
								
								middleware/proxy/policy.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,91 @@
 | 
				
			|||||||
 | 
					package proxy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"math/rand"
 | 
				
			||||||
 | 
						"sync/atomic"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type HostPool []*UpstreamHost
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Policy decides how a host will be selected from a pool.
 | 
				
			||||||
 | 
					type Policy interface {
 | 
				
			||||||
 | 
						Select(pool HostPool) *UpstreamHost
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// The random policy randomly selected an up host from the pool.
 | 
				
			||||||
 | 
					type Random struct{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (r *Random) Select(pool HostPool) *UpstreamHost {
 | 
				
			||||||
 | 
						// instead of just generating a random index
 | 
				
			||||||
 | 
						// this is done to prevent selecting a down host
 | 
				
			||||||
 | 
						var randHost *UpstreamHost
 | 
				
			||||||
 | 
						count := 0
 | 
				
			||||||
 | 
						for _, host := range pool {
 | 
				
			||||||
 | 
							if host.Down() {
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							count++
 | 
				
			||||||
 | 
							if count == 1 {
 | 
				
			||||||
 | 
								randHost = host
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								r := rand.Int() % count
 | 
				
			||||||
 | 
								if r == (count - 1) {
 | 
				
			||||||
 | 
									randHost = host
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return randHost
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// The least_conn policy selects a host with the least connections.
 | 
				
			||||||
 | 
					// If multiple hosts have the least amount of connections, one is randomly
 | 
				
			||||||
 | 
					// chosen.
 | 
				
			||||||
 | 
					type LeastConn struct{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (r *LeastConn) Select(pool HostPool) *UpstreamHost {
 | 
				
			||||||
 | 
						var bestHost *UpstreamHost
 | 
				
			||||||
 | 
						count := 0
 | 
				
			||||||
 | 
						leastConn := int64(1<<63 - 1)
 | 
				
			||||||
 | 
						for _, host := range pool {
 | 
				
			||||||
 | 
							if host.Down() {
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							hostConns := host.Conns
 | 
				
			||||||
 | 
							if hostConns < leastConn {
 | 
				
			||||||
 | 
								bestHost = host
 | 
				
			||||||
 | 
								leastConn = hostConns
 | 
				
			||||||
 | 
								count = 1
 | 
				
			||||||
 | 
							} else if hostConns == leastConn {
 | 
				
			||||||
 | 
								// randomly select host among hosts with least connections
 | 
				
			||||||
 | 
								count++
 | 
				
			||||||
 | 
								if count == 1 {
 | 
				
			||||||
 | 
									bestHost = host
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									r := rand.Int() % count
 | 
				
			||||||
 | 
									if r == (count - 1) {
 | 
				
			||||||
 | 
										bestHost = host
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return bestHost
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// The round_robin policy selects a host based on round robin ordering.
 | 
				
			||||||
 | 
					type RoundRobin struct {
 | 
				
			||||||
 | 
						Robin uint32
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (r *RoundRobin) Select(pool HostPool) *UpstreamHost {
 | 
				
			||||||
 | 
						poolLen := uint32(len(pool))
 | 
				
			||||||
 | 
						selection := atomic.AddUint32(&r.Robin, 1) % poolLen
 | 
				
			||||||
 | 
						host := pool[selection]
 | 
				
			||||||
 | 
						// if the currently selected host is down, just ffwd to up host
 | 
				
			||||||
 | 
						for i := uint32(1); host.Down() && i < poolLen; i++ {
 | 
				
			||||||
 | 
							host = pool[(selection+i)%poolLen]
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if host.Down() {
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return host
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										57
									
								
								middleware/proxy/policy_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										57
									
								
								middleware/proxy/policy_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,57 @@
 | 
				
			|||||||
 | 
					package proxy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func testPool() HostPool {
 | 
				
			||||||
 | 
						pool := []*UpstreamHost{
 | 
				
			||||||
 | 
							&UpstreamHost{
 | 
				
			||||||
 | 
								Name: "http://google.com", // this should resolve (healthcheck test)
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							&UpstreamHost{
 | 
				
			||||||
 | 
								Name: "http://shouldnot.resolve", // this shouldn't
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							&UpstreamHost{
 | 
				
			||||||
 | 
								Name: "http://C",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return HostPool(pool)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestRoundRobinPolicy(t *testing.T) {
 | 
				
			||||||
 | 
						pool := testPool()
 | 
				
			||||||
 | 
						rrPolicy := &RoundRobin{}
 | 
				
			||||||
 | 
						h := rrPolicy.Select(pool)
 | 
				
			||||||
 | 
						// First selected host is 1, because counter starts at 0
 | 
				
			||||||
 | 
						// and increments before host is selected
 | 
				
			||||||
 | 
						if h != pool[1] {
 | 
				
			||||||
 | 
							t.Error("Expected first round robin host to be second host in the pool.")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						h = rrPolicy.Select(pool)
 | 
				
			||||||
 | 
						if h != pool[2] {
 | 
				
			||||||
 | 
							t.Error("Expected second round robin host to be third host in the pool.")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						// mark host as down
 | 
				
			||||||
 | 
						pool[0].Unhealthy = true
 | 
				
			||||||
 | 
						h = rrPolicy.Select(pool)
 | 
				
			||||||
 | 
						if h != pool[1] {
 | 
				
			||||||
 | 
							t.Error("Expected third round robin host to be first host in the pool.")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestLeastConnPolicy(t *testing.T) {
 | 
				
			||||||
 | 
						pool := testPool()
 | 
				
			||||||
 | 
						lcPolicy := &LeastConn{}
 | 
				
			||||||
 | 
						pool[0].Conns = 10
 | 
				
			||||||
 | 
						pool[1].Conns = 10
 | 
				
			||||||
 | 
						h := lcPolicy.Select(pool)
 | 
				
			||||||
 | 
						if h != pool[2] {
 | 
				
			||||||
 | 
							t.Error("Expected least connection host to be third host.")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						pool[2].Conns = 100
 | 
				
			||||||
 | 
						h = lcPolicy.Select(pool)
 | 
				
			||||||
 | 
						if h != pool[0] && h != pool[1] {
 | 
				
			||||||
 | 
							t.Error("Expected least connection host to be first or second host.")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -2,52 +2,119 @@
 | 
				
			|||||||
package proxy
 | 
					package proxy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"net/http"
 | 
						"errors"
 | 
				
			||||||
	"net/http/httputil"
 | 
					 | 
				
			||||||
	"net/url"
 | 
					 | 
				
			||||||
	"strings"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	"github.com/mholt/caddy/middleware"
 | 
						"github.com/mholt/caddy/middleware"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"net/url"
 | 
				
			||||||
 | 
						"sync/atomic"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var errUnreachable = errors.New("Unreachable backend")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Proxy represents a middleware instance that can proxy requests.
 | 
					// Proxy represents a middleware instance that can proxy requests.
 | 
				
			||||||
type Proxy struct {
 | 
					type Proxy struct {
 | 
				
			||||||
	Next      middleware.Handler
 | 
						Next      middleware.Handler
 | 
				
			||||||
	Rules []Rule
 | 
						Upstreams []Upstream
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// An upstream manages a pool of proxy upstream hosts. Select should return a
 | 
				
			||||||
 | 
					// suitable upstream host, or nil if no such hosts are available.
 | 
				
			||||||
 | 
					type Upstream interface {
 | 
				
			||||||
 | 
						// The path this upstream host should be routed on
 | 
				
			||||||
 | 
						From() string
 | 
				
			||||||
 | 
						// Selects an upstream host to be routed to.
 | 
				
			||||||
 | 
						Select() *UpstreamHost
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type UpstreamHostDownFunc func(*UpstreamHost) bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// An UpstreamHost represents a single proxy upstream
 | 
				
			||||||
 | 
					type UpstreamHost struct {
 | 
				
			||||||
 | 
						// The hostname of this upstream host
 | 
				
			||||||
 | 
						Name         string
 | 
				
			||||||
 | 
						ReverseProxy *ReverseProxy
 | 
				
			||||||
 | 
						Conns        int64
 | 
				
			||||||
 | 
						Fails        int32
 | 
				
			||||||
 | 
						FailTimeout  time.Duration
 | 
				
			||||||
 | 
						Unhealthy    bool
 | 
				
			||||||
 | 
						ExtraHeaders http.Header
 | 
				
			||||||
 | 
						CheckDown    UpstreamHostDownFunc
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (uh *UpstreamHost) Down() bool {
 | 
				
			||||||
 | 
						if uh.CheckDown == nil {
 | 
				
			||||||
 | 
							// Default settings
 | 
				
			||||||
 | 
							return uh.Unhealthy || uh.Fails > 0
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return uh.CheckDown(uh)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// ServeHTTP satisfies the middleware.Handler interface.
 | 
					// ServeHTTP satisfies the middleware.Handler interface.
 | 
				
			||||||
func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
 | 
					func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, rule := range p.Rules {
 | 
						for _, upstream := range p.Upstreams {
 | 
				
			||||||
		if middleware.Path(r.URL.Path).Matches(rule.From) {
 | 
							if middleware.Path(r.URL.Path).Matches(upstream.From()) {
 | 
				
			||||||
			var base string
 | 
								var replacer middleware.Replacer
 | 
				
			||||||
 | 
								start := time.Now()
 | 
				
			||||||
 | 
								requestHost := r.Host
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			if strings.HasPrefix(rule.To, "http") { // includes https
 | 
								// Since Select() should give us "up" hosts, keep retrying
 | 
				
			||||||
				// destination includes a scheme! no need to guess
 | 
								// hosts until timeout (or until we get a nil host).
 | 
				
			||||||
				base = rule.To
 | 
								for time.Now().Sub(start) < (60 * time.Second) {
 | 
				
			||||||
			} else {
 | 
									host := upstream.Select()
 | 
				
			||||||
				// no scheme specified; assume same as request
 | 
									if host == nil {
 | 
				
			||||||
				var scheme string
 | 
										return http.StatusBadGateway, errUnreachable
 | 
				
			||||||
				if r.TLS == nil {
 | 
					 | 
				
			||||||
					scheme = "http"
 | 
					 | 
				
			||||||
				} else {
 | 
					 | 
				
			||||||
					scheme = "https"
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				base = scheme + "://" + rule.To
 | 
					 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
									proxy := host.ReverseProxy
 | 
				
			||||||
 | 
									r.Host = host.Name
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			baseUrl, err := url.Parse(base)
 | 
									if baseUrl, err := url.Parse(host.Name); err == nil {
 | 
				
			||||||
			if err != nil {
 | 
										r.Host = baseUrl.Host
 | 
				
			||||||
 | 
										if proxy == nil {
 | 
				
			||||||
 | 
											proxy = NewSingleHostReverseProxy(baseUrl)
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									} else if proxy == nil {
 | 
				
			||||||
					return http.StatusInternalServerError, err
 | 
										return http.StatusInternalServerError, err
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			r.Host = baseUrl.Host
 | 
									var extraHeaders http.Header
 | 
				
			||||||
 | 
									if host.ExtraHeaders != nil {
 | 
				
			||||||
 | 
										extraHeaders = make(http.Header)
 | 
				
			||||||
 | 
										if replacer == nil {
 | 
				
			||||||
 | 
											rHost := r.Host
 | 
				
			||||||
 | 
											r.Host = requestHost
 | 
				
			||||||
 | 
											replacer = middleware.NewReplacer(r, nil)
 | 
				
			||||||
 | 
											r.Host = rHost
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
										for header, values := range host.ExtraHeaders {
 | 
				
			||||||
 | 
											for _, value := range values {
 | 
				
			||||||
 | 
												extraHeaders.Add(header,
 | 
				
			||||||
 | 
													replacer.Replace(value))
 | 
				
			||||||
 | 
												if header == "Host" {
 | 
				
			||||||
 | 
													r.Host = replacer.Replace(value)
 | 
				
			||||||
 | 
												}
 | 
				
			||||||
 | 
											}
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			// TODO: Construct this before; not during every request, if possible
 | 
									atomic.AddInt64(&host.Conns, 1)
 | 
				
			||||||
			proxy := httputil.NewSingleHostReverseProxy(baseUrl)
 | 
									backendErr := proxy.ServeHTTP(w, r, extraHeaders)
 | 
				
			||||||
			proxy.ServeHTTP(w, r)
 | 
									atomic.AddInt64(&host.Conns, -1)
 | 
				
			||||||
 | 
									if backendErr == nil {
 | 
				
			||||||
					return 0, nil
 | 
										return 0, nil
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
									timeout := host.FailTimeout
 | 
				
			||||||
 | 
									if timeout == 0 {
 | 
				
			||||||
 | 
										timeout = 10 * time.Second
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									atomic.AddInt32(&host.Fails, 1)
 | 
				
			||||||
 | 
									go func(host *UpstreamHost, timeout time.Duration) {
 | 
				
			||||||
 | 
										time.Sleep(timeout)
 | 
				
			||||||
 | 
										atomic.AddInt32(&host.Fails, -1)
 | 
				
			||||||
 | 
									}(host, timeout)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								return http.StatusBadGateway, errUnreachable
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return p.Next.ServeHTTP(w, r)
 | 
						return p.Next.ServeHTTP(w, r)
 | 
				
			||||||
@ -55,30 +122,11 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// New creates a new instance of proxy middleware.
 | 
					// New creates a new instance of proxy middleware.
 | 
				
			||||||
func New(c middleware.Controller) (middleware.Middleware, error) {
 | 
					func New(c middleware.Controller) (middleware.Middleware, error) {
 | 
				
			||||||
	rules, err := parse(c)
 | 
						if upstreams, err := newStaticUpstreams(c); err == nil {
 | 
				
			||||||
	if err != nil {
 | 
							return func(next middleware.Handler) middleware.Handler {
 | 
				
			||||||
 | 
								return Proxy{Next: next, Upstreams: upstreams}
 | 
				
			||||||
 | 
							}, nil
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					 | 
				
			||||||
	return func(next middleware.Handler) middleware.Handler {
 | 
					 | 
				
			||||||
		return Proxy{Next: next, Rules: rules}
 | 
					 | 
				
			||||||
	}, nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func parse(c middleware.Controller) ([]Rule, error) {
 | 
					 | 
				
			||||||
	var rules []Rule
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for c.Next() {
 | 
					 | 
				
			||||||
		var rule Rule
 | 
					 | 
				
			||||||
		if !c.Args(&rule.From, &rule.To) {
 | 
					 | 
				
			||||||
			return rules, c.ArgErr()
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		rules = append(rules, rule)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return rules, nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type Rule struct {
 | 
					 | 
				
			||||||
	From, To string
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										215
									
								
								middleware/proxy/reverseproxy.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										215
									
								
								middleware/proxy/reverseproxy.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,215 @@
 | 
				
			|||||||
 | 
					// Copyright 2011 The Go Authors. All rights reserved.
 | 
				
			||||||
 | 
					// Use of this source code is governed by a BSD-style
 | 
				
			||||||
 | 
					// license that can be found in the LICENSE file.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// HTTP reverse proxy handler
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package proxy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"net/url"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// onExitFlushLoop is a callback set by tests to detect the state of the
 | 
				
			||||||
 | 
					// flushLoop() goroutine.
 | 
				
			||||||
 | 
					var onExitFlushLoop func()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ReverseProxy is an HTTP Handler that takes an incoming request and
 | 
				
			||||||
 | 
					// sends it to another server, proxying the response back to the
 | 
				
			||||||
 | 
					// client.
 | 
				
			||||||
 | 
					type ReverseProxy struct {
 | 
				
			||||||
 | 
						// Director must be a function which modifies
 | 
				
			||||||
 | 
						// the request into a new request to be sent
 | 
				
			||||||
 | 
						// using Transport. Its response is then copied
 | 
				
			||||||
 | 
						// back to the original client unmodified.
 | 
				
			||||||
 | 
						Director func(*http.Request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// The transport used to perform proxy requests.
 | 
				
			||||||
 | 
						// If nil, http.DefaultTransport is used.
 | 
				
			||||||
 | 
						Transport http.RoundTripper
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// FlushInterval specifies the flush interval
 | 
				
			||||||
 | 
						// to flush to the client while copying the
 | 
				
			||||||
 | 
						// response body.
 | 
				
			||||||
 | 
						// If zero, no periodic flushing is done.
 | 
				
			||||||
 | 
						FlushInterval time.Duration
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func singleJoiningSlash(a, b string) string {
 | 
				
			||||||
 | 
						aslash := strings.HasSuffix(a, "/")
 | 
				
			||||||
 | 
						bslash := strings.HasPrefix(b, "/")
 | 
				
			||||||
 | 
						switch {
 | 
				
			||||||
 | 
						case aslash && bslash:
 | 
				
			||||||
 | 
							return a + b[1:]
 | 
				
			||||||
 | 
						case !aslash && !bslash:
 | 
				
			||||||
 | 
							return a + "/" + b
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return a + b
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites
 | 
				
			||||||
 | 
					// URLs to the scheme, host, and base path provided in target. If the
 | 
				
			||||||
 | 
					// target's path is "/base" and the incoming request was for "/dir",
 | 
				
			||||||
 | 
					// the target request will be for /base/dir.
 | 
				
			||||||
 | 
					func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
 | 
				
			||||||
 | 
						targetQuery := target.RawQuery
 | 
				
			||||||
 | 
						director := func(req *http.Request) {
 | 
				
			||||||
 | 
							req.URL.Scheme = target.Scheme
 | 
				
			||||||
 | 
							req.URL.Host = target.Host
 | 
				
			||||||
 | 
							req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
 | 
				
			||||||
 | 
							if targetQuery == "" || req.URL.RawQuery == "" {
 | 
				
			||||||
 | 
								req.URL.RawQuery = targetQuery + req.URL.RawQuery
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return &ReverseProxy{Director: director}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func copyHeader(dst, src http.Header) {
 | 
				
			||||||
 | 
						for k, vv := range src {
 | 
				
			||||||
 | 
							for _, v := range vv {
 | 
				
			||||||
 | 
								dst.Add(k, v)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Hop-by-hop headers. These are removed when sent to the backend.
 | 
				
			||||||
 | 
					// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
 | 
				
			||||||
 | 
					var hopHeaders = []string{
 | 
				
			||||||
 | 
						"Connection",
 | 
				
			||||||
 | 
						"Keep-Alive",
 | 
				
			||||||
 | 
						"Proxy-Authenticate",
 | 
				
			||||||
 | 
						"Proxy-Authorization",
 | 
				
			||||||
 | 
						"Te", // canonicalized version of "TE"
 | 
				
			||||||
 | 
						"Trailers",
 | 
				
			||||||
 | 
						"Transfer-Encoding",
 | 
				
			||||||
 | 
						"Upgrade",
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request, extraHeaders http.Header) error {
 | 
				
			||||||
 | 
						transport := p.Transport
 | 
				
			||||||
 | 
						if transport == nil {
 | 
				
			||||||
 | 
							transport = http.DefaultTransport
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						outreq := new(http.Request)
 | 
				
			||||||
 | 
						*outreq = *req // includes shallow copies of maps, but okay
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						p.Director(outreq)
 | 
				
			||||||
 | 
						outreq.Proto = "HTTP/1.1"
 | 
				
			||||||
 | 
						outreq.ProtoMajor = 1
 | 
				
			||||||
 | 
						outreq.ProtoMinor = 1
 | 
				
			||||||
 | 
						outreq.Close = false
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Remove hop-by-hop headers to the backend.  Especially
 | 
				
			||||||
 | 
						// important is "Connection" because we want a persistent
 | 
				
			||||||
 | 
						// connection, regardless of what the client sent to us.  This
 | 
				
			||||||
 | 
						// is modifying the same underlying map from req (shallow
 | 
				
			||||||
 | 
						// copied above) so we only copy it if necessary.
 | 
				
			||||||
 | 
						copiedHeaders := false
 | 
				
			||||||
 | 
						for _, h := range hopHeaders {
 | 
				
			||||||
 | 
							if outreq.Header.Get(h) != "" {
 | 
				
			||||||
 | 
								if !copiedHeaders {
 | 
				
			||||||
 | 
									outreq.Header = make(http.Header)
 | 
				
			||||||
 | 
									copyHeader(outreq.Header, req.Header)
 | 
				
			||||||
 | 
									copiedHeaders = true
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								outreq.Header.Del(h)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
 | 
				
			||||||
 | 
							// If we aren't the first proxy retain prior
 | 
				
			||||||
 | 
							// X-Forwarded-For information as a comma+space
 | 
				
			||||||
 | 
							// separated list and fold multiple headers into one.
 | 
				
			||||||
 | 
							if prior, ok := outreq.Header["X-Forwarded-For"]; ok {
 | 
				
			||||||
 | 
								clientIP = strings.Join(prior, ", ") + ", " + clientIP
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							outreq.Header.Set("X-Forwarded-For", clientIP)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if extraHeaders != nil {
 | 
				
			||||||
 | 
							for k, v := range extraHeaders {
 | 
				
			||||||
 | 
								outreq.Header[k] = v
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						res, err := transport.RoundTrip(outreq)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer res.Body.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, h := range hopHeaders {
 | 
				
			||||||
 | 
							res.Header.Del(h)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						copyHeader(rw.Header(), res.Header)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						rw.WriteHeader(res.StatusCode)
 | 
				
			||||||
 | 
						p.copyResponse(rw, res.Body)
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
 | 
				
			||||||
 | 
						if p.FlushInterval != 0 {
 | 
				
			||||||
 | 
							if wf, ok := dst.(writeFlusher); ok {
 | 
				
			||||||
 | 
								mlw := &maxLatencyWriter{
 | 
				
			||||||
 | 
									dst:     wf,
 | 
				
			||||||
 | 
									latency: p.FlushInterval,
 | 
				
			||||||
 | 
									done:    make(chan bool),
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								go mlw.flushLoop()
 | 
				
			||||||
 | 
								defer mlw.stop()
 | 
				
			||||||
 | 
								dst = mlw
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						io.Copy(dst, src)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type writeFlusher interface {
 | 
				
			||||||
 | 
						io.Writer
 | 
				
			||||||
 | 
						http.Flusher
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type maxLatencyWriter struct {
 | 
				
			||||||
 | 
						dst     writeFlusher
 | 
				
			||||||
 | 
						latency time.Duration
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						lk   sync.Mutex // protects Write + Flush
 | 
				
			||||||
 | 
						done chan bool
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (m *maxLatencyWriter) Write(p []byte) (int, error) {
 | 
				
			||||||
 | 
						m.lk.Lock()
 | 
				
			||||||
 | 
						defer m.lk.Unlock()
 | 
				
			||||||
 | 
						return m.dst.Write(p)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (m *maxLatencyWriter) flushLoop() {
 | 
				
			||||||
 | 
						t := time.NewTicker(m.latency)
 | 
				
			||||||
 | 
						defer t.Stop()
 | 
				
			||||||
 | 
						for {
 | 
				
			||||||
 | 
							select {
 | 
				
			||||||
 | 
							case <-m.done:
 | 
				
			||||||
 | 
								if onExitFlushLoop != nil {
 | 
				
			||||||
 | 
									onExitFlushLoop()
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							case <-t.C:
 | 
				
			||||||
 | 
								m.lk.Lock()
 | 
				
			||||||
 | 
								m.dst.Flush()
 | 
				
			||||||
 | 
								m.lk.Unlock()
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (m *maxLatencyWriter) stop() { m.done <- true }
 | 
				
			||||||
							
								
								
									
										203
									
								
								middleware/proxy/upstream.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										203
									
								
								middleware/proxy/upstream.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,203 @@
 | 
				
			|||||||
 | 
					package proxy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/mholt/caddy/middleware"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"io/ioutil"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"net/url"
 | 
				
			||||||
 | 
						"strconv"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type staticUpstream struct {
 | 
				
			||||||
 | 
						from   string
 | 
				
			||||||
 | 
						Hosts  HostPool
 | 
				
			||||||
 | 
						Policy Policy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						FailTimeout time.Duration
 | 
				
			||||||
 | 
						MaxFails    int32
 | 
				
			||||||
 | 
						HealthCheck struct {
 | 
				
			||||||
 | 
							Path     string
 | 
				
			||||||
 | 
							Interval time.Duration
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func newStaticUpstreams(c middleware.Controller) ([]Upstream, error) {
 | 
				
			||||||
 | 
						var upstreams []Upstream
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for c.Next() {
 | 
				
			||||||
 | 
							upstream := &staticUpstream{
 | 
				
			||||||
 | 
								from:        "",
 | 
				
			||||||
 | 
								Hosts:       nil,
 | 
				
			||||||
 | 
								Policy:      &Random{},
 | 
				
			||||||
 | 
								FailTimeout: 10 * time.Second,
 | 
				
			||||||
 | 
								MaxFails:    1,
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							var proxyHeaders http.Header
 | 
				
			||||||
 | 
							if !c.Args(&upstream.from) {
 | 
				
			||||||
 | 
								return upstreams, c.ArgErr()
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							to := c.RemainingArgs()
 | 
				
			||||||
 | 
							if len(to) == 0 {
 | 
				
			||||||
 | 
								return upstreams, c.ArgErr()
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							for c.NextBlock() {
 | 
				
			||||||
 | 
								switch c.Val() {
 | 
				
			||||||
 | 
								case "policy":
 | 
				
			||||||
 | 
									if !c.NextArg() {
 | 
				
			||||||
 | 
										return upstreams, c.ArgErr()
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									switch c.Val() {
 | 
				
			||||||
 | 
									case "random":
 | 
				
			||||||
 | 
										upstream.Policy = &Random{}
 | 
				
			||||||
 | 
									case "round_robin":
 | 
				
			||||||
 | 
										upstream.Policy = &RoundRobin{}
 | 
				
			||||||
 | 
									case "least_conn":
 | 
				
			||||||
 | 
										upstream.Policy = &LeastConn{}
 | 
				
			||||||
 | 
									default:
 | 
				
			||||||
 | 
										return upstreams, c.ArgErr()
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								case "fail_timeout":
 | 
				
			||||||
 | 
									if !c.NextArg() {
 | 
				
			||||||
 | 
										return upstreams, c.ArgErr()
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									if dur, err := time.ParseDuration(c.Val()); err == nil {
 | 
				
			||||||
 | 
										upstream.FailTimeout = dur
 | 
				
			||||||
 | 
									} else {
 | 
				
			||||||
 | 
										return upstreams, err
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								case "max_fails":
 | 
				
			||||||
 | 
									if !c.NextArg() {
 | 
				
			||||||
 | 
										return upstreams, c.ArgErr()
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									if n, err := strconv.Atoi(c.Val()); err == nil {
 | 
				
			||||||
 | 
										upstream.MaxFails = int32(n)
 | 
				
			||||||
 | 
									} else {
 | 
				
			||||||
 | 
										return upstreams, err
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								case "health_check":
 | 
				
			||||||
 | 
									if !c.NextArg() {
 | 
				
			||||||
 | 
										return upstreams, c.ArgErr()
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									upstream.HealthCheck.Path = c.Val()
 | 
				
			||||||
 | 
									upstream.HealthCheck.Interval = 30 * time.Second
 | 
				
			||||||
 | 
									if c.NextArg() {
 | 
				
			||||||
 | 
										if dur, err := time.ParseDuration(c.Val()); err == nil {
 | 
				
			||||||
 | 
											upstream.HealthCheck.Interval = dur
 | 
				
			||||||
 | 
										} else {
 | 
				
			||||||
 | 
											return upstreams, err
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								case "proxy_header":
 | 
				
			||||||
 | 
									var header, value string
 | 
				
			||||||
 | 
									if !c.Args(&header, &value) {
 | 
				
			||||||
 | 
										return upstreams, c.ArgErr()
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									if proxyHeaders == nil {
 | 
				
			||||||
 | 
										proxyHeaders = make(map[string][]string)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									proxyHeaders.Add(header, value)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							upstream.Hosts = make([]*UpstreamHost, len(to))
 | 
				
			||||||
 | 
							for i, host := range to {
 | 
				
			||||||
 | 
								if !strings.HasPrefix(host, "http") {
 | 
				
			||||||
 | 
									host = "http://" + host
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								uh := &UpstreamHost{
 | 
				
			||||||
 | 
									Name:         host,
 | 
				
			||||||
 | 
									Conns:        0,
 | 
				
			||||||
 | 
									Fails:        0,
 | 
				
			||||||
 | 
									FailTimeout:  upstream.FailTimeout,
 | 
				
			||||||
 | 
									Unhealthy:    false,
 | 
				
			||||||
 | 
									ExtraHeaders: proxyHeaders,
 | 
				
			||||||
 | 
									CheckDown: func(upstream *staticUpstream) UpstreamHostDownFunc {
 | 
				
			||||||
 | 
										return func(uh *UpstreamHost) bool {
 | 
				
			||||||
 | 
											if uh.Unhealthy {
 | 
				
			||||||
 | 
												return true
 | 
				
			||||||
 | 
											}
 | 
				
			||||||
 | 
											if uh.Fails >= upstream.MaxFails &&
 | 
				
			||||||
 | 
												upstream.MaxFails != 0 {
 | 
				
			||||||
 | 
												return true
 | 
				
			||||||
 | 
											}
 | 
				
			||||||
 | 
											return false
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									}(upstream),
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								if baseUrl, err := url.Parse(uh.Name); err == nil {
 | 
				
			||||||
 | 
									uh.ReverseProxy = NewSingleHostReverseProxy(baseUrl)
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									return upstreams, err
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								upstream.Hosts[i] = uh
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if upstream.HealthCheck.Path != "" {
 | 
				
			||||||
 | 
								go upstream.healthCheckWorker(nil)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							upstreams = append(upstreams, upstream)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return upstreams, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (u *staticUpstream) healthCheck() {
 | 
				
			||||||
 | 
						for _, host := range u.Hosts {
 | 
				
			||||||
 | 
							hostUrl := host.Name + u.HealthCheck.Path
 | 
				
			||||||
 | 
							if r, err := http.Get(hostUrl); err == nil {
 | 
				
			||||||
 | 
								io.Copy(ioutil.Discard, r.Body)
 | 
				
			||||||
 | 
								r.Body.Close()
 | 
				
			||||||
 | 
								host.Unhealthy = r.StatusCode < 200 || r.StatusCode >= 400
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								host.Unhealthy = true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (u *staticUpstream) healthCheckWorker(stop chan struct{}) {
 | 
				
			||||||
 | 
						ticker := time.NewTicker(u.HealthCheck.Interval)
 | 
				
			||||||
 | 
						u.healthCheck()
 | 
				
			||||||
 | 
						for {
 | 
				
			||||||
 | 
							select {
 | 
				
			||||||
 | 
							case <-ticker.C:
 | 
				
			||||||
 | 
								u.healthCheck()
 | 
				
			||||||
 | 
							case <-stop:
 | 
				
			||||||
 | 
								// TODO: the library should provide a stop channel and global
 | 
				
			||||||
 | 
								// waitgroup to allow goroutines started by plugins a chance
 | 
				
			||||||
 | 
								// to clean themselves up.
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (u *staticUpstream) From() string {
 | 
				
			||||||
 | 
						return u.from
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (u *staticUpstream) Select() *UpstreamHost {
 | 
				
			||||||
 | 
						pool := u.Hosts
 | 
				
			||||||
 | 
						if len(pool) == 1 {
 | 
				
			||||||
 | 
							if pool[0].Down() {
 | 
				
			||||||
 | 
								return nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return pool[0]
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						allDown := true
 | 
				
			||||||
 | 
						for _, host := range pool {
 | 
				
			||||||
 | 
							if !host.Down() {
 | 
				
			||||||
 | 
								allDown = false
 | 
				
			||||||
 | 
								break
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if allDown {
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if u.Policy == nil {
 | 
				
			||||||
 | 
							return (&Random{}).Select(pool)
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							return u.Policy.Select(pool)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										43
									
								
								middleware/proxy/upstream_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								middleware/proxy/upstream_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,43 @@
 | 
				
			|||||||
 | 
					package proxy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestHealthCheck(t *testing.T) {
 | 
				
			||||||
 | 
						upstream := &staticUpstream{
 | 
				
			||||||
 | 
							from:        "",
 | 
				
			||||||
 | 
							Hosts:       testPool(),
 | 
				
			||||||
 | 
							Policy:      &Random{},
 | 
				
			||||||
 | 
							FailTimeout: 10 * time.Second,
 | 
				
			||||||
 | 
							MaxFails:    1,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						upstream.healthCheck()
 | 
				
			||||||
 | 
						if upstream.Hosts[0].Down() {
 | 
				
			||||||
 | 
							t.Error("Expected first host in testpool to not fail healthcheck.")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if !upstream.Hosts[1].Down() {
 | 
				
			||||||
 | 
							t.Error("Expected second host in testpool to fail healthcheck.")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestSelect(t *testing.T) {
 | 
				
			||||||
 | 
						upstream := &staticUpstream{
 | 
				
			||||||
 | 
							from:        "",
 | 
				
			||||||
 | 
							Hosts:       testPool()[:3],
 | 
				
			||||||
 | 
							Policy:      &Random{},
 | 
				
			||||||
 | 
							FailTimeout: 10 * time.Second,
 | 
				
			||||||
 | 
							MaxFails:    1,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						upstream.Hosts[0].Unhealthy = true
 | 
				
			||||||
 | 
						upstream.Hosts[1].Unhealthy = true
 | 
				
			||||||
 | 
						upstream.Hosts[2].Unhealthy = true
 | 
				
			||||||
 | 
						if h := upstream.Select(); h != nil {
 | 
				
			||||||
 | 
							t.Error("Expected select to return nil as all host are down")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						upstream.Hosts[2].Unhealthy = false
 | 
				
			||||||
 | 
						if h := upstream.Select(); h == nil {
 | 
				
			||||||
 | 
							t.Error("Expected select to not return nil")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -8,17 +8,21 @@ import (
 | 
				
			|||||||
	"time"
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// replacer is a type which can replace placeholder
 | 
					// Replacer is a type which can replace placeholder
 | 
				
			||||||
// substrings in a string with actual values from a
 | 
					// substrings in a string with actual values from a
 | 
				
			||||||
// http.Request and responseRecorder. Always use
 | 
					// http.Request and responseRecorder. Always use
 | 
				
			||||||
// NewReplacer to get one of these.
 | 
					// NewReplacer to get one of these.
 | 
				
			||||||
 | 
					type Replacer interface {
 | 
				
			||||||
 | 
						Replace(string) string
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type replacer map[string]string
 | 
					type replacer map[string]string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// NewReplacer makes a new replacer based on r and rr.
 | 
					// NewReplacer makes a new replacer based on r and rr.
 | 
				
			||||||
// Do not create a new replacer until r and rr have all
 | 
					// Do not create a new replacer until r and rr have all
 | 
				
			||||||
// the needed values, because this function copies those
 | 
					// the needed values, because this function copies those
 | 
				
			||||||
// values into the replacer.
 | 
					// values into the replacer.
 | 
				
			||||||
func NewReplacer(r *http.Request, rr *responseRecorder) replacer {
 | 
					func NewReplacer(r *http.Request, rr *responseRecorder) Replacer {
 | 
				
			||||||
	rep := replacer{
 | 
						rep := replacer{
 | 
				
			||||||
		"{method}": r.Method,
 | 
							"{method}": r.Method,
 | 
				
			||||||
		"{scheme}": func() string {
 | 
							"{scheme}": func() string {
 | 
				
			||||||
@ -33,6 +37,9 @@ func NewReplacer(r *http.Request, rr *responseRecorder) replacer {
 | 
				
			|||||||
		"{fragment}": r.URL.Fragment,
 | 
							"{fragment}": r.URL.Fragment,
 | 
				
			||||||
		"{proto}":    r.Proto,
 | 
							"{proto}":    r.Proto,
 | 
				
			||||||
		"{remote}": func() string {
 | 
							"{remote}": func() string {
 | 
				
			||||||
 | 
								if fwdFor := r.Header.Get("X-Forwarded-For"); fwdFor != "" {
 | 
				
			||||||
 | 
									return fwdFor
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
			host, _, err := net.SplitHostPort(r.RemoteAddr)
 | 
								host, _, err := net.SplitHostPort(r.RemoteAddr)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return r.RemoteAddr
 | 
									return r.RemoteAddr
 | 
				
			||||||
@ -50,9 +57,11 @@ func NewReplacer(r *http.Request, rr *responseRecorder) replacer {
 | 
				
			|||||||
		"{when}": func() string {
 | 
							"{when}": func() string {
 | 
				
			||||||
			return time.Now().Format(timeFormat)
 | 
								return time.Now().Format(timeFormat)
 | 
				
			||||||
		}(),
 | 
							}(),
 | 
				
			||||||
		"{status}":  strconv.Itoa(rr.status),
 | 
						}
 | 
				
			||||||
		"{size}":    strconv.Itoa(rr.size),
 | 
						if rr != nil {
 | 
				
			||||||
		"{latency}": time.Since(rr.start).String(),
 | 
							rep["{status}"] = strconv.Itoa(rr.status)
 | 
				
			||||||
 | 
							rep["{size}"] = strconv.Itoa(rr.size)
 | 
				
			||||||
 | 
							rep["{latency}"] = time.Since(rr.start).String()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Header placeholders
 | 
						// Header placeholders
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user