Compare commits

...

5 Commits

Author SHA1 Message Date
Matt Holt 4ec1ea5153 reverseproxy: Add ability to clear dynamic upstreams cache during retries (#7662)
* reverseproxy: Add ability to clear dynamic upstreams cache during retries

This is an optional interface for dynamic upstream modules to implement if they cache results.

TODO: More documentation; this is an experiment.

* Add some godoc

* Export interface; update godoc
2026-04-28 09:21:03 -06:00
Matthew Holt 163910e74e encode: Implement Flush for legacy compatibility
(By sponsor request)
2026-02-16 15:59:28 -07:00
WeidiDeng b499a1a823 reverseproxy: do not disable keepalive if proxy protocol is used (#7300) 2026-02-06 08:47:09 -07:00
Matthew Holt 99231c12ef reverseproxy: Customizable dial network for SRV upstreams
By request of a sponsor
2026-02-03 12:55:21 -07:00
jjiang-stripe 3d9b1df852 caddytls: Fix TrustedCACerts backwards compatibility (#6889)
* add failing test

* fix ca pool provisioning

* remove unused param
2025-03-11 11:38:24 -06:00
8 changed files with 146 additions and 24 deletions
+8
View File
@@ -295,6 +295,14 @@ func (rw *responseWriter) FlushError() error {
return http.NewResponseController(rw.ResponseWriter).Flush() return http.NewResponseController(rw.ResponseWriter).Flush()
} }
// Flush calls FlushError() and simply discards any error. It is only implemented for backwards
// compatibility with legacy code that does not use FlushError; we know at least one sponsor
// needs this. It should not be relied upon as a stable part of the exported API, as it may be
// removed in the future.
func (rw *responseWriter) Flush() {
_ = rw.FlushError()
}
// Write writes to the response. If the response qualifies, // Write writes to the response. If the response qualifies,
// it is encoded using the encoder, which is initialized // it is encoded using the encoder, which is initialized
// if not done so already. // if not done so already.
@@ -1504,6 +1504,7 @@ func (u *SRVUpstreams) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
return d.Errf("bad delay value '%s': %v", d.Val(), err) return d.Errf("bad delay value '%s': %v", d.Val(), err)
} }
u.FallbackDelay = caddy.Duration(dur) u.FallbackDelay = caddy.Duration(dur)
case "grace_period": case "grace_period":
if !d.NextArg() { if !d.NextArg() {
return d.ArgErr() return d.ArgErr()
+3
View File
@@ -283,3 +283,6 @@ const proxyProtocolInfoVarKey = "reverse_proxy.proxy_protocol_info"
type ProxyProtocolInfo struct { type ProxyProtocolInfo struct {
AddrPort netip.AddrPort AddrPort netip.AddrPort
} }
// proxyVarKey is the key used that indicates the proxy server used for a request.
const proxyVarKey = "reverse_proxy.proxy"
+19 -17
View File
@@ -236,15 +236,15 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
} }
dialContext := func(ctx context.Context, network, address string) (net.Conn, error) { dialContext := func(ctx context.Context, network, address string) (net.Conn, error) {
// For unix socket upstreams, we need to recover the dial info from // The network is usually tcp, and the address is the host in http.Request.URL.Host
// the request's context, because the Host on the request's URL // and that's been overwritten in directRequest
// will have been modified by directing the request, overwriting // However, if proxy is used according to http.ProxyFromEnvironment or proxy providers,
// the unix socket filename. // address will be the address of the proxy server.
// Also, we need to avoid overwriting the address at this point
// when not necessary, because http.ProxyFromEnvironment may have // This means we can safely use the address in dialInfo if proxy is not used (the address and network will be same any way)
// modified the address according to the user's env proxy config. // or if the upstream is unix (because there is no way socks or http proxy can be used for unix address).
if dialInfo, ok := GetDialInfo(ctx); ok { if dialInfo, ok := GetDialInfo(ctx); ok {
if strings.HasPrefix(dialInfo.Network, "unix") { if caddyhttp.GetVar(ctx, proxyVarKey) == nil || strings.HasPrefix(dialInfo.Network, "unix") {
network = dialInfo.Network network = dialInfo.Network
address = dialInfo.Address address = dialInfo.Address
} }
@@ -339,9 +339,19 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
} else { } else {
proxy = http.ProxyFromEnvironment proxy = http.ProxyFromEnvironment
} }
// we need to keep track if a proxy is used for a request
proxyWrapper := func(req *http.Request) (*url.URL, error) {
u, err := proxy(req)
if u == nil || err != nil {
return u, err
}
// there must be a proxy for this request
caddyhttp.SetVar(req.Context(), proxyVarKey, u)
return u, nil
}
rt := &http.Transport{ rt := &http.Transport{
Proxy: proxy, Proxy: proxyWrapper,
DialContext: dialContext, DialContext: dialContext,
MaxConnsPerHost: h.MaxConnsPerHost, MaxConnsPerHost: h.MaxConnsPerHost,
ResponseHeaderTimeout: time.Duration(h.ResponseHeaderTimeout), ResponseHeaderTimeout: time.Duration(h.ResponseHeaderTimeout),
@@ -370,14 +380,6 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
rt.IdleConnTimeout = time.Duration(h.KeepAlive.IdleConnTimeout) rt.IdleConnTimeout = time.Duration(h.KeepAlive.IdleConnTimeout)
} }
// The proxy protocol header can only be sent once right after opening the connection.
// So single connection must not be used for multiple requests, which can potentially
// come from different clients.
if !rt.DisableKeepAlives && h.ProxyProtocol != "" {
caddyCtx.Logger().Warn("disabling keepalives, they are incompatible with using PROXY protocol")
rt.DisableKeepAlives = true
}
if h.Compression != nil { if h.Compression != nil {
rt.DisableCompression = !*h.Compression rt.DisableCompression = !*h.Compression
} }
+37 -1
View File
@@ -494,6 +494,17 @@ func (h *Handler) proxyLoopIteration(r *http.Request, origReq *http.Request, w h
// get the updated list of upstreams // get the updated list of upstreams
upstreams := h.Upstreams upstreams := h.Upstreams
if h.DynamicUpstreams != nil { if h.DynamicUpstreams != nil {
if retries > 0 {
// after a failure (and thus during a retry), give dynamic upstream modules an opportunity
// to purge their relevant cache entries so we don't keep retrying bad upstreams
if cachingDynamicUpstreams, ok := h.DynamicUpstreams.(CachingUpstreamSource); ok {
if err := cachingDynamicUpstreams.ResetCache(r); err != nil {
if c := h.logger.Check(zapcore.ErrorLevel, "failed clearing dynamic upstream source's cache"); c != nil {
c.Write(zap.Error(err))
}
}
}
}
dUpstreams, err := h.DynamicUpstreams.GetUpstreams(r) dUpstreams, err := h.DynamicUpstreams.GetUpstreams(r)
if err != nil { if err != nil {
if c := h.logger.Check(zapcore.ErrorLevel, "failed getting dynamic upstreams; falling back to static upstreams"); c != nil { if c := h.logger.Check(zapcore.ErrorLevel, "failed getting dynamic upstreams; falling back to static upstreams"); c != nil {
@@ -1175,7 +1186,7 @@ func (lb LoadBalancing) tryAgain(ctx caddy.Context, start time.Time, retries int
// directRequest modifies only req.URL so that it points to the upstream // directRequest modifies only req.URL so that it points to the upstream
// in the given DialInfo. It must modify ONLY the request URL. // in the given DialInfo. It must modify ONLY the request URL.
func (Handler) directRequest(req *http.Request, di DialInfo) { func (h *Handler) directRequest(req *http.Request, di DialInfo) {
// we need a host, so set the upstream's host address // we need a host, so set the upstream's host address
reqHost := di.Address reqHost := di.Address
@@ -1186,6 +1197,13 @@ func (Handler) directRequest(req *http.Request, di DialInfo) {
reqHost = di.Host reqHost = di.Host
} }
// add client address to the host to let transport differentiate requests from different clients
if ht, ok := h.Transport.(*HTTPTransport); ok && ht.ProxyProtocol != "" {
if proxyProtocolInfo, ok := caddyhttp.GetVar(req.Context(), proxyProtocolInfoVarKey).(ProxyProtocolInfo); ok {
reqHost = proxyProtocolInfo.AddrPort.String() + "->" + reqHost
}
}
req.URL.Host = reqHost req.URL.Host = reqHost
} }
@@ -1428,10 +1446,28 @@ type Selector interface {
// may be called during each retry, multiple times per request, and as // may be called during each retry, multiple times per request, and as
// such, needs to be instantaneous. The returned slice will not be // such, needs to be instantaneous. The returned slice will not be
// modified. // modified.
//
// For upstream sources that cache results, implement the
// [CachingUpstreamSource] interface for optimal performance.
type UpstreamSource interface { type UpstreamSource interface {
GetUpstreams(*http.Request) ([]*Upstream, error) GetUpstreams(*http.Request) ([]*Upstream, error)
} }
// CachingUpstreamSource is an upstream source that caches its upstreams.
// The relevant cache entry can be cleared/reset for a given request during
// retries if a request fails. This can help ensure that failing backends
// are not retried.
//
// EXPERIMENTAL: Subject to change.
type CachingUpstreamSource interface {
UpstreamSource
// ResetCache clears any cache entry related to the given request.
// The next time GetUpstreams is called, it should have new upstream
// information for the given request.
ResetCache(*http.Request) error
}
// Hop-by-hop headers. These are removed when sent to the backend. // Hop-by-hop headers. These are removed when sent to the backend.
// As of RFC 7230, hop-by-hop headers are required to appear in the // As of RFC 7230, hop-by-hop headers are required to appear in the
// Connection header field. These are the headers defined by the // Connection header field. These are the headers defined by the
+25 -4
View File
@@ -70,6 +70,11 @@ type SRVUpstreams struct {
// A negative value disables this. // A negative value disables this.
FallbackDelay caddy.Duration `json:"dial_fallback_delay,omitempty"` FallbackDelay caddy.Duration `json:"dial_fallback_delay,omitempty"`
// Specific network to dial when connecting to the upstream(s)
// provided by SRV records upstream. See Go's net package for
// accepted values. For example, to restrict to IPv4, use "tcp4".
DialNetwork string `json:"dial_network,omitempty"`
resolver *net.Resolver resolver *net.Resolver
logger *zap.Logger logger *zap.Logger
@@ -114,6 +119,18 @@ func (su *SRVUpstreams) Provision(ctx caddy.Context) error {
return nil return nil
} }
func (su *SRVUpstreams) ResetCache(r *http.Request) error {
srvsMu.Lock()
if r == nil {
srvs = make(map[string]srvLookup)
} else {
suAddr, _, _, _ := su.expandedAddr(r)
delete(srvs, suAddr)
}
srvsMu.Unlock()
return nil
}
func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) { func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
suAddr, service, proto, name := su.expandedAddr(r) suAddr, service, proto, name := su.expandedAddr(r)
@@ -177,6 +194,9 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
) )
} }
addr := net.JoinHostPort(rec.Target, strconv.Itoa(int(rec.Port))) addr := net.JoinHostPort(rec.Target, strconv.Itoa(int(rec.Port)))
if su.DialNetwork != "" {
addr = su.DialNetwork + "/" + addr
}
upstreams[i] = Upstream{Dial: addr} upstreams[i] = Upstream{Dial: addr}
} }
@@ -546,8 +566,9 @@ var (
// Interface guards // Interface guards
var ( var (
_ caddy.Provisioner = (*SRVUpstreams)(nil) _ caddy.Provisioner = (*SRVUpstreams)(nil)
_ UpstreamSource = (*SRVUpstreams)(nil) _ UpstreamSource = (*SRVUpstreams)(nil)
_ caddy.Provisioner = (*AUpstreams)(nil) _ CachingUpstreamSource = (*SRVUpstreams)(nil)
_ UpstreamSource = (*AUpstreams)(nil) _ caddy.Provisioner = (*AUpstreams)(nil)
_ UpstreamSource = (*AUpstreams)(nil)
) )
+6 -2
View File
@@ -749,10 +749,14 @@ func (clientauth *ClientAuthentication) provision(ctx caddy.Context) error {
// if we have TrustedCACerts explicitly set, create an 'inline' CA and return // if we have TrustedCACerts explicitly set, create an 'inline' CA and return
if len(clientauth.TrustedCACerts) > 0 { if len(clientauth.TrustedCACerts) > 0 {
clientauth.ca = InlineCAPool{ caPool := InlineCAPool{
TrustedCACerts: clientauth.TrustedCACerts, TrustedCACerts: clientauth.TrustedCACerts,
} }
return nil err := caPool.Provision(ctx)
if err != nil {
return nil
}
clientauth.ca = caPool
} }
// if we don't have any CARaw set, there's not much work to do // if we don't have any CARaw set, there's not much work to do
+47
View File
@@ -20,6 +20,7 @@ import (
"reflect" "reflect"
"testing" "testing"
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
) )
@@ -278,3 +279,49 @@ func TestClientAuthenticationUnmarshalCaddyfileWithDirectiveName(t *testing.T) {
}) })
} }
} }
func TestClientAuthenticationProvision(t *testing.T) {
tests := []struct {
name string
ca ClientAuthentication
wantErr bool
}{
{
name: "specifying both 'CARaw' and 'TrustedCACerts' produces an error",
ca: ClientAuthentication{
CARaw: json.RawMessage(`{"provider":"inline","trusted_ca_certs":["foo"]}`),
TrustedCACerts: []string{"foo"},
},
wantErr: true,
},
{
name: "specifying both 'CARaw' and 'TrustedCACertPEMFiles' produces an error",
ca: ClientAuthentication{
CARaw: json.RawMessage(`{"provider":"inline","trusted_ca_certs":["foo"]}`),
TrustedCACertPEMFiles: []string{"foo"},
},
wantErr: true,
},
{
name: "setting 'TrustedCACerts' provisions the cert pool",
ca: ClientAuthentication{
TrustedCACerts: []string{test_der_1},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.ca.provision(caddy.Context{})
if (err != nil) != tt.wantErr {
t.Errorf("ClientAuthentication.provision() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr {
if tt.ca.ca.CertPool() == nil {
t.Error("CertPool is nil, expected non-nil value")
}
}
})
}
}