mirror of
				https://github.com/caddyserver/caddy.git
				synced 2025-11-04 03:27:23 -05:00 
			
		
		
		
	
		
			
				
	
	
		
			400 lines
		
	
	
		
			8.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			400 lines
		
	
	
		
			8.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package proxy
 | 
						|
 | 
						|
import (
 | 
						|
	"fmt"
 | 
						|
	"io"
 | 
						|
	"io/ioutil"
 | 
						|
	"net/http"
 | 
						|
	"net/url"
 | 
						|
	"path"
 | 
						|
	"strconv"
 | 
						|
	"strings"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/mholt/caddy/caddyfile"
 | 
						|
	"github.com/mholt/caddy/caddyhttp/httpserver"
 | 
						|
)
 | 
						|
 | 
						|
var (
 | 
						|
	supportedPolicies            = make(map[string]func() Policy)
 | 
						|
	warnedProxyHeaderDeprecation bool // TODO: Temporary, until proxy_header is removed entirely
 | 
						|
)
 | 
						|
 | 
						|
type staticUpstream struct {
 | 
						|
	from               string
 | 
						|
	upstreamHeaders    http.Header
 | 
						|
	downstreamHeaders  http.Header
 | 
						|
	Hosts              HostPool
 | 
						|
	Policy             Policy
 | 
						|
	KeepAlive          int
 | 
						|
	insecureSkipVerify bool
 | 
						|
 | 
						|
	FailTimeout time.Duration
 | 
						|
	MaxFails    int32
 | 
						|
	MaxConns    int64
 | 
						|
	HealthCheck struct {
 | 
						|
		Client   http.Client
 | 
						|
		Path     string
 | 
						|
		Interval time.Duration
 | 
						|
		Timeout  time.Duration
 | 
						|
	}
 | 
						|
	WithoutPathPrefix string
 | 
						|
	IgnoredSubPaths   []string
 | 
						|
}
 | 
						|
 | 
						|
// NewStaticUpstreams parses the configuration input and sets up
 | 
						|
// static upstreams for the proxy middleware.
 | 
						|
func NewStaticUpstreams(c caddyfile.Dispenser) ([]Upstream, error) {
 | 
						|
	var upstreams []Upstream
 | 
						|
	for c.Next() {
 | 
						|
		upstream := &staticUpstream{
 | 
						|
			from:              "",
 | 
						|
			upstreamHeaders:   make(http.Header),
 | 
						|
			downstreamHeaders: make(http.Header),
 | 
						|
			Hosts:             nil,
 | 
						|
			Policy:            &Random{},
 | 
						|
			FailTimeout:       10 * time.Second,
 | 
						|
			MaxFails:          1,
 | 
						|
			MaxConns:          0,
 | 
						|
			KeepAlive:         http.DefaultMaxIdleConnsPerHost,
 | 
						|
		}
 | 
						|
 | 
						|
		if !c.Args(&upstream.from) {
 | 
						|
			return upstreams, c.ArgErr()
 | 
						|
		}
 | 
						|
 | 
						|
		var to []string
 | 
						|
		for _, t := range c.RemainingArgs() {
 | 
						|
			parsed, err := parseUpstream(t)
 | 
						|
			if err != nil {
 | 
						|
				return upstreams, err
 | 
						|
			}
 | 
						|
			to = append(to, parsed...)
 | 
						|
		}
 | 
						|
 | 
						|
		for c.NextBlock() {
 | 
						|
			switch c.Val() {
 | 
						|
			case "upstream":
 | 
						|
				if !c.NextArg() {
 | 
						|
					return upstreams, c.ArgErr()
 | 
						|
				}
 | 
						|
				parsed, err := parseUpstream(c.Val())
 | 
						|
				if err != nil {
 | 
						|
					return upstreams, err
 | 
						|
				}
 | 
						|
				to = append(to, parsed...)
 | 
						|
			default:
 | 
						|
				if err := parseBlock(&c, upstream); err != nil {
 | 
						|
					return upstreams, err
 | 
						|
				}
 | 
						|
			}
 | 
						|
		}
 | 
						|
 | 
						|
		if len(to) == 0 {
 | 
						|
			return upstreams, c.ArgErr()
 | 
						|
		}
 | 
						|
 | 
						|
		upstream.Hosts = make([]*UpstreamHost, len(to))
 | 
						|
		for i, host := range to {
 | 
						|
			uh, err := upstream.NewHost(host)
 | 
						|
			if err != nil {
 | 
						|
				return upstreams, err
 | 
						|
			}
 | 
						|
			upstream.Hosts[i] = uh
 | 
						|
		}
 | 
						|
 | 
						|
		if upstream.HealthCheck.Path != "" {
 | 
						|
			upstream.HealthCheck.Client = http.Client{
 | 
						|
				Timeout: upstream.HealthCheck.Timeout,
 | 
						|
			}
 | 
						|
			go upstream.HealthCheckWorker(nil)
 | 
						|
		}
 | 
						|
		upstreams = append(upstreams, upstream)
 | 
						|
	}
 | 
						|
	return upstreams, nil
 | 
						|
}
 | 
						|
 | 
						|
// RegisterPolicy adds a custom policy to the proxy.
 | 
						|
func RegisterPolicy(name string, policy func() Policy) {
 | 
						|
	supportedPolicies[name] = policy
 | 
						|
}
 | 
						|
 | 
						|
func (u *staticUpstream) From() string {
 | 
						|
	return u.from
 | 
						|
}
 | 
						|
 | 
						|
func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
 | 
						|
	if !strings.HasPrefix(host, "http") &&
 | 
						|
		!strings.HasPrefix(host, "unix:") {
 | 
						|
		host = "http://" + host
 | 
						|
	}
 | 
						|
	uh := &UpstreamHost{
 | 
						|
		Name:              host,
 | 
						|
		Conns:             0,
 | 
						|
		Fails:             0,
 | 
						|
		FailTimeout:       u.FailTimeout,
 | 
						|
		Unhealthy:         false,
 | 
						|
		UpstreamHeaders:   u.upstreamHeaders,
 | 
						|
		DownstreamHeaders: u.downstreamHeaders,
 | 
						|
		CheckDown: func(u *staticUpstream) UpstreamHostDownFunc {
 | 
						|
			return func(uh *UpstreamHost) bool {
 | 
						|
				if uh.Unhealthy {
 | 
						|
					return true
 | 
						|
				}
 | 
						|
				if uh.Fails >= u.MaxFails &&
 | 
						|
					u.MaxFails != 0 {
 | 
						|
					return true
 | 
						|
				}
 | 
						|
				return false
 | 
						|
			}
 | 
						|
		}(u),
 | 
						|
		WithoutPathPrefix: u.WithoutPathPrefix,
 | 
						|
		MaxConns:          u.MaxConns,
 | 
						|
	}
 | 
						|
 | 
						|
	baseURL, err := url.Parse(uh.Name)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	uh.ReverseProxy = NewSingleHostReverseProxy(baseURL, uh.WithoutPathPrefix, u.KeepAlive)
 | 
						|
	if u.insecureSkipVerify {
 | 
						|
		uh.ReverseProxy.UseInsecureTransport()
 | 
						|
	}
 | 
						|
 | 
						|
	return uh, nil
 | 
						|
}
 | 
						|
 | 
						|
func parseUpstream(u string) ([]string, error) {
 | 
						|
	if !strings.HasPrefix(u, "unix:") {
 | 
						|
		colonIdx := strings.LastIndex(u, ":")
 | 
						|
		protoIdx := strings.Index(u, "://")
 | 
						|
 | 
						|
		if colonIdx != -1 && colonIdx != protoIdx {
 | 
						|
			us := u[:colonIdx]
 | 
						|
			ue := ""
 | 
						|
			portsEnd := len(u)
 | 
						|
			if nextSlash := strings.Index(u[colonIdx:], "/"); nextSlash != -1 {
 | 
						|
				portsEnd = colonIdx + nextSlash
 | 
						|
				ue = u[portsEnd:]
 | 
						|
			}
 | 
						|
			ports := u[len(us)+1 : portsEnd]
 | 
						|
 | 
						|
			if separators := strings.Count(ports, "-"); separators == 1 {
 | 
						|
				portsStr := strings.Split(ports, "-")
 | 
						|
				pIni, err := strconv.Atoi(portsStr[0])
 | 
						|
				if err != nil {
 | 
						|
					return nil, err
 | 
						|
				}
 | 
						|
 | 
						|
				pEnd, err := strconv.Atoi(portsStr[1])
 | 
						|
				if err != nil {
 | 
						|
					return nil, err
 | 
						|
				}
 | 
						|
 | 
						|
				if pEnd <= pIni {
 | 
						|
					return nil, fmt.Errorf("port range [%s] is invalid", ports)
 | 
						|
				}
 | 
						|
 | 
						|
				hosts := []string{}
 | 
						|
				for p := pIni; p <= pEnd; p++ {
 | 
						|
					hosts = append(hosts, fmt.Sprintf("%s:%d%s", us, p, ue))
 | 
						|
				}
 | 
						|
				return hosts, nil
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return []string{u}, nil
 | 
						|
 | 
						|
}
 | 
						|
 | 
						|
func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error {
 | 
						|
	switch c.Val() {
 | 
						|
	case "policy":
 | 
						|
		if !c.NextArg() {
 | 
						|
			return c.ArgErr()
 | 
						|
		}
 | 
						|
		policyCreateFunc, ok := supportedPolicies[c.Val()]
 | 
						|
		if !ok {
 | 
						|
			return c.ArgErr()
 | 
						|
		}
 | 
						|
		u.Policy = policyCreateFunc()
 | 
						|
	case "fail_timeout":
 | 
						|
		if !c.NextArg() {
 | 
						|
			return c.ArgErr()
 | 
						|
		}
 | 
						|
		dur, err := time.ParseDuration(c.Val())
 | 
						|
		if err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
		u.FailTimeout = dur
 | 
						|
	case "max_fails":
 | 
						|
		if !c.NextArg() {
 | 
						|
			return c.ArgErr()
 | 
						|
		}
 | 
						|
		n, err := strconv.Atoi(c.Val())
 | 
						|
		if err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
		u.MaxFails = int32(n)
 | 
						|
	case "max_conns":
 | 
						|
		if !c.NextArg() {
 | 
						|
			return c.ArgErr()
 | 
						|
		}
 | 
						|
		n, err := strconv.ParseInt(c.Val(), 10, 64)
 | 
						|
		if err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
		u.MaxConns = n
 | 
						|
	case "health_check":
 | 
						|
		if !c.NextArg() {
 | 
						|
			return c.ArgErr()
 | 
						|
		}
 | 
						|
		u.HealthCheck.Path = c.Val()
 | 
						|
 | 
						|
		// Set defaults
 | 
						|
		if u.HealthCheck.Interval == 0 {
 | 
						|
			u.HealthCheck.Interval = 30 * time.Second
 | 
						|
		}
 | 
						|
		if u.HealthCheck.Timeout == 0 {
 | 
						|
			u.HealthCheck.Timeout = 60 * time.Second
 | 
						|
		}
 | 
						|
	case "health_check_interval":
 | 
						|
		var interval string
 | 
						|
		if !c.Args(&interval) {
 | 
						|
			return c.ArgErr()
 | 
						|
		}
 | 
						|
		dur, err := time.ParseDuration(interval)
 | 
						|
		if err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
		u.HealthCheck.Interval = dur
 | 
						|
	case "health_check_timeout":
 | 
						|
		var interval string
 | 
						|
		if !c.Args(&interval) {
 | 
						|
			return c.ArgErr()
 | 
						|
		}
 | 
						|
		dur, err := time.ParseDuration(interval)
 | 
						|
		if err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
		u.HealthCheck.Timeout = dur
 | 
						|
	case "proxy_header": // TODO: deprecate this shortly after 0.9
 | 
						|
		if !warnedProxyHeaderDeprecation {
 | 
						|
			fmt.Println("WARNING: proxy_header is deprecated and will be removed soon; use header_upstream instead.")
 | 
						|
			warnedProxyHeaderDeprecation = true
 | 
						|
		}
 | 
						|
		fallthrough
 | 
						|
	case "header_upstream":
 | 
						|
		var header, value string
 | 
						|
		if !c.Args(&header, &value) {
 | 
						|
			return c.ArgErr()
 | 
						|
		}
 | 
						|
		u.upstreamHeaders.Add(header, value)
 | 
						|
	case "header_downstream":
 | 
						|
		var header, value string
 | 
						|
		if !c.Args(&header, &value) {
 | 
						|
			return c.ArgErr()
 | 
						|
		}
 | 
						|
		u.downstreamHeaders.Add(header, value)
 | 
						|
	case "transparent":
 | 
						|
		u.upstreamHeaders.Add("Host", "{host}")
 | 
						|
		u.upstreamHeaders.Add("X-Real-IP", "{remote}")
 | 
						|
		u.upstreamHeaders.Add("X-Forwarded-For", "{remote}")
 | 
						|
		u.upstreamHeaders.Add("X-Forwarded-Proto", "{scheme}")
 | 
						|
	case "websocket":
 | 
						|
		u.upstreamHeaders.Add("Connection", "{>Connection}")
 | 
						|
		u.upstreamHeaders.Add("Upgrade", "{>Upgrade}")
 | 
						|
	case "without":
 | 
						|
		if !c.NextArg() {
 | 
						|
			return c.ArgErr()
 | 
						|
		}
 | 
						|
		u.WithoutPathPrefix = c.Val()
 | 
						|
	case "except":
 | 
						|
		ignoredPaths := c.RemainingArgs()
 | 
						|
		if len(ignoredPaths) == 0 {
 | 
						|
			return c.ArgErr()
 | 
						|
		}
 | 
						|
		u.IgnoredSubPaths = ignoredPaths
 | 
						|
	case "insecure_skip_verify":
 | 
						|
		u.insecureSkipVerify = true
 | 
						|
	case "keepalive":
 | 
						|
		if !c.NextArg() {
 | 
						|
			return c.ArgErr()
 | 
						|
		}
 | 
						|
		n, err := strconv.Atoi(c.Val())
 | 
						|
		if err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
		if n < 0 {
 | 
						|
			return c.ArgErr()
 | 
						|
		}
 | 
						|
		u.KeepAlive = n
 | 
						|
	default:
 | 
						|
		return c.Errf("unknown property '%s'", c.Val())
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (u *staticUpstream) healthCheck() {
 | 
						|
	for _, host := range u.Hosts {
 | 
						|
		hostURL := host.Name + u.HealthCheck.Path
 | 
						|
		if r, err := u.HealthCheck.Client.Get(hostURL); err == nil {
 | 
						|
			io.Copy(ioutil.Discard, r.Body)
 | 
						|
			r.Body.Close()
 | 
						|
			host.Unhealthy = r.StatusCode < 200 || r.StatusCode >= 400
 | 
						|
		} else {
 | 
						|
			host.Unhealthy = true
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (u *staticUpstream) HealthCheckWorker(stop chan struct{}) {
 | 
						|
	ticker := time.NewTicker(u.HealthCheck.Interval)
 | 
						|
	u.healthCheck()
 | 
						|
	for {
 | 
						|
		select {
 | 
						|
		case <-ticker.C:
 | 
						|
			u.healthCheck()
 | 
						|
		case <-stop:
 | 
						|
			// TODO: the library should provide a stop channel and global
 | 
						|
			// waitgroup to allow goroutines started by plugins a chance
 | 
						|
			// to clean themselves up.
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (u *staticUpstream) Select(r *http.Request) *UpstreamHost {
 | 
						|
	pool := u.Hosts
 | 
						|
	if len(pool) == 1 {
 | 
						|
		if !pool[0].Available() {
 | 
						|
			return nil
 | 
						|
		}
 | 
						|
		return pool[0]
 | 
						|
	}
 | 
						|
	allUnavailable := true
 | 
						|
	for _, host := range pool {
 | 
						|
		if host.Available() {
 | 
						|
			allUnavailable = false
 | 
						|
			break
 | 
						|
		}
 | 
						|
	}
 | 
						|
	if allUnavailable {
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
	if u.Policy == nil {
 | 
						|
		return (&Random{}).Select(pool, r)
 | 
						|
	}
 | 
						|
	return u.Policy.Select(pool, r)
 | 
						|
}
 | 
						|
 | 
						|
func (u *staticUpstream) AllowedPath(requestPath string) bool {
 | 
						|
	for _, ignoredSubPath := range u.IgnoredSubPaths {
 | 
						|
		if httpserver.Path(path.Clean(requestPath)).Matches(path.Join(u.From(), ignoredSubPath)) {
 | 
						|
			return false
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return true
 | 
						|
}
 |