Merge branch 'master' into trace-servicegraph

This commit is contained in:
Matt Holt
2024-10-02 08:08:04 -06:00
committed by GitHub
110 changed files with 4120 additions and 1250 deletions
+1 -1
View File
@@ -137,7 +137,7 @@ func parseUpstreamDialAddress(upstreamAddr string) (parsedAddr, error) {
}
// we can assume a port if only a hostname is specified, but use of a
// placeholder without a port likely means a port will be filled in
if port == "" && !strings.Contains(host, "{") && !caddy.IsUnixNetwork(network) {
if port == "" && !strings.Contains(host, "{") && !caddy.IsUnixNetwork(network) && !caddy.IsFdNetwork(network) {
port = "80"
}
}
+68 -14
View File
@@ -16,6 +16,7 @@ package reverseproxy
import (
"fmt"
"net"
"net/http"
"reflect"
"strconv"
@@ -27,6 +28,7 @@ import (
"github.com/caddyserver/caddy/v2/caddyconfig"
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
"github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile"
"github.com/caddyserver/caddy/v2/internal"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
"github.com/caddyserver/caddy/v2/modules/caddyhttp/headers"
"github.com/caddyserver/caddy/v2/modules/caddyhttp/rewrite"
@@ -67,14 +69,16 @@ func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error)
// lb_retry_match <request-matcher>
//
// # active health checking
// health_uri <uri>
// health_port <port>
// health_interval <interval>
// health_passes <num>
// health_fails <num>
// health_timeout <duration>
// health_status <status>
// health_body <regexp>
// health_uri <uri>
// health_port <port>
// health_interval <interval>
// health_passes <num>
// health_fails <num>
// health_timeout <duration>
// health_status <status>
// health_body <regexp>
// health_method <value>
// health_request_body <value>
// health_follow_redirects
// health_headers {
// <field> [<values...>]
@@ -89,12 +93,11 @@ func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error)
//
// # streaming
// flush_interval <duration>
// buffer_requests
// buffer_responses
// max_buffer_size <size>
// request_buffers <size>
// response_buffers <size>
// stream_timeout <duration>
// stream_close_delay <duration>
// trace_logs
// verbose_logs
//
// # request manipulation
// trusted_proxies [private_ranges] <ranges...>
@@ -353,6 +356,26 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
h.HealthChecks.Active.Path = d.Val()
caddy.Log().Named("config.adapter.caddyfile").Warn("the 'health_path' subdirective is deprecated, please use 'health_uri' instead!")
case "health_upstream":
if !d.NextArg() {
return d.ArgErr()
}
if h.HealthChecks == nil {
h.HealthChecks = new(HealthChecks)
}
if h.HealthChecks.Active == nil {
h.HealthChecks.Active = new(ActiveHealthChecks)
}
_, port, err := net.SplitHostPort(d.Val())
if err != nil {
return d.Errf("health_upstream is malformed '%s': %v", d.Val(), err)
}
_, err = strconv.Atoi(port)
if err != nil {
return d.Errf("bad port number '%s': %v", d.Val(), err)
}
h.HealthChecks.Active.Upstream = d.Val()
case "health_port":
if !d.NextArg() {
return d.ArgErr()
@@ -363,6 +386,9 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
if h.HealthChecks.Active == nil {
h.HealthChecks.Active = new(ActiveHealthChecks)
}
if h.HealthChecks.Active.Upstream != "" {
return d.Errf("the 'health_port' subdirective is ignored if 'health_upstream' is used!")
}
portNum, err := strconv.Atoi(d.Val())
if err != nil {
return d.Errf("bad port number '%s': %v", d.Val(), err)
@@ -387,6 +413,30 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
}
h.HealthChecks.Active.Headers = healthHeaders
case "health_method":
if !d.NextArg() {
return d.ArgErr()
}
if h.HealthChecks == nil {
h.HealthChecks = new(HealthChecks)
}
if h.HealthChecks.Active == nil {
h.HealthChecks.Active = new(ActiveHealthChecks)
}
h.HealthChecks.Active.Method = d.Val()
case "health_request_body":
if !d.NextArg() {
return d.ArgErr()
}
if h.HealthChecks == nil {
h.HealthChecks = new(HealthChecks)
}
if h.HealthChecks.Active == nil {
h.HealthChecks.Active = new(ActiveHealthChecks)
}
h.HealthChecks.Active.Body = d.Val()
case "health_interval":
if !d.NextArg() {
return d.ArgErr()
@@ -651,7 +701,7 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
case "trusted_proxies":
for d.NextArg() {
if d.Val() == "private_ranges" {
h.TrustedProxies = append(h.TrustedProxies, caddyhttp.PrivateRangesCIDR()...)
h.TrustedProxies = append(h.TrustedProxies, internal.PrivateRangesCIDR()...)
continue
}
h.TrustedProxies = append(h.TrustedProxies, d.Val())
@@ -1275,7 +1325,11 @@ func (h *HTTPTransport) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
return d.Err("cannot specify \"tls_trust_pool\" twice in caddyfile")
}
h.TLS.CARaw = caddyconfig.JSONModuleObject(ca, "provider", modStem, nil)
case "local_address":
if !d.NextArg() {
return d.ArgErr()
}
h.LocalAddress = d.Val()
default:
return d.Errf("unrecognized subdirective %s", d.Val())
}
+7 -5
View File
@@ -229,11 +229,13 @@ func cmdReverseProxy(fs caddycmd.Flags) (int, error) {
if changeHost {
if handler.Headers == nil {
handler.Headers = &headers.Handler{
Request: &headers.HeaderOps{
Set: http.Header{},
},
}
handler.Headers = new(headers.Handler)
}
if handler.Headers.Request == nil {
handler.Headers.Request = new(headers.HeaderOps)
}
if handler.Headers.Request.Set == nil {
handler.Headers.Request.Set = http.Header{}
}
handler.Headers.Request.Set.Set("Host", "{http.reverse_proxy.upstream.hostport}")
}
@@ -40,6 +40,7 @@ import (
"time"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
// FCGIListenSockFileno describes listen socket file number.
@@ -184,10 +185,13 @@ func (f clientCloser) Close() error {
return f.rwc.Close()
}
logLevel := zapcore.WarnLevel
if f.status >= 400 {
f.logger.Error("stderr", zap.ByteString("body", stderr))
} else {
f.logger.Warn("stderr", zap.ByteString("body", stderr))
logLevel = zapcore.ErrorLevel
}
if c := f.logger.Check(logLevel, "stderr"); c != nil {
c.Write(zap.ByteString("body", stderr))
}
return f.rwc.Close()
@@ -148,10 +148,13 @@ func (t Transport) RoundTrip(r *http.Request) (*http.Response, error) {
zap.Object("request", loggableReq),
zap.Object("env", loggableEnv),
)
logger.Debug("roundtrip",
zap.String("dial", address),
zap.Object("env", loggableEnv),
zap.Object("request", loggableReq))
if c := t.logger.Check(zapcore.DebugLevel, "roundtrip"); c != nil {
c.Write(
zap.String("dial", address),
zap.Object("env", loggableEnv),
zap.Object("request", loggableReq),
)
}
// connect to the backend
dialer := net.Dialer{Timeout: time.Duration(t.DialTimeout)}
+153 -74
View File
@@ -23,10 +23,13 @@ import (
"net/url"
"regexp"
"runtime/debug"
"slices"
"strconv"
"strings"
"time"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
@@ -75,13 +78,27 @@ type ActiveHealthChecks struct {
// The URI (path and query) to use for health checks
URI string `json:"uri,omitempty"`
// The host:port to use (if different from the upstream's dial address)
// for health checks. This should be used in tandem with `health_header` and
// `{http.reverse_proxy.active.target_upstream}`. This can be helpful when
// creating an intermediate service to do a more thorough health check.
// If upstream is set, the active health check port is ignored.
Upstream string `json:"upstream,omitempty"`
// The port to use (if different from the upstream's dial
// address) for health checks.
// address) for health checks. If active upstream is set,
// this value is ignored.
Port int `json:"port,omitempty"`
// HTTP headers to set on health check requests.
Headers http.Header `json:"headers,omitempty"`
// The HTTP method to use for health checks (default "GET").
Method string `json:"method,omitempty"`
// The body to send with the health check request.
Body string `json:"body,omitempty"`
// Whether to follow HTTP redirects in response to active health checks (default off).
FollowRedirects bool `json:"follow_redirects,omitempty"`
@@ -133,6 +150,11 @@ func (a *ActiveHealthChecks) Provision(ctx caddy.Context, h *Handler) error {
}
a.Headers = cleaned
// If Method is not set, default to GET
if a.Method == "" {
a.Method = http.MethodGet
}
h.HealthChecks.Active.logger = h.logger.Named("health_checker.active")
timeout := time.Duration(a.Timeout)
@@ -165,9 +187,14 @@ func (a *ActiveHealthChecks) Provision(ctx caddy.Context, h *Handler) error {
}
for _, upstream := range h.Upstreams {
// if there's an alternative port for health-check provided in the config,
// then use it, otherwise use the port of upstream.
if a.Port != 0 {
// if there's an alternative upstream for health-check provided in the config,
// then use it, otherwise use the upstream's dial address. if upstream is used,
// then the port is ignored.
if a.Upstream != "" {
upstream.activeHealthCheckUpstream = a.Upstream
} else if a.Port != 0 {
// if there's an alternative port for health-check provided in the config,
// then use it, otherwise use the port of upstream.
upstream.activeHealthCheckPort = a.Port
}
}
@@ -245,9 +272,12 @@ type CircuitBreaker interface {
func (h *Handler) activeHealthChecker() {
defer func() {
if err := recover(); err != nil {
h.HealthChecks.Active.logger.Error("active health checker panicked",
zap.Any("error", err),
zap.ByteString("stack", debug.Stack()))
if c := h.HealthChecks.Active.logger.Check(zapcore.ErrorLevel, "active health checker panicked"); c != nil {
c.Write(
zap.Any("error", err),
zap.ByteString("stack", debug.Stack()),
)
}
}
}()
ticker := time.NewTicker(time.Duration(h.HealthChecks.Active.Interval))
@@ -270,54 +300,65 @@ func (h *Handler) doActiveHealthCheckForAllHosts() {
go func(upstream *Upstream) {
defer func() {
if err := recover(); err != nil {
h.HealthChecks.Active.logger.Error("active health check panicked",
zap.Any("error", err),
zap.ByteString("stack", debug.Stack()))
if c := h.HealthChecks.Active.logger.Check(zapcore.ErrorLevel, "active health checker panicked"); c != nil {
c.Write(
zap.Any("error", err),
zap.ByteString("stack", debug.Stack()),
)
}
}
}()
networkAddr, err := caddy.NewReplacer().ReplaceOrErr(upstream.Dial, true, true)
if err != nil {
h.HealthChecks.Active.logger.Error("invalid use of placeholders in dial address for active health checks",
zap.String("address", networkAddr),
zap.Error(err),
)
if c := h.HealthChecks.Active.logger.Check(zapcore.ErrorLevel, "invalid use of placeholders in dial address for active health checks"); c != nil {
c.Write(
zap.String("address", networkAddr),
zap.Error(err),
)
}
return
}
addr, err := caddy.ParseNetworkAddress(networkAddr)
if err != nil {
h.HealthChecks.Active.logger.Error("bad network address",
zap.String("address", networkAddr),
zap.Error(err),
)
if c := h.HealthChecks.Active.logger.Check(zapcore.ErrorLevel, "bad network address"); c != nil {
c.Write(
zap.String("address", networkAddr),
zap.Error(err),
)
}
return
}
if hcp := uint(upstream.activeHealthCheckPort); hcp != 0 {
if addr.IsUnixNetwork() {
if addr.IsUnixNetwork() || addr.IsFdNetwork() {
addr.Network = "tcp" // I guess we just assume TCP since we are using a port??
}
addr.StartPort, addr.EndPort = hcp, hcp
}
if addr.PortRangeSize() != 1 {
h.HealthChecks.Active.logger.Error("multiple addresses (upstream must map to only one address)",
zap.String("address", networkAddr),
)
if c := h.HealthChecks.Active.logger.Check(zapcore.ErrorLevel, "multiple addresses (upstream must map to only one address)"); c != nil {
c.Write(
zap.String("address", networkAddr),
)
}
return
}
hostAddr := addr.JoinHostPort(0)
dialAddr := hostAddr
if addr.IsUnixNetwork() {
if addr.IsUnixNetwork() || addr.IsFdNetwork() {
// this will be used as the Host portion of a http.Request URL, and
// paths to socket files would produce an error when creating URL,
// so use a fake Host value instead; unix sockets are usually local
hostAddr = "localhost"
}
err = h.doActiveHealthCheck(DialInfo{Network: addr.Network, Address: dialAddr}, hostAddr, upstream)
err = h.doActiveHealthCheck(DialInfo{Network: addr.Network, Address: dialAddr}, hostAddr, networkAddr, upstream)
if err != nil {
h.HealthChecks.Active.logger.Error("active health check failed",
zap.String("address", hostAddr),
zap.Error(err),
)
if c := h.HealthChecks.Active.logger.Check(zapcore.ErrorLevel, "active health check failed"); c != nil {
c.Write(
zap.String("address", hostAddr),
zap.Error(err),
)
}
}
}(upstream)
}
@@ -330,7 +371,7 @@ func (h *Handler) doActiveHealthCheckForAllHosts() {
// according to whether it passes the health check. An error is
// returned only if the health check fails to occur or if marking
// the host's health status fails.
func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, upstream *Upstream) error {
func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, networkAddr string, upstream *Upstream) error {
// create the URL for the request that acts as a health check
u := &url.URL{
Scheme: "http",
@@ -342,7 +383,12 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, upstre
if err != nil {
host = hostAddr
}
if h.HealthChecks.Active.Port != 0 {
// ignore active health check port if active upstream is provided as the
// active upstream already contains the replacement port
if h.HealthChecks.Active.Upstream != "" {
u.Host = h.HealthChecks.Active.Upstream
} else if h.HealthChecks.Active.Port != 0 {
port := strconv.Itoa(h.HealthChecks.Active.Port)
u.Host = net.JoinHostPort(host, port)
}
@@ -352,12 +398,8 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, upstre
u.Scheme = "https"
// if the port is in the except list, flip back to HTTP
if ht, ok := h.Transport.(*HTTPTransport); ok {
for _, exceptPort := range ht.TLS.ExceptPorts {
if exceptPort == port {
u.Scheme = "http"
}
}
if ht, ok := h.Transport.(*HTTPTransport); ok && slices.Contains(ht.TLS.ExceptPorts, port) {
u.Scheme = "http"
}
}
@@ -370,6 +412,16 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, upstre
u.Path = h.HealthChecks.Active.Path
}
// replacer used for both body and headers. Only globals (env vars, system info, etc.) are available
repl := caddy.NewReplacer()
// if body is provided, create a reader for it, otherwise nil
var requestBody io.Reader
if h.HealthChecks.Active.Body != "" {
// set body, using replacer
requestBody = strings.NewReader(repl.ReplaceAll(h.HealthChecks.Active.Body, ""))
}
// attach dialing information to this request, as well as context values that
// may be expected by handlers of this request
ctx := h.ctx.Context
@@ -377,15 +429,15 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, upstre
ctx = context.WithValue(ctx, caddyhttp.VarsCtxKey, map[string]any{
dialInfoVarKey: dialInfo,
})
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
req, err := http.NewRequestWithContext(ctx, h.HealthChecks.Active.Method, u.String(), requestBody)
if err != nil {
return fmt.Errorf("making request: %v", err)
}
ctx = context.WithValue(ctx, caddyhttp.OriginalRequestCtxKey, *req)
req = req.WithContext(ctx)
// set headers, using a replacer with only globals (env vars, system info, etc.)
repl := caddy.NewReplacer()
// set headers, using replacer
repl.Set("http.reverse_proxy.active.target_upstream", networkAddr)
for key, vals := range h.HealthChecks.Active.Headers {
key = repl.ReplaceAll(key, "")
if key == "Host" {
@@ -401,9 +453,12 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, upstre
// increment failures and then check if it has reached the threshold to mark unhealthy
err := upstream.Host.countHealthFail(1)
if err != nil {
h.HealthChecks.Active.logger.Error("could not count active health failure",
zap.String("host", upstream.Dial),
zap.Error(err))
if c := h.HealthChecks.Active.logger.Check(zapcore.ErrorLevel, "could not count active health failure"); c != nil {
c.Write(
zap.String("host", upstream.Dial),
zap.Error(err),
)
}
return
}
if upstream.Host.activeHealthFails() >= h.HealthChecks.Active.Fails {
@@ -419,13 +474,19 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, upstre
// increment passes and then check if it has reached the threshold to be healthy
err := upstream.Host.countHealthPass(1)
if err != nil {
h.HealthChecks.Active.logger.Error("could not count active health pass",
zap.String("host", upstream.Dial),
zap.Error(err))
if c := h.HealthChecks.Active.logger.Check(zapcore.ErrorLevel, "could not count active health pass"); c != nil {
c.Write(
zap.String("host", upstream.Dial),
zap.Error(err),
)
}
return
}
if upstream.Host.activeHealthPasses() >= h.HealthChecks.Active.Passes {
if upstream.setHealthy(true) {
if c := h.HealthChecks.Active.logger.Check(zapcore.InfoLevel, "host is up"); c != nil {
c.Write(zap.String("host", hostAddr))
}
h.events.Emit(h.ctx, "healthy", map[string]any{"host": hostAddr})
upstream.Host.resetHealth()
}
@@ -435,10 +496,12 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, upstre
// do the request, being careful to tame the response body
resp, err := h.HealthChecks.Active.httpClient.Do(req)
if err != nil {
h.HealthChecks.Active.logger.Info("HTTP request failed",
zap.String("host", hostAddr),
zap.Error(err),
)
if c := h.HealthChecks.Active.logger.Check(zapcore.InfoLevel, "HTTP request failed"); c != nil {
c.Write(
zap.String("host", hostAddr),
zap.Error(err),
)
}
markUnhealthy()
return nil
}
@@ -455,18 +518,22 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, upstre
// if status code is outside criteria, mark down
if h.HealthChecks.Active.ExpectStatus > 0 {
if !caddyhttp.StatusCodeMatches(resp.StatusCode, h.HealthChecks.Active.ExpectStatus) {
h.HealthChecks.Active.logger.Info("unexpected status code",
zap.Int("status_code", resp.StatusCode),
zap.String("host", hostAddr),
)
if c := h.HealthChecks.Active.logger.Check(zapcore.InfoLevel, "unexpected status code"); c != nil {
c.Write(
zap.Int("status_code", resp.StatusCode),
zap.String("host", hostAddr),
)
}
markUnhealthy()
return nil
}
} else if resp.StatusCode < 200 || resp.StatusCode >= 300 {
h.HealthChecks.Active.logger.Info("status code out of tolerances",
zap.Int("status_code", resp.StatusCode),
zap.String("host", hostAddr),
)
if c := h.HealthChecks.Active.logger.Check(zapcore.InfoLevel, "status code out of tolerances"); c != nil {
c.Write(
zap.Int("status_code", resp.StatusCode),
zap.String("host", hostAddr),
)
}
markUnhealthy()
return nil
}
@@ -475,24 +542,27 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, upstre
if h.HealthChecks.Active.bodyRegexp != nil {
bodyBytes, err := io.ReadAll(body)
if err != nil {
h.HealthChecks.Active.logger.Info("failed to read response body",
zap.String("host", hostAddr),
zap.Error(err),
)
if c := h.HealthChecks.Active.logger.Check(zapcore.InfoLevel, "failed to read response body"); c != nil {
c.Write(
zap.String("host", hostAddr),
zap.Error(err),
)
}
markUnhealthy()
return nil
}
if !h.HealthChecks.Active.bodyRegexp.Match(bodyBytes) {
h.HealthChecks.Active.logger.Info("response body failed expectations",
zap.String("host", hostAddr),
)
if c := h.HealthChecks.Active.logger.Check(zapcore.InfoLevel, "response body failed expectations"); c != nil {
c.Write(
zap.String("host", hostAddr),
)
}
markUnhealthy()
return nil
}
}
// passed health check parameters, so mark as healthy
h.HealthChecks.Active.logger.Info("host is up", zap.String("host", hostAddr))
markHealthy()
return nil
@@ -516,9 +586,12 @@ func (h *Handler) countFailure(upstream *Upstream) {
// count failure immediately
err := upstream.Host.countFail(1)
if err != nil {
h.HealthChecks.Passive.logger.Error("could not count failure",
zap.String("host", upstream.Dial),
zap.Error(err))
if c := h.HealthChecks.Active.logger.Check(zapcore.ErrorLevel, "could not count failure"); c != nil {
c.Write(
zap.String("host", upstream.Dial),
zap.Error(err),
)
}
return
}
@@ -526,9 +599,12 @@ func (h *Handler) countFailure(upstream *Upstream) {
go func(host *Host, failDuration time.Duration) {
defer func() {
if err := recover(); err != nil {
h.HealthChecks.Passive.logger.Error("passive health check failure forgetter panicked",
zap.Any("error", err),
zap.ByteString("stack", debug.Stack()))
if c := h.HealthChecks.Active.logger.Check(zapcore.ErrorLevel, "passive health check failure forgetter panicked"); c != nil {
c.Write(
zap.Any("error", err),
zap.ByteString("stack", debug.Stack()),
)
}
}
}()
timer := time.NewTimer(failDuration)
@@ -541,9 +617,12 @@ func (h *Handler) countFailure(upstream *Upstream) {
}
err := host.countFail(-1)
if err != nil {
h.HealthChecks.Passive.logger.Error("could not forget failure",
zap.String("host", upstream.Dial),
zap.Error(err))
if c := h.HealthChecks.Active.logger.Check(zapcore.ErrorLevel, "could not forget failure"); c != nil {
c.Write(
zap.String("host", upstream.Dial),
zap.Error(err),
)
}
}
}(upstream.Host, failDuration)
}
+5 -4
View File
@@ -57,10 +57,11 @@ type Upstream struct {
// HeaderAffinity string
// IPAffinity string
activeHealthCheckPort int
healthCheckPolicy *PassiveHealthChecks
cb CircuitBreaker
unhealthy int32 // accessed atomically; status from active health checker
activeHealthCheckPort int
activeHealthCheckUpstream string
healthCheckPolicy *PassiveHealthChecks
cb CircuitBreaker
unhealthy int32 // accessed atomically; status from active health checker
}
// (pointer receiver necessary to avoid a race condition, since
+50 -15
View File
@@ -27,12 +27,14 @@ import (
"net/url"
"os"
"reflect"
"slices"
"strings"
"time"
"github.com/pires/go-proxyproto"
"github.com/quic-go/quic-go/http3"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"golang.org/x/net/http2"
"github.com/caddyserver/caddy/v2"
@@ -132,6 +134,10 @@ type HTTPTransport struct {
// to change or removal while experimental.
Versions []string `json:"versions,omitempty"`
// Specify the address to bind to when connecting to an upstream. In other words,
// it is the address the upstream sees as the remote address.
LocalAddress string `json:"local_address,omitempty"`
// The pre-configured underlying HTTP transport.
Transport *http.Transport `json:"-"`
@@ -185,6 +191,31 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
FallbackDelay: time.Duration(h.FallbackDelay),
}
if h.LocalAddress != "" {
netaddr, err := caddy.ParseNetworkAddressWithDefaults(h.LocalAddress, "tcp", 0)
if err != nil {
return nil, err
}
if netaddr.PortRangeSize() > 1 {
return nil, fmt.Errorf("local_address must be a single address, not a port range")
}
switch netaddr.Network {
case "tcp", "tcp4", "tcp6":
dialer.LocalAddr, err = net.ResolveTCPAddr(netaddr.Network, netaddr.JoinHostPort(0))
if err != nil {
return nil, err
}
case "unix", "unixgram", "unixpacket":
dialer.LocalAddr, err = net.ResolveUnixAddr(netaddr.Network, netaddr.JoinHostPort(0))
if err != nil {
return nil, err
}
case "udp", "udp4", "udp6":
return nil, fmt.Errorf("local_address must be a TCP address, not a UDP address")
default:
return nil, fmt.Errorf("unsupported network")
}
}
if h.Resolver != nil {
err := h.Resolver.ParseAddresses()
if err != nil {
@@ -351,7 +382,7 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
rt.DisableCompression = !*h.Compression
}
if sliceContains(h.Versions, "2") {
if slices.Contains(h.Versions, "2") {
if err := http2.ConfigureTransport(rt); err != nil {
return nil, err
}
@@ -363,13 +394,20 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
// site owners control the backends), so it must be exclusive
if len(h.Versions) == 1 && h.Versions[0] == "3" {
h.h3Transport = new(http3.RoundTripper)
} else if len(h.Versions) > 1 && sliceContains(h.Versions, "3") {
if h.TLS != nil {
var err error
h.h3Transport.TLSClientConfig, err = h.TLS.MakeTLSClientConfig(caddyCtx)
if err != nil {
return nil, fmt.Errorf("making TLS client config for HTTP/3 transport: %v", err)
}
}
} else if len(h.Versions) > 1 && slices.Contains(h.Versions, "3") {
return nil, fmt.Errorf("if HTTP/3 is enabled to the upstream, no other HTTP versions are supported")
}
// if h2c is enabled, configure its transport (std lib http.Transport
// does not "HTTP/2 over cleartext TCP")
if sliceContains(h.Versions, "h2c") {
if slices.Contains(h.Versions, "h2c") {
// crafting our own http2.Transport doesn't allow us to utilize
// most of the customizations/preferences on the http.Transport,
// because, for some reason, only http2.ConfigureTransport()
@@ -439,6 +477,9 @@ func (h *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// if H2C ("HTTP/2 over cleartext") is enabled and the upstream request is
// HTTP without TLS, use the alternate H2C-capable transport instead
if req.URL.Scheme == "http" && h.h2cTransport != nil {
// There is no dedicated DisableKeepAlives field in *http2.Transport.
// This is an alternative way to disable keep-alive.
req.Close = h.Transport.DisableKeepAlives
return h.h2cTransport.RoundTrip(req)
}
@@ -711,7 +752,9 @@ func (c *tcpRWTimeoutConn) Read(b []byte) (int, error) {
if c.readTimeout > 0 {
err := c.TCPConn.SetReadDeadline(time.Now().Add(c.readTimeout))
if err != nil {
c.logger.Error("failed to set read deadline", zap.Error(err))
if ce := c.logger.Check(zapcore.ErrorLevel, "failed to set read deadline"); ce != nil {
ce.Write(zap.Error(err))
}
}
}
return c.TCPConn.Read(b)
@@ -721,7 +764,9 @@ func (c *tcpRWTimeoutConn) Write(b []byte) (int, error) {
if c.writeTimeout > 0 {
err := c.TCPConn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
if err != nil {
c.logger.Error("failed to set write deadline", zap.Error(err))
if ce := c.logger.Check(zapcore.ErrorLevel, "failed to set write deadline"); ce != nil {
ce.Write(zap.Error(err))
}
}
}
return c.TCPConn.Write(b)
@@ -739,16 +784,6 @@ func decodeBase64DERCert(certStr string) (*x509.Certificate, error) {
return x509.ParseCertificate(derBytes)
}
// sliceContains returns true if needle is in haystack.
func sliceContains(haystack []string, needle string) bool {
for _, s := range haystack {
if s == needle {
return true
}
}
return false
}
// Interface guards
var (
_ caddy.Provisioner = (*HTTPTransport)(nil)
+7 -3
View File
@@ -8,6 +8,7 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
var reverseProxyMetrics = struct {
@@ -48,9 +49,12 @@ func (m *metricsUpstreamsHealthyUpdater) Init() {
go func() {
defer func() {
if err := recover(); err != nil {
reverseProxyMetrics.logger.Error("upstreams healthy metrics updater panicked",
zap.Any("error", err),
zap.ByteString("stack", debug.Stack()))
if c := reverseProxyMetrics.logger.Check(zapcore.ErrorLevel, "upstreams healthy metrics updater panicked"); c != nil {
c.Write(
zap.Any("error", err),
zap.ByteString("stack", debug.Stack()),
)
}
}
}()
+54 -16
View File
@@ -34,6 +34,7 @@ import (
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"golang.org/x/net/http/httpguts"
"github.com/caddyserver/caddy/v2"
@@ -68,6 +69,7 @@ func init() {
// `{http.reverse_proxy.upstream.duration_ms}` | Same as 'upstream.duration', but in milliseconds.
// `{http.reverse_proxy.duration}` | Total time spent proxying, including selecting an upstream, retries, and writing response.
// `{http.reverse_proxy.duration_ms}` | Same as 'duration', but in milliseconds.
// `{http.reverse_proxy.retries}` | The number of retries actually performed to communicate with an upstream.
type Handler struct {
// Configures the method of transport for the proxy. A transport
// is what performs the actual "round trip" to the backend.
@@ -439,11 +441,16 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht
if h.LoadBalancing != nil {
lbWait = time.Duration(h.LoadBalancing.TryInterval)
}
h.logger.Debug("retrying", zap.Error(proxyErr), zap.Duration("after", lbWait))
if c := h.logger.Check(zapcore.DebugLevel, "retrying"); c != nil {
c.Write(zap.Error(proxyErr), zap.Duration("after", lbWait))
}
}
retries++
}
// number of retries actually performed
repl.Set("http.reverse_proxy.retries", retries)
if proxyErr != nil {
return statusError(proxyErr)
}
@@ -463,13 +470,17 @@ func (h *Handler) proxyLoopIteration(r *http.Request, origReq *http.Request, w h
if h.DynamicUpstreams != nil {
dUpstreams, err := h.DynamicUpstreams.GetUpstreams(r)
if err != nil {
h.logger.Error("failed getting dynamic upstreams; falling back to static upstreams", zap.Error(err))
if c := h.logger.Check(zapcore.ErrorLevel, "failed getting dynamic upstreams; falling back to static upstreams"); c != nil {
c.Write(zap.Error(err))
}
} else {
upstreams = dUpstreams
for _, dUp := range dUpstreams {
h.provisionUpstream(dUp)
}
h.logger.Debug("provisioned dynamic upstreams", zap.Int("count", len(dUpstreams)))
if c := h.logger.Check(zapcore.DebugLevel, "provisioned dynamic upstreams"); c != nil {
c.Write(zap.Int("count", len(dUpstreams)))
}
defer func() {
// these upstreams are dynamic, so they are only used for this iteration
// of the proxy loop; be sure to let them go away when we're done with them
@@ -500,9 +511,12 @@ func (h *Handler) proxyLoopIteration(r *http.Request, origReq *http.Request, w h
return true, fmt.Errorf("making dial info: %v", err)
}
h.logger.Debug("selected upstream",
zap.String("dial", dialInfo.Address),
zap.Int("total_upstreams", len(upstreams)))
if c := h.logger.Check(zapcore.DebugLevel, "selected upstream"); c != nil {
c.Write(
zap.String("dial", dialInfo.Address),
zap.Int("total_upstreams", len(upstreams)),
)
}
// attach to the request information about how to dial the upstream;
// this is necessary because the information cannot be sufficiently
@@ -606,6 +620,18 @@ func (h Handler) prepareRequest(req *http.Request, repl *caddy.Replacer) (*http.
req.Header.Set("User-Agent", "")
}
// Indicate if request has been conveyed in early data.
// RFC 8470: "An intermediary that forwards a request prior to the
// completion of the TLS handshake with its client MUST send it with
// the Early-Data header field set to “1” (i.e., it adds it if not
// present in the request). An intermediary MUST use the Early-Data
// header field if the request might have been subject to a replay and
// might already have been forwarded by it or another instance
// (see Section 6.2)."
if req.TLS != nil && !req.TLS.HandshakeComplete {
req.Header.Set("Early-Data", "1")
}
reqUpType := upgradeType(req.Header)
removeConnectionHeaders(req.Header)
@@ -798,16 +824,22 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe
ShouldLogCredentials: shouldLogCredentials,
}),
)
if err != nil {
logger.Debug("upstream roundtrip", zap.Error(err))
if c := logger.Check(zapcore.DebugLevel, "upstream roundtrip"); c != nil {
c.Write(zap.Error(err))
}
return err
}
logger.Debug("upstream roundtrip",
zap.Object("headers", caddyhttp.LoggableHTTPHeader{
Header: res.Header,
ShouldLogCredentials: shouldLogCredentials,
}),
zap.Int("status", res.StatusCode))
if c := logger.Check(zapcore.DebugLevel, "upstream roundtrip"); c != nil {
c.Write(
zap.Object("headers", caddyhttp.LoggableHTTPHeader{
Header: res.Header,
ShouldLogCredentials: shouldLogCredentials,
}),
zap.Int("status", res.StatusCode),
)
}
// duration until upstream wrote response headers (roundtrip duration)
repl.Set("http.reverse_proxy.upstream.latency", duration)
@@ -866,7 +898,9 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe
repl.Set("http.reverse_proxy.status_code", res.StatusCode)
repl.Set("http.reverse_proxy.status_text", res.Status)
logger.Debug("handling response", zap.Int("handler", i))
if c := logger.Check(zapcore.DebugLevel, "handling response"); c != nil {
c.Write(zap.Int("handler", i))
}
// we make some data available via request context to child routes
// so that they may inherit some options and functions from the
@@ -962,7 +996,9 @@ func (h *Handler) finalizeResponse(
err := h.copyResponse(rw, res.Body, h.flushInterval(req, res), logger)
errClose := res.Body.Close() // close now, instead of defer, to populate res.Trailer
if h.VerboseLogs || errClose != nil {
logger.Debug("closed response body from upstream", zap.Error(errClose))
if c := logger.Check(zapcore.DebugLevel, "closed response body from upstream"); c != nil {
c.Write(zap.Error(errClose))
}
}
if err != nil {
// we're streaming the response and we've already written headers, so
@@ -970,7 +1006,9 @@ func (h *Handler) finalizeResponse(
// we'll just log the error and abort the stream here and panic just as
// the standard lib's proxy to propagate the stream error.
// see issue https://github.com/caddyserver/caddy/issues/5951
logger.Error("aborting with incomplete response", zap.Error(err))
if c := logger.Check(zapcore.WarnLevel, "aborting with incomplete response"); c != nil {
c.Write(zap.Error(err))
}
// no extra logging from stdlib
panic(http.ErrAbortHandler)
}
+59 -25
View File
@@ -31,6 +31,7 @@ import (
"unsafe"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"golang.org/x/net/http/httpguts"
)
@@ -41,14 +42,18 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
// Taken from https://github.com/golang/go/commit/5c489514bc5e61ad9b5b07bd7d8ec65d66a0512a
// We know reqUpType is ASCII, it's checked by the caller.
if !asciiIsPrint(resUpType) {
logger.Debug("backend tried to switch to invalid protocol",
zap.String("backend_upgrade", resUpType))
if c := logger.Check(zapcore.DebugLevel, "backend tried to switch to invalid protocol"); c != nil {
c.Write(zap.String("backend_upgrade", resUpType))
}
return
}
if !asciiEqualFold(reqUpType, resUpType) {
logger.Debug("backend tried to switch to unexpected protocol via Upgrade header",
zap.String("backend_upgrade", resUpType),
zap.String("requested_upgrade", reqUpType))
if c := logger.Check(zapcore.DebugLevel, "backend tried to switch to unexpected protocol via Upgrade header"); c != nil {
c.Write(
zap.String("backend_upgrade", resUpType),
zap.String("requested_upgrade", reqUpType),
)
}
return
}
@@ -68,12 +73,16 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
//nolint:bodyclose
conn, brw, hijackErr := http.NewResponseController(rw).Hijack()
if errors.Is(hijackErr, http.ErrNotSupported) {
h.logger.Error("can't switch protocols using non-Hijacker ResponseWriter", zap.String("type", fmt.Sprintf("%T", rw)))
if c := logger.Check(zapcore.ErrorLevel, "can't switch protocols using non-Hijacker ResponseWriter"); c != nil {
c.Write(zap.String("type", fmt.Sprintf("%T", rw)))
}
return
}
if hijackErr != nil {
h.logger.Error("hijack failed on protocol switch", zap.Error(hijackErr))
if c := logger.Check(zapcore.ErrorLevel, "hijack failed on protocol switch"); c != nil {
c.Write(zap.Error(hijackErr))
}
return
}
@@ -93,11 +102,15 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
start := time.Now()
defer func() {
conn.Close()
logger.Debug("connection closed", zap.Duration("duration", time.Since(start)))
if c := logger.Check(zapcore.DebugLevel, "hijack failed on protocol switch"); c != nil {
c.Write(zap.Duration("duration", time.Since(start)))
}
}()
if err := brw.Flush(); err != nil {
logger.Debug("response flush", zap.Error(err))
if c := logger.Check(zapcore.DebugLevel, "response flush"); c != nil {
c.Write(zap.Error(err))
}
return
}
@@ -107,7 +120,9 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
data, _ := brw.Peek(buffered)
_, err := backConn.Write(data)
if err != nil {
logger.Debug("backConn write failed", zap.Error(err))
if c := logger.Check(zapcore.DebugLevel, "backConn write failed"); c != nil {
c.Write(zap.Error(err))
}
return
}
}
@@ -148,9 +163,13 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
go spc.copyFromBackend(errc)
select {
case err := <-errc:
logger.Debug("streaming error", zap.Error(err))
if c := logger.Check(zapcore.DebugLevel, "streaming error"); c != nil {
c.Write(zap.Error(err))
}
case time := <-timeoutc:
logger.Debug("stream timed out", zap.Time("timeout", time))
if c := logger.Check(zapcore.DebugLevel, "stream timed out"); c != nil {
c.Write(zap.Time("timeout", time))
}
}
}
@@ -247,7 +266,9 @@ func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte, logger *za
logger.Debug("waiting to read from upstream")
nr, rerr := src.Read(buf)
logger := logger.With(zap.Int("read", nr))
logger.Debug("read from upstream", zap.Error(rerr))
if c := logger.Check(zapcore.DebugLevel, "read from upstream"); c != nil {
c.Write(zap.Error(rerr))
}
if rerr != nil && rerr != io.EOF && rerr != context.Canceled {
// TODO: this could be useful to know (indeed, it revealed an error in our
// fastcgi PoC earlier; but it's this single error report here that necessitates
@@ -256,7 +277,9 @@ func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte, logger *za
// something we need to report to the client, but read errors are a problem on our
// end for sure. so we need to decide what we want.)
// p.logf("copyBuffer: ReverseProxy read error during body copy: %v", rerr)
h.logger.Error("reading from backend", zap.Error(rerr))
if c := logger.Check(zapcore.ErrorLevel, "reading from backend"); c != nil {
c.Write(zap.Error(rerr))
}
}
if nr > 0 {
logger.Debug("writing to downstream")
@@ -264,10 +287,13 @@ func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte, logger *za
if nw > 0 {
written += int64(nw)
}
logger.Debug("wrote to downstream",
zap.Int("written", nw),
zap.Int64("written_total", written),
zap.Error(werr))
if c := logger.Check(zapcore.DebugLevel, "wrote to downstream"); c != nil {
c.Write(
zap.Int("written", nw),
zap.Int64("written_total", written),
zap.Error(werr),
)
}
if werr != nil {
return written, fmt.Errorf("writing: %w", werr)
}
@@ -347,13 +373,17 @@ func (h *Handler) cleanupConnections() error {
if len(h.connections) > 0 {
delay := time.Duration(h.StreamCloseDelay)
h.connectionsCloseTimer = time.AfterFunc(delay, func() {
h.logger.Debug("closing streaming connections after delay",
zap.Duration("delay", delay))
if c := h.logger.Check(zapcore.DebugLevel, "closing streaming connections after delay"); c != nil {
c.Write(zap.Duration("delay", delay))
}
err := h.closeConnections()
if err != nil {
h.logger.Error("failed to closed connections after delay",
zap.Error(err),
zap.Duration("delay", delay))
if c := h.logger.Check(zapcore.ErrorLevel, "failed to closed connections after delay"); c != nil {
c.Write(
zap.Error(err),
zap.Duration("delay", delay),
)
}
}
})
}
@@ -494,7 +524,9 @@ func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
m.mu.Lock()
defer m.mu.Unlock()
n, err = m.dst.Write(p)
m.logger.Debug("wrote bytes", zap.Int("n", n), zap.Error(err))
if c := m.logger.Check(zapcore.DebugLevel, "wrote bytes"); c != nil {
c.Write(zap.Int("n", n), zap.Error(err))
}
if m.latency < 0 {
m.logger.Debug("flushing immediately")
//nolint:errcheck
@@ -510,7 +542,9 @@ func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
} else {
m.t.Reset(m.latency)
}
m.logger.Debug("timer set for delayed flush", zap.Duration("duration", m.latency))
if c := m.logger.Check(zapcore.DebugLevel, "timer set for delayed flush"); c != nil {
c.Write(zap.Duration("duration", m.latency))
}
m.flushPending = true
return
}
+55 -33
View File
@@ -12,6 +12,7 @@ import (
"time"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"github.com/caddyserver/caddy/v2"
)
@@ -136,10 +137,13 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
return allNew(cached.upstreams), nil
}
su.logger.Debug("refreshing SRV upstreams",
zap.String("service", service),
zap.String("proto", proto),
zap.String("name", name))
if c := su.logger.Check(zapcore.DebugLevel, "refreshing SRV upstreams"); c != nil {
c.Write(
zap.String("service", service),
zap.String("proto", proto),
zap.String("name", name),
)
}
_, records, err := su.resolver.LookupSRV(r.Context(), service, proto, name)
if err != nil {
@@ -148,23 +152,30 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
// only return an error if no records were also returned.
if len(records) == 0 {
if su.GracePeriod > 0 {
su.logger.Error("SRV lookup failed; using previously cached", zap.Error(err))
if c := su.logger.Check(zapcore.ErrorLevel, "SRV lookup failed; using previously cached"); c != nil {
c.Write(zap.Error(err))
}
cached.freshness = time.Now().Add(time.Duration(su.GracePeriod) - time.Duration(su.Refresh))
srvs[suAddr] = cached
return allNew(cached.upstreams), nil
}
return nil, err
}
su.logger.Warn("SRV records filtered", zap.Error(err))
if c := su.logger.Check(zapcore.WarnLevel, "SRV records filtered"); c != nil {
c.Write(zap.Error(err))
}
}
upstreams := make([]Upstream, len(records))
for i, rec := range records {
su.logger.Debug("discovered SRV record",
zap.String("target", rec.Target),
zap.Uint16("port", rec.Port),
zap.Uint16("priority", rec.Priority),
zap.Uint16("weight", rec.Weight))
if c := su.logger.Check(zapcore.DebugLevel, "discovered SRV record"); c != nil {
c.Write(
zap.String("target", rec.Target),
zap.Uint16("port", rec.Port),
zap.Uint16("priority", rec.Priority),
zap.Uint16("weight", rec.Weight),
)
}
addr := net.JoinHostPort(rec.Target, strconv.Itoa(int(rec.Port)))
upstreams[i] = Upstream{Dial: addr}
}
@@ -231,6 +242,19 @@ type IPVersions struct {
IPv6 *bool `json:"ipv6,omitempty"`
}
func resolveIpVersion(versions *IPVersions) string {
resolveIpv4 := versions == nil || (versions.IPv4 == nil && versions.IPv6 == nil) || (versions.IPv4 != nil && *versions.IPv4)
resolveIpv6 := versions == nil || (versions.IPv6 == nil && versions.IPv4 == nil) || (versions.IPv6 != nil && *versions.IPv6)
switch {
case resolveIpv4 && !resolveIpv6:
return "ip4"
case !resolveIpv4 && resolveIpv6:
return "ip6"
default:
return "ip"
}
}
// AUpstreams provides upstreams from A/AAAA lookups.
// Results are cached and refreshed at the configured
// refresh interval.
@@ -313,9 +337,6 @@ func (au *AUpstreams) Provision(ctx caddy.Context) error {
func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
resolveIpv4 := au.Versions == nil || au.Versions.IPv4 == nil || *au.Versions.IPv4
resolveIpv6 := au.Versions == nil || au.Versions.IPv6 == nil || *au.Versions.IPv6
// Map ipVersion early, so we can use it as part of the cache-key.
// This should be fairly inexpensive and comes and the upside of
// allowing the same dynamic upstream (name + port combination)
@@ -324,15 +345,7 @@ func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
// It also forced a cache-miss if a previously cached dynamic
// upstream changes its ip version, e.g. after a config reload,
// while keeping the cache-invalidation as simple as it currently is.
var ipVersion string
switch {
case resolveIpv4 && !resolveIpv6:
ipVersion = "ip4"
case !resolveIpv4 && resolveIpv6:
ipVersion = "ip6"
default:
ipVersion = "ip"
}
ipVersion := resolveIpVersion(au.Versions)
auStr := repl.ReplaceAll(au.String()+ipVersion, "")
@@ -359,10 +372,13 @@ func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
name := repl.ReplaceAll(au.Name, "")
port := repl.ReplaceAll(au.Port, "")
au.logger.Debug("refreshing A upstreams",
zap.String("version", ipVersion),
zap.String("name", name),
zap.String("port", port))
if c := au.logger.Check(zapcore.DebugLevel, "refreshing A upstreams"); c != nil {
c.Write(
zap.String("version", ipVersion),
zap.String("name", name),
zap.String("port", port),
)
}
ips, err := au.resolver.LookupIP(r.Context(), ipVersion, name)
if err != nil {
@@ -371,8 +387,9 @@ func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
upstreams := make([]Upstream, len(ips))
for i, ip := range ips {
au.logger.Debug("discovered A record",
zap.String("ip", ip.String()))
if c := au.logger.Check(zapcore.DebugLevel, "discovered A record"); c != nil {
c.Write(zap.String("ip", ip.String()))
}
upstreams[i] = Upstream{
Dial: net.JoinHostPort(ip.String(), port),
}
@@ -465,11 +482,16 @@ func (mu MultiUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
up, err := src.GetUpstreams(r)
if err != nil {
mu.logger.Error("upstream source returned error",
zap.Int("source_idx", i),
zap.Error(err))
if c := mu.logger.Check(zapcore.ErrorLevel, "upstream source returned error"); c != nil {
c.Write(
zap.Int("source_idx", i),
zap.Error(err),
)
}
} else if len(up) == 0 {
mu.logger.Warn("upstream source returned 0 upstreams", zap.Int("source_idx", i))
if c := mu.logger.Check(zapcore.WarnLevel, "upstream source returned 0 upstreams"); c != nil {
c.Write(zap.Int("source_idx", i))
}
} else {
upstreams = append(upstreams, up...)
}
@@ -0,0 +1,56 @@
package reverseproxy
import "testing"
func TestResolveIpVersion(t *testing.T) {
falseBool := false
trueBool := true
tests := []struct {
Versions *IPVersions
expectedIpVersion string
}{
{
Versions: &IPVersions{IPv4: &trueBool},
expectedIpVersion: "ip4",
},
{
Versions: &IPVersions{IPv4: &falseBool},
expectedIpVersion: "ip",
},
{
Versions: &IPVersions{IPv4: &trueBool, IPv6: &falseBool},
expectedIpVersion: "ip4",
},
{
Versions: &IPVersions{IPv6: &trueBool},
expectedIpVersion: "ip6",
},
{
Versions: &IPVersions{IPv6: &falseBool},
expectedIpVersion: "ip",
},
{
Versions: &IPVersions{IPv6: &trueBool, IPv4: &falseBool},
expectedIpVersion: "ip6",
},
{
Versions: &IPVersions{},
expectedIpVersion: "ip",
},
{
Versions: &IPVersions{IPv4: &trueBool, IPv6: &trueBool},
expectedIpVersion: "ip",
},
{
Versions: &IPVersions{IPv4: &falseBool, IPv6: &falseBool},
expectedIpVersion: "ip",
},
}
for _, test := range tests {
ipVersion := resolveIpVersion(test.Versions)
if ipVersion != test.expectedIpVersion {
t.Errorf("resolveIpVersion(): Expected %s got %s", test.expectedIpVersion, ipVersion)
}
}
}