mirror of
				https://github.com/caddyserver/caddy.git
				synced 2025-11-03 19:17:29 -05:00 
			
		
		
		
	Merge pull request #40 from ChannelMeter/proxy-middleware
Proxy Middleware: Add support for multiple backends, load balancing & healthchecks
This commit is contained in:
		
						commit
						5f32f9b1c8
					
				
							
								
								
									
										91
									
								
								middleware/proxy/policy.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										91
									
								
								middleware/proxy/policy.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,91 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"math/rand"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type HostPool []*UpstreamHost
 | 
			
		||||
 | 
			
		||||
// Policy decides how a host will be selected from a pool.
 | 
			
		||||
type Policy interface {
 | 
			
		||||
	Select(pool HostPool) *UpstreamHost
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// The random policy randomly selected an up host from the pool.
 | 
			
		||||
type Random struct{}
 | 
			
		||||
 | 
			
		||||
func (r *Random) Select(pool HostPool) *UpstreamHost {
 | 
			
		||||
	// instead of just generating a random index
 | 
			
		||||
	// this is done to prevent selecting a down host
 | 
			
		||||
	var randHost *UpstreamHost
 | 
			
		||||
	count := 0
 | 
			
		||||
	for _, host := range pool {
 | 
			
		||||
		if host.Down() {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		count++
 | 
			
		||||
		if count == 1 {
 | 
			
		||||
			randHost = host
 | 
			
		||||
		} else {
 | 
			
		||||
			r := rand.Int() % count
 | 
			
		||||
			if r == (count - 1) {
 | 
			
		||||
				randHost = host
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return randHost
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// The least_conn policy selects a host with the least connections.
 | 
			
		||||
// If multiple hosts have the least amount of connections, one is randomly
 | 
			
		||||
// chosen.
 | 
			
		||||
type LeastConn struct{}
 | 
			
		||||
 | 
			
		||||
func (r *LeastConn) Select(pool HostPool) *UpstreamHost {
 | 
			
		||||
	var bestHost *UpstreamHost
 | 
			
		||||
	count := 0
 | 
			
		||||
	leastConn := int64(1<<63 - 1)
 | 
			
		||||
	for _, host := range pool {
 | 
			
		||||
		if host.Down() {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		hostConns := host.Conns
 | 
			
		||||
		if hostConns < leastConn {
 | 
			
		||||
			bestHost = host
 | 
			
		||||
			leastConn = hostConns
 | 
			
		||||
			count = 1
 | 
			
		||||
		} else if hostConns == leastConn {
 | 
			
		||||
			// randomly select host among hosts with least connections
 | 
			
		||||
			count++
 | 
			
		||||
			if count == 1 {
 | 
			
		||||
				bestHost = host
 | 
			
		||||
			} else {
 | 
			
		||||
				r := rand.Int() % count
 | 
			
		||||
				if r == (count - 1) {
 | 
			
		||||
					bestHost = host
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return bestHost
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// The round_robin policy selects a host based on round robin ordering.
 | 
			
		||||
type RoundRobin struct {
 | 
			
		||||
	Robin uint32
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *RoundRobin) Select(pool HostPool) *UpstreamHost {
 | 
			
		||||
	poolLen := uint32(len(pool))
 | 
			
		||||
	selection := atomic.AddUint32(&r.Robin, 1) % poolLen
 | 
			
		||||
	host := pool[selection]
 | 
			
		||||
	// if the currently selected host is down, just ffwd to up host
 | 
			
		||||
	for i := uint32(1); host.Down() && i < poolLen; i++ {
 | 
			
		||||
		host = pool[(selection+i)%poolLen]
 | 
			
		||||
	}
 | 
			
		||||
	if host.Down() {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return host
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										57
									
								
								middleware/proxy/policy_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										57
									
								
								middleware/proxy/policy_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,57 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func testPool() HostPool {
 | 
			
		||||
	pool := []*UpstreamHost{
 | 
			
		||||
		&UpstreamHost{
 | 
			
		||||
			Name: "http://google.com", // this should resolve (healthcheck test)
 | 
			
		||||
		},
 | 
			
		||||
		&UpstreamHost{
 | 
			
		||||
			Name: "http://shouldnot.resolve", // this shouldn't
 | 
			
		||||
		},
 | 
			
		||||
		&UpstreamHost{
 | 
			
		||||
			Name: "http://C",
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	return HostPool(pool)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestRoundRobinPolicy(t *testing.T) {
 | 
			
		||||
	pool := testPool()
 | 
			
		||||
	rrPolicy := &RoundRobin{}
 | 
			
		||||
	h := rrPolicy.Select(pool)
 | 
			
		||||
	// First selected host is 1, because counter starts at 0
 | 
			
		||||
	// and increments before host is selected
 | 
			
		||||
	if h != pool[1] {
 | 
			
		||||
		t.Error("Expected first round robin host to be second host in the pool.")
 | 
			
		||||
	}
 | 
			
		||||
	h = rrPolicy.Select(pool)
 | 
			
		||||
	if h != pool[2] {
 | 
			
		||||
		t.Error("Expected second round robin host to be third host in the pool.")
 | 
			
		||||
	}
 | 
			
		||||
	// mark host as down
 | 
			
		||||
	pool[0].Unhealthy = true
 | 
			
		||||
	h = rrPolicy.Select(pool)
 | 
			
		||||
	if h != pool[1] {
 | 
			
		||||
		t.Error("Expected third round robin host to be first host in the pool.")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestLeastConnPolicy(t *testing.T) {
 | 
			
		||||
	pool := testPool()
 | 
			
		||||
	lcPolicy := &LeastConn{}
 | 
			
		||||
	pool[0].Conns = 10
 | 
			
		||||
	pool[1].Conns = 10
 | 
			
		||||
	h := lcPolicy.Select(pool)
 | 
			
		||||
	if h != pool[2] {
 | 
			
		||||
		t.Error("Expected least connection host to be third host.")
 | 
			
		||||
	}
 | 
			
		||||
	pool[2].Conns = 100
 | 
			
		||||
	h = lcPolicy.Select(pool)
 | 
			
		||||
	if h != pool[0] && h != pool[1] {
 | 
			
		||||
		t.Error("Expected least connection host to be first or second host.")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -2,51 +2,118 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/http/httputil"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"errors"
 | 
			
		||||
	"github.com/mholt/caddy/middleware"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var errUnreachable = errors.New("Unreachable backend")
 | 
			
		||||
 | 
			
		||||
// Proxy represents a middleware instance that can proxy requests.
 | 
			
		||||
type Proxy struct {
 | 
			
		||||
	Next  middleware.Handler
 | 
			
		||||
	Rules []Rule
 | 
			
		||||
	Next      middleware.Handler
 | 
			
		||||
	Upstreams []Upstream
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// An upstream manages a pool of proxy upstream hosts. Select should return a
 | 
			
		||||
// suitable upstream host, or nil if no such hosts are available.
 | 
			
		||||
type Upstream interface {
 | 
			
		||||
	// The path this upstream host should be routed on
 | 
			
		||||
	From() string
 | 
			
		||||
	// Selects an upstream host to be routed to.
 | 
			
		||||
	Select() *UpstreamHost
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type UpstreamHostDownFunc func(*UpstreamHost) bool
 | 
			
		||||
 | 
			
		||||
// An UpstreamHost represents a single proxy upstream
 | 
			
		||||
type UpstreamHost struct {
 | 
			
		||||
	// The hostname of this upstream host
 | 
			
		||||
	Name         string
 | 
			
		||||
	ReverseProxy *ReverseProxy
 | 
			
		||||
	Conns        int64
 | 
			
		||||
	Fails        int32
 | 
			
		||||
	FailTimeout  time.Duration
 | 
			
		||||
	Unhealthy    bool
 | 
			
		||||
	ExtraHeaders http.Header
 | 
			
		||||
	CheckDown    UpstreamHostDownFunc
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (uh *UpstreamHost) Down() bool {
 | 
			
		||||
	if uh.CheckDown == nil {
 | 
			
		||||
		// Default settings
 | 
			
		||||
		return uh.Unhealthy || uh.Fails > 0
 | 
			
		||||
	}
 | 
			
		||||
	return uh.CheckDown(uh)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ServeHTTP satisfies the middleware.Handler interface.
 | 
			
		||||
func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
 | 
			
		||||
 | 
			
		||||
	for _, rule := range p.Rules {
 | 
			
		||||
		if middleware.Path(r.URL.Path).Matches(rule.From) {
 | 
			
		||||
			var base string
 | 
			
		||||
	for _, upstream := range p.Upstreams {
 | 
			
		||||
		if middleware.Path(r.URL.Path).Matches(upstream.From()) {
 | 
			
		||||
			var replacer middleware.Replacer
 | 
			
		||||
			start := time.Now()
 | 
			
		||||
			requestHost := r.Host
 | 
			
		||||
 | 
			
		||||
			if strings.HasPrefix(rule.To, "http") { // includes https
 | 
			
		||||
				// destination includes a scheme! no need to guess
 | 
			
		||||
				base = rule.To
 | 
			
		||||
			} else {
 | 
			
		||||
				// no scheme specified; assume same as request
 | 
			
		||||
				var scheme string
 | 
			
		||||
				if r.TLS == nil {
 | 
			
		||||
					scheme = "http"
 | 
			
		||||
				} else {
 | 
			
		||||
					scheme = "https"
 | 
			
		||||
			// Since Select() should give us "up" hosts, keep retrying
 | 
			
		||||
			// hosts until timeout (or until we get a nil host).
 | 
			
		||||
			for time.Now().Sub(start) < (60 * time.Second) {
 | 
			
		||||
				host := upstream.Select()
 | 
			
		||||
				if host == nil {
 | 
			
		||||
					return http.StatusBadGateway, errUnreachable
 | 
			
		||||
				}
 | 
			
		||||
				base = scheme + "://" + rule.To
 | 
			
		||||
			}
 | 
			
		||||
				proxy := host.ReverseProxy
 | 
			
		||||
				r.Host = host.Name
 | 
			
		||||
 | 
			
		||||
			baseUrl, err := url.Parse(base)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return http.StatusInternalServerError, err
 | 
			
		||||
			}
 | 
			
		||||
			r.Host = baseUrl.Host
 | 
			
		||||
				if baseUrl, err := url.Parse(host.Name); err == nil {
 | 
			
		||||
					r.Host = baseUrl.Host
 | 
			
		||||
					if proxy == nil {
 | 
			
		||||
						proxy = NewSingleHostReverseProxy(baseUrl)
 | 
			
		||||
					}
 | 
			
		||||
				} else if proxy == nil {
 | 
			
		||||
					return http.StatusInternalServerError, err
 | 
			
		||||
				}
 | 
			
		||||
				var extraHeaders http.Header
 | 
			
		||||
				if host.ExtraHeaders != nil {
 | 
			
		||||
					extraHeaders = make(http.Header)
 | 
			
		||||
					if replacer == nil {
 | 
			
		||||
						rHost := r.Host
 | 
			
		||||
						r.Host = requestHost
 | 
			
		||||
						replacer = middleware.NewReplacer(r, nil)
 | 
			
		||||
						r.Host = rHost
 | 
			
		||||
					}
 | 
			
		||||
					for header, values := range host.ExtraHeaders {
 | 
			
		||||
						for _, value := range values {
 | 
			
		||||
							extraHeaders.Add(header,
 | 
			
		||||
								replacer.Replace(value))
 | 
			
		||||
							if header == "Host" {
 | 
			
		||||
								r.Host = replacer.Replace(value)
 | 
			
		||||
							}
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
			// TODO: Construct this before; not during every request, if possible
 | 
			
		||||
			proxy := httputil.NewSingleHostReverseProxy(baseUrl)
 | 
			
		||||
			proxy.ServeHTTP(w, r)
 | 
			
		||||
			return 0, nil
 | 
			
		||||
				atomic.AddInt64(&host.Conns, 1)
 | 
			
		||||
				backendErr := proxy.ServeHTTP(w, r, extraHeaders)
 | 
			
		||||
				atomic.AddInt64(&host.Conns, -1)
 | 
			
		||||
				if backendErr == nil {
 | 
			
		||||
					return 0, nil
 | 
			
		||||
				}
 | 
			
		||||
				timeout := host.FailTimeout
 | 
			
		||||
				if timeout == 0 {
 | 
			
		||||
					timeout = 10 * time.Second
 | 
			
		||||
				}
 | 
			
		||||
				atomic.AddInt32(&host.Fails, 1)
 | 
			
		||||
				go func(host *UpstreamHost, timeout time.Duration) {
 | 
			
		||||
					time.Sleep(timeout)
 | 
			
		||||
					atomic.AddInt32(&host.Fails, -1)
 | 
			
		||||
				}(host, timeout)
 | 
			
		||||
			}
 | 
			
		||||
			return http.StatusBadGateway, errUnreachable
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -55,30 +122,11 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
 | 
			
		||||
 | 
			
		||||
// New creates a new instance of proxy middleware.
 | 
			
		||||
func New(c middleware.Controller) (middleware.Middleware, error) {
 | 
			
		||||
	rules, err := parse(c)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
	if upstreams, err := newStaticUpstreams(c); err == nil {
 | 
			
		||||
		return func(next middleware.Handler) middleware.Handler {
 | 
			
		||||
			return Proxy{Next: next, Upstreams: upstreams}
 | 
			
		||||
		}, nil
 | 
			
		||||
	} else {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return func(next middleware.Handler) middleware.Handler {
 | 
			
		||||
		return Proxy{Next: next, Rules: rules}
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func parse(c middleware.Controller) ([]Rule, error) {
 | 
			
		||||
	var rules []Rule
 | 
			
		||||
 | 
			
		||||
	for c.Next() {
 | 
			
		||||
		var rule Rule
 | 
			
		||||
		if !c.Args(&rule.From, &rule.To) {
 | 
			
		||||
			return rules, c.ArgErr()
 | 
			
		||||
		}
 | 
			
		||||
		rules = append(rules, rule)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return rules, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Rule struct {
 | 
			
		||||
	From, To string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										215
									
								
								middleware/proxy/reverseproxy.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										215
									
								
								middleware/proxy/reverseproxy.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,215 @@
 | 
			
		||||
// Copyright 2011 The Go Authors. All rights reserved.
 | 
			
		||||
// Use of this source code is governed by a BSD-style
 | 
			
		||||
// license that can be found in the LICENSE file.
 | 
			
		||||
 | 
			
		||||
// HTTP reverse proxy handler
 | 
			
		||||
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"io"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// onExitFlushLoop is a callback set by tests to detect the state of the
 | 
			
		||||
// flushLoop() goroutine.
 | 
			
		||||
var onExitFlushLoop func()
 | 
			
		||||
 | 
			
		||||
// ReverseProxy is an HTTP Handler that takes an incoming request and
 | 
			
		||||
// sends it to another server, proxying the response back to the
 | 
			
		||||
// client.
 | 
			
		||||
type ReverseProxy struct {
 | 
			
		||||
	// Director must be a function which modifies
 | 
			
		||||
	// the request into a new request to be sent
 | 
			
		||||
	// using Transport. Its response is then copied
 | 
			
		||||
	// back to the original client unmodified.
 | 
			
		||||
	Director func(*http.Request)
 | 
			
		||||
 | 
			
		||||
	// The transport used to perform proxy requests.
 | 
			
		||||
	// If nil, http.DefaultTransport is used.
 | 
			
		||||
	Transport http.RoundTripper
 | 
			
		||||
 | 
			
		||||
	// FlushInterval specifies the flush interval
 | 
			
		||||
	// to flush to the client while copying the
 | 
			
		||||
	// response body.
 | 
			
		||||
	// If zero, no periodic flushing is done.
 | 
			
		||||
	FlushInterval time.Duration
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func singleJoiningSlash(a, b string) string {
 | 
			
		||||
	aslash := strings.HasSuffix(a, "/")
 | 
			
		||||
	bslash := strings.HasPrefix(b, "/")
 | 
			
		||||
	switch {
 | 
			
		||||
	case aslash && bslash:
 | 
			
		||||
		return a + b[1:]
 | 
			
		||||
	case !aslash && !bslash:
 | 
			
		||||
		return a + "/" + b
 | 
			
		||||
	}
 | 
			
		||||
	return a + b
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites
 | 
			
		||||
// URLs to the scheme, host, and base path provided in target. If the
 | 
			
		||||
// target's path is "/base" and the incoming request was for "/dir",
 | 
			
		||||
// the target request will be for /base/dir.
 | 
			
		||||
func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
 | 
			
		||||
	targetQuery := target.RawQuery
 | 
			
		||||
	director := func(req *http.Request) {
 | 
			
		||||
		req.URL.Scheme = target.Scheme
 | 
			
		||||
		req.URL.Host = target.Host
 | 
			
		||||
		req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
 | 
			
		||||
		if targetQuery == "" || req.URL.RawQuery == "" {
 | 
			
		||||
			req.URL.RawQuery = targetQuery + req.URL.RawQuery
 | 
			
		||||
		} else {
 | 
			
		||||
			req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return &ReverseProxy{Director: director}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func copyHeader(dst, src http.Header) {
 | 
			
		||||
	for k, vv := range src {
 | 
			
		||||
		for _, v := range vv {
 | 
			
		||||
			dst.Add(k, v)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Hop-by-hop headers. These are removed when sent to the backend.
 | 
			
		||||
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
 | 
			
		||||
var hopHeaders = []string{
 | 
			
		||||
	"Connection",
 | 
			
		||||
	"Keep-Alive",
 | 
			
		||||
	"Proxy-Authenticate",
 | 
			
		||||
	"Proxy-Authorization",
 | 
			
		||||
	"Te", // canonicalized version of "TE"
 | 
			
		||||
	"Trailers",
 | 
			
		||||
	"Transfer-Encoding",
 | 
			
		||||
	"Upgrade",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request, extraHeaders http.Header) error {
 | 
			
		||||
	transport := p.Transport
 | 
			
		||||
	if transport == nil {
 | 
			
		||||
		transport = http.DefaultTransport
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	outreq := new(http.Request)
 | 
			
		||||
	*outreq = *req // includes shallow copies of maps, but okay
 | 
			
		||||
 | 
			
		||||
	p.Director(outreq)
 | 
			
		||||
	outreq.Proto = "HTTP/1.1"
 | 
			
		||||
	outreq.ProtoMajor = 1
 | 
			
		||||
	outreq.ProtoMinor = 1
 | 
			
		||||
	outreq.Close = false
 | 
			
		||||
 | 
			
		||||
	// Remove hop-by-hop headers to the backend.  Especially
 | 
			
		||||
	// important is "Connection" because we want a persistent
 | 
			
		||||
	// connection, regardless of what the client sent to us.  This
 | 
			
		||||
	// is modifying the same underlying map from req (shallow
 | 
			
		||||
	// copied above) so we only copy it if necessary.
 | 
			
		||||
	copiedHeaders := false
 | 
			
		||||
	for _, h := range hopHeaders {
 | 
			
		||||
		if outreq.Header.Get(h) != "" {
 | 
			
		||||
			if !copiedHeaders {
 | 
			
		||||
				outreq.Header = make(http.Header)
 | 
			
		||||
				copyHeader(outreq.Header, req.Header)
 | 
			
		||||
				copiedHeaders = true
 | 
			
		||||
			}
 | 
			
		||||
			outreq.Header.Del(h)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
 | 
			
		||||
		// If we aren't the first proxy retain prior
 | 
			
		||||
		// X-Forwarded-For information as a comma+space
 | 
			
		||||
		// separated list and fold multiple headers into one.
 | 
			
		||||
		if prior, ok := outreq.Header["X-Forwarded-For"]; ok {
 | 
			
		||||
			clientIP = strings.Join(prior, ", ") + ", " + clientIP
 | 
			
		||||
		}
 | 
			
		||||
		outreq.Header.Set("X-Forwarded-For", clientIP)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if extraHeaders != nil {
 | 
			
		||||
		for k, v := range extraHeaders {
 | 
			
		||||
			outreq.Header[k] = v
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res, err := transport.RoundTrip(outreq)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	defer res.Body.Close()
 | 
			
		||||
 | 
			
		||||
	for _, h := range hopHeaders {
 | 
			
		||||
		res.Header.Del(h)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	copyHeader(rw.Header(), res.Header)
 | 
			
		||||
 | 
			
		||||
	rw.WriteHeader(res.StatusCode)
 | 
			
		||||
	p.copyResponse(rw, res.Body)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
 | 
			
		||||
	if p.FlushInterval != 0 {
 | 
			
		||||
		if wf, ok := dst.(writeFlusher); ok {
 | 
			
		||||
			mlw := &maxLatencyWriter{
 | 
			
		||||
				dst:     wf,
 | 
			
		||||
				latency: p.FlushInterval,
 | 
			
		||||
				done:    make(chan bool),
 | 
			
		||||
			}
 | 
			
		||||
			go mlw.flushLoop()
 | 
			
		||||
			defer mlw.stop()
 | 
			
		||||
			dst = mlw
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	io.Copy(dst, src)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type writeFlusher interface {
 | 
			
		||||
	io.Writer
 | 
			
		||||
	http.Flusher
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type maxLatencyWriter struct {
 | 
			
		||||
	dst     writeFlusher
 | 
			
		||||
	latency time.Duration
 | 
			
		||||
 | 
			
		||||
	lk   sync.Mutex // protects Write + Flush
 | 
			
		||||
	done chan bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *maxLatencyWriter) Write(p []byte) (int, error) {
 | 
			
		||||
	m.lk.Lock()
 | 
			
		||||
	defer m.lk.Unlock()
 | 
			
		||||
	return m.dst.Write(p)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *maxLatencyWriter) flushLoop() {
 | 
			
		||||
	t := time.NewTicker(m.latency)
 | 
			
		||||
	defer t.Stop()
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
		case <-m.done:
 | 
			
		||||
			if onExitFlushLoop != nil {
 | 
			
		||||
				onExitFlushLoop()
 | 
			
		||||
			}
 | 
			
		||||
			return
 | 
			
		||||
		case <-t.C:
 | 
			
		||||
			m.lk.Lock()
 | 
			
		||||
			m.dst.Flush()
 | 
			
		||||
			m.lk.Unlock()
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *maxLatencyWriter) stop() { m.done <- true }
 | 
			
		||||
							
								
								
									
										203
									
								
								middleware/proxy/upstream.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										203
									
								
								middleware/proxy/upstream.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,203 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/mholt/caddy/middleware"
 | 
			
		||||
	"io"
 | 
			
		||||
	"io/ioutil"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type staticUpstream struct {
 | 
			
		||||
	from   string
 | 
			
		||||
	Hosts  HostPool
 | 
			
		||||
	Policy Policy
 | 
			
		||||
 | 
			
		||||
	FailTimeout time.Duration
 | 
			
		||||
	MaxFails    int32
 | 
			
		||||
	HealthCheck struct {
 | 
			
		||||
		Path     string
 | 
			
		||||
		Interval time.Duration
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newStaticUpstreams(c middleware.Controller) ([]Upstream, error) {
 | 
			
		||||
	var upstreams []Upstream
 | 
			
		||||
 | 
			
		||||
	for c.Next() {
 | 
			
		||||
		upstream := &staticUpstream{
 | 
			
		||||
			from:        "",
 | 
			
		||||
			Hosts:       nil,
 | 
			
		||||
			Policy:      &Random{},
 | 
			
		||||
			FailTimeout: 10 * time.Second,
 | 
			
		||||
			MaxFails:    1,
 | 
			
		||||
		}
 | 
			
		||||
		var proxyHeaders http.Header
 | 
			
		||||
		if !c.Args(&upstream.from) {
 | 
			
		||||
			return upstreams, c.ArgErr()
 | 
			
		||||
		}
 | 
			
		||||
		to := c.RemainingArgs()
 | 
			
		||||
		if len(to) == 0 {
 | 
			
		||||
			return upstreams, c.ArgErr()
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		for c.NextBlock() {
 | 
			
		||||
			switch c.Val() {
 | 
			
		||||
			case "policy":
 | 
			
		||||
				if !c.NextArg() {
 | 
			
		||||
					return upstreams, c.ArgErr()
 | 
			
		||||
				}
 | 
			
		||||
				switch c.Val() {
 | 
			
		||||
				case "random":
 | 
			
		||||
					upstream.Policy = &Random{}
 | 
			
		||||
				case "round_robin":
 | 
			
		||||
					upstream.Policy = &RoundRobin{}
 | 
			
		||||
				case "least_conn":
 | 
			
		||||
					upstream.Policy = &LeastConn{}
 | 
			
		||||
				default:
 | 
			
		||||
					return upstreams, c.ArgErr()
 | 
			
		||||
				}
 | 
			
		||||
			case "fail_timeout":
 | 
			
		||||
				if !c.NextArg() {
 | 
			
		||||
					return upstreams, c.ArgErr()
 | 
			
		||||
				}
 | 
			
		||||
				if dur, err := time.ParseDuration(c.Val()); err == nil {
 | 
			
		||||
					upstream.FailTimeout = dur
 | 
			
		||||
				} else {
 | 
			
		||||
					return upstreams, err
 | 
			
		||||
				}
 | 
			
		||||
			case "max_fails":
 | 
			
		||||
				if !c.NextArg() {
 | 
			
		||||
					return upstreams, c.ArgErr()
 | 
			
		||||
				}
 | 
			
		||||
				if n, err := strconv.Atoi(c.Val()); err == nil {
 | 
			
		||||
					upstream.MaxFails = int32(n)
 | 
			
		||||
				} else {
 | 
			
		||||
					return upstreams, err
 | 
			
		||||
				}
 | 
			
		||||
			case "health_check":
 | 
			
		||||
				if !c.NextArg() {
 | 
			
		||||
					return upstreams, c.ArgErr()
 | 
			
		||||
				}
 | 
			
		||||
				upstream.HealthCheck.Path = c.Val()
 | 
			
		||||
				upstream.HealthCheck.Interval = 30 * time.Second
 | 
			
		||||
				if c.NextArg() {
 | 
			
		||||
					if dur, err := time.ParseDuration(c.Val()); err == nil {
 | 
			
		||||
						upstream.HealthCheck.Interval = dur
 | 
			
		||||
					} else {
 | 
			
		||||
						return upstreams, err
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			case "proxy_header":
 | 
			
		||||
				var header, value string
 | 
			
		||||
				if !c.Args(&header, &value) {
 | 
			
		||||
					return upstreams, c.ArgErr()
 | 
			
		||||
				}
 | 
			
		||||
				if proxyHeaders == nil {
 | 
			
		||||
					proxyHeaders = make(map[string][]string)
 | 
			
		||||
				}
 | 
			
		||||
				proxyHeaders.Add(header, value)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		upstream.Hosts = make([]*UpstreamHost, len(to))
 | 
			
		||||
		for i, host := range to {
 | 
			
		||||
			if !strings.HasPrefix(host, "http") {
 | 
			
		||||
				host = "http://" + host
 | 
			
		||||
			}
 | 
			
		||||
			uh := &UpstreamHost{
 | 
			
		||||
				Name:         host,
 | 
			
		||||
				Conns:        0,
 | 
			
		||||
				Fails:        0,
 | 
			
		||||
				FailTimeout:  upstream.FailTimeout,
 | 
			
		||||
				Unhealthy:    false,
 | 
			
		||||
				ExtraHeaders: proxyHeaders,
 | 
			
		||||
				CheckDown: func(upstream *staticUpstream) UpstreamHostDownFunc {
 | 
			
		||||
					return func(uh *UpstreamHost) bool {
 | 
			
		||||
						if uh.Unhealthy {
 | 
			
		||||
							return true
 | 
			
		||||
						}
 | 
			
		||||
						if uh.Fails >= upstream.MaxFails &&
 | 
			
		||||
							upstream.MaxFails != 0 {
 | 
			
		||||
							return true
 | 
			
		||||
						}
 | 
			
		||||
						return false
 | 
			
		||||
					}
 | 
			
		||||
				}(upstream),
 | 
			
		||||
			}
 | 
			
		||||
			if baseUrl, err := url.Parse(uh.Name); err == nil {
 | 
			
		||||
				uh.ReverseProxy = NewSingleHostReverseProxy(baseUrl)
 | 
			
		||||
			} else {
 | 
			
		||||
				return upstreams, err
 | 
			
		||||
			}
 | 
			
		||||
			upstream.Hosts[i] = uh
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if upstream.HealthCheck.Path != "" {
 | 
			
		||||
			go upstream.healthCheckWorker(nil)
 | 
			
		||||
		}
 | 
			
		||||
		upstreams = append(upstreams, upstream)
 | 
			
		||||
	}
 | 
			
		||||
	return upstreams, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (u *staticUpstream) healthCheck() {
 | 
			
		||||
	for _, host := range u.Hosts {
 | 
			
		||||
		hostUrl := host.Name + u.HealthCheck.Path
 | 
			
		||||
		if r, err := http.Get(hostUrl); err == nil {
 | 
			
		||||
			io.Copy(ioutil.Discard, r.Body)
 | 
			
		||||
			r.Body.Close()
 | 
			
		||||
			host.Unhealthy = r.StatusCode < 200 || r.StatusCode >= 400
 | 
			
		||||
		} else {
 | 
			
		||||
			host.Unhealthy = true
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (u *staticUpstream) healthCheckWorker(stop chan struct{}) {
 | 
			
		||||
	ticker := time.NewTicker(u.HealthCheck.Interval)
 | 
			
		||||
	u.healthCheck()
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
		case <-ticker.C:
 | 
			
		||||
			u.healthCheck()
 | 
			
		||||
		case <-stop:
 | 
			
		||||
			// TODO: the library should provide a stop channel and global
 | 
			
		||||
			// waitgroup to allow goroutines started by plugins a chance
 | 
			
		||||
			// to clean themselves up.
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (u *staticUpstream) From() string {
 | 
			
		||||
	return u.from
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (u *staticUpstream) Select() *UpstreamHost {
 | 
			
		||||
	pool := u.Hosts
 | 
			
		||||
	if len(pool) == 1 {
 | 
			
		||||
		if pool[0].Down() {
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
		return pool[0]
 | 
			
		||||
	}
 | 
			
		||||
	allDown := true
 | 
			
		||||
	for _, host := range pool {
 | 
			
		||||
		if !host.Down() {
 | 
			
		||||
			allDown = false
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if allDown {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if u.Policy == nil {
 | 
			
		||||
		return (&Random{}).Select(pool)
 | 
			
		||||
	} else {
 | 
			
		||||
		return u.Policy.Select(pool)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										43
									
								
								middleware/proxy/upstream_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								middleware/proxy/upstream_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,43 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestHealthCheck(t *testing.T) {
 | 
			
		||||
	upstream := &staticUpstream{
 | 
			
		||||
		from:        "",
 | 
			
		||||
		Hosts:       testPool(),
 | 
			
		||||
		Policy:      &Random{},
 | 
			
		||||
		FailTimeout: 10 * time.Second,
 | 
			
		||||
		MaxFails:    1,
 | 
			
		||||
	}
 | 
			
		||||
	upstream.healthCheck()
 | 
			
		||||
	if upstream.Hosts[0].Down() {
 | 
			
		||||
		t.Error("Expected first host in testpool to not fail healthcheck.")
 | 
			
		||||
	}
 | 
			
		||||
	if !upstream.Hosts[1].Down() {
 | 
			
		||||
		t.Error("Expected second host in testpool to fail healthcheck.")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSelect(t *testing.T) {
 | 
			
		||||
	upstream := &staticUpstream{
 | 
			
		||||
		from:        "",
 | 
			
		||||
		Hosts:       testPool()[:3],
 | 
			
		||||
		Policy:      &Random{},
 | 
			
		||||
		FailTimeout: 10 * time.Second,
 | 
			
		||||
		MaxFails:    1,
 | 
			
		||||
	}
 | 
			
		||||
	upstream.Hosts[0].Unhealthy = true
 | 
			
		||||
	upstream.Hosts[1].Unhealthy = true
 | 
			
		||||
	upstream.Hosts[2].Unhealthy = true
 | 
			
		||||
	if h := upstream.Select(); h != nil {
 | 
			
		||||
		t.Error("Expected select to return nil as all host are down")
 | 
			
		||||
	}
 | 
			
		||||
	upstream.Hosts[2].Unhealthy = false
 | 
			
		||||
	if h := upstream.Select(); h == nil {
 | 
			
		||||
		t.Error("Expected select to not return nil")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -8,17 +8,21 @@ import (
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// replacer is a type which can replace placeholder
 | 
			
		||||
// Replacer is a type which can replace placeholder
 | 
			
		||||
// substrings in a string with actual values from a
 | 
			
		||||
// http.Request and responseRecorder. Always use
 | 
			
		||||
// NewReplacer to get one of these.
 | 
			
		||||
type Replacer interface {
 | 
			
		||||
	Replace(string) string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type replacer map[string]string
 | 
			
		||||
 | 
			
		||||
// NewReplacer makes a new replacer based on r and rr.
 | 
			
		||||
// Do not create a new replacer until r and rr have all
 | 
			
		||||
// the needed values, because this function copies those
 | 
			
		||||
// values into the replacer.
 | 
			
		||||
func NewReplacer(r *http.Request, rr *responseRecorder) replacer {
 | 
			
		||||
func NewReplacer(r *http.Request, rr *responseRecorder) Replacer {
 | 
			
		||||
	rep := replacer{
 | 
			
		||||
		"{method}": r.Method,
 | 
			
		||||
		"{scheme}": func() string {
 | 
			
		||||
@ -33,6 +37,9 @@ func NewReplacer(r *http.Request, rr *responseRecorder) replacer {
 | 
			
		||||
		"{fragment}": r.URL.Fragment,
 | 
			
		||||
		"{proto}":    r.Proto,
 | 
			
		||||
		"{remote}": func() string {
 | 
			
		||||
			if fwdFor := r.Header.Get("X-Forwarded-For"); fwdFor != "" {
 | 
			
		||||
				return fwdFor
 | 
			
		||||
			}
 | 
			
		||||
			host, _, err := net.SplitHostPort(r.RemoteAddr)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return r.RemoteAddr
 | 
			
		||||
@ -50,9 +57,11 @@ func NewReplacer(r *http.Request, rr *responseRecorder) replacer {
 | 
			
		||||
		"{when}": func() string {
 | 
			
		||||
			return time.Now().Format(timeFormat)
 | 
			
		||||
		}(),
 | 
			
		||||
		"{status}":  strconv.Itoa(rr.status),
 | 
			
		||||
		"{size}":    strconv.Itoa(rr.size),
 | 
			
		||||
		"{latency}": time.Since(rr.start).String(),
 | 
			
		||||
	}
 | 
			
		||||
	if rr != nil {
 | 
			
		||||
		rep["{status}"] = strconv.Itoa(rr.status)
 | 
			
		||||
		rep["{size}"] = strconv.Itoa(rr.size)
 | 
			
		||||
		rep["{latency}"] = time.Since(rr.start).String()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Header placeholders
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user