Merge branch 'master' into add-tests

This commit is contained in:
Mohammed Al Sahaf
2026-03-20 04:17:15 +03:00
149 changed files with 9739 additions and 1396 deletions
+28 -2
View File
@@ -18,6 +18,7 @@ import (
"cmp"
"context"
"crypto/tls"
"errors"
"fmt"
"maps"
"net"
@@ -51,6 +52,7 @@ func init() {
// Placeholder | Description
// ------------|---------------
// `{http.request.body}` | The request body (⚠️ inefficient; use only for debugging)
// `{http.request.body_base64}` | The request body, base64-encoded (⚠️ for debugging)
// `{http.request.cookie.*}` | HTTP request cookie
// `{http.request.duration}` | Time up to now spent handling the request (after decoding headers from client)
// `{http.request.duration_ms}` | Same as 'duration', but in milliseconds.
@@ -82,6 +84,7 @@ func init() {
// `{http.request.tls.proto}` | The negotiated next protocol
// `{http.request.tls.proto_mutual}` | The negotiated next protocol was advertised by the server
// `{http.request.tls.server_name}` | The server name requested by the client, if any
// `{http.request.tls.ech}` | Whether ECH was offered by the client and accepted by the server
// `{http.request.tls.client.fingerprint}` | The SHA256 checksum of the client certificate
// `{http.request.tls.client.public_key}` | The public key of the client certificate.
// `{http.request.tls.client.public_key_sha256}` | The SHA256 checksum of the client's public key.
@@ -198,6 +201,8 @@ func (app *App) Provision(ctx caddy.Context) error {
if app.Metrics != nil {
app.Metrics.init = sync.Once{}
app.Metrics.httpMetrics = &httpMetrics{}
// Scan config for allowed hosts to prevent cardinality explosion
app.Metrics.scanConfigForHosts(app)
}
// prepare each server
oldContext := ctx.Context
@@ -344,6 +349,20 @@ func (app *App) Provision(ctx caddy.Context) error {
srv.listenerWrappers = append([]caddy.ListenerWrapper{new(tlsPlaceholderWrapper)}, srv.listenerWrappers...)
}
}
// set up each packet conn modifier
if srv.PacketConnWrappersRaw != nil {
vals, err := ctx.LoadModule(srv, "PacketConnWrappersRaw")
if err != nil {
return fmt.Errorf("loading packet conn wrapper modules: %v", err)
}
// if any wrappers were configured, they come before the QUIC handshake;
// unlike TLS above, there is no QUIC placeholder
for _, val := range vals.([]any) {
srv.packetConnWrappers = append(srv.packetConnWrappers, val.(caddy.PacketConnWrapper))
}
}
// pre-compile the primary handler chain, and be sure to wrap it in our
// route handler so that important security checks are done, etc.
primaryRoute := emptyHandler
@@ -693,9 +712,10 @@ func (app *App) Stop() error {
// enforce grace period if configured
if app.GracePeriod > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, time.Duration(app.GracePeriod))
timeout := time.Duration(app.GracePeriod)
ctx, cancel = context.WithTimeoutCause(ctx, timeout, fmt.Errorf("server graceful shutdown %ds timeout", int(timeout.Seconds())))
defer cancel()
app.logger.Info("servers shutting down; grace period initiated", zap.Duration("duration", time.Duration(app.GracePeriod)))
app.logger.Info("servers shutting down; grace period initiated", zap.Duration("duration", timeout))
} else {
app.logger.Info("servers shutting down with eternal grace period")
}
@@ -721,6 +741,9 @@ func (app *App) Stop() error {
}
if err := server.server.Shutdown(ctx); err != nil {
if cause := context.Cause(ctx); cause != nil && errors.Is(err, context.DeadlineExceeded) {
err = cause
}
app.logger.Error("server shutdown",
zap.Error(err),
zap.Strings("addresses", server.Listen))
@@ -744,6 +767,9 @@ func (app *App) Stop() error {
}
if err := server.h3server.Shutdown(ctx); err != nil {
if cause := context.Cause(ctx); cause != nil && errors.Is(err, context.DeadlineExceeded) {
err = cause
}
app.logger.Error("HTTP/3 server shutdown",
zap.Error(err),
zap.Strings("addresses", server.Listen))
+101 -3
View File
@@ -90,7 +90,16 @@ func (app *App) automaticHTTPSPhase1(ctx caddy.Context, repl *caddy.Replacer) er
// the log configuration for an HTTPS enabled server
var logCfg *ServerLogConfig
for srvName, srv := range app.Servers {
// Sort server names to ensure deterministic iteration.
// This prevents race conditions where the order of server processing
// could affect which server gets assigned the HTTP->HTTPS redirect listener.
srvNames := make([]string, 0, len(app.Servers))
for name := range app.Servers {
srvNames = append(srvNames, name)
}
slices.Sort(srvNames)
for _, srvName := range srvNames {
srv := app.Servers[srvName]
// as a prerequisite, provision route matchers; this is
// required for all routes on all servers, and must be
// done before we attempt to do phase 1 of auto HTTPS,
@@ -398,15 +407,60 @@ uniqueDomainsLoop:
return append(routes, app.makeRedirRoute(uint(app.httpsPort()), MatcherSet{MatchProtocol("http")}))
}
// Sort redirect addresses to ensure deterministic process
redirServerAddrsSorted := make([]string, 0, len(redirServers))
for addr := range redirServers {
redirServerAddrsSorted = append(redirServerAddrsSorted, addr)
}
slices.Sort(redirServerAddrsSorted)
redirServersLoop:
for redirServerAddr, routes := range redirServers {
for _, redirServerAddr := range redirServerAddrsSorted {
routes := redirServers[redirServerAddr]
// for each redirect listener, see if there's already a
// server configured to listen on that exact address; if so,
// insert the redirect route to the end of its route list
// after any other routes with host matchers; otherwise,
// we'll create a new server for all the listener addresses
// that are unused and serve the remaining redirects from it
for _, srv := range app.Servers {
// Sort redirect routes by host specificity to ensure exact matches
// take precedence over wildcards, preventing ambiguous routing.
slices.SortFunc(routes, func(a, b Route) int {
hostA := getFirstHostFromRoute(a)
hostB := getFirstHostFromRoute(b)
// Catch-all routes (empty host) have the lowest priority
if hostA == "" && hostB != "" {
return 1
}
if hostB == "" && hostA != "" {
return -1
}
hasWildcardA := strings.Contains(hostA, "*")
hasWildcardB := strings.Contains(hostB, "*")
// Exact domains take precedence over wildcards
if !hasWildcardA && hasWildcardB {
return -1
}
if hasWildcardA && !hasWildcardB {
return 1
}
// If both are exact or both are wildcards, the longer one is more specific
if len(hostA) != len(hostB) {
return len(hostB) - len(hostA)
}
// Tie-breaker: alphabetical order to ensure determinism
return strings.Compare(hostA, hostB)
})
// Use the sorted srvNames to consistently find the target server
for _, srvName := range srvNames {
srv := app.Servers[srvName]
// only look at servers which listen on an address which
// we want to add redirects to
if !srv.hasListenerAddress(redirServerAddr) {
@@ -560,6 +614,27 @@ func (app *App) createAutomationPolicies(ctx caddy.Context, internalNames, tails
}
}
// Ensure automation policies' CertMagic configs are rebuilt when
// ACME issuer templates may have been modified above (for example,
// alternate ports filled in by the HTTP app). If a policy is already
// provisioned, perform a lightweight rebuild of the CertMagic config
// so issuers receive SetConfig with the updated templates; otherwise
// run a normal Provision to initialize the policy.
for i, ap := range app.tlsApp.Automation.Policies {
// If the policy is already provisioned, rebuild only the CertMagic
// config so issuers get SetConfig with updated templates. Otherwise
// provision the policy normally (which may load modules).
if ap.IsProvisioned() {
if err := ap.RebuildCertMagic(app.tlsApp); err != nil {
return fmt.Errorf("rebuilding certmagic config for automation policy %d: %v", i, err)
}
} else {
if err := ap.Provision(app.tlsApp); err != nil {
return fmt.Errorf("provisioning automation policy %d after auto-HTTPS defaults: %v", i, err)
}
}
}
if basePolicy == nil {
// no base policy found; we will make one
basePolicy = new(caddytls.AutomationPolicy)
@@ -773,3 +848,26 @@ func isTailscaleDomain(name string) bool {
}
type acmeCapable interface{ GetACMEIssuer() *caddytls.ACMEIssuer }
// getFirstHostFromRoute traverses a route's matchers to find the Host rule.
// Since we are dealing with internally generated redirect routes, the host
// is typically the first string within the MatchHost.
func getFirstHostFromRoute(r Route) string {
for _, matcherSet := range r.MatcherSets {
for _, m := range matcherSet {
// Check if the matcher is of type MatchHost (value or pointer)
switch hm := m.(type) {
case MatchHost:
if len(hm) > 0 {
return hm[0]
}
case *MatchHost:
if len(*hm) > 0 {
return (*hm)[0]
}
}
}
}
// Return an empty string if it's a catch-all route (no specific host)
return ""
}
+3 -3
View File
@@ -19,7 +19,7 @@ import (
"encoding/hex"
"encoding/json"
"fmt"
weakrand "math/rand"
weakrand "math/rand/v2"
"net/http"
"strings"
"sync"
@@ -244,7 +244,7 @@ func (c *Cache) makeRoom() {
// strategy; generating random numbers is cheap and
// ensures a much better distribution.
//nolint:gosec
rnd := weakrand.Intn(len(c.cache))
rnd := weakrand.IntN(len(c.cache))
i := 0
for key := range c.cache {
if i == rnd {
@@ -287,7 +287,7 @@ type Account struct {
// The user's hashed password, in Modular Crypt Format (with `$` prefix)
// or base64-encoded.
Password string `json:"password"`
Password string `json:"password"` //nolint:gosec // false positive, this is a hashed password
password []byte
}
+37 -5
View File
@@ -412,10 +412,12 @@ func CELMatcherImpl(macroName, funcName string, matcherDataTypes []*cel.Type, fa
return nil, fmt.Errorf("unsupported matcher data type: %s, %s", matcherDataTypes[0], matcherDataTypes[1])
}
case 3:
// nolint:gosec // false positive, impossible to be out of bounds; see: https://github.com/securego/gosec/issues/1525
if matcherDataTypes[0] == cel.StringType && matcherDataTypes[1] == cel.StringType && matcherDataTypes[2] == cel.StringType {
macro = parser.NewGlobalMacro(macroName, 3, celMatcherStringListMacroExpander(funcName))
matcherDataTypes = []*cel.Type{cel.ListType(cel.StringType)}
} else {
// nolint:gosec // false positive, impossible to be out of bounds; see: https://github.com/securego/gosec/issues/1525
return nil, fmt.Errorf("unsupported matcher data type: %s, %s, %s", matcherDataTypes[0], matcherDataTypes[1], matcherDataTypes[2])
}
}
@@ -665,12 +667,29 @@ func celMatcherJSONMacroExpander(funcName string) parser.MacroExpander {
// map literals containing heterogeneous values, in this case string and list
// of string.
func CELValueToMapStrList(data ref.Val) (map[string][]string, error) {
mapStrType := reflect.TypeOf(map[string]any{})
// Prefer map[string]any, but newer cel-go versions may return map[any]any
mapStrType := reflect.TypeFor[map[string]any]()
mapStrRaw, err := data.ConvertToNative(mapStrType)
var mapStrIface map[string]any
if err != nil {
return nil, err
// Try map[any]any and convert keys to strings
mapAnyType := reflect.TypeFor[map[any]any]()
mapAnyRaw, err2 := data.ConvertToNative(mapAnyType)
if err2 != nil {
return nil, err
}
mapAnyIface := mapAnyRaw.(map[any]any)
mapStrIface = make(map[string]any, len(mapAnyIface))
for k, v := range mapAnyIface {
ks, ok := k.(string)
if !ok {
return nil, fmt.Errorf("unsupported map key type in header match: %T", k)
}
mapStrIface[ks] = v
}
} else {
mapStrIface = mapStrRaw.(map[string]any)
}
mapStrIface := mapStrRaw.(map[string]any)
mapStrListStr := make(map[string][]string, len(mapStrIface))
for k, v := range mapStrIface {
switch val := v.(type) {
@@ -685,13 +704,26 @@ func CELValueToMapStrList(data ref.Val) (map[string][]string, error) {
for i, elem := range val {
strVal, ok := elem.(types.String)
if !ok {
return nil, fmt.Errorf("unsupported value type in header match: %T", val)
return nil, fmt.Errorf("unsupported value type in matcher input: %T", val)
}
convVals[i] = string(strVal)
}
mapStrListStr[k] = convVals
case []any:
convVals := make([]string, len(val))
for i, elem := range val {
switch e := elem.(type) {
case string:
convVals[i] = e
case types.String:
convVals[i] = string(e)
default:
return nil, fmt.Errorf("unsupported element type in matcher input list: %T", elem)
}
}
mapStrListStr[k] = convVals
default:
return nil, fmt.Errorf("unsupported value type in header match: %T", val)
return nil, fmt.Errorf("unsupported value type in matcher input: %T", val)
}
}
return mapStrListStr, nil
+2 -2
View File
@@ -168,8 +168,8 @@ func (enc *Encode) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyh
// caches without knowing about our changes...
if etag := r.Header.Get("If-None-Match"); etag != "" && !strings.HasPrefix(etag, "W/") {
ourSuffix := "-" + encName + `"`
if strings.HasSuffix(etag, ourSuffix) {
etag = strings.TrimSuffix(etag, ourSuffix) + `"`
if before, ok := strings.CutSuffix(etag, ourSuffix); ok {
etag = before + `"`
r.Header.Set("If-None-Match", etag)
}
}
+2 -2
View File
@@ -17,7 +17,7 @@ package caddyhttp
import (
"errors"
"fmt"
weakrand "math/rand"
weakrand "math/rand/v2"
"path"
"runtime"
"strings"
@@ -101,7 +101,7 @@ func randString(n int, sameCase bool) string {
b := make([]byte, n)
for i := range b {
//nolint:gosec
b[i] = dict[weakrand.Int63()%int64(len(dict))]
b[i] = dict[weakrand.IntN(len(dict))]
}
return string(b)
}
+1
View File
@@ -169,6 +169,7 @@ func (fsrv *FileServer) serveBrowse(fileSystem fs.FS, root, dirPath string, w ht
// Actual files
for _, item := range listing.Items {
//nolint:gosec // not sure how this could be XSS unless you lose control of the file system (like aren't sanitizing) and client ignores Content-Type of text/plain
if _, err := fmt.Fprintf(writer, "%s\t%s\t%s\n",
item.Name, item.HumanSize(), item.HumanModTime("January 2, 2006 at 15:04:05"),
); err != nil {
+2 -1
View File
@@ -404,7 +404,7 @@ func (m MatchFile) selectFile(r *http.Request) (bool, error) {
}
// for each glob result, combine all the forms of the path
var candidates []matchCandidate
candidates := make([]matchCandidate, 0, len(globResults))
for _, result := range globResults {
candidates = append(candidates, matchCandidate{
fullpath: result,
@@ -720,6 +720,7 @@ var globSafeRepl = strings.NewReplacer(
"*", "\\*",
"[", "\\[",
"?", "\\?",
"\\", "\\\\",
)
const (
+69 -38
View File
@@ -20,7 +20,9 @@ import (
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
"github.com/caddyserver/caddy/v2"
@@ -28,6 +30,13 @@ import (
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
)
type testCase struct {
path string
expectedPath string
expectedType string
matched bool
}
func TestFileMatcher(t *testing.T) {
// Windows doesn't like colons in files names
isWindows := runtime.GOOS == "windows"
@@ -45,12 +54,7 @@ func TestFileMatcher(t *testing.T) {
f.Close()
}
for i, tc := range []struct {
path string
expectedPath string
expectedType string
matched bool
}{
for i, tc := range []testCase{
{
path: "/foo.txt",
expectedPath: "/foo.txt",
@@ -116,44 +120,71 @@ func TestFileMatcher(t *testing.T) {
matched: !isWindows,
},
} {
m := &MatchFile{
fsmap: &filesystems.FileSystemMap{},
Root: "./testdata",
TryFiles: []string{"{http.request.uri.path}", "{http.request.uri.path}/"},
}
fileMatcherTest(t, i, tc)
}
}
u, err := url.Parse(tc.path)
if err != nil {
t.Errorf("Test %d: parsing path: %v", i, err)
}
func TestFileMatcherNonWindows(t *testing.T) {
if runtime.GOOS == "windows" {
return
}
req := &http.Request{URL: u}
repl := caddyhttp.NewTestReplacer(req)
// this is impossible to test on Windows, but tests a security patch for other platforms
tc := testCase{
path: "/foodir/secr%5Cet.txt",
expectedPath: "/foodir/secr\\et.txt",
expectedType: "file",
matched: true,
}
result, err := m.MatchWithError(req)
if err != nil {
t.Errorf("Test %d: unexpected error: %v", i, err)
}
if result != tc.matched {
t.Errorf("Test %d: expected match=%t, got %t", i, tc.matched, result)
}
f, err := os.Create(filepath.Join("testdata", strings.TrimPrefix(tc.expectedPath, "/")))
if err != nil {
t.Fatalf("could not create test file: %v", err)
}
defer f.Close()
defer os.Remove(f.Name())
rel, ok := repl.Get("http.matchers.file.relative")
if !ok && result {
t.Errorf("Test %d: expected replacer value", i)
}
if !result {
continue
}
fileMatcherTest(t, 0, tc)
}
if rel != tc.expectedPath {
t.Errorf("Test %d: actual path: %v, expected: %v", i, rel, tc.expectedPath)
}
func fileMatcherTest(t *testing.T, i int, tc testCase) {
m := &MatchFile{
fsmap: &filesystems.FileSystemMap{},
Root: "./testdata",
TryFiles: []string{"{http.request.uri.path}", "{http.request.uri.path}/"},
}
fileType, _ := repl.Get("http.matchers.file.type")
if fileType != tc.expectedType {
t.Errorf("Test %d: actual file type: %v, expected: %v", i, fileType, tc.expectedType)
}
u, err := url.Parse(tc.path)
if err != nil {
t.Errorf("Test %d: parsing path: %v", i, err)
}
req := &http.Request{URL: u}
repl := caddyhttp.NewTestReplacer(req)
result, err := m.MatchWithError(req)
if err != nil {
t.Errorf("Test %d: unexpected error: %v", i, err)
}
if result != tc.matched {
t.Errorf("Test %d: expected match=%t, got %t", i, tc.matched, result)
}
rel, ok := repl.Get("http.matchers.file.relative")
if !ok && result {
t.Errorf("Test %d: expected replacer value", i)
}
if !result {
return
}
if rel != tc.expectedPath {
t.Errorf("Test %d: actual path: %v, expected: %v", i, rel, tc.expectedPath)
}
fileType, _ := repl.Get("http.matchers.file.type")
if fileType != tc.expectedType {
t.Errorf("Test %d: actual file type: %v, expected: %v", i, fileType, tc.expectedType)
}
}
+7 -2
View File
@@ -20,7 +20,7 @@ import (
"fmt"
"io"
"io/fs"
weakrand "math/rand"
weakrand "math/rand/v2"
"mime"
"net/http"
"os"
@@ -125,6 +125,11 @@ type FileServer struct {
// When possible, all paths are resolved to their absolute form before
// comparisons are made. For maximum clarity and explictness, use complete,
// absolute paths; or, for greater portability, use relative paths instead.
//
// Note that hide comparisons are case-sensitive. On case-insensitive
// filesystems, requests with different path casing may still resolve to the
// same file or directory on disk, so hide should not be treated as a
// security boundary for sensitive paths.
Hide []string `json:"hide,omitempty"`
// The names of files to try as index files if a folder is requested.
@@ -601,7 +606,7 @@ func (fsrv *FileServer) openFile(fileSystem fs.FS, filename string, w http.Respo
// maybe the server is under load and ran out of file descriptors?
// have client wait arbitrary seconds to help prevent a stampede
//nolint:gosec
backoff := weakrand.Intn(maxBackoff-minBackoff) + minBackoff
backoff := weakrand.IntN(maxBackoff-minBackoff) + minBackoff
w.Header().Set("Retry-After", strconv.Itoa(backoff))
if c := fsrv.logger.Check(zapcore.DebugLevel, "retry after backoff"); c != nil {
c.Write(zap.String("filename", filename), zap.Int("backoff", backoff), zap.Error(err))
+1 -3
View File
@@ -168,8 +168,6 @@ func parseReqHdrCaddyfile(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue,
}
h.Next() // consume the directive name again (matcher parsing resets)
configValues := []httpcaddyfile.ConfigValue{}
if !h.NextArg() {
return nil, h.ArgErr()
}
@@ -204,7 +202,7 @@ func parseReqHdrCaddyfile(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue,
return nil, h.Err(err.Error())
}
configValues = append(configValues, h.NewRoute(matcherSet, hdr)...)
configValues := h.NewRoute(matcherSet, hdr)
if h.NextArg() {
return nil, h.ArgErr()
+7 -4
View File
@@ -161,11 +161,11 @@ func (ops *HeaderOps) Provision(_ caddy.Context) error {
// containsPlaceholders checks if the string contains Caddy placeholder syntax {key}
func containsPlaceholders(s string) bool {
openIdx := strings.Index(s, "{")
if openIdx == -1 {
_, after, ok := strings.Cut(s, "{")
if !ok {
return false
}
closeIdx := strings.Index(s[openIdx+1:], "}")
closeIdx := strings.Index(after, "}")
if closeIdx == -1 {
return false
}
@@ -217,7 +217,10 @@ type RespHeaderOps struct {
}
// ApplyTo applies ops to hdr using repl.
func (ops HeaderOps) ApplyTo(hdr http.Header, repl *caddy.Replacer) {
func (ops *HeaderOps) ApplyTo(hdr http.Header, repl *caddy.Replacer) {
if ops == nil {
return
}
// before manipulating headers in other ways, check if there
// is configuration to delete all headers, and do that first
// because if a header is to be added, we don't want to delete
+28 -2
View File
@@ -17,6 +17,7 @@ package intercept
import (
"bytes"
"fmt"
"io"
"net/http"
"strconv"
"strings"
@@ -175,10 +176,35 @@ func (ir Intercept) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddy
c.Write(zap.Int("handler", rec.handlerIndex))
}
// pass the request through the response handler routes
return rec.handler.Routes.Compile(next).ServeHTTP(w, r)
// response recorder doesn't create a new copy of the original headers, they're
// present in the original response writer
// create a new recorder to see if any response body from the new handler is present,
// if not, use the already buffered response body
recorder := caddyhttp.NewResponseRecorder(w, nil, nil)
if err := rec.handler.Routes.Compile(emptyHandler).ServeHTTP(recorder, r); err != nil {
return err
}
// no new response status and the status is not 0
if recorder.Status() == 0 && rec.Status() != 0 {
w.WriteHeader(rec.Status())
}
// no new response body and there is some in the original response
// TODO: what if the new response doesn't have a body by design?
// see: https://github.com/caddyserver/caddy/pull/6232#issue-2235224400
if recorder.Size() == 0 && buf.Len() > 0 {
_, err := io.Copy(w, buf)
return err
}
return nil
}
// this handler does nothing because everything we need is already buffered
var emptyHandler caddyhttp.Handler = caddyhttp.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) error {
return nil
})
// UnmarshalCaddyfile sets up the handler from Caddyfile tokens. Syntax:
//
// intercept [<matcher>] {
+78 -1
View File
@@ -15,18 +15,28 @@
package caddyhttp
import (
"context"
"encoding/json"
"errors"
"log/slog"
"net"
"net/http"
"strings"
"sync"
"go.uber.org/zap"
"go.uber.org/zap/exp/zapslog"
"go.uber.org/zap/zapcore"
"github.com/caddyserver/caddy/v2"
)
func init() {
caddy.RegisterSlogHandlerFactory(func(handler slog.Handler, core zapcore.Core, moduleID string) slog.Handler {
return &extraFieldsSlogHandler{defaultHandler: handler, core: core, moduleID: moduleID}
})
}
// ServerLogConfig describes a server's logging configuration. If
// enabled without customization, all requests to this server are
// logged to the default logger; logger destinations may be
@@ -223,17 +233,21 @@ func errLogValues(err error) (status int, msg string, fields func() []zapcore.Fi
// ExtraLogFields is a list of extra fields to log with every request.
type ExtraLogFields struct {
fields []zapcore.Field
fields []zapcore.Field
handlers sync.Map
}
// Add adds a field to the list of extra fields to log.
func (e *ExtraLogFields) Add(field zap.Field) {
e.handlers.Clear()
e.fields = append(e.fields, field)
}
// Set sets a field in the list of extra fields to log.
// If the field already exists, it is replaced.
func (e *ExtraLogFields) Set(field zap.Field) {
e.handlers.Clear()
for i := range e.fields {
if e.fields[i].Key == field.Key {
e.fields[i] = field
@@ -243,6 +257,29 @@ func (e *ExtraLogFields) Set(field zap.Field) {
e.fields = append(e.fields, field)
}
func (e *ExtraLogFields) getSloggerHandler(handler *extraFieldsSlogHandler) (h slog.Handler) {
if existing, ok := e.handlers.Load(handler); ok {
return existing.(slog.Handler)
}
if handler.moduleID == "" {
h = zapslog.NewHandler(handler.core.With(e.fields))
} else {
h = zapslog.NewHandler(handler.core.With(e.fields), zapslog.WithName(handler.moduleID))
}
if handler.group != "" {
h = h.WithGroup(handler.group)
}
if handler.attrs != nil {
h = h.WithAttrs(handler.attrs)
}
e.handlers.Store(handler, h)
return h
}
const (
// Variable name used to indicate that this request
// should be omitted from the access logs
@@ -254,3 +291,43 @@ const (
// Variable name used to indicate the logger to be used
AccessLoggerNameVarKey string = "access_logger_names"
)
type extraFieldsSlogHandler struct {
defaultHandler slog.Handler
core zapcore.Core
moduleID string
group string
attrs []slog.Attr
}
func (e *extraFieldsSlogHandler) Enabled(ctx context.Context, level slog.Level) bool {
return e.defaultHandler.Enabled(ctx, level)
}
func (e *extraFieldsSlogHandler) Handle(ctx context.Context, record slog.Record) error {
if elf, ok := ctx.Value(ExtraLogFieldsCtxKey).(*ExtraLogFields); ok {
return elf.getSloggerHandler(e).Handle(ctx, record)
}
return e.defaultHandler.Handle(ctx, record)
}
func (e *extraFieldsSlogHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
return &extraFieldsSlogHandler{
e.defaultHandler.WithAttrs(attrs),
e.core,
e.moduleID,
e.group,
append(e.attrs, attrs...),
}
}
func (e *extraFieldsSlogHandler) WithGroup(name string) slog.Handler {
return &extraFieldsSlogHandler{
e.defaultHandler.WithGroup(name),
e.core,
e.moduleID,
name,
e.attrs,
}
}
+7 -1
View File
@@ -15,6 +15,8 @@
package logging
import (
"strings"
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
"github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
@@ -26,7 +28,7 @@ func init() {
// parseCaddyfile sets up the log_append handler from Caddyfile tokens. Syntax:
//
// log_append [<matcher>] <key> <value>
// log_append [<matcher>] [<]<key> <value>
func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) {
handler := new(LogAppend)
err := handler.UnmarshalCaddyfile(h.Dispenser)
@@ -43,6 +45,10 @@ func (h *LogAppend) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
if !d.NextArg() {
return d.ArgErr()
}
if strings.HasPrefix(h.Key, "<") && len(h.Key) > 1 {
h.Early = true
h.Key = h.Key[1:]
}
h.Value = d.Val()
return nil
}
@@ -15,6 +15,8 @@
package logging
import (
"bytes"
"encoding/base64"
"net/http"
"strings"
@@ -42,6 +44,12 @@ type LogAppend struct {
// map, the value of that key will be used. Otherwise
// the value will be used as-is as a constant string.
Value string `json:"value,omitempty"`
// Early, if true, adds the log field before calling
// the next handler in the chain. By default, the log
// field is added on the way back up the middleware chain,
// after all subsequent handlers have completed.
Early bool `json:"early,omitempty"`
}
// CaddyModule returns the Caddy module information.
@@ -53,13 +61,63 @@ func (LogAppend) CaddyModule() caddy.ModuleInfo {
}
func (h LogAppend) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
// Run the next handler in the chain first.
// Determine if we need to add the log field early.
// We do if the Early flag is set, or for convenience,
// if the value is a special placeholder for the request body.
needsEarly := h.Early || h.Value == placeholderRequestBody || h.Value == placeholderRequestBodyBase64
// Check if we need to buffer the response for special placeholders
needsResponseBody := h.Value == placeholderResponseBody || h.Value == placeholderResponseBodyBase64
if needsEarly && !needsResponseBody {
// Add the log field before calling the next handler
// (but not if we need the response body, which isn't available yet)
h.addLogField(r, nil)
}
var rec caddyhttp.ResponseRecorder
var buf *bytes.Buffer
if needsResponseBody {
// Wrap the response writer with a recorder to capture the response body
buf = new(bytes.Buffer)
rec = caddyhttp.NewResponseRecorder(w, buf, func(status int, header http.Header) bool {
// Always buffer the response when we need to log the body
return true
})
w = rec
}
// Run the next handler in the chain.
// If an error occurs, we still want to add
// any extra log fields that we can, so we
// hold onto the error and return it later.
handlerErr := next.ServeHTTP(w, r)
// On the way back up the chain, add the extra log field
if needsResponseBody {
// Write the buffered response to the client
if rec.Buffered() {
h.addLogField(r, buf)
err := rec.WriteResponse()
if err != nil {
return err
}
}
return handlerErr
}
if !h.Early {
// Add the log field after the handler completes
h.addLogField(r, buf)
}
return handlerErr
}
// addLogField adds the log field to the request's extra log fields.
// If buf is not nil, it contains the buffered response body for special
// response body placeholders.
func (h LogAppend) addLogField(r *http.Request, buf *bytes.Buffer) {
ctx := r.Context()
vars := ctx.Value(caddyhttp.VarsCtxKey).(map[string]any)
@@ -67,7 +125,21 @@ func (h LogAppend) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyh
extra := ctx.Value(caddyhttp.ExtraLogFieldsCtxKey).(*caddyhttp.ExtraLogFields)
var varValue any
if strings.HasPrefix(h.Value, "{") &&
// Handle special case placeholders for response body
if h.Value == placeholderResponseBody {
if buf != nil {
varValue = buf.String()
} else {
varValue = ""
}
} else if h.Value == placeholderResponseBodyBase64 {
if buf != nil {
varValue = base64.StdEncoding.EncodeToString(buf.Bytes())
} else {
varValue = ""
}
} else if strings.HasPrefix(h.Value, "{") &&
strings.HasSuffix(h.Value, "}") &&
strings.Count(h.Value, "{") == 1 {
// the value looks like a placeholder, so get its value
@@ -84,10 +156,17 @@ func (h LogAppend) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyh
// We use zap.Any because it will reflect
// to the correct type for us.
extra.Add(zap.Any(h.Key, varValue))
return handlerErr
}
const (
// Special placeholder values that are handled by log_append
// rather than by the replacer.
placeholderRequestBody = "{http.request.body}"
placeholderRequestBodyBase64 = "{http.request.body_base64}"
placeholderResponseBody = "{http.response.body}"
placeholderResponseBodyBase64 = "{http.response.body_base64}"
)
// Interface guards
var (
_ caddyhttp.MiddlewareHandler = (*LogAppend)(nil)
+1
View File
@@ -110,6 +110,7 @@ func (t LoggableTLSConnState) MarshalLogObject(enc zapcore.ObjectEncoder) error
enc.AddUint16("cipher_suite", t.CipherSuite)
enc.AddString("proto", t.NegotiatedProtocol)
enc.AddString("server_name", t.ServerName)
enc.AddBool("ech", t.ECHAccepted)
if len(t.PeerCertificates) > 0 {
enc.AddString("client_common_name", t.PeerCertificates[0].Subject.CommonName)
enc.AddString("client_serial", t.PeerCertificates[0].SerialNumber.String())
+13 -7
View File
@@ -262,13 +262,17 @@ func (m MatchHost) Provision(_ caddy.Context) error {
if err != nil {
return fmt.Errorf("converting hostname '%s' to ASCII: %v", host, err)
}
if asciiHost != host {
m[i] = asciiHost
}
normalizedHost := strings.ToLower(asciiHost)
if firstI, ok := seen[normalizedHost]; ok {
return fmt.Errorf("host at index %d is repeated at index %d: %s", firstI, i, host)
}
// Normalize exact hosts for standardized comparison in large-list fastpath later on.
// Keep wildcards/placeholders untouched.
if m.fuzzy(asciiHost) {
m[i] = asciiHost
} else {
m[i] = normalizedHost
}
seen[normalizedHost] = i
}
@@ -312,14 +316,15 @@ func (m MatchHost) MatchWithError(r *http.Request) (bool, error) {
}
if m.large() {
reqHostLower := strings.ToLower(reqHost)
// fast path: locate exact match using binary search (about 100-1000x faster for large lists)
pos := sort.Search(len(m), func(i int) bool {
if m.fuzzy(m[i]) {
return false
}
return m[i] >= reqHost
return m[i] >= reqHostLower
})
if pos < len(m) && m[pos] == reqHost {
if pos < len(m) && m[pos] == reqHostLower {
return true, nil
}
}
@@ -533,6 +538,7 @@ func (m MatchPath) MatchWithError(r *http.Request) (bool, error) {
}
func (MatchPath) matchPatternWithEscapeSequence(escapedPath, matchPath string) bool {
escapedPath = strings.ToLower(escapedPath)
// We would just compare the pattern against r.URL.Path,
// but the pattern contains %, indicating that we should
// compare at least some part of the path in raw/escaped
@@ -632,8 +638,8 @@ func (MatchPath) matchPatternWithEscapeSequence(escapedPath, matchPath string) b
// we can now treat rawpath globs (%*) as regular globs (*)
matchPath = strings.ReplaceAll(matchPath, "%*", "*")
// ignore error here because we can't handle it anyway=
matches, _ := path.Match(matchPath, sb.String())
// ignore error here because we can't handle it anyway
matches, _ := path.Match(matchPath, strings.ToLower(sb.String()))
return matches
}
+20 -1
View File
@@ -412,6 +412,16 @@ func TestPathMatcher(t *testing.T) {
input: "/foo%2fbar/baz",
expect: true,
},
{
match: MatchPath{"/admin%2fpanel"},
input: "/ADMIN%2fpanel",
expect: true,
},
{
match: MatchPath{"/admin%2fpa*el"},
input: "/ADMIN%2fPaAzZLm123NEL",
expect: true,
},
} {
err := tc.match.Provision(caddy.Context{})
if err == nil && tc.provisionErr {
@@ -957,6 +967,7 @@ func TestVarREMatcher(t *testing.T) {
desc string
match MatchVarsRE
input VarsMiddleware
headers http.Header
expect bool
expectRepl map[string]string
}{
@@ -991,6 +1002,14 @@ func TestVarREMatcher(t *testing.T) {
input: VarsMiddleware{"Var1": "var1Value"},
expect: true,
},
{
desc: "placeholder key value containing braces is not double-expanded",
match: MatchVarsRE{"{http.request.header.X-Input}": &MatchRegexp{Pattern: ".+", Name: "val"}},
input: VarsMiddleware{},
headers: http.Header{"X-Input": []string{"{env.HOME}"}},
expect: true,
expectRepl: map[string]string{"val.0": "{env.HOME}"},
},
} {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
@@ -1007,7 +1026,7 @@ func TestVarREMatcher(t *testing.T) {
}
// set up the fake request and its Replacer
req := &http.Request{URL: new(url.URL), Method: http.MethodGet}
req := &http.Request{URL: new(url.URL), Method: http.MethodGet, Header: tc.headers}
repl := caddy.NewReplacer()
ctx := context.WithValue(req.Context(), caddy.ReplacerCtxKey, repl)
ctx = context.WithValue(ctx, VarsCtxKey, make(map[string]any))
+128 -12
View File
@@ -17,14 +17,60 @@ import (
// Metrics configures metrics observations.
// EXPERIMENTAL and subject to change or removal.
//
// Example configuration:
//
// {
// "apps": {
// "http": {
// "metrics": {
// "per_host": true,
// "observe_catchall_hosts": false
// },
// "servers": {
// "srv0": {
// "routes": [{
// "match": [{"host": ["example.com", "www.example.com"]}],
// "handle": [{"handler": "static_response", "body": "Hello"}]
// }]
// }
// }
// }
// }
// }
//
// In this configuration:
// - Requests to example.com and www.example.com get individual host labels
// - All other hosts (e.g., attacker.com) are aggregated under "_other" label
// - This prevents unlimited cardinality from arbitrary Host headers
type Metrics struct {
// Enable per-host metrics. Enabling this option may
// incur high-memory consumption, depending on the number of hosts
// managed by Caddy.
//
// CARDINALITY PROTECTION: To prevent unbounded cardinality attacks,
// only explicitly configured hosts (via host matchers) are allowed
// by default. Other hosts are aggregated under the "_other" label.
// See AllowCatchAllHosts to change this behavior.
PerHost bool `json:"per_host,omitempty"`
init sync.Once
httpMetrics *httpMetrics `json:"-"`
// Allow metrics for catch-all hosts (hosts without explicit configuration).
// When false (default), only hosts explicitly configured via host matchers
// will get individual metrics labels. All other hosts will be aggregated
// under the "_other" label to prevent cardinality explosion.
//
// This is automatically enabled for HTTPS servers (since certificates provide
// some protection against unbounded cardinality), but disabled for HTTP servers
// by default to prevent cardinality attacks from arbitrary Host headers.
//
// Set to true to allow all hosts to get individual metrics (NOT RECOMMENDED
// for production environments exposed to the internet).
ObserveCatchallHosts bool `json:"observe_catchall_hosts,omitempty"`
init sync.Once
httpMetrics *httpMetrics
allowedHosts map[string]struct{}
hasHTTPSServer bool
}
type httpMetrics struct {
@@ -101,6 +147,63 @@ func initHTTPMetrics(ctx caddy.Context, metrics *Metrics) {
}, httpLabels)
}
// scanConfigForHosts scans the HTTP app configuration to build a set of allowed hosts
// for metrics collection, similar to how auto-HTTPS scans for domain names.
func (m *Metrics) scanConfigForHosts(app *App) {
if !m.PerHost {
return
}
m.allowedHosts = make(map[string]struct{})
m.hasHTTPSServer = false
for _, srv := range app.Servers {
// Check if this server has TLS enabled
serverHasTLS := len(srv.TLSConnPolicies) > 0
if serverHasTLS {
m.hasHTTPSServer = true
}
// Collect hosts from route matchers
for _, route := range srv.Routes {
for _, matcherSet := range route.MatcherSets {
for _, matcher := range matcherSet {
if hm, ok := matcher.(*MatchHost); ok {
for _, host := range *hm {
// Only allow non-fuzzy hosts to prevent unbounded cardinality
if !hm.fuzzy(host) {
m.allowedHosts[strings.ToLower(host)] = struct{}{}
}
}
}
}
}
}
}
}
// shouldAllowHostMetrics determines if metrics should be collected for the given host.
// This implements the cardinality protection by only allowing metrics for:
// 1. Explicitly configured hosts
// 2. Catch-all requests on HTTPS servers (if AllowCatchAllHosts is true or auto-enabled)
// 3. Catch-all requests on HTTP servers only if explicitly allowed
func (m *Metrics) shouldAllowHostMetrics(host string, isHTTPS bool) bool {
if !m.PerHost {
return true // host won't be used in labels anyway
}
normalizedHost := strings.ToLower(host)
// Always allow explicitly configured hosts
if _, exists := m.allowedHosts[normalizedHost]; exists {
return true
}
// For catch-all requests (not in allowed hosts)
allowCatchAll := m.ObserveCatchallHosts || (isHTTPS && m.hasHTTPSServer)
return allowCatchAll
}
// serverNameFromContext extracts the current server name from the context.
// Returns "UNKNOWN" if none is available (should probably never happen).
func serverNameFromContext(ctx context.Context) string {
@@ -111,21 +214,24 @@ func serverNameFromContext(ctx context.Context) string {
return srv.name
}
type metricsInstrumentedHandler struct {
// metricsInstrumentedRoute wraps a compiled route Handler with metrics
// instrumentation. It wraps the entire compiled route chain once,
// collecting metrics only once per route match.
type metricsInstrumentedRoute struct {
handler string
mh MiddlewareHandler
next Handler
metrics *Metrics
}
func newMetricsInstrumentedHandler(ctx caddy.Context, handler string, mh MiddlewareHandler, metrics *Metrics) *metricsInstrumentedHandler {
metrics.init.Do(func() {
initHTTPMetrics(ctx, metrics)
func newMetricsInstrumentedRoute(ctx caddy.Context, handler string, next Handler, m *Metrics) *metricsInstrumentedRoute {
m.init.Do(func() {
initHTTPMetrics(ctx, m)
})
return &metricsInstrumentedHandler{handler, mh, metrics}
return &metricsInstrumentedRoute{handler: handler, next: next, metrics: m}
}
func (h *metricsInstrumentedHandler) ServeHTTP(w http.ResponseWriter, r *http.Request, next Handler) error {
func (h *metricsInstrumentedRoute) ServeHTTP(w http.ResponseWriter, r *http.Request) error {
server := serverNameFromContext(r.Context())
labels := prometheus.Labels{"server": server, "handler": h.handler}
method := metrics.SanitizeMethod(r.Method)
@@ -133,9 +239,19 @@ func (h *metricsInstrumentedHandler) ServeHTTP(w http.ResponseWriter, r *http.Re
// of a panic
statusLabels := prometheus.Labels{"server": server, "handler": h.handler, "method": method, "code": ""}
// Determine if this is an HTTPS request
isHTTPS := r.TLS != nil
if h.metrics.PerHost {
labels["host"] = strings.ToLower(r.Host)
statusLabels["host"] = strings.ToLower(r.Host)
// Apply cardinality protection for host metrics
if h.metrics.shouldAllowHostMetrics(r.Host, isHTTPS) {
labels["host"] = strings.ToLower(r.Host)
statusLabels["host"] = strings.ToLower(r.Host)
} else {
// Use a catch-all label for unallowed hosts to prevent cardinality explosion
labels["host"] = "_other"
statusLabels["host"] = "_other"
}
}
inFlight := h.metrics.httpMetrics.requestInFlight.With(labels)
@@ -154,7 +270,7 @@ func (h *metricsInstrumentedHandler) ServeHTTP(w http.ResponseWriter, r *http.Re
return false
})
wrec := NewResponseRecorder(w, nil, writeHeaderRecorder)
err := h.mh.ServeHTTP(wrec, r, next)
err := h.next.ServeHTTP(wrec, r)
dur := time.Since(start).Seconds()
h.metrics.httpMetrics.requestCount.With(labels).Inc()
+227 -32
View File
@@ -2,6 +2,7 @@ package caddyhttp
import (
"context"
"crypto/tls"
"errors"
"net/http"
"net/http/httptest"
@@ -46,16 +47,12 @@ func TestMetricsInstrumentedHandler(t *testing.T) {
return handlerErr
})
mh := middlewareHandlerFunc(func(w http.ResponseWriter, r *http.Request, h Handler) error {
return h.ServeHTTP(w, r)
})
ih := newMetricsInstrumentedHandler(ctx, "bar", mh, metrics)
ih := newMetricsInstrumentedRoute(ctx, "bar", h, metrics)
r := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
if actual := ih.ServeHTTP(w, r, h); actual != handlerErr {
if actual := ih.ServeHTTP(w, r); actual != handlerErr {
t.Errorf("Not same: expected %#v, but got %#v", handlerErr, actual)
}
if actual := testutil.ToFloat64(metrics.httpMetrics.requestInFlight); actual != 0.0 {
@@ -63,19 +60,19 @@ func TestMetricsInstrumentedHandler(t *testing.T) {
}
handlerErr = nil
if err := ih.ServeHTTP(w, r, h); err != nil {
if err := ih.ServeHTTP(w, r); err != nil {
t.Errorf("Received unexpected error: %v", err)
}
// an empty handler - no errors, no header written
mh = middlewareHandlerFunc(func(w http.ResponseWriter, r *http.Request, h Handler) error {
emptyHandler := HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
return nil
})
ih = newMetricsInstrumentedHandler(ctx, "empty", mh, metrics)
ih = newMetricsInstrumentedRoute(ctx, "empty", emptyHandler, metrics)
r = httptest.NewRequest("GET", "/", nil)
w = httptest.NewRecorder()
if err := ih.ServeHTTP(w, r, h); err != nil {
if err := ih.ServeHTTP(w, r); err != nil {
t.Errorf("Received unexpected error: %v", err)
}
if actual := w.Result().StatusCode; actual != 200 {
@@ -86,16 +83,16 @@ func TestMetricsInstrumentedHandler(t *testing.T) {
}
// handler returning an error with an HTTP status
mh = middlewareHandlerFunc(func(w http.ResponseWriter, r *http.Request, h Handler) error {
errHandler := HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
return Error(http.StatusTooManyRequests, nil)
})
ih = newMetricsInstrumentedHandler(ctx, "foo", mh, metrics)
ih = newMetricsInstrumentedRoute(ctx, "foo", errHandler, metrics)
r = httptest.NewRequest("GET", "/", nil)
w = httptest.NewRecorder()
if err := ih.ServeHTTP(w, r, nil); err == nil {
if err := ih.ServeHTTP(w, r); err == nil {
t.Errorf("expected error to be propagated")
}
@@ -206,9 +203,11 @@ func TestMetricsInstrumentedHandler(t *testing.T) {
func TestMetricsInstrumentedHandlerPerHost(t *testing.T) {
ctx, _ := caddy.NewContext(caddy.Context{Context: context.Background()})
metrics := &Metrics{
PerHost: true,
init: sync.Once{},
httpMetrics: &httpMetrics{},
PerHost: true,
ObserveCatchallHosts: true, // Allow all hosts for testing
init: sync.Once{},
httpMetrics: &httpMetrics{},
allowedHosts: make(map[string]struct{}),
}
handlerErr := errors.New("oh noes")
response := []byte("hello world!")
@@ -222,16 +221,12 @@ func TestMetricsInstrumentedHandlerPerHost(t *testing.T) {
return handlerErr
})
mh := middlewareHandlerFunc(func(w http.ResponseWriter, r *http.Request, h Handler) error {
return h.ServeHTTP(w, r)
})
ih := newMetricsInstrumentedHandler(ctx, "bar", mh, metrics)
ih := newMetricsInstrumentedRoute(ctx, "bar", h, metrics)
r := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
if actual := ih.ServeHTTP(w, r, h); actual != handlerErr {
if actual := ih.ServeHTTP(w, r); actual != handlerErr {
t.Errorf("Not same: expected %#v, but got %#v", handlerErr, actual)
}
if actual := testutil.ToFloat64(metrics.httpMetrics.requestInFlight); actual != 0.0 {
@@ -239,19 +234,19 @@ func TestMetricsInstrumentedHandlerPerHost(t *testing.T) {
}
handlerErr = nil
if err := ih.ServeHTTP(w, r, h); err != nil {
if err := ih.ServeHTTP(w, r); err != nil {
t.Errorf("Received unexpected error: %v", err)
}
// an empty handler - no errors, no header written
mh = middlewareHandlerFunc(func(w http.ResponseWriter, r *http.Request, h Handler) error {
emptyHandler := HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
return nil
})
ih = newMetricsInstrumentedHandler(ctx, "empty", mh, metrics)
ih = newMetricsInstrumentedRoute(ctx, "empty", emptyHandler, metrics)
r = httptest.NewRequest("GET", "/", nil)
w = httptest.NewRecorder()
if err := ih.ServeHTTP(w, r, h); err != nil {
if err := ih.ServeHTTP(w, r); err != nil {
t.Errorf("Received unexpected error: %v", err)
}
if actual := w.Result().StatusCode; actual != 200 {
@@ -262,16 +257,16 @@ func TestMetricsInstrumentedHandlerPerHost(t *testing.T) {
}
// handler returning an error with an HTTP status
mh = middlewareHandlerFunc(func(w http.ResponseWriter, r *http.Request, h Handler) error {
errHandler := HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
return Error(http.StatusTooManyRequests, nil)
})
ih = newMetricsInstrumentedHandler(ctx, "foo", mh, metrics)
ih = newMetricsInstrumentedRoute(ctx, "foo", errHandler, metrics)
r = httptest.NewRequest("GET", "/", nil)
w = httptest.NewRecorder()
if err := ih.ServeHTTP(w, r, nil); err == nil {
if err := ih.ServeHTTP(w, r); err == nil {
t.Errorf("expected error to be propagated")
}
@@ -379,8 +374,208 @@ func TestMetricsInstrumentedHandlerPerHost(t *testing.T) {
}
}
type middlewareHandlerFunc func(http.ResponseWriter, *http.Request, Handler) error
func TestMetricsCardinalityProtection(t *testing.T) {
ctx, _ := caddy.NewContext(caddy.Context{Context: context.Background()})
func (f middlewareHandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request, h Handler) error {
return f(w, r, h)
// Test 1: Without AllowCatchAllHosts, arbitrary hosts should be mapped to "_other"
metrics := &Metrics{
PerHost: true,
ObserveCatchallHosts: false, // Default - should map unknown hosts to "_other"
init: sync.Once{},
httpMetrics: &httpMetrics{},
allowedHosts: make(map[string]struct{}),
}
// Add one allowed host
metrics.allowedHosts["allowed.com"] = struct{}{}
h := HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
w.Write([]byte("hello"))
return nil
})
ih := newMetricsInstrumentedRoute(ctx, "test", h, metrics)
// Test request to allowed host
r1 := httptest.NewRequest("GET", "http://allowed.com/", nil)
r1.Host = "allowed.com"
w1 := httptest.NewRecorder()
ih.ServeHTTP(w1, r1)
// Test request to unknown host (should be mapped to "_other")
r2 := httptest.NewRequest("GET", "http://attacker.com/", nil)
r2.Host = "attacker.com"
w2 := httptest.NewRecorder()
ih.ServeHTTP(w2, r2)
// Test request to another unknown host (should also be mapped to "_other")
r3 := httptest.NewRequest("GET", "http://evil.com/", nil)
r3.Host = "evil.com"
w3 := httptest.NewRecorder()
ih.ServeHTTP(w3, r3)
// Check that metrics contain:
// - One entry for "allowed.com"
// - One entry for "_other" (aggregating attacker.com and evil.com)
expected := `
# HELP caddy_http_requests_total Counter of HTTP(S) requests made.
# TYPE caddy_http_requests_total counter
caddy_http_requests_total{handler="test",host="_other",server="UNKNOWN"} 2
caddy_http_requests_total{handler="test",host="allowed.com",server="UNKNOWN"} 1
`
if err := testutil.GatherAndCompare(ctx.GetMetricsRegistry(), strings.NewReader(expected),
"caddy_http_requests_total",
); err != nil {
t.Errorf("Cardinality protection test failed: %s", err)
}
}
func TestMetricsHTTPSCatchAll(t *testing.T) {
ctx, _ := caddy.NewContext(caddy.Context{Context: context.Background()})
// Test that HTTPS requests allow catch-all even when AllowCatchAllHosts is false
metrics := &Metrics{
PerHost: true,
ObserveCatchallHosts: false,
hasHTTPSServer: true, // Simulate having HTTPS servers
init: sync.Once{},
httpMetrics: &httpMetrics{},
allowedHosts: make(map[string]struct{}), // Empty - no explicitly allowed hosts
}
h := HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
w.Write([]byte("hello"))
return nil
})
ih := newMetricsInstrumentedRoute(ctx, "test", h, metrics)
// Test HTTPS request (should be allowed even though not in allowedHosts)
r1 := httptest.NewRequest("GET", "https://unknown.com/", nil)
r1.Host = "unknown.com"
r1.TLS = &tls.ConnectionState{} // Mark as TLS/HTTPS
w1 := httptest.NewRecorder()
ih.ServeHTTP(w1, r1)
// Test HTTP request (should be mapped to "_other")
r2 := httptest.NewRequest("GET", "http://unknown.com/", nil)
r2.Host = "unknown.com"
// No TLS field = HTTP request
w2 := httptest.NewRecorder()
ih.ServeHTTP(w2, r2)
// Check that HTTPS request gets real host, HTTP gets "_other"
expected := `
# HELP caddy_http_requests_total Counter of HTTP(S) requests made.
# TYPE caddy_http_requests_total counter
caddy_http_requests_total{handler="test",host="_other",server="UNKNOWN"} 1
caddy_http_requests_total{handler="test",host="unknown.com",server="UNKNOWN"} 1
`
if err := testutil.GatherAndCompare(ctx.GetMetricsRegistry(), strings.NewReader(expected),
"caddy_http_requests_total",
); err != nil {
t.Errorf("HTTPS catch-all test failed: %s", err)
}
}
func TestMetricsInstrumentedRoute(t *testing.T) {
ctx, _ := caddy.NewContext(caddy.Context{Context: context.Background()})
m := &Metrics{
init: sync.Once{},
httpMetrics: &httpMetrics{},
}
handlerErr := errors.New("oh noes")
response := []byte("hello world!")
innerHandler := HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
if actual := testutil.ToFloat64(m.httpMetrics.requestInFlight); actual != 1.0 {
t.Errorf("Expected requestInFlight to be 1.0, got %v", actual)
}
if handlerErr == nil {
w.Write(response)
}
return handlerErr
})
ih := newMetricsInstrumentedRoute(ctx, "test_handler", innerHandler, m)
r := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
// Test with error
if actual := ih.ServeHTTP(w, r); actual != handlerErr {
t.Errorf("Expected error %v, got %v", handlerErr, actual)
}
if actual := testutil.ToFloat64(m.httpMetrics.requestInFlight); actual != 0.0 {
t.Errorf("Expected requestInFlight to be 0.0 after request, got %v", actual)
}
if actual := testutil.ToFloat64(m.httpMetrics.requestErrors); actual != 1.0 {
t.Errorf("Expected requestErrors to be 1.0, got %v", actual)
}
// Test without error
handlerErr = nil
w = httptest.NewRecorder()
if err := ih.ServeHTTP(w, r); err != nil {
t.Errorf("Unexpected error: %v", err)
}
}
func BenchmarkMetricsInstrumentedRoute(b *testing.B) {
ctx, _ := caddy.NewContext(caddy.Context{Context: context.Background()})
m := &Metrics{
init: sync.Once{},
httpMetrics: &httpMetrics{},
}
noopHandler := HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
w.Write([]byte("ok"))
return nil
})
ih := newMetricsInstrumentedRoute(ctx, "bench_handler", noopHandler, m)
r := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
ih.ServeHTTP(w, r)
}
}
// BenchmarkSingleRouteMetrics simulates the new behavior where metrics
// are collected once for the entire route.
func BenchmarkSingleRouteMetrics(b *testing.B) {
ctx, _ := caddy.NewContext(caddy.Context{Context: context.Background()})
m := &Metrics{
init: sync.Once{},
httpMetrics: &httpMetrics{},
}
// Build a chain of 5 plain middleware handlers (no per-handler metrics)
var next Handler = HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
return nil
})
for i := 0; i < 5; i++ {
capturedNext := next
next = HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
return capturedNext.ServeHTTP(w, r)
})
}
// Wrap the entire chain with a single route-level metrics handler
ih := newMetricsInstrumentedRoute(ctx, "handler", next, m)
r := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
ih.ServeHTTP(w, r)
}
}
+1
View File
@@ -64,6 +64,7 @@ func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error)
var err error
// include current token, which we treat as an argument here
// nolint:prealloc
args := []string{h.Val()}
args = append(args, h.RemainingArgs()...)
+27 -1
View File
@@ -229,6 +229,21 @@ func addHTTPVarsToReplacer(repl *caddy.Replacer, req *http.Request, w http.Respo
req.Body = io.NopCloser(buf) // replace real body with buffered data
return buf.String(), true
case "http.request.body_base64":
if req.Body == nil {
return "", true
}
// normally net/http will close the body for us, but since we
// are replacing it with a fake one, we have to ensure we close
// the real body ourselves when we're done
defer req.Body.Close()
// read the request body into a buffer (can't pool because we
// don't know its lifetime and would have to make a copy anyway)
buf := new(bytes.Buffer)
_, _ = io.Copy(buf, req.Body) // can't handle error, so just ignore it
req.Body = io.NopCloser(buf) // replace real body with buffered data
return base64.StdEncoding.EncodeToString(buf.Bytes()), true
// original request, before any internal changes
case "http.request.orig_method":
or, _ := req.Context().Value(OriginalRequestCtxKey).(http.Request)
@@ -405,7 +420,16 @@ func getReqTLSReplacement(req *http.Request, key string) (any, bool) {
if strings.HasPrefix(field, "client.") {
cert := getTLSPeerCert(req.TLS)
if cert == nil {
return nil, false
// Instead of returning (nil, false) here, we set it to a dummy
// value to fix #7530. This way, even if there is no client cert,
// evaluating placeholders with ReplaceKnown() will still remove
// the placeholder, which would be expected. It is not expected
// for the placeholder to sometimes get removed based on whether
// the client presented a cert. We also do not return true here
// because we probably should remain accurate about whether a
// placeholder is, in fact, known or not.
// (This allocation may be slightly inefficient.)
cert = new(x509.Certificate)
}
// subject alternate names (SANs)
@@ -511,6 +535,8 @@ func getReqTLSReplacement(req *http.Request, key string) (any, bool) {
return true, true
case "server_name":
return req.TLS.ServerName, true
case "ech":
return req.TLS.ECHAccepted, true
}
return nil, false
}
+26 -2
View File
@@ -73,8 +73,9 @@ func (adminUpstreams) handleUpstreams(w http.ResponseWriter, r *http.Request) er
// Collect the results to respond with
results := []upstreamStatus{}
knownHosts := make(map[string]struct{})
// Iterate over the upstream pool (needs to be fast)
// Iterate over the static upstream pool (needs to be fast)
var rangeErr error
hosts.Range(func(key, val any) bool {
address, ok := key.(string)
@@ -95,6 +96,8 @@ func (adminUpstreams) handleUpstreams(w http.ResponseWriter, r *http.Request) er
return false
}
knownHosts[address] = struct{}{}
results = append(results, upstreamStatus{
Address: address,
NumRequests: upstream.NumRequests(),
@@ -103,11 +106,32 @@ func (adminUpstreams) handleUpstreams(w http.ResponseWriter, r *http.Request) er
return true
})
// If an error happened during the range, return it
currentInFlight := getInFlightRequests()
for address, count := range currentInFlight {
if _, exists := knownHosts[address]; !exists && count > 0 {
results = append(results, upstreamStatus{
Address: address,
NumRequests: int(count),
Fails: 0,
})
}
}
if rangeErr != nil {
return rangeErr
}
// Also include dynamic upstreams
dynamicHostsMu.RLock()
for address, entry := range dynamicHosts {
results = append(results, upstreamStatus{
Address: address,
NumRequests: entry.host.NumRequests(),
Fails: entry.host.Fails(),
})
}
dynamicHostsMu.RUnlock()
err := enc.Encode(results)
if err != nil {
return caddy.APIError{
@@ -0,0 +1,275 @@
// Copyright 2015 Matthew Holt and The Caddy Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package reverseproxy
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
)
// adminHandlerFixture sets up the global host state for an admin endpoint test
// and returns a cleanup function that must be deferred by the caller.
//
// staticAddrs are inserted into the UsagePool (as a static upstream would be).
// dynamicAddrs are inserted into the dynamicHosts map (as a dynamic upstream would be).
func adminHandlerFixture(t *testing.T, staticAddrs, dynamicAddrs []string) func() {
t.Helper()
for _, addr := range staticAddrs {
u := &Upstream{Dial: addr}
u.fillHost()
}
dynamicHostsMu.Lock()
for _, addr := range dynamicAddrs {
dynamicHosts[addr] = dynamicHostEntry{host: new(Host), lastSeen: time.Now()}
}
dynamicHostsMu.Unlock()
return func() {
// Remove static entries from the UsagePool.
for _, addr := range staticAddrs {
_, _ = hosts.Delete(addr)
}
// Remove dynamic entries.
dynamicHostsMu.Lock()
for _, addr := range dynamicAddrs {
delete(dynamicHosts, addr)
}
dynamicHostsMu.Unlock()
}
}
// callAdminUpstreams fires a GET against handleUpstreams and returns the
// decoded response body.
func callAdminUpstreams(t *testing.T) []upstreamStatus {
t.Helper()
req := httptest.NewRequest(http.MethodGet, "/reverse_proxy/upstreams", nil)
w := httptest.NewRecorder()
handler := adminUpstreams{}
if err := handler.handleUpstreams(w, req); err != nil {
t.Fatalf("handleUpstreams returned unexpected error: %v", err)
}
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
if ct := w.Header().Get("Content-Type"); ct != "application/json" {
t.Fatalf("expected Content-Type application/json, got %q", ct)
}
var results []upstreamStatus
if err := json.NewDecoder(w.Body).Decode(&results); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
return results
}
// resultsByAddress indexes a slice of upstreamStatus by address for easier
// lookup in assertions.
func resultsByAddress(statuses []upstreamStatus) map[string]upstreamStatus {
m := make(map[string]upstreamStatus, len(statuses))
for _, s := range statuses {
m[s.Address] = s
}
return m
}
// TestAdminUpstreamsMethodNotAllowed verifies that non-GET methods are rejected.
func TestAdminUpstreamsMethodNotAllowed(t *testing.T) {
for _, method := range []string{http.MethodPost, http.MethodPut, http.MethodDelete} {
req := httptest.NewRequest(method, "/reverse_proxy/upstreams", nil)
w := httptest.NewRecorder()
err := (adminUpstreams{}).handleUpstreams(w, req)
if err == nil {
t.Errorf("method %s: expected an error, got nil", method)
continue
}
apiErr, ok := err.(interface{ HTTPStatus() int })
if !ok {
// caddy.APIError stores the code in HTTPStatus field, access via the
// exported interface it satisfies indirectly; just check non-nil.
continue
}
if code := apiErr.HTTPStatus(); code != http.StatusMethodNotAllowed {
t.Errorf("method %s: expected 405, got %d", method, code)
}
}
}
// TestAdminUpstreamsEmpty verifies that an empty response is valid JSON when
// no upstreams are registered.
func TestAdminUpstreamsEmpty(t *testing.T) {
resetDynamicHosts()
results := callAdminUpstreams(t)
if results == nil {
t.Error("expected non-nil (empty) slice, got nil")
}
if len(results) != 0 {
t.Errorf("expected 0 results with empty pools, got %d", len(results))
}
}
// TestAdminUpstreamsStaticOnly verifies that static upstreams (from the
// UsagePool) appear in the response with correct addresses.
func TestAdminUpstreamsStaticOnly(t *testing.T) {
resetDynamicHosts()
cleanup := adminHandlerFixture(t,
[]string{"10.0.0.1:80", "10.0.0.2:80"},
nil,
)
defer cleanup()
results := callAdminUpstreams(t)
byAddr := resultsByAddress(results)
for _, addr := range []string{"10.0.0.1:80", "10.0.0.2:80"} {
if _, ok := byAddr[addr]; !ok {
t.Errorf("expected static upstream %q in response", addr)
}
}
if len(results) != 2 {
t.Errorf("expected exactly 2 results, got %d", len(results))
}
}
// TestAdminUpstreamsDynamicOnly verifies that dynamic upstreams (from
// dynamicHosts) appear in the response with correct addresses.
func TestAdminUpstreamsDynamicOnly(t *testing.T) {
resetDynamicHosts()
cleanup := adminHandlerFixture(t,
nil,
[]string{"10.0.1.1:80", "10.0.1.2:80"},
)
defer cleanup()
results := callAdminUpstreams(t)
byAddr := resultsByAddress(results)
for _, addr := range []string{"10.0.1.1:80", "10.0.1.2:80"} {
if _, ok := byAddr[addr]; !ok {
t.Errorf("expected dynamic upstream %q in response", addr)
}
}
if len(results) != 2 {
t.Errorf("expected exactly 2 results, got %d", len(results))
}
}
// TestAdminUpstreamsBothPools verifies that static and dynamic upstreams are
// both present in the same response and that there is no overlap or omission.
func TestAdminUpstreamsBothPools(t *testing.T) {
resetDynamicHosts()
cleanup := adminHandlerFixture(t,
[]string{"10.0.2.1:80"},
[]string{"10.0.2.2:80"},
)
defer cleanup()
results := callAdminUpstreams(t)
if len(results) != 2 {
t.Fatalf("expected 2 results (1 static + 1 dynamic), got %d", len(results))
}
byAddr := resultsByAddress(results)
if _, ok := byAddr["10.0.2.1:80"]; !ok {
t.Error("static upstream missing from response")
}
if _, ok := byAddr["10.0.2.2:80"]; !ok {
t.Error("dynamic upstream missing from response")
}
}
// TestAdminUpstreamsNoOverlapBetweenPools verifies that an address registered
// only as a static upstream does not also appear as a dynamic entry, and
// vice-versa.
func TestAdminUpstreamsNoOverlapBetweenPools(t *testing.T) {
resetDynamicHosts()
cleanup := adminHandlerFixture(t,
[]string{"10.0.3.1:80"},
[]string{"10.0.3.2:80"},
)
defer cleanup()
results := callAdminUpstreams(t)
seen := make(map[string]int)
for _, r := range results {
seen[r.Address]++
}
for addr, count := range seen {
if count > 1 {
t.Errorf("address %q appeared %d times; expected exactly once", addr, count)
}
}
}
// TestAdminUpstreamsReportsFailCounts verifies that fail counts accumulated on
// a dynamic upstream's Host are reflected in the response.
func TestAdminUpstreamsReportsFailCounts(t *testing.T) {
resetDynamicHosts()
const addr = "10.0.4.1:80"
h := new(Host)
_ = h.countFail(3)
dynamicHostsMu.Lock()
dynamicHosts[addr] = dynamicHostEntry{host: h, lastSeen: time.Now()}
dynamicHostsMu.Unlock()
defer func() {
dynamicHostsMu.Lock()
delete(dynamicHosts, addr)
dynamicHostsMu.Unlock()
}()
results := callAdminUpstreams(t)
byAddr := resultsByAddress(results)
status, ok := byAddr[addr]
if !ok {
t.Fatalf("expected %q in response", addr)
}
if status.Fails != 3 {
t.Errorf("expected Fails=3, got %d", status.Fails)
}
}
// TestAdminUpstreamsReportsNumRequests verifies that the active request count
// for a static upstream is reflected in the response.
func TestAdminUpstreamsReportsNumRequests(t *testing.T) {
resetDynamicHosts()
const addr = "10.0.4.2:80"
u := &Upstream{Dial: addr}
u.fillHost()
defer func() { _, _ = hosts.Delete(addr) }()
_ = u.Host.countRequest(2)
defer func() { _ = u.Host.countRequest(-2) }()
results := callAdminUpstreams(t)
byAddr := resultsByAddress(results)
status, ok := byAddr[addr]
if !ok {
t.Fatalf("expected %q in response", addr)
}
if status.NumRequests != 2 {
t.Errorf("expected NumRequests=2, got %d", status.NumRequests)
}
}
+6 -2
View File
@@ -888,8 +888,11 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
if commonScheme == "http" && te.TLSEnabled() {
return d.Errf("upstream address scheme is HTTP but transport is configured for HTTP+TLS (HTTPS)")
}
if te, ok := transport.(*HTTPTransport); ok && commonScheme == "h2c" {
te.Versions = []string{"h2c", "2"}
if h2ct, ok := transport.(H2CTransport); ok && commonScheme == "h2c" {
err := h2ct.EnableH2C()
if err != nil {
return err
}
}
} else if commonScheme == "https" {
return d.Errf("upstreams are configured for HTTPS but transport module does not support TLS: %T", transport)
@@ -1525,6 +1528,7 @@ func (u *SRVUpstreams) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
return d.Errf("bad delay value '%s': %v", d.Val(), err)
}
u.FallbackDelay = caddy.Duration(dur)
case "grace_period":
if !d.NextArg() {
return d.ArgErr()
@@ -0,0 +1,345 @@
// Copyright 2015 Matthew Holt and The Caddy Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package reverseproxy
import (
"sync"
"testing"
"time"
"github.com/caddyserver/caddy/v2"
)
// resetDynamicHosts clears global dynamic host state between tests.
func resetDynamicHosts() {
dynamicHostsMu.Lock()
dynamicHosts = make(map[string]dynamicHostEntry)
dynamicHostsMu.Unlock()
// Reset the Once so cleanup goroutine tests can re-trigger if needed.
dynamicHostsCleanerOnce = sync.Once{}
}
// TestFillDynamicHostCreatesEntry verifies that calling fillDynamicHost on a
// new address inserts an entry into dynamicHosts and assigns a non-nil Host.
func TestFillDynamicHostCreatesEntry(t *testing.T) {
resetDynamicHosts()
u := &Upstream{Dial: "192.0.2.1:80"}
u.fillDynamicHost()
if u.Host == nil {
t.Fatal("expected Host to be set after fillDynamicHost")
}
dynamicHostsMu.RLock()
entry, ok := dynamicHosts["192.0.2.1:80"]
dynamicHostsMu.RUnlock()
if !ok {
t.Fatal("expected entry in dynamicHosts map")
}
if entry.host != u.Host {
t.Error("dynamicHosts entry host should be the same pointer assigned to Upstream.Host")
}
if entry.lastSeen.IsZero() {
t.Error("expected lastSeen to be set")
}
}
// TestFillDynamicHostReusesSameHost verifies that two calls for the same address
// return the exact same *Host pointer so that state (e.g. fail counts) is shared.
func TestFillDynamicHostReusesSameHost(t *testing.T) {
resetDynamicHosts()
u1 := &Upstream{Dial: "192.0.2.2:80"}
u1.fillDynamicHost()
u2 := &Upstream{Dial: "192.0.2.2:80"}
u2.fillDynamicHost()
if u1.Host != u2.Host {
t.Error("expected both upstreams to share the same *Host pointer")
}
}
// TestFillDynamicHostUpdatesLastSeen verifies that a second call for the same
// address advances the lastSeen timestamp.
func TestFillDynamicHostUpdatesLastSeen(t *testing.T) {
resetDynamicHosts()
u := &Upstream{Dial: "192.0.2.3:80"}
u.fillDynamicHost()
dynamicHostsMu.RLock()
first := dynamicHosts["192.0.2.3:80"].lastSeen
dynamicHostsMu.RUnlock()
// Ensure measurable time passes.
time.Sleep(2 * time.Millisecond)
u2 := &Upstream{Dial: "192.0.2.3:80"}
u2.fillDynamicHost()
dynamicHostsMu.RLock()
second := dynamicHosts["192.0.2.3:80"].lastSeen
dynamicHostsMu.RUnlock()
if !second.After(first) {
t.Error("expected lastSeen to be updated on second fillDynamicHost call")
}
}
// TestFillDynamicHostIndependentAddresses verifies that different addresses get
// independent Host entries.
func TestFillDynamicHostIndependentAddresses(t *testing.T) {
resetDynamicHosts()
u1 := &Upstream{Dial: "192.0.2.4:80"}
u1.fillDynamicHost()
u2 := &Upstream{Dial: "192.0.2.5:80"}
u2.fillDynamicHost()
if u1.Host == u2.Host {
t.Error("different addresses should have different *Host entries")
}
}
// TestFillDynamicHostPreservesFailCount verifies that fail counts on a dynamic
// host survive across multiple fillDynamicHost calls (simulating sequential
// requests), which is the core behaviour fixed by this change.
func TestFillDynamicHostPreservesFailCount(t *testing.T) {
resetDynamicHosts()
// First "request": provision and record a failure.
u1 := &Upstream{Dial: "192.0.2.6:80"}
u1.fillDynamicHost()
_ = u1.Host.countFail(1)
if u1.Host.Fails() != 1 {
t.Fatalf("expected 1 fail, got %d", u1.Host.Fails())
}
// Second "request": provision the same address again (new *Upstream, same address).
u2 := &Upstream{Dial: "192.0.2.6:80"}
u2.fillDynamicHost()
if u2.Host.Fails() != 1 {
t.Errorf("expected fail count to persist across fillDynamicHost calls, got %d", u2.Host.Fails())
}
}
// TestProvisionUpstreamDynamic verifies that provisionUpstream with dynamic=true
// uses fillDynamicHost (not the UsagePool) and sets healthCheckPolicy /
// MaxRequests correctly from handler config.
func TestProvisionUpstreamDynamic(t *testing.T) {
resetDynamicHosts()
passive := &PassiveHealthChecks{
FailDuration: caddy.Duration(10 * time.Second),
MaxFails: 3,
UnhealthyRequestCount: 5,
}
h := Handler{
HealthChecks: &HealthChecks{
Passive: passive,
},
}
u := &Upstream{Dial: "192.0.2.7:80"}
h.provisionUpstream(u, true)
if u.Host == nil {
t.Fatal("Host should be set after provisionUpstream")
}
if u.healthCheckPolicy != passive {
t.Error("healthCheckPolicy should point to the handler's PassiveHealthChecks")
}
if u.MaxRequests != 5 {
t.Errorf("expected MaxRequests=5 from UnhealthyRequestCount, got %d", u.MaxRequests)
}
// Must be in dynamicHosts, not in the static UsagePool.
dynamicHostsMu.RLock()
_, inDynamic := dynamicHosts["192.0.2.7:80"]
dynamicHostsMu.RUnlock()
if !inDynamic {
t.Error("dynamic upstream should be stored in dynamicHosts")
}
_, inPool := hosts.References("192.0.2.7:80")
if inPool {
t.Error("dynamic upstream should NOT be stored in the static UsagePool")
}
}
// TestProvisionUpstreamStatic verifies that provisionUpstream with dynamic=false
// uses the UsagePool and does NOT insert into dynamicHosts.
func TestProvisionUpstreamStatic(t *testing.T) {
resetDynamicHosts()
h := Handler{}
u := &Upstream{Dial: "192.0.2.8:80"}
h.provisionUpstream(u, false)
if u.Host == nil {
t.Fatal("Host should be set after provisionUpstream")
}
refs, inPool := hosts.References("192.0.2.8:80")
if !inPool {
t.Error("static upstream should be in the UsagePool")
}
if refs != 1 {
t.Errorf("expected ref count 1, got %d", refs)
}
dynamicHostsMu.RLock()
_, inDynamic := dynamicHosts["192.0.2.8:80"]
dynamicHostsMu.RUnlock()
if inDynamic {
t.Error("static upstream should NOT be in dynamicHosts")
}
// Clean up the pool entry we just added.
_, _ = hosts.Delete("192.0.2.8:80")
}
// TestDynamicHostHealthyConsultsFails verifies the end-to-end passive health
// check path: after enough failures are recorded against a dynamic upstream's
// shared *Host, Healthy() returns false for a newly provisioned *Upstream with
// the same address.
func TestDynamicHostHealthyConsultsFails(t *testing.T) {
resetDynamicHosts()
passive := &PassiveHealthChecks{
FailDuration: caddy.Duration(time.Minute),
MaxFails: 2,
}
h := Handler{
HealthChecks: &HealthChecks{Passive: passive},
}
// First request: provision and record two failures.
u1 := &Upstream{Dial: "192.0.2.9:80"}
h.provisionUpstream(u1, true)
_ = u1.Host.countFail(1)
_ = u1.Host.countFail(1)
// Second request: fresh *Upstream, same address.
u2 := &Upstream{Dial: "192.0.2.9:80"}
h.provisionUpstream(u2, true)
if u2.Healthy() {
t.Error("upstream should be unhealthy after MaxFails failures have been recorded against its shared Host")
}
}
// TestDynamicHostCleanupEvictsStaleEntries verifies that the cleanup sweep
// removes entries whose lastSeen is older than dynamicHostIdleExpiry.
func TestDynamicHostCleanupEvictsStaleEntries(t *testing.T) {
resetDynamicHosts()
const addr = "192.0.2.10:80"
// Insert an entry directly with a lastSeen far in the past.
dynamicHostsMu.Lock()
dynamicHosts[addr] = dynamicHostEntry{
host: new(Host),
lastSeen: time.Now().Add(-2 * dynamicHostIdleExpiry),
}
dynamicHostsMu.Unlock()
// Run the cleanup logic inline (same logic as the goroutine).
dynamicHostsMu.Lock()
for a, entry := range dynamicHosts {
if time.Since(entry.lastSeen) > dynamicHostIdleExpiry {
delete(dynamicHosts, a)
}
}
dynamicHostsMu.Unlock()
dynamicHostsMu.RLock()
_, stillPresent := dynamicHosts[addr]
dynamicHostsMu.RUnlock()
if stillPresent {
t.Error("stale dynamic host entry should have been evicted by cleanup sweep")
}
}
// TestDynamicHostCleanupRetainsFreshEntries verifies that the cleanup sweep
// keeps entries whose lastSeen is within dynamicHostIdleExpiry.
func TestDynamicHostCleanupRetainsFreshEntries(t *testing.T) {
resetDynamicHosts()
const addr = "192.0.2.11:80"
dynamicHostsMu.Lock()
dynamicHosts[addr] = dynamicHostEntry{
host: new(Host),
lastSeen: time.Now(),
}
dynamicHostsMu.Unlock()
// Run the cleanup logic inline.
dynamicHostsMu.Lock()
for a, entry := range dynamicHosts {
if time.Since(entry.lastSeen) > dynamicHostIdleExpiry {
delete(dynamicHosts, a)
}
}
dynamicHostsMu.Unlock()
dynamicHostsMu.RLock()
_, stillPresent := dynamicHosts[addr]
dynamicHostsMu.RUnlock()
if !stillPresent {
t.Error("fresh dynamic host entry should be retained by cleanup sweep")
}
}
// TestDynamicHostConcurrentFillHost verifies that concurrent calls to
// fillDynamicHost for the same address all get the same *Host pointer and
// don't race (run with -race).
func TestDynamicHostConcurrentFillHost(t *testing.T) {
resetDynamicHosts()
const addr = "192.0.2.12:80"
const goroutines = 50
var wg sync.WaitGroup
hosts := make([]*Host, goroutines)
for i := range goroutines {
wg.Add(1)
go func(idx int) {
defer wg.Done()
u := &Upstream{Dial: addr}
u.fillDynamicHost()
hosts[idx] = u.Host
}(i)
}
wg.Wait()
first := hosts[0]
for i, h := range hosts {
if h != first {
t.Errorf("goroutine %d got a different *Host pointer; expected all to share the same entry", i)
}
}
}
@@ -27,7 +27,7 @@ import (
"fmt"
"io"
"log"
"math/rand"
"math/rand/v2"
"net"
"net/http"
"net/http/fcgi"
@@ -197,7 +197,7 @@ func generateRandFile(size int) (p string, m string) {
h := md5.New()
for i := 0; i < size/16; i++ {
buf := make([]byte, 16)
binary.PutVarint(buf, rand.Int63())
binary.PutVarint(buf, rand.Int64())
if _, err := fo.Write(buf); err != nil {
log.Printf("[ERROR] failed to write buffer: %v\n", err)
}
@@ -16,6 +16,7 @@ package fastcgi
import (
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
@@ -23,9 +24,12 @@ import (
"strconv"
"strings"
"time"
"unicode/utf8"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"golang.org/x/text/language"
"golang.org/x/text/search"
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
@@ -33,7 +37,11 @@ import (
"github.com/caddyserver/caddy/v2/modules/caddytls"
)
var noopLogger = zap.NewNop()
var (
ErrInvalidSplitPath = errors.New("split path contains non-ASCII characters")
noopLogger = zap.NewNop()
)
func init() {
caddy.RegisterModule(Transport{})
@@ -50,6 +58,9 @@ type Transport struct {
// actual resource (CGI script) name, and the second piece will be set to
// PATH_INFO for the CGI script to use.
//
// Split paths can only contain ASCII characters.
// Comparison is case-insensitive.
//
// Future enhancements should be careful to avoid CVE-2019-11043,
// which can be mitigated with use of a try_files-like behavior
// that 404s if the fastcgi path info is not found.
@@ -109,9 +120,45 @@ func (t *Transport) Provision(ctx caddy.Context) error {
t.DialTimeout = caddy.Duration(3 * time.Second)
}
var b strings.Builder
for i, split := range t.SplitPath {
b.Grow(len(split))
for j := 0; j < len(split); j++ {
c := split[j]
if c >= utf8.RuneSelf {
return ErrInvalidSplitPath
}
if 'A' <= c && c <= 'Z' {
b.WriteByte(c + 'a' - 'A')
} else {
b.WriteByte(c)
}
}
t.SplitPath[i] = b.String()
b.Reset()
}
return nil
}
// DefaultBufferSizes enables request buffering for fastcgi if not configured.
// This is because most fastcgi servers are php-fpm that require the content length to be set to read the body, golang
// std has fastcgi implementation that doesn't need this value to process the body, but we can safely assume that's
// not used.
// http3 requests have a negative content length for GET and HEAD requests, if that header is not sent.
// see: https://github.com/caddyserver/caddy/issues/6678#issuecomment-2472224182
// Though it appears even if CONTENT_LENGTH is invalid, php-fpm can handle just fine if the body is empty (no Stdin records sent).
// php-fpm will hang if there is any data in the body though, https://github.com/caddyserver/caddy/issues/5420#issuecomment-2415943516
// TODO: better default buffering for fastcgi requests without content length, in theory a value of 1 should be enough, make it bigger anyway
func (t Transport) DefaultBufferSizes() (int64, int64) {
return 4096, 0
}
// RoundTrip implements http.RoundTripper.
func (t Transport) RoundTrip(r *http.Request) (*http.Response, error) {
server := r.Context().Value(caddyhttp.ServerCtxKey).(*caddyhttp.Server)
@@ -371,8 +418,15 @@ func (t Transport) buildEnv(r *http.Request) (envVars, error) {
return env, nil
}
var splitSearchNonASCII = search.New(language.Und, search.IgnoreCase)
// splitPos returns the index where path should
// be split based on t.SplitPath.
//
// example: if splitPath is [".php"]
// "/path/to/script.php/some/path": ("/path/to/script.php", "/some/path")
//
// Adapted from FrankenPHP's code (copyright 2026 Kévin Dunglas, MIT license)
func (t Transport) splitPos(path string) int {
// TODO: from v1...
// if httpserver.CaseSensitivePath {
@@ -382,12 +436,54 @@ func (t Transport) splitPos(path string) int {
return 0
}
lowerPath := strings.ToLower(path)
pathLen := len(path)
// We are sure that split strings are all ASCII-only and lower-case because of validation and normalization in Provision().
for _, split := range t.SplitPath {
if idx := strings.Index(lowerPath, strings.ToLower(split)); idx > -1 {
return idx + len(split)
splitLen := len(split)
for i := range pathLen {
if path[i] >= utf8.RuneSelf {
if _, end := splitSearchNonASCII.IndexString(path, split); end > -1 {
return end
}
break
}
if i+splitLen > pathLen {
continue
}
match := true
for j := range splitLen {
c := path[i+j]
if c >= utf8.RuneSelf {
if _, end := splitSearchNonASCII.IndexString(path, split); end > -1 {
return end
}
break
}
if 'A' <= c && c <= 'Z' {
c += 'a' - 'A'
}
if c != split[j] {
match = false
break
}
}
if match {
return i + splitLen
}
}
}
return -1
}
@@ -427,6 +523,7 @@ var headerNameReplacer = strings.NewReplacer(" ", "_", "-", "_")
var (
_ zapcore.ObjectMarshaler = (*loggableEnv)(nil)
_ caddy.Provisioner = (*Transport)(nil)
_ http.RoundTripper = (*Transport)(nil)
_ caddy.Provisioner = (*Transport)(nil)
_ http.RoundTripper = (*Transport)(nil)
_ reverseproxy.BufferedTransport = (*Transport)(nil)
)
@@ -0,0 +1,246 @@
package fastcgi
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/caddyserver/caddy/v2"
)
func TestProvisionSplitPath(t *testing.T) {
tests := []struct {
name string
splitPath []string
wantErr error
wantSplitPath []string
}{
{
name: "valid lowercase split path",
splitPath: []string{".php"},
wantErr: nil,
wantSplitPath: []string{".php"},
},
{
name: "valid uppercase split path normalized",
splitPath: []string{".PHP"},
wantErr: nil,
wantSplitPath: []string{".php"},
},
{
name: "valid mixed case split path normalized",
splitPath: []string{".PhP", ".PHTML"},
wantErr: nil,
wantSplitPath: []string{".php", ".phtml"},
},
{
name: "empty split path",
splitPath: []string{},
wantErr: nil,
wantSplitPath: []string{},
},
{
name: "non-ASCII character in split path rejected",
splitPath: []string{".php", ".Ⱥphp"},
wantErr: ErrInvalidSplitPath,
},
{
name: "unicode character in split path rejected",
splitPath: []string{".phpⱥ"},
wantErr: ErrInvalidSplitPath,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tr := Transport{SplitPath: tt.splitPath}
err := tr.Provision(caddy.Context{})
if tt.wantErr != nil {
require.ErrorIs(t, err, tt.wantErr)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantSplitPath, tr.SplitPath)
})
}
}
func TestSplitPos(t *testing.T) {
tests := []struct {
name string
path string
splitPath []string
wantPos int
}{
{
name: "simple php extension",
path: "/path/to/script.php",
splitPath: []string{".php"},
wantPos: 19,
},
{
name: "php extension with path info",
path: "/path/to/script.php/some/path",
splitPath: []string{".php"},
wantPos: 19,
},
{
name: "case insensitive match",
path: "/path/to/script.PHP",
splitPath: []string{".php"},
wantPos: 19,
},
{
name: "mixed case match",
path: "/path/to/script.PhP/info",
splitPath: []string{".php"},
wantPos: 19,
},
{
name: "no match",
path: "/path/to/script.txt",
splitPath: []string{".php"},
wantPos: -1,
},
{
name: "empty split path",
path: "/path/to/script.php",
splitPath: []string{},
wantPos: 0,
},
{
name: "multiple split paths first match",
path: "/path/to/script.php",
splitPath: []string{".php", ".phtml"},
wantPos: 19,
},
{
name: "multiple split paths second match",
path: "/path/to/script.phtml",
splitPath: []string{".php", ".phtml"},
wantPos: 21,
},
// Unicode case-folding tests (security fix for GHSA-g966-83w7-6w38)
// U+023A (Ⱥ) lowercases to U+2C65 (ⱥ), which has different UTF-8 byte length
// Ⱥ: 2 bytes (C8 BA), ⱥ: 3 bytes (E2 B1 A5)
{
name: "unicode path with case-folding length expansion",
path: "/ȺȺȺȺshell.php",
splitPath: []string{".php"},
wantPos: 18, // correct position in original string
},
{
name: "unicode path with extension after expansion chars",
path: "/ȺȺȺȺshell.php/path/info",
splitPath: []string{".php"},
wantPos: 18,
},
{
name: "unicode in filename with multiple php occurrences",
path: "/ȺȺȺȺshell.php.txt.php",
splitPath: []string{".php"},
wantPos: 18, // should match first .php, not be confused by byte offset shift
},
{
name: "unicode case insensitive extension",
path: "/ȺȺȺȺshell.PHP",
splitPath: []string{".php"},
wantPos: 18,
},
{
name: "unicode in middle of path",
path: "/path/Ⱥtest/script.php",
splitPath: []string{".php"},
wantPos: 23, // Ⱥ is 2 bytes, so path is 23 bytes total, .php ends at byte 23
},
{
name: "unicode only in directory not filename",
path: "/Ⱥ/script.php",
splitPath: []string{".php"},
wantPos: 14,
},
// Additional Unicode characters that expand when lowercased
// U+0130 (İ - Turkish capital I with dot) lowercases to U+0069 + U+0307
{
name: "turkish capital I with dot",
path: "/İtest.php",
splitPath: []string{".php"},
wantPos: 11,
},
// Ensure standard ASCII still works correctly
{
name: "ascii only path with case variation",
path: "/PATH/TO/SCRIPT.PHP/INFO",
splitPath: []string{".php"},
wantPos: 19,
},
{
name: "path at root",
path: "/index.php",
splitPath: []string{".php"},
wantPos: 10,
},
{
name: "extension in middle of filename",
path: "/test.php.bak",
splitPath: []string{".php"},
wantPos: 9,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotPos := Transport{SplitPath: tt.splitPath}.splitPos(tt.path)
assert.Equal(t, tt.wantPos, gotPos, "splitPos(%q, %v)", tt.path, tt.splitPath)
// Verify that the split produces valid substrings
if gotPos > 0 && gotPos <= len(tt.path) {
scriptName := tt.path[:gotPos]
pathInfo := tt.path[gotPos:]
// The script name should end with one of the split extensions (case-insensitive)
hasValidEnding := false
for _, split := range tt.splitPath {
if strings.HasSuffix(strings.ToLower(scriptName), split) {
hasValidEnding = true
break
}
}
assert.True(t, hasValidEnding, "script name %q should end with one of %v", scriptName, tt.splitPath)
// Original path should be reconstructable
assert.Equal(t, tt.path, scriptName+pathInfo, "path should be reconstructable from split parts")
}
})
}
}
// TestSplitPosUnicodeSecurityRegression specifically tests the vulnerability
// described in GHSA-g966-83w7-6w38 where Unicode case-folding caused
// incorrect SCRIPT_NAME/PATH_INFO splitting
func TestSplitPosUnicodeSecurityRegression(t *testing.T) {
// U+023A: Ⱥ (UTF-8: C8 BA). Lowercase is ⱥ (UTF-8: E2 B1 A5), longer in bytes.
path := "/ȺȺȺȺshell.php.txt.php"
split := []string{".php"}
pos := Transport{SplitPath: split}.splitPos(path)
// The vulnerable code would return 22 (computed on lowercased string)
// The correct code should return 18 (position in original string)
expectedPos := strings.Index(path, ".php") + len(".php")
assert.Equal(t, expectedPos, pos, "split position should match first .php in original string")
assert.Equal(t, 18, pos, "split position should be 18, not 22")
if pos > 0 && pos <= len(path) {
scriptName := path[:pos]
pathInfo := path[pos:]
assert.Equal(t, "/ȺȺȺȺshell.php", scriptName, "script name should be the path up to first .php")
assert.Equal(t, ".txt.php", pathInfo, "path info should be the remainder after first .php")
}
}
@@ -112,7 +112,7 @@ func encodeSize(b []byte, size uint32) int {
binary.BigEndian.PutUint32(b, size)
return 4
}
b[0] = byte(size)
b[0] = byte(size) //nolint:gosec // false positive; b is made 8 bytes long, then this function is always called with b being at least 4 or 1 byte long
return 1
}
@@ -208,6 +208,24 @@ func parseCaddyfile(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue, error)
for _, from := range sortedHeadersToCopy {
to := http.CanonicalHeaderKey(headersToCopy[from])
placeholderName := "http.reverse_proxy.header." + http.CanonicalHeaderKey(from)
// Always delete the client-supplied header before conditionally setting
// it from the auth response. Without this, a client that pre-supplies a
// header listed in copy_headers can inject arbitrary values when the auth
// service does not return that header: the MatchNot guard below would
// skip the Set entirely, leaving the original client-controlled value
// intact and forwarding it to the backend.
copyHeaderRoutes = append(copyHeaderRoutes, caddyhttp.Route{
HandlersRaw: []json.RawMessage{caddyconfig.JSONModuleObject(
&headers.Handler{
Request: &headers.HeaderOps{
Delete: []string{to},
},
},
"handler", "headers", nil,
)},
})
handler := &headers.Handler{
Request: &headers.HeaderOps{
Set: http.Header{
@@ -0,0 +1,127 @@
package reverseproxy
import (
"context"
"net/http/httptest"
"testing"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
)
func TestAddForwardedHeadersNonIP(t *testing.T) {
h := Handler{}
// Simulate a request with a non-IP remote address (e.g. SCION, abstract socket, or hostname)
req := httptest.NewRequest("GET", "/", nil)
req.RemoteAddr = "my-weird-network:12345"
// Mock the context variables required by Caddy.
// We need to inject the variable map manually since we aren't running the full server.
vars := map[string]interface{}{
caddyhttp.TrustedProxyVarKey: false,
}
ctx := context.WithValue(req.Context(), caddyhttp.VarsCtxKey, vars)
req = req.WithContext(ctx)
// Execute the unexported function
err := h.addForwardedHeaders(req)
// Expectation: No error should be returned for non-IP addresses.
// The function should simply skip the trusted proxy check.
if err != nil {
t.Errorf("expected no error for non-IP address, got: %v", err)
}
}
func TestAddForwardedHeaders_UnixSocketTrusted(t *testing.T) {
h := Handler{}
req := httptest.NewRequest("GET", "http://example.com/", nil)
req.RemoteAddr = "@"
req.Header.Set("X-Forwarded-For", "1.2.3.4, 10.0.0.1")
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("X-Forwarded-Host", "original.example.com")
vars := map[string]interface{}{
caddyhttp.TrustedProxyVarKey: true,
caddyhttp.ClientIPVarKey: "1.2.3.4",
}
ctx := context.WithValue(req.Context(), caddyhttp.VarsCtxKey, vars)
req = req.WithContext(ctx)
err := h.addForwardedHeaders(req)
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
if got := req.Header.Get("X-Forwarded-For"); got != "1.2.3.4, 10.0.0.1" {
t.Errorf("X-Forwarded-For = %q, want %q", got, "1.2.3.4, 10.0.0.1")
}
if got := req.Header.Get("X-Forwarded-Proto"); got != "https" {
t.Errorf("X-Forwarded-Proto = %q, want %q", got, "https")
}
if got := req.Header.Get("X-Forwarded-Host"); got != "original.example.com" {
t.Errorf("X-Forwarded-Host = %q, want %q", got, "original.example.com")
}
}
func TestAddForwardedHeaders_UnixSocketUntrusted(t *testing.T) {
h := Handler{}
req := httptest.NewRequest("GET", "http://example.com/", nil)
req.RemoteAddr = "@"
req.Header.Set("X-Forwarded-For", "1.2.3.4")
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("X-Forwarded-Host", "spoofed.example.com")
vars := map[string]interface{}{
caddyhttp.TrustedProxyVarKey: false,
caddyhttp.ClientIPVarKey: "",
}
ctx := context.WithValue(req.Context(), caddyhttp.VarsCtxKey, vars)
req = req.WithContext(ctx)
err := h.addForwardedHeaders(req)
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
if got := req.Header.Get("X-Forwarded-For"); got != "" {
t.Errorf("X-Forwarded-For should be deleted, got %q", got)
}
if got := req.Header.Get("X-Forwarded-Proto"); got != "" {
t.Errorf("X-Forwarded-Proto should be deleted, got %q", got)
}
if got := req.Header.Get("X-Forwarded-Host"); got != "" {
t.Errorf("X-Forwarded-Host should be deleted, got %q", got)
}
}
func TestAddForwardedHeaders_UnixSocketTrustedNoExistingHeaders(t *testing.T) {
h := Handler{}
req := httptest.NewRequest("GET", "http://example.com/", nil)
req.RemoteAddr = "@"
vars := map[string]interface{}{
caddyhttp.TrustedProxyVarKey: true,
caddyhttp.ClientIPVarKey: "5.6.7.8",
}
ctx := context.WithValue(req.Context(), caddyhttp.VarsCtxKey, vars)
req = req.WithContext(ctx)
err := h.addForwardedHeaders(req)
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
if got := req.Header.Get("X-Forwarded-For"); got != "" {
t.Errorf("X-Forwarded-For should be empty when no prior XFF exists, got %q", got)
}
if got := req.Header.Get("X-Forwarded-Proto"); got != "http" {
t.Errorf("X-Forwarded-Proto = %q, want %q", got, "http")
}
if got := req.Header.Get("X-Forwarded-Host"); got != "example.com" {
t.Errorf("X-Forwarded-Host = %q, want %q", got, "example.com")
}
}
+10 -10
View File
@@ -23,7 +23,6 @@ import (
"net/url"
"regexp"
"runtime/debug"
"slices"
"strconv"
"strings"
"time"
@@ -360,6 +359,12 @@ func (h *Handler) doActiveHealthCheckForAllHosts() {
dialInfoUpstream = &Upstream{
Dial: h.HealthChecks.Active.Upstream,
}
} else if upstream.activeHealthCheckPort != 0 {
// health_port overrides the port; addr has already been updated
// with the health port, so use its address for dialing
dialInfoUpstream = &Upstream{
Dial: addr.JoinHostPort(0),
}
}
dialInfo, _ := dialInfoUpstream.fillDialInfo(repl)
@@ -405,14 +410,9 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, networ
u.Host = net.JoinHostPort(host, port)
}
// this is kind of a hacky way to know if we should use HTTPS, but whatever
if tt, ok := h.Transport.(TLSTransport); ok && tt.TLSEnabled() {
u.Scheme = "https"
// if the port is in the except list, flip back to HTTP
if ht, ok := h.Transport.(*HTTPTransport); ok && slices.Contains(ht.TLS.ExceptPorts, port) {
u.Scheme = "http"
}
// override health check schemes if applicable
if hcsot, ok := h.Transport.(HealthCheckSchemeOverriderTransport); ok {
hcsot.OverrideHealthCheckScheme(u, port)
}
// if we have a provisioned uri, use that, otherwise use
@@ -506,7 +506,7 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, networ
}
// do the request, being careful to tame the response body
resp, err := h.HealthChecks.Active.httpClient.Do(req)
resp, err := h.HealthChecks.Active.httpClient.Do(req) //nolint:gosec // no SSRF
if err != nil {
if c := h.HealthChecks.Active.logger.Check(zapcore.InfoLevel, "HTTP request failed"); c != nil {
c.Write(
+64
View File
@@ -19,7 +19,9 @@ import (
"fmt"
"net/netip"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
@@ -132,6 +134,43 @@ func (u *Upstream) fillHost() {
u.Host = host
}
// fillDynamicHost is like fillHost, but stores the host in the separate
// dynamicHosts map rather than the reference-counted UsagePool. Dynamic
// hosts are not reference-counted; instead, they are retained as long as
// they are actively seen and are evicted by a background cleanup goroutine
// after dynamicHostIdleExpiry of inactivity. This preserves health state
// (e.g. passive fail counts) across sequential requests.
func (u *Upstream) fillDynamicHost() {
dynamicHostsMu.Lock()
entry, ok := dynamicHosts[u.String()]
if ok {
entry.lastSeen = time.Now()
dynamicHosts[u.String()] = entry
u.Host = entry.host
} else {
h := new(Host)
dynamicHosts[u.String()] = dynamicHostEntry{host: h, lastSeen: time.Now()}
u.Host = h
}
dynamicHostsMu.Unlock()
// ensure the cleanup goroutine is running
dynamicHostsCleanerOnce.Do(func() {
go func() {
for {
time.Sleep(dynamicHostCleanupInterval)
dynamicHostsMu.Lock()
for addr, entry := range dynamicHosts {
if time.Since(entry.lastSeen) > dynamicHostIdleExpiry {
delete(dynamicHosts, addr)
}
}
dynamicHostsMu.Unlock()
}
}()
})
}
// Host is the basic, in-memory representation of the state of a remote host.
// Its fields are accessed atomically and Host values must not be copied.
type Host struct {
@@ -268,6 +307,28 @@ func GetDialInfo(ctx context.Context) (DialInfo, bool) {
// through config reloads.
var hosts = caddy.NewUsagePool()
// dynamicHosts tracks hosts that were provisioned from dynamic upstream
// sources. Unlike static upstreams which are reference-counted via the
// UsagePool, dynamic upstream hosts are not reference-counted. Instead,
// their last-seen time is updated on each request, and a background
// goroutine evicts entries that have been idle for dynamicHostIdleExpiry.
// This preserves health state (e.g. passive fail counts) across requests
// to the same dynamic backend.
var (
dynamicHosts = make(map[string]dynamicHostEntry)
dynamicHostsMu sync.RWMutex
dynamicHostsCleanerOnce sync.Once
dynamicHostCleanupInterval = 5 * time.Minute
dynamicHostIdleExpiry = time.Hour
)
// dynamicHostEntry holds a Host and the last time it was seen
// in a set of dynamic upstreams returned for a request.
type dynamicHostEntry struct {
host *Host
lastSeen time.Time
}
// dialInfoVarKey is the key used for the variable that holds
// the dial info for the upstream connection.
const dialInfoVarKey = "reverse_proxy.dial_info"
@@ -285,3 +346,6 @@ type ProxyProtocolInfo struct {
// tlsH1OnlyVarKey is the key used that indicates the connection will use h1 only for TLS.
// https://github.com/caddyserver/caddy/issues/7292
const tlsH1OnlyVarKey = "reverse_proxy.tls_h1_only"
// proxyVarKey is the key used that indicates the proxy server used for a request.
const proxyVarKey = "reverse_proxy.proxy"
+112 -67
View File
@@ -21,9 +21,10 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
weakrand "math/rand"
weakrand "math/rand/v2"
"net"
"net/http"
"net/url"
"os"
"reflect"
"slices"
@@ -39,6 +40,7 @@ import (
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/caddyconfig"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
"github.com/caddyserver/caddy/v2/modules/caddyhttp/headers"
"github.com/caddyserver/caddy/v2/modules/caddytls"
"github.com/caddyserver/caddy/v2/modules/internal/network"
)
@@ -159,8 +161,7 @@ type HTTPTransport struct {
// `HTTPS_PROXY`, and `NO_PROXY` environment variables.
NetworkProxyRaw json.RawMessage `json:"network_proxy,omitempty" caddy:"namespace=caddy.network_proxy inline_key=from"`
h2cTransport *http2.Transport
h3Transport *http3.Transport // TODO: EXPERIMENTAL (May 2024)
h3Transport *http3.Transport // TODO: EXPERIMENTAL (May 2024)
}
// CaddyModule returns the Caddy module information.
@@ -204,11 +205,16 @@ func (h *HTTPTransport) Provision(ctx caddy.Context) error {
func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, error) {
// Set keep-alive defaults if it wasn't otherwise configured
if h.KeepAlive == nil {
h.KeepAlive = &KeepAlive{
ProbeInterval: caddy.Duration(30 * time.Second),
IdleConnTimeout: caddy.Duration(2 * time.Minute),
MaxIdleConnsPerHost: 32, // seems about optimal, see #2805
}
h.KeepAlive = new(KeepAlive)
}
if h.KeepAlive.ProbeInterval == 0 {
h.KeepAlive.ProbeInterval = caddy.Duration(30 * time.Second)
}
if h.KeepAlive.IdleConnTimeout == 0 {
h.KeepAlive.IdleConnTimeout = caddy.Duration(2 * time.Minute)
}
if h.KeepAlive.MaxIdleConnsPerHost == 0 {
h.KeepAlive.MaxIdleConnsPerHost = 32 // seems about optimal, see #2805
}
// Set a relatively short default dial timeout.
@@ -260,22 +266,22 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
PreferGo: true,
Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
//nolint:gosec
addr := h.Resolver.netAddrs[weakrand.Intn(len(h.Resolver.netAddrs))]
addr := h.Resolver.netAddrs[weakrand.IntN(len(h.Resolver.netAddrs))]
return d.DialContext(ctx, addr.Network, addr.JoinHostPort(0))
},
}
}
dialContext := func(ctx context.Context, network, address string) (net.Conn, error) {
// For unix socket upstreams, we need to recover the dial info from
// the request's context, because the Host on the request's URL
// will have been modified by directing the request, overwriting
// the unix socket filename.
// Also, we need to avoid overwriting the address at this point
// when not necessary, because http.ProxyFromEnvironment may have
// modified the address according to the user's env proxy config.
// The network is usually tcp, and the address is the host in http.Request.URL.Host
// and that's been overwritten in directRequest
// However, if proxy is used according to http.ProxyFromEnvironment or proxy providers,
// address will be the address of the proxy server.
// This means we can safely use the address in dialInfo if proxy is not used (the address and network will be same any way)
// or if the upstream is unix (because there is no way socks or http proxy can be used for unix address).
if dialInfo, ok := GetDialInfo(ctx); ok {
if strings.HasPrefix(dialInfo.Network, "unix") {
if caddyhttp.GetVar(ctx, proxyVarKey) == nil || strings.HasPrefix(dialInfo.Network, "unix") {
network = dialInfo.Network
address = dialInfo.Address
}
@@ -376,9 +382,22 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
return nil, fmt.Errorf("network_proxy module is not `(func(*http.Request) (*url.URL, error))``")
}
}
// we need to keep track if a proxy is used for a request
proxyWrapper := func(req *http.Request) (*url.URL, error) {
if proxy == nil {
return nil, nil
}
u, err := proxy(req)
if u == nil || err != nil {
return u, err
}
// there must be a proxy for this request
caddyhttp.SetVar(req.Context(), proxyVarKey, u)
return u, nil
}
rt := &http.Transport{
Proxy: proxy,
Proxy: proxyWrapper,
DialContext: dialContext,
MaxConnsPerHost: h.MaxConnsPerHost,
ResponseHeaderTimeout: time.Duration(h.ResponseHeaderTimeout),
@@ -396,8 +415,13 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
return nil, fmt.Errorf("making TLS client config: %v", err)
}
// servername has a placeholder, so we need to replace it
if strings.Contains(h.TLS.ServerName, "{") {
serverNameHasPlaceholder := strings.Contains(h.TLS.ServerName, "{")
// We need to use custom DialTLSContext if:
// 1. ServerName has a placeholder that needs to be replaced at request-time, OR
// 2. ProxyProtocol is enabled, because req.URL.Host is modified to include
// client address info with "->" separator which breaks Go's address parsing
if serverNameHasPlaceholder || h.ProxyProtocol != "" {
rt.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
// reuses the dialer from above to establish a plaintext connection
conn, err := dialContext(ctx, network, addr)
@@ -406,9 +430,11 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
}
// but add our own handshake logic
repl := ctx.Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
tlsConfig := rt.TLSClientConfig.Clone()
tlsConfig.ServerName = repl.ReplaceAll(tlsConfig.ServerName, "")
if serverNameHasPlaceholder {
repl := ctx.Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
tlsConfig.ServerName = repl.ReplaceAll(tlsConfig.ServerName, "")
}
// h1 only
if caddyhttp.GetVar(ctx, tlsH1OnlyVarKey) == true {
@@ -422,7 +448,7 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
// complete the handshake before returning the connection
if rt.TLSHandshakeTimeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, rt.TLSHandshakeTimeout)
ctx, cancel = context.WithTimeoutCause(ctx, rt.TLSHandshakeTimeout, fmt.Errorf("HTTP transport TLS handshake %ds timeout", int(rt.TLSHandshakeTimeout.Seconds())))
defer cancel()
}
err = tlsConn.HandshakeContext(ctx)
@@ -457,24 +483,10 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
rt.IdleConnTimeout = time.Duration(h.KeepAlive.IdleConnTimeout)
}
// The proxy protocol header can only be sent once right after opening the connection.
// So single connection must not be used for multiple requests, which can potentially
// come from different clients.
if !rt.DisableKeepAlives && h.ProxyProtocol != "" {
caddyCtx.Logger().Warn("disabling keepalives, they are incompatible with using PROXY protocol")
rt.DisableKeepAlives = true
}
if h.Compression != nil {
rt.DisableCompression = !*h.Compression
}
if slices.Contains(h.Versions, "2") {
if err := http2.ConfigureTransport(rt); err != nil {
return nil, err
}
}
// configure HTTP/3 transport if enabled; however, this does not
// automatically fall back to lower versions like most web browsers
// do (that'd add latency and complexity, besides, we expect that
@@ -492,30 +504,49 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
return nil, fmt.Errorf("if HTTP/3 is enabled to the upstream, no other HTTP versions are supported")
}
// if h2c is enabled, configure its transport (std lib http.Transport
// does not "HTTP/2 over cleartext TCP")
if slices.Contains(h.Versions, "h2c") {
// crafting our own http2.Transport doesn't allow us to utilize
// most of the customizations/preferences on the http.Transport,
// because, for some reason, only http2.ConfigureTransport()
// is allowed to set the unexported field that refers to a base
// http.Transport config; oh well
h2t := &http2.Transport{
// kind of a hack, but for plaintext/H2C requests, pretend to dial TLS
DialTLSContext: func(ctx context.Context, network, address string, _ *tls.Config) (net.Conn, error) {
return dialContext(ctx, network, address)
},
AllowHTTP: true,
// if h2/c is enabled, configure it explicitly
if slices.Contains(h.Versions, "2") || slices.Contains(h.Versions, "h2c") {
if err := http2.ConfigureTransport(rt); err != nil {
return nil, err
}
if h.Compression != nil {
h2t.DisableCompression = !*h.Compression
// DisableCompression from h2 is configured by http2.ConfigureTransport
// Likewise, DisableKeepAlives from h1 is used too.
// Protocols field is only used when the request is not using TLS,
// http1/2 over tls is still allowed
if slices.Contains(h.Versions, "h2c") {
rt.Protocols = new(http.Protocols)
rt.Protocols.SetUnencryptedHTTP2(true)
rt.Protocols.SetHTTP1(false)
}
h.h2cTransport = h2t
}
return rt, nil
}
// RequestHeaderOps implements TransportHeaderOpsProvider. It returns header
// operations for requests when the transport's configuration indicates they
// should be applied. In particular, when TLS is enabled for this transport,
// return an operation to set the Host header to the upstream host:port
// placeholder so HTTPS upstreams get the proper Host by default.
//
// Note: this is a provision-time hook; the Handler will call this during
// its Provision and cache the resulting HeaderOps. The HeaderOps are
// applied per-request (so placeholders are expanded at request time).
func (h *HTTPTransport) RequestHeaderOps() *headers.HeaderOps {
// If TLS is not configured for this transport, don't inject Host
// defaults. TLS being non-nil indicates HTTPS to the upstream.
if h.TLS == nil {
return nil
}
return &headers.HeaderOps{
Set: http.Header{
"Host": []string{"{http.reverse_proxy.upstream.hostport}"},
},
}
}
// RoundTrip implements http.RoundTripper.
func (h *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
h.SetScheme(req)
@@ -525,15 +556,6 @@ func (h *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return h.h3Transport.RoundTrip(req)
}
// if H2C ("HTTP/2 over cleartext") is enabled and the upstream request is
// HTTP without TLS, use the alternate H2C-capable transport instead
if req.URL.Scheme == "http" && h.h2cTransport != nil {
// There is no dedicated DisableKeepAlives field in *http2.Transport.
// This is an alternative way to disable keep-alive.
req.Close = h.Transport.DisableKeepAlives
return h.h2cTransport.RoundTrip(req)
}
return h.Transport.RoundTrip(req)
}
@@ -575,6 +597,26 @@ func (h *HTTPTransport) EnableTLS(base *TLSConfig) error {
return nil
}
// EnableH2C enables H2C (HTTP/2 over Cleartext) on the transport.
func (h *HTTPTransport) EnableH2C() error {
h.Versions = []string{"h2c", "2"}
return nil
}
// OverrideHealthCheckScheme overrides the scheme of the given URL
// used for health checks.
func (h HTTPTransport) OverrideHealthCheckScheme(base *url.URL, port string) {
// if tls is enabled and the port isn't in the except list, use HTTPs
if h.TLSEnabled() && !slices.Contains(h.TLS.ExceptPorts, port) {
base.Scheme = "https"
}
}
// ProxyProtocolEnabled returns true if proxy protocol is enabled.
func (h HTTPTransport) ProxyProtocolEnabled() bool {
return h.ProxyProtocol != ""
}
// Cleanup implements caddy.CleanerUpper and closes any idle connections.
func (h HTTPTransport) Cleanup() error {
if h.Transport == nil {
@@ -831,8 +873,11 @@ func decodeBase64DERCert(certStr string) (*x509.Certificate, error) {
// Interface guards
var (
_ caddy.Provisioner = (*HTTPTransport)(nil)
_ http.RoundTripper = (*HTTPTransport)(nil)
_ caddy.CleanerUpper = (*HTTPTransport)(nil)
_ TLSTransport = (*HTTPTransport)(nil)
_ caddy.Provisioner = (*HTTPTransport)(nil)
_ http.RoundTripper = (*HTTPTransport)(nil)
_ caddy.CleanerUpper = (*HTTPTransport)(nil)
_ TLSTransport = (*HTTPTransport)(nil)
_ H2CTransport = (*HTTPTransport)(nil)
_ HealthCheckSchemeOverriderTransport = (*HTTPTransport)(nil)
_ ProxyProtocolTransport = (*HTTPTransport)(nil)
)
@@ -1,11 +1,13 @@
package reverseproxy
import (
"context"
"encoding/json"
"fmt"
"reflect"
"testing"
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
)
@@ -94,3 +96,102 @@ func TestHTTPTransportUnmarshalCaddyFileWithCaPools(t *testing.T) {
})
}
}
func TestHTTPTransport_RequestHeaderOps_TLS(t *testing.T) {
var ht HTTPTransport
// When TLS is nil, expect no header ops
if ops := ht.RequestHeaderOps(); ops != nil {
t.Fatalf("expected nil HeaderOps when TLS is nil, got: %#v", ops)
}
// When TLS is configured, expect a HeaderOps that sets Host
ht.TLS = &TLSConfig{}
ops := ht.RequestHeaderOps()
if ops == nil {
t.Fatal("expected non-nil HeaderOps when TLS is set")
}
if ops.Set == nil {
t.Fatalf("expected ops.Set to be non-nil, got nil")
}
if got := ops.Set.Get("Host"); got != "{http.reverse_proxy.upstream.hostport}" {
t.Fatalf("unexpected Host value; want placeholder, got: %s", got)
}
}
// TestHTTPTransport_DialTLSContext_ProxyProtocol verifies that when TLS and
// ProxyProtocol are both enabled, DialTLSContext is set. This is critical because
// ProxyProtocol modifies req.URL.Host to include client info with "->" separator
// (e.g., "[2001:db8::1]:12345->127.0.0.1:443"), which breaks Go's address parsing.
// Without a custom DialTLSContext, Go's HTTP library would fail with
// "too many colons in address" when trying to parse the mangled host.
func TestHTTPTransport_DialTLSContext_ProxyProtocol(t *testing.T) {
ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()})
defer cancel()
tests := []struct {
name string
tls *TLSConfig
proxyProtocol string
serverNameHasPlaceholder bool
expectDialTLSContext bool
}{
{
name: "no TLS, no proxy protocol",
tls: nil,
proxyProtocol: "",
expectDialTLSContext: false,
},
{
name: "TLS without proxy protocol",
tls: &TLSConfig{},
proxyProtocol: "",
expectDialTLSContext: false,
},
{
name: "TLS with proxy protocol v1",
tls: &TLSConfig{},
proxyProtocol: "v1",
expectDialTLSContext: true,
},
{
name: "TLS with proxy protocol v2",
tls: &TLSConfig{},
proxyProtocol: "v2",
expectDialTLSContext: true,
},
{
name: "TLS with placeholder ServerName",
tls: &TLSConfig{ServerName: "{http.request.host}"},
proxyProtocol: "",
serverNameHasPlaceholder: true,
expectDialTLSContext: true,
},
{
name: "TLS with placeholder ServerName and proxy protocol",
tls: &TLSConfig{ServerName: "{http.request.host}"},
proxyProtocol: "v2",
serverNameHasPlaceholder: true,
expectDialTLSContext: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ht := &HTTPTransport{
TLS: tt.tls,
ProxyProtocol: tt.proxyProtocol,
}
rt, err := ht.NewTransport(ctx)
if err != nil {
t.Fatalf("NewTransport() error = %v", err)
}
hasDialTLSContext := rt.DialTLSContext != nil
if hasDialTLSContext != tt.expectDialTLSContext {
t.Errorf("DialTLSContext set = %v, want %v", hasDialTLSContext, tt.expectDialTLSContext)
}
})
}
}
@@ -0,0 +1,391 @@
// Copyright 2015 Matthew Holt and The Caddy Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package reverseproxy
import (
"context"
"testing"
"time"
"github.com/caddyserver/caddy/v2"
)
// newPassiveHandler builds a minimal Handler with passive health checks
// configured and a live caddy.Context so the fail-forgetter goroutine can
// be cancelled cleanly. The caller must call cancel() when done.
func newPassiveHandler(t *testing.T, maxFails int, failDuration time.Duration) (*Handler, context.CancelFunc) {
t.Helper()
caddyCtx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()})
h := &Handler{
ctx: caddyCtx,
HealthChecks: &HealthChecks{
Passive: &PassiveHealthChecks{
MaxFails: maxFails,
FailDuration: caddy.Duration(failDuration),
},
},
}
return h, cancel
}
// provisionedStaticUpstream creates a static upstream, registers it in the
// UsagePool, and returns a cleanup func that removes it from the pool.
func provisionedStaticUpstream(t *testing.T, h *Handler, addr string) (*Upstream, func()) {
t.Helper()
u := &Upstream{Dial: addr}
h.provisionUpstream(u, false)
return u, func() { _, _ = hosts.Delete(addr) }
}
// provisionedDynamicUpstream creates a dynamic upstream, registers it in
// dynamicHosts, and returns a cleanup func that removes it.
func provisionedDynamicUpstream(t *testing.T, h *Handler, addr string) (*Upstream, func()) {
t.Helper()
u := &Upstream{Dial: addr}
h.provisionUpstream(u, true)
return u, func() {
dynamicHostsMu.Lock()
delete(dynamicHosts, addr)
dynamicHostsMu.Unlock()
}
}
// --- countFailure behaviour ---
// TestCountFailureNoopWhenNoHealthChecks verifies that countFailure is a no-op
// when HealthChecks is nil.
func TestCountFailureNoopWhenNoHealthChecks(t *testing.T) {
resetDynamicHosts()
h := &Handler{}
u := &Upstream{Dial: "10.1.0.1:80", Host: new(Host)}
h.countFailure(u)
if u.Host.Fails() != 0 {
t.Errorf("expected 0 fails with no HealthChecks config, got %d", u.Host.Fails())
}
}
// TestCountFailureNoopWhenZeroDuration verifies that countFailure is a no-op
// when FailDuration is 0 (the zero value disables passive checks).
func TestCountFailureNoopWhenZeroDuration(t *testing.T) {
resetDynamicHosts()
caddyCtx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()})
defer cancel()
h := &Handler{
ctx: caddyCtx,
HealthChecks: &HealthChecks{
Passive: &PassiveHealthChecks{MaxFails: 1, FailDuration: 0},
},
}
u := &Upstream{Dial: "10.1.0.2:80", Host: new(Host)}
h.countFailure(u)
if u.Host.Fails() != 0 {
t.Errorf("expected 0 fails with zero FailDuration, got %d", u.Host.Fails())
}
}
// TestCountFailureIncrementsCount verifies that countFailure increments the
// fail count on the upstream's Host.
func TestCountFailureIncrementsCount(t *testing.T) {
resetDynamicHosts()
h, cancel := newPassiveHandler(t, 2, time.Minute)
defer cancel()
u := &Upstream{Dial: "10.1.0.3:80", Host: new(Host)}
h.countFailure(u)
if u.Host.Fails() != 1 {
t.Errorf("expected 1 fail after countFailure, got %d", u.Host.Fails())
}
}
// TestCountFailureDecrementsAfterDuration verifies that the fail count is
// decremented back after FailDuration elapses.
func TestCountFailureDecrementsAfterDuration(t *testing.T) {
resetDynamicHosts()
const failDuration = 50 * time.Millisecond
h, cancel := newPassiveHandler(t, 2, failDuration)
defer cancel()
u := &Upstream{Dial: "10.1.0.4:80", Host: new(Host)}
h.countFailure(u)
if u.Host.Fails() != 1 {
t.Fatalf("expected 1 fail immediately after countFailure, got %d", u.Host.Fails())
}
// Wait long enough for the forgetter goroutine to fire.
time.Sleep(3 * failDuration)
if u.Host.Fails() != 0 {
t.Errorf("expected fail count to return to 0 after FailDuration, got %d", u.Host.Fails())
}
}
// TestCountFailureCancelledContextForgets verifies that cancelling the handler
// context (simulating a config unload) also triggers the forgetter to run,
// decrementing the fail count.
func TestCountFailureCancelledContextForgets(t *testing.T) {
resetDynamicHosts()
h, cancel := newPassiveHandler(t, 2, time.Hour) // very long duration
u := &Upstream{Dial: "10.1.0.5:80", Host: new(Host)}
h.countFailure(u)
if u.Host.Fails() != 1 {
t.Fatalf("expected 1 fail immediately after countFailure, got %d", u.Host.Fails())
}
// Cancelling the context should cause the forgetter goroutine to exit and
// decrement the count.
cancel()
time.Sleep(50 * time.Millisecond)
if u.Host.Fails() != 0 {
t.Errorf("expected fail count to be decremented after context cancel, got %d", u.Host.Fails())
}
}
// --- static upstream passive health check ---
// TestStaticUpstreamHealthyWithNoFailures verifies that a static upstream with
// no recorded failures is considered healthy.
func TestStaticUpstreamHealthyWithNoFailures(t *testing.T) {
resetDynamicHosts()
h, cancel := newPassiveHandler(t, 2, time.Minute)
defer cancel()
u, cleanup := provisionedStaticUpstream(t, h, "10.2.0.1:80")
defer cleanup()
if !u.Healthy() {
t.Error("upstream with no failures should be healthy")
}
}
// TestStaticUpstreamUnhealthyAtMaxFails verifies that a static upstream is
// marked unhealthy once its fail count reaches MaxFails.
func TestStaticUpstreamUnhealthyAtMaxFails(t *testing.T) {
resetDynamicHosts()
h, cancel := newPassiveHandler(t, 2, time.Minute)
defer cancel()
u, cleanup := provisionedStaticUpstream(t, h, "10.2.0.2:80")
defer cleanup()
h.countFailure(u)
if !u.Healthy() {
t.Error("upstream should still be healthy after 1 of 2 allowed failures")
}
h.countFailure(u)
if u.Healthy() {
t.Error("upstream should be unhealthy after reaching MaxFails=2")
}
}
// TestStaticUpstreamRecoversAfterFailDuration verifies that a static upstream
// returns to healthy once its failures expire.
func TestStaticUpstreamRecoversAfterFailDuration(t *testing.T) {
resetDynamicHosts()
const failDuration = 50 * time.Millisecond
h, cancel := newPassiveHandler(t, 1, failDuration)
defer cancel()
u, cleanup := provisionedStaticUpstream(t, h, "10.2.0.3:80")
defer cleanup()
h.countFailure(u)
if u.Healthy() {
t.Fatal("upstream should be unhealthy immediately after MaxFails failure")
}
time.Sleep(3 * failDuration)
if !u.Healthy() {
t.Errorf("upstream should recover to healthy after FailDuration, Fails=%d", u.Host.Fails())
}
}
// TestStaticUpstreamHealthPersistedAcrossReprovisioning verifies that static
// upstreams share a Host via the UsagePool, so a second call to provisionUpstream
// for the same address (as happens on config reload) sees the accumulated state.
func TestStaticUpstreamHealthPersistedAcrossReprovisioning(t *testing.T) {
resetDynamicHosts()
h, cancel := newPassiveHandler(t, 2, time.Minute)
defer cancel()
u1, cleanup1 := provisionedStaticUpstream(t, h, "10.2.0.4:80")
defer cleanup1()
h.countFailure(u1)
h.countFailure(u1)
// Simulate a second handler instance referencing the same upstream
// (e.g. after a config reload that keeps the same backend address).
u2, cleanup2 := provisionedStaticUpstream(t, h, "10.2.0.4:80")
defer cleanup2()
if u1.Host != u2.Host {
t.Fatal("expected both Upstream structs to share the same *Host via UsagePool")
}
if u2.Healthy() {
t.Error("re-provisioned upstream should still see the prior fail count and be unhealthy")
}
}
// --- dynamic upstream passive health check ---
// TestDynamicUpstreamHealthyWithNoFailures verifies that a freshly provisioned
// dynamic upstream is healthy.
func TestDynamicUpstreamHealthyWithNoFailures(t *testing.T) {
resetDynamicHosts()
h, cancel := newPassiveHandler(t, 2, time.Minute)
defer cancel()
u, cleanup := provisionedDynamicUpstream(t, h, "10.3.0.1:80")
defer cleanup()
if !u.Healthy() {
t.Error("dynamic upstream with no failures should be healthy")
}
}
// TestDynamicUpstreamUnhealthyAtMaxFails verifies that a dynamic upstream is
// marked unhealthy once its fail count reaches MaxFails.
func TestDynamicUpstreamUnhealthyAtMaxFails(t *testing.T) {
resetDynamicHosts()
h, cancel := newPassiveHandler(t, 2, time.Minute)
defer cancel()
u, cleanup := provisionedDynamicUpstream(t, h, "10.3.0.2:80")
defer cleanup()
h.countFailure(u)
if !u.Healthy() {
t.Error("dynamic upstream should still be healthy after 1 of 2 allowed failures")
}
h.countFailure(u)
if u.Healthy() {
t.Error("dynamic upstream should be unhealthy after reaching MaxFails=2")
}
}
// TestDynamicUpstreamFailCountPersistedBetweenRequests is the core regression
// test: it simulates two sequential (non-concurrent) requests to the same
// dynamic upstream. Before the fix, the UsagePool entry would be deleted
// between requests, wiping the fail count. Now it should survive.
func TestDynamicUpstreamFailCountPersistedBetweenRequests(t *testing.T) {
resetDynamicHosts()
h, cancel := newPassiveHandler(t, 2, time.Minute)
defer cancel()
// --- first request ---
u1 := &Upstream{Dial: "10.3.0.3:80"}
h.provisionUpstream(u1, true)
h.countFailure(u1)
if u1.Host.Fails() != 1 {
t.Fatalf("expected 1 fail after first request, got %d", u1.Host.Fails())
}
// Simulate end of first request: no delete from any pool (key difference
// vs. the old behaviour where hosts.Delete was deferred).
// --- second request: brand-new *Upstream struct, same dial address ---
u2 := &Upstream{Dial: "10.3.0.3:80"}
h.provisionUpstream(u2, true)
if u1.Host != u2.Host {
t.Fatal("expected both requests to share the same *Host pointer from dynamicHosts")
}
if u2.Host.Fails() != 1 {
t.Errorf("expected fail count to persist across requests, got %d", u2.Host.Fails())
}
// A second failure now tips it over MaxFails=2.
h.countFailure(u2)
if u2.Healthy() {
t.Error("upstream should be unhealthy after accumulated failures across requests")
}
// Cleanup.
dynamicHostsMu.Lock()
delete(dynamicHosts, "10.3.0.3:80")
dynamicHostsMu.Unlock()
}
// TestDynamicUpstreamRecoveryAfterFailDuration verifies that a dynamic
// upstream's fail count expires and it returns to healthy.
func TestDynamicUpstreamRecoveryAfterFailDuration(t *testing.T) {
resetDynamicHosts()
const failDuration = 50 * time.Millisecond
h, cancel := newPassiveHandler(t, 1, failDuration)
defer cancel()
u, cleanup := provisionedDynamicUpstream(t, h, "10.3.0.4:80")
defer cleanup()
h.countFailure(u)
if u.Healthy() {
t.Fatal("upstream should be unhealthy immediately after MaxFails failure")
}
time.Sleep(3 * failDuration)
// Re-provision (as a new request would) to get fresh *Upstream with policy set.
u2 := &Upstream{Dial: "10.3.0.4:80"}
h.provisionUpstream(u2, true)
if !u2.Healthy() {
t.Errorf("dynamic upstream should recover to healthy after FailDuration, Fails=%d", u2.Host.Fails())
}
}
// TestDynamicUpstreamMaxRequestsFromUnhealthyRequestCount verifies that
// UnhealthyRequestCount is copied into MaxRequests so Full() works correctly.
func TestDynamicUpstreamMaxRequestsFromUnhealthyRequestCount(t *testing.T) {
resetDynamicHosts()
caddyCtx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()})
defer cancel()
h := &Handler{
ctx: caddyCtx,
HealthChecks: &HealthChecks{
Passive: &PassiveHealthChecks{
UnhealthyRequestCount: 3,
},
},
}
u, cleanup := provisionedDynamicUpstream(t, h, "10.3.0.5:80")
defer cleanup()
if u.MaxRequests != 3 {
t.Errorf("expected MaxRequests=3 from UnhealthyRequestCount, got %d", u.MaxRequests)
}
// Should not be full with fewer requests than the limit.
_ = u.Host.countRequest(2)
if u.Full() {
t.Error("upstream should not be full with 2 of 3 allowed requests")
}
_ = u.Host.countRequest(1)
if !u.Full() {
t.Error("upstream should be full at UnhealthyRequestCount concurrent requests")
}
}
@@ -0,0 +1,257 @@
package reverseproxy
import (
"errors"
"io"
"net"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"go.uber.org/zap"
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
)
// prepareTestRequest injects the context values that ServeHTTP and
// proxyLoopIteration require (caddy.ReplacerCtxKey, VarsCtxKey, etc.) using
// the same helper that the real HTTP server uses.
//
// A zero-value Server is passed so that caddyhttp.ServerCtxKey is set to a
// non-nil pointer; reverseProxy dereferences it to check ShouldLogCredentials.
func prepareTestRequest(req *http.Request) *http.Request {
repl := caddy.NewReplacer()
return caddyhttp.PrepareRequest(req, repl, nil, &caddyhttp.Server{})
}
// closeOnCloseReader is an io.ReadCloser whose Close method actually makes
// subsequent reads fail, mimicking the behaviour of a real HTTP request body
// (as opposed to io.NopCloser, whose Close is a no-op and would mask the bug
// we are testing).
type closeOnCloseReader struct {
mu sync.Mutex
r *strings.Reader
closed bool
}
func newCloseOnCloseReader(s string) *closeOnCloseReader {
return &closeOnCloseReader{r: strings.NewReader(s)}
}
func (c *closeOnCloseReader) Read(p []byte) (int, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return 0, errors.New("http: invalid Read on closed Body")
}
return c.r.Read(p)
}
func (c *closeOnCloseReader) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
c.closed = true
return nil
}
// deadUpstreamAddr returns a TCP address that is guaranteed to refuse
// connections: we bind a listener, note its address, close it immediately,
// and return the address. Any dial to that address will get ECONNREFUSED.
func deadUpstreamAddr(t *testing.T) string {
t.Helper()
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("failed to create dead upstream listener: %v", err)
}
addr := ln.Addr().String()
ln.Close()
return addr
}
// testTransport wraps http.Transport to:
// 1. Set the URL scheme to "http" when it is empty (matching what
// HTTPTransport.SetScheme does in production; cloneRequest strips the
// scheme intentionally so a plain *http.Transport would fail with
// "unsupported protocol scheme").
// 2. Wrap dial errors as DialError so that tryAgain correctly identifies them
// as safe-to-retry regardless of request method (as HTTPTransport does in
// production via its custom dialer).
type testTransport struct{ *http.Transport }
func (t testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if req.URL.Scheme == "" {
req.URL.Scheme = "http"
}
resp, err := t.Transport.RoundTrip(req)
if err != nil {
// Wrap dial errors as DialError to match production behaviour.
// Without this wrapping, tryAgain treats ECONNREFUSED on a POST
// request as non-retryable (only GET is retried by default when
// the error is not a DialError).
var opErr *net.OpError
if errors.As(err, &opErr) && opErr.Op == "dial" {
return nil, DialError{err}
}
}
return resp, err
}
// minimalHandler returns a Handler with only the fields required by ServeHTTP
// set directly, bypassing Provision (which requires a full Caddy runtime).
// RoundRobinSelection is used so that successive iterations of the proxy loop
// advance through the upstream pool in a predictable order.
func minimalHandler(retries int, upstreams ...*Upstream) *Handler {
return &Handler{
logger: zap.NewNop(),
Transport: testTransport{&http.Transport{}},
Upstreams: upstreams,
LoadBalancing: &LoadBalancing{
Retries: retries,
SelectionPolicy: &RoundRobinSelection{},
// RetryMatch intentionally nil: dial errors are always retried
// regardless of RetryMatch or request method.
},
// ctx, connections, connectionsMu, events: zero/nil values are safe
// for the code paths exercised by these tests (TryInterval=0 so
// ctx.Done() is never consulted; no WebSocket hijacking; no passive
// health-check event emission).
}
}
// TestDialErrorBodyRetry verifies that a POST request whose body has NOT been
// pre-buffered via request_buffers can still be retried after a dial error.
//
// Before the fix, a dial error caused Go's transport to close the shared body
// (via cloneRequest's shallow copy), so the retry attempt would read from an
// already-closed io.ReadCloser and produce:
//
// http: invalid Read on closed Body → HTTP 502
//
// After the fix the handler wraps the body in noCloseBody when retries are
// configured, preventing the transport's Close() from propagating to the
// shared body. Since dial errors never read any bytes, the body remains at
// position 0 for the retry.
func TestDialErrorBodyRetry(t *testing.T) {
// Good upstream: echoes the request body with 200 OK.
goodServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "read body: "+err.Error(), http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write(body)
}))
t.Cleanup(goodServer.Close)
const requestBody = "hello, retry"
tests := []struct {
name string
method string
body string
retries int
wantStatus int
wantBody string
}{
{
// Core regression case: POST with a body, no request_buffers,
// dial error on first upstream → retry to second upstream succeeds.
name: "POST body retried after dial error",
method: http.MethodPost,
body: requestBody,
retries: 1,
wantStatus: http.StatusOK,
wantBody: requestBody,
},
{
// Dial errors are always retried regardless of method, but there
// is no body to re-read, so GET has always worked. Keep it as a
// sanity check that we did not break the no-body path.
name: "GET without body retried after dial error",
method: http.MethodGet,
body: "",
retries: 1,
wantStatus: http.StatusOK,
wantBody: "",
},
{
// Without any retry configuration the handler must give up on the
// first dial error and return a 502. Confirms no wrapping occurs
// in the no-retry path.
name: "no retries configured returns 502 on dial error",
method: http.MethodPost,
body: requestBody,
retries: 0,
wantStatus: http.StatusBadGateway,
wantBody: "",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
dead := deadUpstreamAddr(t)
// Build the upstream pool. RoundRobinSelection starts its
// counter at 0 and increments before returning, so with a
// two-element pool it picks index 1 first, then index 0.
// Put the good upstream at index 0 and the dead one at
// index 1 so that:
// attempt 1 → pool[1] = dead → DialError (ECONNREFUSED)
// attempt 2 → pool[0] = good → 200
upstreams := []*Upstream{
{Host: new(Host), Dial: goodServer.Listener.Addr().String()},
{Host: new(Host), Dial: dead},
}
if tc.retries == 0 {
// For the "no retries" case use only the dead upstream so
// there is nowhere to retry to.
upstreams = []*Upstream{
{Host: new(Host), Dial: dead},
}
}
h := minimalHandler(tc.retries, upstreams...)
// Use closeOnCloseReader so that Close() truly prevents further
// reads, matching real http.body semantics. io.NopCloser would
// mask the bug because its Close is a no-op.
var bodyReader io.ReadCloser
if tc.body != "" {
bodyReader = newCloseOnCloseReader(tc.body)
}
req := httptest.NewRequest(tc.method, "http://example.com/", bodyReader)
if bodyReader != nil {
// httptest.NewRequest wraps the reader in NopCloser; replace
// it with our close-aware reader so Close() is propagated.
req.Body = bodyReader
req.ContentLength = int64(len(tc.body))
}
req = prepareTestRequest(req)
rec := httptest.NewRecorder()
err := h.ServeHTTP(rec, req, caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
return nil
}))
// For error cases (e.g. 502) ServeHTTP returns a HandlerError
// rather than writing the status itself.
gotStatus := rec.Code
if err != nil {
if herr, ok := err.(caddyhttp.HandlerError); ok {
gotStatus = herr.StatusCode
}
}
if gotStatus != tc.wantStatus {
t.Errorf("status: got %d, want %d (err=%v)", gotStatus, tc.wantStatus, err)
}
if tc.wantBody != "" && rec.Body.String() != tc.wantBody {
t.Errorf("body: got %q, want %q", rec.Body.String(), tc.wantBody)
}
})
}
}
+232 -65
View File
@@ -32,6 +32,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"go.uber.org/zap"
@@ -46,6 +47,31 @@ import (
"github.com/caddyserver/caddy/v2/modules/caddyhttp/rewrite"
)
// inFlightRequests uses sync.Map with atomic.Int64 for lock-free updates on the hot path
var inFlightRequests sync.Map
func incInFlightRequest(address string) {
v, _ := inFlightRequests.LoadOrStore(address, new(atomic.Int64))
v.(*atomic.Int64).Add(1)
}
func decInFlightRequest(address string) {
if v, ok := inFlightRequests.Load(address); ok {
if v.(*atomic.Int64).Add(-1) <= 0 {
inFlightRequests.Delete(address)
}
}
}
func getInFlightRequests() map[string]int64 {
copyMap := make(map[string]int64)
inFlightRequests.Range(func(key, value any) bool {
copyMap[key.(string)] = value.(*atomic.Int64).Load()
return true
})
return copyMap
}
func init() {
caddy.RegisterModule(Handler{})
}
@@ -192,6 +218,13 @@ type Handler struct {
CB CircuitBreaker `json:"-"`
DynamicUpstreams UpstreamSource `json:"-"`
// transportHeaderOps is a set of header operations provided
// by the transport at provision time, if the transport
// implements TransportHeaderOpsProvider. These ops are
// applied before any user-configured header ops so the
// user can override transport defaults.
transportHeaderOps *headers.HeaderOps
// Holds the parsed CIDR ranges from TrustedProxies
trustedProxies []netip.Prefix
@@ -243,18 +276,16 @@ func (h *Handler) Provision(ctx caddy.Context) error {
return fmt.Errorf("loading transport: %v", err)
}
h.Transport = mod.(http.RoundTripper)
// enable request buffering for fastcgi if not configured
// This is because most fastcgi servers are php-fpm that require the content length to be set to read the body, golang
// std has fastcgi implementation that doesn't need this value to process the body, but we can safely assume that's
// not used.
// http3 requests have a negative content length for GET and HEAD requests, if that header is not sent.
// see: https://github.com/caddyserver/caddy/issues/6678#issuecomment-2472224182
// Though it appears even if CONTENT_LENGTH is invalid, php-fpm can handle just fine if the body is empty (no Stdin records sent).
// php-fpm will hang if there is any data in the body though, https://github.com/caddyserver/caddy/issues/5420#issuecomment-2415943516
// TODO: better default buffering for fastcgi requests without content length, in theory a value of 1 should be enough, make it bigger anyway
if module, ok := h.Transport.(caddy.Module); ok && module.CaddyModule().ID.Name() == "fastcgi" && h.RequestBuffers == 0 {
h.RequestBuffers = 4096
// set default buffer sizes if applicable
if bt, ok := h.Transport.(BufferedTransport); ok {
reqBuffers, respBuffers := bt.DefaultBufferSizes()
if h.RequestBuffers == 0 {
h.RequestBuffers = reqBuffers
}
if h.ResponseBuffers == 0 {
h.ResponseBuffers = respBuffers
}
}
}
if h.LoadBalancing != nil && h.LoadBalancing.SelectionPolicyRaw != nil {
@@ -324,6 +355,18 @@ func (h *Handler) Provision(ctx caddy.Context) error {
h.Transport = t
}
// If the transport can provide header ops, cache them now so we don't
// have to compute them per-request. Provision the HeaderOps if present
// so any runtime artifacts (like precompiled regex) are prepared.
if tph, ok := h.Transport.(RequestHeaderOpsTransport); ok {
h.transportHeaderOps = tph.RequestHeaderOps()
if h.transportHeaderOps != nil {
if err := h.transportHeaderOps.Provision(ctx); err != nil {
return fmt.Errorf("provisioning transport header ops: %v", err)
}
}
}
// set up load balancing
if h.LoadBalancing == nil {
h.LoadBalancing = new(LoadBalancing)
@@ -349,7 +392,7 @@ func (h *Handler) Provision(ctx caddy.Context) error {
// set up upstreams
for _, u := range h.Upstreams {
h.provisionUpstream(u)
h.provisionUpstream(u, false)
}
if h.HealthChecks != nil {
@@ -439,6 +482,33 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht
reqHost := clonedReq.Host
reqHeader := clonedReq.Header
// When retries are configured and there is a body, wrap it in
// io.NopCloser to prevent Go's transport from closing it on dial
// errors. cloneRequest does a shallow copy, so clonedReq.Body and
// r.Body share the same io.ReadCloser — a dial-failure Close()
// would kill the original body for all subsequent retry attempts.
// The real body is closed by the HTTP server when the handler
// returns.
//
// If the body was already fully buffered (via request_buffers),
// we also extract the buffer so the retry loop can replay it
// from the beginning on each attempt. (see #6259, #7546)
var bufferedReqBody *bytes.Buffer
if clonedReq.Body != nil && h.LoadBalancing != nil &&
(h.LoadBalancing.Retries > 0 || h.LoadBalancing.TryDuration > 0) {
if reqBodyBuf, ok := clonedReq.Body.(bodyReadCloser); ok && reqBodyBuf.body == nil && reqBodyBuf.buf != nil {
bufferedReqBody = reqBodyBuf.buf
reqBodyBuf.buf = nil
clonedReq.Body = io.NopCloser(bytes.NewReader(bufferedReqBody.Bytes()))
defer func() {
bufferedReqBody.Reset()
bufPool.Put(bufferedReqBody)
}()
} else {
clonedReq.Body = io.NopCloser(clonedReq.Body)
}
}
start := time.Now()
defer func() {
// total proxying duration, including time spent on LB and retries
@@ -457,8 +527,8 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht
// and reusable, so if a backend partially or fully reads the body but then
// produces an error, the request can be repeated to the next backend with
// the full body (retries should only happen for idempotent requests) (see #6259)
if reqBodyBuf, ok := r.Body.(bodyReadCloser); ok && reqBodyBuf.body == nil {
r.Body = io.NopCloser(bytes.NewReader(reqBodyBuf.buf.Bytes()))
if bufferedReqBody != nil {
clonedReq.Body = io.NopCloser(bytes.NewReader(bufferedReqBody.Bytes()))
}
var done bool
@@ -506,18 +576,11 @@ func (h *Handler) proxyLoopIteration(r *http.Request, origReq *http.Request, w h
} else {
upstreams = dUpstreams
for _, dUp := range dUpstreams {
h.provisionUpstream(dUp)
h.provisionUpstream(dUp, true)
}
if c := h.logger.Check(zapcore.DebugLevel, "provisioned dynamic upstreams"); c != nil {
c.Write(zap.Int("count", len(dUpstreams)))
}
defer func() {
// these upstreams are dynamic, so they are only used for this iteration
// of the proxy loop; be sure to let them go away when we're done with them
for _, upstream := range dUpstreams {
_, _ = hosts.Delete(upstream.String())
}
}()
}
}
@@ -563,14 +626,26 @@ func (h *Handler) proxyLoopIteration(r *http.Request, origReq *http.Request, w h
repl.Set("http.reverse_proxy.upstream.fails", upstream.Host.Fails())
// mutate request headers according to this upstream;
// because we're in a retry loop, we have to copy
// headers (and the r.Host value) from the original
// so that each retry is identical to the first
if h.Headers != nil && h.Headers.Request != nil {
// because we're in a retry loop, we have to copy headers
// (and the r.Host value) from the original so that each
// retry is identical to the first. If either transport or
// user ops exist, apply them in order (transport first,
// then user, so user's config wins).
var userOps *headers.HeaderOps
if h.Headers != nil {
userOps = h.Headers.Request
}
transportOps := h.transportHeaderOps
if transportOps != nil || userOps != nil {
r.Header = make(http.Header)
copyHeader(r.Header, reqHeader)
r.Host = reqHost
h.Headers.Request.ApplyToRequest(r)
if transportOps != nil {
transportOps.ApplyToRequest(r)
}
if userOps != nil {
userOps.ApplyToRequest(r)
}
}
// proxy the request to that upstream
@@ -758,48 +833,71 @@ func (h Handler) prepareRequest(req *http.Request, repl *caddy.Replacer) (*http.
// the headers at all, then they will be added with the values
// that we can glean from the request.
func (h Handler) addForwardedHeaders(req *http.Request) error {
// Parse the remote IP, ignore the error as non-fatal,
// but the remote IP is required to continue, so we
// just return early. This should probably never happen
// though, unless some other module manipulated the request's
// remote address and used an invalid value.
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil {
// Remove the `X-Forwarded-*` headers to avoid upstreams
// potentially trusting a header that came from the client
req.Header.Del("X-Forwarded-For")
req.Header.Del("X-Forwarded-Proto")
req.Header.Del("X-Forwarded-Host")
return nil
}
// Client IP may contain a zone if IPv6, so we need
// to pull that out before parsing the IP
clientIP, _, _ = strings.Cut(clientIP, "%")
ipAddr, err := netip.ParseAddr(clientIP)
if err != nil {
return fmt.Errorf("invalid IP address: '%s': %v", clientIP, err)
}
// Check if the client is a trusted proxy
trusted := caddyhttp.GetVar(req.Context(), caddyhttp.TrustedProxyVarKey).(bool)
for _, ipRange := range h.trustedProxies {
if ipRange.Contains(ipAddr) {
trusted = true
break
var clientIP string
if req.RemoteAddr == "@" {
// For Unix socket connections, RemoteAddr is "@" which cannot
// be parsed as host:port. If untrusted, strip forwarded headers
// for security. If trusted, there is no peer IP to append to
// X-Forwarded-For, so clientIP stays empty.
if !trusted {
req.Header.Del("X-Forwarded-For")
req.Header.Del("X-Forwarded-Proto")
req.Header.Del("X-Forwarded-Host")
return nil
}
} else {
// Parse the remote IP, ignore the error as non-fatal,
// but the remote IP is required to continue, so we
// just return early. This should probably never happen
// though, unless some other module manipulated the request's
// remote address and used an invalid value.
var err error
clientIP, _, err = net.SplitHostPort(req.RemoteAddr)
if err != nil {
// Remove the `X-Forwarded-*` headers to avoid upstreams
// potentially trusting a header that came from the client
req.Header.Del("X-Forwarded-For")
req.Header.Del("X-Forwarded-Proto")
req.Header.Del("X-Forwarded-Host")
return nil
}
// Client IP may contain a zone if IPv6, so we need
// to pull that out before parsing the IP
clientIP, _, _ = strings.Cut(clientIP, "%")
ipAddr, err := netip.ParseAddr(clientIP)
// If ParseAddr fails (e.g. non-IP network like SCION), we cannot check
// if it is a trusted proxy by IP range. In this case, we ignore the
// error and treat the connection as untrusted (or retain existing status).
if err == nil {
for _, ipRange := range h.trustedProxies {
if ipRange.Contains(ipAddr) {
trusted = true
break
}
}
}
}
// If we aren't the first proxy, and the proxy is trusted,
// retain prior X-Forwarded-For information as a comma+space
// separated list and fold multiple headers into one.
clientXFF := clientIP
prior, ok, omit := allHeaderValues(req.Header, "X-Forwarded-For")
if trusted && ok && prior != "" {
clientXFF = prior + ", " + clientXFF
}
if !omit {
req.Header.Set("X-Forwarded-For", clientXFF)
if trusted && ok && prior != "" {
if clientIP != "" {
req.Header.Set("X-Forwarded-For", prior+", "+clientIP)
} else {
req.Header.Set("X-Forwarded-For", prior)
}
} else if clientIP != "" {
req.Header.Set("X-Forwarded-For", clientIP)
}
}
// Set X-Forwarded-Proto; many backend apps expect this,
@@ -838,8 +936,16 @@ func (h Handler) addForwardedHeaders(req *http.Request) error {
// Go standard library which was used as the foundation.)
func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origReq *http.Request, repl *caddy.Replacer, di DialInfo, next caddyhttp.Handler) error {
_ = di.Upstream.Host.countRequest(1)
// Increment the in-flight request count
incInFlightRequest(di.Address)
//nolint:errcheck
defer di.Upstream.Host.countRequest(-1)
defer func() {
di.Upstream.Host.countRequest(-1)
// Decrement the in-flight request count
decInFlightRequest(di.Address)
}()
// point the request to this upstream
h.directRequest(req, di)
@@ -1198,7 +1304,7 @@ func (lb LoadBalancing) tryAgain(ctx caddy.Context, start time.Time, retries int
// directRequest modifies only req.URL so that it points to the upstream
// in the given DialInfo. It must modify ONLY the request URL.
func (Handler) directRequest(req *http.Request, di DialInfo) {
func (h *Handler) directRequest(req *http.Request, di DialInfo) {
// we need a host, so set the upstream's host address
reqHost := di.Address
@@ -1209,12 +1315,31 @@ func (Handler) directRequest(req *http.Request, di DialInfo) {
reqHost = di.Host
}
// add client address to the host to let transport differentiate requests from different clients
if ppt, ok := h.Transport.(ProxyProtocolTransport); ok && ppt.ProxyProtocolEnabled() {
if proxyProtocolInfo, ok := caddyhttp.GetVar(req.Context(), proxyProtocolInfoVarKey).(ProxyProtocolInfo); ok {
// encode the request so it plays well with h2 transport, it's unnecessary for h1 but anyway
// The issue is that h2 transport will use the address to determine if new connections are needed
// to roundtrip requests but the without escaping, new connections are constantly created and closed until
// file descriptors are exhausted.
// see: https://github.com/caddyserver/caddy/issues/7529
reqHost = url.QueryEscape(proxyProtocolInfo.AddrPort.String() + "->" + reqHost)
}
}
req.URL.Host = reqHost
}
func (h Handler) provisionUpstream(upstream *Upstream) {
// create or get the host representation for this upstream
upstream.fillHost()
func (h Handler) provisionUpstream(upstream *Upstream, dynamic bool) {
// create or get the host representation for this upstream;
// dynamic upstreams are tracked in a separate map with last-seen
// timestamps so their health state persists across requests without
// being reference-counted (and thus discarded between requests).
if dynamic {
upstream.fillDynamicHost()
} else {
upstream.fillHost()
}
// give it the circuit breaker, if any
upstream.cb = h.CB
@@ -1494,6 +1619,43 @@ type TLSTransport interface {
EnableTLS(base *TLSConfig) error
}
// H2CTransport is implemented by transports
// that are capable of using h2c.
type H2CTransport interface {
EnableH2C() error
}
// ProxyProtocolTransport is implemented by transports
// that are capable of using proxy protocol.
type ProxyProtocolTransport interface {
ProxyProtocolEnabled() bool
}
// HealthCheckSchemeOverriderTransport is implemented by transports
// that can override the scheme used for health checks.
type HealthCheckSchemeOverriderTransport interface {
OverrideHealthCheckScheme(base *url.URL, port string)
}
// BufferedTransport is implemented by transports
// that needs to buffer requests and/or responses.
type BufferedTransport interface {
// DefaultBufferSizes returns the default buffer sizes
// for requests and responses, respectively if buffering isn't enabled.
DefaultBufferSizes() (int64, int64)
}
// RequestHeaderOpsTransport may be implemented by a transport to provide
// header operations to apply to requests immediately before the RoundTrip.
// For example, overriding the default Host when TLS is enabled.
type RequestHeaderOpsTransport interface {
// RequestHeaderOps allows a transport to provide header operations
// to apply to the request. The transport is asked at provision time
// to return a HeaderOps (or nil) that will be applied before
// user-configured header ops.
RequestHeaderOps() *headers.HeaderOps
}
// roundtripSucceededError is an error type that is returned if the
// roundtrip succeeded, but an error occurred after-the-fact.
type roundtripSucceededError struct{ error }
@@ -1507,7 +1669,12 @@ type bodyReadCloser struct {
}
func (brc bodyReadCloser) Close() error {
bufPool.Put(brc.buf)
// Inside this package this will be set to nil for fully-buffered
// requests due to the possibility of retrial.
if brc.buf != nil {
bufPool.Put(brc.buf)
}
// For fully-buffered bodies, body is nil, so Close is a no-op.
if brc.body != nil {
return brc.body.Close()
}
@@ -20,7 +20,7 @@ import (
"encoding/hex"
"encoding/json"
"fmt"
weakrand "math/rand"
weakrand "math/rand/v2"
"net"
"net/http"
"strconv"
@@ -225,7 +225,7 @@ func (r RandomChoiceSelection) Select(pool UpstreamPool, _ *http.Request, _ http
if !upstream.Available() {
continue
}
j := weakrand.Intn(i + 1) //nolint:gosec
j := weakrand.IntN(i + 1) //nolint:gosec
if j < k {
choices[j] = upstream
}
@@ -274,7 +274,7 @@ func (LeastConnSelection) Select(pool UpstreamPool, _ *http.Request, _ http.Resp
// sample: https://en.wikipedia.org/wiki/Reservoir_sampling
if numReqs == leastReqs {
count++
if count == 1 || (weakrand.Int()%count) == 0 { //nolint:gosec
if count == 1 || weakrand.IntN(count) == 0 { //nolint:gosec
bestHost = host
}
}
@@ -312,7 +312,7 @@ func (r *RoundRobinSelection) Select(pool UpstreamPool, _ *http.Request, _ http.
if n == 0 {
return nil
}
for i := uint32(0); i < n; i++ {
for range n {
robin := atomic.AddUint32(&r.robin, 1)
host := pool[robin%n]
if host.Available() {
@@ -617,7 +617,7 @@ type CookieHashSelection struct {
// The HTTP cookie name whose value is to be hashed and used for upstream selection.
Name string `json:"name,omitempty"`
// Secret to hash (Hmac256) chosen upstream in cookie
Secret string `json:"secret,omitempty"`
Secret string `json:"secret,omitempty"` //nolint:gosec // yes it's exported because it needs to encode to JSON
// The cookie's Max-Age before it expires. Default is no expiry.
MaxAge caddy.Duration `json:"max_age,omitempty"`
@@ -788,7 +788,7 @@ func selectRandomHost(pool []*Upstream) *Upstream {
// upstream will always be chosen if there is at
// least one available
count++
if (weakrand.Int() % count) == 0 { //nolint:gosec
if weakrand.IntN(count) == 0 { //nolint:gosec
randomHost = upstream
}
}
@@ -827,7 +827,7 @@ func leastRequests(upstreams []*Upstream) *Upstream {
if len(best) == 1 {
return best[0]
}
return best[weakrand.Intn(len(best))] //nolint:gosec
return best[weakrand.IntN(len(best))] //nolint:gosec
}
// hostByHashing returns an available host from pool based on a hashable string s.
+7 -4
View File
@@ -24,7 +24,7 @@ import (
"errors"
"fmt"
"io"
weakrand "math/rand"
weakrand "math/rand/v2"
"mime"
"net/http"
"sync"
@@ -214,7 +214,10 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
timeoutc = timer.C
}
errc := make(chan error, 1)
// when a stream timeout is encountered, no error will be read from errc
// a buffer size of 2 will allow both the read and write goroutines to send the error and exit
// see: https://github.com/caddyserver/caddy/issues/7418
errc := make(chan error, 2)
wg.Add(2)
go spc.copyToBackend(errc)
go spc.copyFromBackend(errc)
@@ -526,14 +529,14 @@ func maskBytes(key [4]byte, pos int, b []byte) int {
// Create aligned word size key.
var k [wordSize]byte
for i := range k {
k[i] = key[(pos+i)&3]
k[i] = key[(pos+i)&3] // nolint:gosec // false positive, impossible to be out of bounds; see: https://github.com/securego/gosec/issues/1525
}
kw := *(*uintptr)(unsafe.Pointer(&k))
// Mask one word at a time.
n := (len(b) / wordSize) * wordSize
for i := 0; i < n; i += wordSize {
*(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw
*(*uintptr)(unsafe.Add(unsafe.Pointer(&b[0]), i)) ^= kw
}
// Mask one byte at a time for remaining bytes.
+11 -3
View File
@@ -4,7 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
weakrand "math/rand"
weakrand "math/rand/v2"
"net"
"net/http"
"strconv"
@@ -70,6 +70,11 @@ type SRVUpstreams struct {
// A negative value disables this.
FallbackDelay caddy.Duration `json:"dial_fallback_delay,omitempty"`
// Specific network to dial when connecting to the upstream(s)
// provided by SRV records upstream. See Go's net package for
// accepted values. For example, to restrict to IPv4, use "tcp4".
DialNetwork string `json:"dial_network,omitempty"`
resolver *net.Resolver
logger *zap.Logger
@@ -102,7 +107,7 @@ func (su *SRVUpstreams) Provision(ctx caddy.Context) error {
PreferGo: true,
Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
//nolint:gosec
addr := su.Resolver.netAddrs[weakrand.Intn(len(su.Resolver.netAddrs))]
addr := su.Resolver.netAddrs[weakrand.IntN(len(su.Resolver.netAddrs))]
return d.DialContext(ctx, addr.Network, addr.JoinHostPort(0))
},
}
@@ -177,6 +182,9 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
)
}
addr := net.JoinHostPort(rec.Target, strconv.Itoa(int(rec.Port)))
if su.DialNetwork != "" {
addr = su.DialNetwork + "/" + addr
}
upstreams[i] = Upstream{Dial: addr}
}
@@ -322,7 +330,7 @@ func (au *AUpstreams) Provision(ctx caddy.Context) error {
PreferGo: true,
Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
//nolint:gosec
addr := au.Resolver.netAddrs[weakrand.Intn(len(au.Resolver.netAddrs))]
addr := au.Resolver.netAddrs[weakrand.IntN(len(au.Resolver.netAddrs))]
return d.DialContext(ctx, addr.Network, addr.JoinHostPort(0))
},
}
+1
View File
@@ -173,6 +173,7 @@ func parseCaddyfileURI(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, err
if hasArgs {
return nil, h.Err("Cannot specify uri query rewrites in both argument and block")
}
// nolint:prealloc
queryArgs := []string{h.Val()}
queryArgs = append(queryArgs, h.RemainingArgs()...)
err := applyQueryOps(h, rewr.Query, queryArgs)
+1
View File
@@ -247,6 +247,7 @@ func (rewr Rewrite) Rewrite(r *http.Request, repl *caddy.Replacer) bool {
} else {
r.URL.Path = path
}
r.URL.RawPath = "" // force recomputing when EscapedPath() is called
}
if qsStart >= 0 {
r.URL.RawQuery = newQuery
@@ -224,6 +224,11 @@ func TestRewrite(t *testing.T) {
input: newRequest(t, "GET", "/foo#fragFirst?c=d"),
expect: newRequest(t, "GET", "/bar#fragFirst?c=d"),
},
{
rule: Rewrite{URI: "/api/admin/panel"},
input: newRequest(t, "GET", "/api/admin%2Fpanel"),
expect: newRequest(t, "GET", "/api/admin/panel"),
},
{
rule: Rewrite{StripPathPrefix: "/prefix"},
+38 -18
View File
@@ -18,6 +18,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"strings"
"github.com/caddyserver/caddy/v2"
)
@@ -96,7 +97,10 @@ type Route struct {
MatcherSets MatcherSets `json:"-"`
Handlers []MiddlewareHandler `json:"-"`
middleware []Middleware
middleware []Middleware
metrics *Metrics
metricsCtx caddy.Context
handlerName string
}
// Empty returns true if the route has all zero/default values.
@@ -110,14 +114,16 @@ func (r Route) Empty() bool {
}
func (r Route) String() string {
handlersRaw := "["
var handlersRaw strings.Builder
handlersRaw.WriteByte('[')
for _, hr := range r.HandlersRaw {
handlersRaw += " " + string(hr)
handlersRaw.WriteByte(' ')
handlersRaw.WriteString(string(hr))
}
handlersRaw += "]"
handlersRaw.WriteByte(']')
return fmt.Sprintf(`{Group:"%s" MatcherSetsRaw:%s HandlersRaw:%s Terminal:%t}`,
r.Group, r.MatcherSetsRaw, handlersRaw, r.Terminal)
r.Group, r.MatcherSetsRaw, handlersRaw.String(), r.Terminal)
}
// Provision sets up both the matchers and handlers in the route.
@@ -159,12 +165,20 @@ func (r *Route) ProvisionHandlers(ctx caddy.Context, metrics *Metrics) error {
r.Handlers = append(r.Handlers, handler.(MiddlewareHandler))
}
// Store metrics info for route-level instrumentation (applied once
// per route in wrapRoute, instead of per-handler which was redundant).
r.metrics = metrics
r.metricsCtx = ctx
if len(r.Handlers) > 0 {
r.handlerName = caddy.GetModuleName(r.Handlers[0])
}
// Make ProvisionHandlers idempotent by clearing the middleware field
r.middleware = []Middleware{}
// pre-compile the middleware handler chain
for _, midhandler := range r.Handlers {
r.middleware = append(r.middleware, wrapMiddleware(ctx, midhandler, metrics))
r.middleware = append(r.middleware, wrapMiddleware(ctx, midhandler))
}
return nil
}
@@ -295,6 +309,16 @@ func wrapRoute(route Route) Middleware {
nextCopy = route.middleware[i](nextCopy)
}
// Apply metrics instrumentation once for the entire route,
// rather than wrapping each individual handler. This avoids
// redundant metrics collection that caused significant CPU
// overhead (see issue #4644).
if route.metrics != nil {
nextCopy = newMetricsInstrumentedRoute(
route.metricsCtx, route.handlerName, nextCopy, route.metrics,
)
}
return nextCopy.ServeHTTP(rw, req)
})
}
@@ -303,20 +327,14 @@ func wrapRoute(route Route) Middleware {
// wrapMiddleware wraps mh such that it can be correctly
// appended to a list of middleware in preparation for
// compiling into a handler chain.
func wrapMiddleware(ctx caddy.Context, mh MiddlewareHandler, metrics *Metrics) Middleware {
handlerToUse := mh
if metrics != nil {
// wrap the middleware with metrics instrumentation
handlerToUse = newMetricsInstrumentedHandler(ctx, caddy.GetModuleName(mh), mh, metrics)
}
func wrapMiddleware(ctx caddy.Context, mh MiddlewareHandler) Middleware {
return func(next Handler) Handler {
return HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
// EXPERIMENTAL: Trace each module that gets invoked
if server, ok := r.Context().Value(ServerCtxKey).(*Server); ok && server != nil {
server.logTrace(handlerToUse)
server.logTrace(mh)
}
return handlerToUse.ServeHTTP(w, r, next)
return mh.ServeHTTP(w, r, next)
})
}
}
@@ -440,13 +458,15 @@ func (ms *MatcherSets) FromInterface(matcherSets any) error {
// TODO: Is this used?
func (ms MatcherSets) String() string {
result := "["
var result strings.Builder
result.WriteByte('[')
for _, matcherSet := range ms {
for _, matcher := range matcherSet {
result += fmt.Sprintf(" %#v", matcher)
fmt.Fprintf(&result, " %#v", matcher)
}
}
return result + " ]"
result.WriteByte(']')
return result.String()
}
var routeGroupCtxKey = caddy.CtxKey("route_group")
+117 -77
View File
@@ -18,6 +18,7 @@ import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"net"
@@ -33,7 +34,7 @@ import (
"github.com/caddyserver/certmagic"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
"github.com/quic-go/quic-go/qlog"
h3qlog "github.com/quic-go/quic-go/http3/qlog"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
@@ -55,6 +56,10 @@ type Server struct {
// of the base listener. They are applied in the given order.
ListenerWrappersRaw []json.RawMessage `json:"listener_wrappers,omitempty" caddy:"namespace=caddy.listeners inline_key=wrapper"`
// A list of packet conn wrapper modules, which can modify the behavior
// of the base packet conn. They are applied in the given order.
PacketConnWrappersRaw []json.RawMessage `json:"packet_conn_wrappers,omitempty" caddy:"namespace=caddy.packetconns inline_key=wrapper"`
// How long to allow a read from a client's upload. Setting this
// to a short, non-zero value can mitigate slowloris attacks, but
// may also affect legitimately slow clients.
@@ -248,6 +253,16 @@ type Server struct {
// A nil value or element indicates that Protocols will be used instead.
ListenProtocols [][]string `json:"listen_protocols,omitempty"`
// If set, overrides whether QUIC listeners allow 0-RTT (early data).
// If nil, the default behavior is used (currently allowed).
//
// One reason to disable 0-RTT is if a remote IP matcher is used,
// which introduces a dependency on the remote address being verified
// if routing happens before the TLS handshake completes. An HTTP 425
// response is written in that case, but some clients misbehave and
// don't perform a retry, so disabling 0-RTT can smooth it out.
Allow0RTT *bool `json:"allow_0rtt,omitempty"`
// If set, metrics observations will be enabled.
// This setting is EXPERIMENTAL and subject to change.
// DEPRECATED: Use the app-level `metrics` field.
@@ -258,7 +273,8 @@ type Server struct {
primaryHandlerChain Handler
errorHandlerChain Handler
listenerWrappers []caddy.ListenerWrapper
listeners []net.Listener // stdlib http.Server will close these
packetConnWrappers []caddy.PacketConnWrapper
listeners []net.Listener
quicListeners []http3.QUICListener // http3 now leave the quic.Listener management to us
tlsApp *caddytls.TLS
@@ -285,8 +301,15 @@ type Server struct {
onStopFuncs []func(context.Context) error // TODO: Experimental (Nov. 2023)
}
var (
ServerHeader = "Caddy"
serverHeader = []string{ServerHeader}
)
// ServeHTTP is the entry point for all HTTP requests.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// If there are listener wrappers that process tls connections but don't return a *tls.Conn, this field will be nil.
if r.TLS == nil {
if tlsConnStateFunc, ok := r.Context().Value(tlsConnectionStateFuncCtxKey).(func() *tls.ConnectionState); ok {
@@ -294,55 +317,37 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
w.Header().Set("Server", "Caddy")
// advertise HTTP/3, if enabled
if s.h3server != nil {
if r.ProtoMajor < 3 {
err := s.h3server.SetQUICHeaders(w.Header())
if err != nil {
if c := s.logger.Check(zapcore.ErrorLevel, "setting HTTP/3 Alt-Svc header"); c != nil {
c.Write(zap.Error(err))
}
}
}
}
// reject very long methods; probably a mistake or an attack
if len(r.Method) > 32 {
if s.shouldLogRequest(r) {
if c := s.accessLogger.Check(zapcore.DebugLevel, "rejecting request with long method"); c != nil {
c.Write(
zap.String("method_trunc", r.Method[:32]),
zap.String("remote_addr", r.RemoteAddr),
)
}
}
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
repl := caddy.NewReplacer()
r = PrepareRequest(r, repl, w, s)
// enable full-duplex for HTTP/1, ensuring the entire
// request body gets consumed before writing the response
if s.EnableFullDuplex && r.ProtoMajor == 1 {
//nolint:bodyclose
err := http.NewResponseController(w).EnableFullDuplex()
if err != nil {
if err := http.NewResponseController(w).EnableFullDuplex(); err != nil { //nolint:bodyclose
if c := s.logger.Check(zapcore.WarnLevel, "failed to enable full duplex"); c != nil {
c.Write(zap.Error(err))
}
}
}
// clone the request for logging purposes before
// it enters any handler chain; this is necessary
// to capture the original request in case it gets
// modified during handling
// cloning the request and using .WithLazy is considerably faster
// than using .With, which will JSON encode the request immediately
// set the Server header
h := w.Header()
h["Server"] = serverHeader
// advertise HTTP/3, if enabled
if s.h3server != nil && r.ProtoMajor < 3 {
if err := s.h3server.SetQUICHeaders(h); err != nil {
if c := s.logger.Check(zapcore.ErrorLevel, "setting HTTP/3 Alt-Svc header"); c != nil {
c.Write(zap.Error(err))
}
}
}
// prepare internals of the request for the handler pipeline
repl := caddy.NewReplacer()
r = PrepareRequest(r, repl, w, s)
// clone the request for logging purposes before it enters any handler chain;
// this is necessary to capture the original request in case it gets modified
// during handling (cloning the request and using .WithLazy is considerably
// faster than using .With, which will JSON-encode the request immediately)
shouldLogCredentials := s.Logs != nil && s.Logs.ShouldLogCredentials
loggableReq := zap.Object("request", LoggableHTTPRequest{
Request: r.Clone(r.Context()),
@@ -370,36 +375,33 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
// capture the original version of the request
accLog := s.accessLogger.With(loggableReq)
accLog := s.accessLogger.WithLazy(loggableReq)
defer s.logRequest(accLog, r, wrec, &duration, repl, bodyReader, shouldLogCredentials)
}
start := time.Now()
// guarantee ACME HTTP challenges; handle them
// separately from any user-defined handlers
// guarantee ACME HTTP challenges; handle them separately from any user-defined handlers
if s.tlsApp.HandleHTTPChallenge(w, r) {
duration = time.Since(start)
return
}
// execute the primary handler chain
err := s.primaryHandlerChain.ServeHTTP(w, r)
err := s.serveHTTP(w, r)
duration = time.Since(start)
// if no errors, we're done!
if err == nil {
return
}
// restore original request before invoking error handler chain (issue #3717)
// TODO: this does not restore original headers, if modified (for efficiency)
origReq := r.Context().Value(OriginalRequestCtxKey).(http.Request)
r.Method = origReq.Method
r.RemoteAddr = origReq.RemoteAddr
r.RequestURI = origReq.RequestURI
cloneURL(origReq.URL, r.URL)
// NOTE: this does not restore original headers if modified (for efficiency)
origReq, ok := r.Context().Value(OriginalRequestCtxKey).(http.Request)
if ok {
r.Method = origReq.Method
r.RemoteAddr = origReq.RemoteAddr
r.RequestURI = origReq.RequestURI
cloneURL(origReq.URL, r.URL)
}
// prepare the error log
errLog = errLog.With(zap.Duration("duration", duration))
@@ -417,10 +419,8 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var fields []zapcore.Field
if s.Errors != nil && len(s.Errors.Routes) > 0 {
// execute user-defined error handling route
err2 := s.errorHandlerChain.ServeHTTP(w, r)
if err2 == nil {
// user's error route handled the error response
// successfully, so now just log the error
if err2 := s.errorHandlerChain.ServeHTTP(w, r); err2 == nil {
// user's error route handled the error response successfully, so now just log the error
for _, logger := range errLoggers {
if c := logger.Check(zapcore.DebugLevel, errMsg); c != nil {
if fields == nil {
@@ -468,6 +468,35 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error {
// reject very long methods; probably a mistake or an attack
if len(r.Method) > 32 {
if s.shouldLogRequest(r) {
if c := s.accessLogger.Check(zapcore.DebugLevel, "rejecting request with long method"); c != nil {
c.Write(
zap.String("method_trunc", r.Method[:32]),
zap.String("remote_addr", r.RemoteAddr),
)
}
}
return HandlerError{StatusCode: http.StatusMethodNotAllowed}
}
// RFC 9112 section 3.2: "A server MUST respond with a 400 (Bad Request) status
// code to any HTTP/1.1 request message that lacks a Host header field and to any
// request message that contains more than one Host header field line or a Host
// header field with an invalid field value."
if r.ProtoMajor == 1 && r.ProtoMinor == 1 && r.Host == "" {
return HandlerError{
Err: errors.New("rfc9112 forbids empty Host"),
StatusCode: http.StatusBadRequest,
}
}
// execute the primary handler chain
return s.primaryHandlerChain.ServeHTTP(w, r)
}
// wrapPrimaryRoute wraps stack (a compiled middleware handler chain)
// in s.enforcementHandler which performs crucial security checks, etc.
func (s *Server) wrapPrimaryRoute(stack Handler) Handler {
@@ -551,15 +580,21 @@ func (s *Server) hasListenerAddress(fullAddr string) bool {
// The second issue seems very similar to a discussion here:
// https://github.com/nodejs/node/issues/9390
//
// This is very easy to reproduce by creating an HTTP server
// that listens to both addresses or just one with a host
// interface; or for a more confusing reproduction, try
// listening on "127.0.0.1:80" and ":443" and you'll see
// the error, if you take away the GOOS condition below.
//
// So, an address is equivalent if the port is in the port
// range, and if not on Linux, the host is the same... sigh.
if (runtime.GOOS == "linux" || thisAddrs.Host == laddrs.Host) &&
// However, binding to *different specific* interfaces
// (e.g. 127.0.0.2:80 and 127.0.0.3:80) IS allowed on Linux.
// The conflict only happens when mixing specific IPs with
// wildcards (0.0.0.0 or ::).
// Hosts match exactly (e.g. 127.0.0.2 == 127.0.0.2) -> Conflict.
hostMatch := thisAddrs.Host == laddrs.Host
// On Linux, specific IP vs Wildcard fails to bind.
// So if we are on Linux AND either host is empty (wildcard), we treat
// it as a match (conflict). But if both are specific and different
// (127.0.0.2 vs 127.0.0.3), this remains false (no conflict).
linuxWildcardConflict := runtime.GOOS == "linux" && (thisAddrs.Host == "" || laddrs.Host == "")
if (hostMatch || linuxWildcardConflict) &&
(laddrs.StartPort <= thisAddrs.EndPort) &&
(laddrs.StartPort >= thisAddrs.StartPort) {
return true
@@ -625,7 +660,7 @@ func (s *Server) serveHTTP3(addr caddy.NetworkAddress, tlsCfg *tls.Config) error
return fmt.Errorf("starting HTTP/3 QUIC listener: %v", err)
}
addr.Network = h3net
h3ln, err := addr.ListenQUIC(s.ctx, 0, net.ListenConfig{}, tlsCfg)
h3ln, err := addr.ListenQUIC(s.ctx, 0, net.ListenConfig{}, tlsCfg, s.packetConnWrappers, s.Allow0RTT)
if err != nil {
return fmt.Errorf("starting HTTP/3 QUIC listener: %v", err)
}
@@ -638,7 +673,7 @@ func (s *Server) serveHTTP3(addr caddy.NetworkAddress, tlsCfg *tls.Config) error
MaxHeaderBytes: s.MaxHeaderBytes,
QUICConfig: &quic.Config{
Versions: []quic.Version{quic.Version1, quic.Version2},
Tracer: qlog.DefaultConnectionTracer,
Tracer: h3qlog.DefaultConnectionTracer,
},
IdleTimeout: time.Duration(s.IdleTimeout),
}
@@ -763,9 +798,11 @@ func (s *Server) shouldLogRequest(r *http.Request) bool {
hostWithoutPort = r.Host
}
if _, ok := s.Logs.LoggerNames[hostWithoutPort]; ok {
// this host is mapped to a particular logger name
return true
for loggerName := range s.Logs.LoggerNames {
if certmagic.MatchWildcard(hostWithoutPort, loggerName) {
// this host is mapped to a particular logger name
return true
}
}
for _, dh := range s.Logs.SkipHosts {
// logging for this particular host is disabled
@@ -793,8 +830,10 @@ func (s *Server) logRequest(
accLog *zap.Logger, r *http.Request, wrec ResponseRecorder, duration *time.Duration,
repl *caddy.Replacer, bodyReader *lengthReader, shouldLogCredentials bool,
) {
ctx := r.Context()
// this request may be flagged as omitted from the logs
if skip, ok := GetVar(r.Context(), LogSkipVar).(bool); ok && skip {
if skip, ok := GetVar(ctx, LogSkipVar).(bool); ok && skip {
return
}
@@ -812,7 +851,7 @@ func (s *Server) logRequest(
}
message := "handled request"
if nop, ok := GetVar(r.Context(), "unhandled").(bool); ok && nop {
if nop, ok := GetVar(ctx, "unhandled").(bool); ok && nop {
message = "NOP"
}
@@ -836,7 +875,7 @@ func (s *Server) logRequest(
reqBodyLength = bodyReader.Length
}
extra := r.Context().Value(ExtraLogFieldsCtxKey).(*ExtraLogFields)
extra := ctx.Value(ExtraLogFieldsCtxKey).(*ExtraLogFields)
fieldCount := 6
fields = make([]zapcore.Field, 0, fieldCount+len(extra.fields))
@@ -1001,6 +1040,7 @@ func isTrustedClientIP(ipAddr netip.Addr, trusted []netip.Prefix) bool {
// then the first value from those headers is used.
func trustedRealClientIP(r *http.Request, headers []string, clientIP string) string {
// Read all the values of the configured client IP headers, in order
// nolint:prealloc
var values []string
for _, field := range headers {
values = append(values, r.Header.Values(field)...)
+11 -2
View File
@@ -246,7 +246,7 @@ func (s StaticResponse) ServeHTTP(w http.ResponseWriter, r *http.Request, next H
// write response body
if statusCode != http.StatusEarlyHints && body != "" {
fmt.Fprint(w, body)
fmt.Fprint(w, body) //nolint:gosec // no XSS unless you sabatoge your own config
}
// continue handling after Early Hints as they are not the final response
@@ -257,7 +257,16 @@ func (s StaticResponse) ServeHTTP(w http.ResponseWriter, r *http.Request, next H
return nil
}
func buildHTTPServer(i int, port uint, addr string, statusCode int, hdr http.Header, body string, accessLog bool) (*Server, error) {
func buildHTTPServer(
i int,
port uint,
addr string,
statusCode int,
hdr http.Header,
body string,
accessLog bool,
) (*Server, error) {
// nolint:prealloc
var handlers []json.RawMessage
// response body supports a basic template; evaluate it
+22 -4
View File
@@ -306,6 +306,13 @@ func init() {
// find the documentation on time layouts [in Go's docs](https://pkg.go.dev/time#pkg-constants).
// The default time layout is `RFC1123Z`, i.e. `Mon, 02 Jan 2006 15:04:05 -0700`.
//
// ```
// {{humanize "size" "2048000"}}
// {{placeholder "http.response.header.Content-Length" | humanize "size"}}
// {{humanize "time" "Fri, 05 May 2022 15:04:05 +0200"}}
// {{humanize "time:2006-Jan-02" "2022-May-05"}}
// ```
//
// ##### `pathEscape`
//
// Passes a string through `url.PathEscape`, replacing characters that have
@@ -318,11 +325,22 @@ func init() {
// {{pathEscape "50%_valid_filename?.jpg"}}
// ```
//
// ##### `maybe`
//
// Invokes a custom template function only if it is registered (plugged-in)
// in the `http.handlers.templates.functions.*` namespace.
//
// The first argument is the function name, and any subsequent arguments
// are forwarded to that function. If the named function is not available,
// the invocation is ignored and a log message is emitted.
//
// This is useful for templates that optionally use components which may
// not be present in every build or environment.
//
// NOTE: This function is EXPERIMENTAL and subject to change or removal.
//
// ```
// {{humanize "size" "2048000"}}
// {{placeholder "http.response.header.Content-Length" | humanize "size"}}
// {{humanize "time" "Fri, 05 May 2022 15:04:05 +0200"}}
// {{humanize "time:2006-Jan-02" "2022-May-05"}}
// {{ maybe "myOptionalFunc" "arg1" 2 }}
// ```
type Templates struct {
// The root path from which to load files. Required if template functions
+31 -6
View File
@@ -27,6 +27,9 @@ type Tracing struct {
// https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/trace/api.md#span
SpanName string `json:"span"`
// SpanAttributes are custom key-value pairs to be added to spans
SpanAttributes map[string]string `json:"span_attributes,omitempty"`
// otel implements opentelemetry related logic.
otel openTelemetryWrapper
@@ -46,7 +49,7 @@ func (ot *Tracing) Provision(ctx caddy.Context) error {
ot.logger = ctx.Logger()
var err error
ot.otel, err = newOpenTelemetryWrapper(ctx, ot.SpanName)
ot.otel, err = newOpenTelemetryWrapper(ctx, ot.SpanName, ot.SpanAttributes)
return err
}
@@ -69,6 +72,10 @@ func (ot *Tracing) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyh
//
// tracing {
// [span <span_name>]
// [span_attributes {
// attr1 value1
// attr2 value2
// }]
// }
func (ot *Tracing) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
setParameter := func(d *caddyfile.Dispenser, val *string) error {
@@ -94,12 +101,30 @@ func (ot *Tracing) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
}
for d.NextBlock(0) {
if dst, ok := paramsMap[d.Val()]; ok {
if err := setParameter(d, dst); err != nil {
return err
switch d.Val() {
case "span_attributes":
if ot.SpanAttributes == nil {
ot.SpanAttributes = make(map[string]string)
}
for d.NextBlock(1) {
key := d.Val()
if !d.NextArg() {
return d.ArgErr()
}
value := d.Val()
if d.NextArg() {
return d.ArgErr()
}
ot.SpanAttributes[key] = value
}
default:
if dst, ok := paramsMap[d.Val()]; ok {
if err := setParameter(d, dst); err != nil {
return err
}
} else {
return d.ArgErr()
}
} else {
return d.ArgErr()
}
}
return nil
+220 -4
View File
@@ -2,12 +2,16 @@ package tracing
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/sdk/trace/tracetest"
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
@@ -15,17 +19,26 @@ import (
func TestTracing_UnmarshalCaddyfile(t *testing.T) {
tests := []struct {
name string
spanName string
d *caddyfile.Dispenser
wantErr bool
name string
spanName string
spanAttributes map[string]string
d *caddyfile.Dispenser
wantErr bool
}{
{
name: "Full config",
spanName: "my-span",
spanAttributes: map[string]string{
"attr1": "value1",
"attr2": "value2",
},
d: caddyfile.NewTestDispenser(`
tracing {
span my-span
span_attributes {
attr1 value1
attr2 value2
}
}`),
wantErr: false,
},
@@ -42,6 +55,21 @@ tracing {
name: "Empty config",
d: caddyfile.NewTestDispenser(`
tracing {
}`),
wantErr: false,
},
{
name: "Only span attributes",
spanAttributes: map[string]string{
"service.name": "my-service",
"service.version": "1.0.0",
},
d: caddyfile.NewTestDispenser(`
tracing {
span_attributes {
service.name my-service
service.version 1.0.0
}
}`),
wantErr: false,
},
@@ -56,6 +84,20 @@ tracing {
if ot.SpanName != tt.spanName {
t.Errorf("UnmarshalCaddyfile() SpanName = %v, want SpanName %v", ot.SpanName, tt.spanName)
}
if len(tt.spanAttributes) > 0 {
if ot.SpanAttributes == nil {
t.Errorf("UnmarshalCaddyfile() SpanAttributes is nil, expected %v", tt.spanAttributes)
} else {
for key, expectedValue := range tt.spanAttributes {
if actualValue, exists := ot.SpanAttributes[key]; !exists {
t.Errorf("UnmarshalCaddyfile() SpanAttributes missing key %v", key)
} else if actualValue != expectedValue {
t.Errorf("UnmarshalCaddyfile() SpanAttributes[%v] = %v, want %v", key, actualValue, expectedValue)
}
}
}
}
})
}
}
@@ -79,6 +121,26 @@ func TestTracing_UnmarshalCaddyfile_Error(t *testing.T) {
d: caddyfile.NewTestDispenser(`
tracing {
span
}`),
wantErr: true,
},
{
name: "Span attributes missing value",
d: caddyfile.NewTestDispenser(`
tracing {
span_attributes {
key
}
}`),
wantErr: true,
},
{
name: "Span attributes too many arguments",
d: caddyfile.NewTestDispenser(`
tracing {
span_attributes {
key value extra
}
}`),
wantErr: true,
},
@@ -181,6 +243,160 @@ func TestTracing_ServeHTTP_Next_Error(t *testing.T) {
}
}
func TestTracing_JSON_Configuration(t *testing.T) {
// Test that our struct correctly marshals to and from JSON
original := &Tracing{
SpanName: "test-span",
SpanAttributes: map[string]string{
"service.name": "test-service",
"service.version": "1.0.0",
"env": "test",
},
}
jsonData, err := json.Marshal(original)
if err != nil {
t.Fatalf("Failed to marshal to JSON: %v", err)
}
var unmarshaled Tracing
if err := json.Unmarshal(jsonData, &unmarshaled); err != nil {
t.Fatalf("Failed to unmarshal from JSON: %v", err)
}
if unmarshaled.SpanName != original.SpanName {
t.Errorf("Expected SpanName %s, got %s", original.SpanName, unmarshaled.SpanName)
}
if len(unmarshaled.SpanAttributes) != len(original.SpanAttributes) {
t.Errorf("Expected %d span attributes, got %d", len(original.SpanAttributes), len(unmarshaled.SpanAttributes))
}
for key, expectedValue := range original.SpanAttributes {
if actualValue, exists := unmarshaled.SpanAttributes[key]; !exists {
t.Errorf("Expected span attribute %s to exist", key)
} else if actualValue != expectedValue {
t.Errorf("Expected span attribute %s = %s, got %s", key, expectedValue, actualValue)
}
}
t.Logf("JSON representation: %s", string(jsonData))
}
func TestTracing_OpenTelemetry_Span_Attributes(t *testing.T) {
// Create an in-memory span recorder to capture actual span data
spanRecorder := tracetest.NewSpanRecorder()
provider := trace.NewTracerProvider(
trace.WithSpanProcessor(spanRecorder),
)
// Create our tracing module with span attributes that include placeholders
ot := &Tracing{
SpanName: "test-span",
SpanAttributes: map[string]string{
"static": "test-service",
"request-placeholder": "{http.request.method}",
"response-placeholder": "{http.response.header.X-Some-Header}",
"mixed": "prefix-{http.request.method}-{http.response.header.X-Some-Header}",
},
}
// Create a specific request to test against
req, _ := http.NewRequest("POST", "https://api.example.com/v1/users?id=123", nil)
req.Host = "api.example.com"
w := httptest.NewRecorder()
// Set up the replacer
repl := caddy.NewReplacer()
ctx := context.WithValue(req.Context(), caddy.ReplacerCtxKey, repl)
ctx = context.WithValue(ctx, caddyhttp.VarsCtxKey, make(map[string]any))
req = req.WithContext(ctx)
// Set up request placeholders
repl.Set("http.request.method", req.Method)
repl.Set("http.request.uri", req.URL.RequestURI())
// Handler to generate the response
var handler caddyhttp.HandlerFunc = func(writer http.ResponseWriter, request *http.Request) error {
writer.Header().Set("X-Some-Header", "some-value")
writer.WriteHeader(200)
// Make response headers available to replacer
repl.Set("http.response.header.X-Some-Header", writer.Header().Get("X-Some-Header"))
return nil
}
// Set up Caddy context
caddyCtx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()})
defer cancel()
// Override the global tracer provider with our test provider
// This is a bit hacky but necessary to capture the actual spans
originalProvider := globalTracerProvider
globalTracerProvider = &tracerProvider{
tracerProvider: provider,
tracerProvidersCounter: 1, // Simulate one user
}
defer func() {
globalTracerProvider = originalProvider
}()
// Provision the tracing module
if err := ot.Provision(caddyCtx); err != nil {
t.Errorf("Provision error: %v", err)
t.FailNow()
}
// Execute the request
if err := ot.ServeHTTP(w, req, handler); err != nil {
t.Errorf("ServeHTTP error: %v", err)
}
// Get the recorded spans
spans := spanRecorder.Ended()
if len(spans) == 0 {
t.Fatal("Expected at least one span to be recorded")
}
// Find our span (should be the one with our test span name)
var testSpan trace.ReadOnlySpan
for _, span := range spans {
if span.Name() == "test-span" {
testSpan = span
break
}
}
if testSpan == nil {
t.Fatal("Could not find test span in recorded spans")
}
// Verify that the span attributes were set correctly with placeholder replacement
expectedAttributes := map[string]string{
"static": "test-service",
"request-placeholder": "POST",
"response-placeholder": "some-value",
"mixed": "prefix-POST-some-value",
}
actualAttributes := make(map[string]string)
for _, attr := range testSpan.Attributes() {
actualAttributes[string(attr.Key)] = attr.Value.AsString()
}
for key, expectedValue := range expectedAttributes {
if actualValue, exists := actualAttributes[key]; !exists {
t.Errorf("Expected span attribute %s to be set", key)
} else if actualValue != expectedValue {
t.Errorf("Expected span attribute %s = %s, got %s", key, expectedValue, actualValue)
}
}
t.Logf("Recorded span attributes: %+v", actualAttributes)
}
func createRequestWithContext(method string, url string) *http.Request {
r, _ := http.NewRequest(method, url, nil)
repl := caddy.NewReplacer()
+22 -4
View File
@@ -5,9 +5,10 @@ import (
"fmt"
"net/http"
"go.opentelemetry.io/contrib/exporters/autoexport"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"go.opentelemetry.io/contrib/propagators/autoprop"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/sdk/resource"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
@@ -37,20 +38,23 @@ type openTelemetryWrapper struct {
handler http.Handler
spanName string
spanName string
spanAttributes map[string]string
}
// newOpenTelemetryWrapper is responsible for the openTelemetryWrapper initialization using provided configuration.
func newOpenTelemetryWrapper(
ctx context.Context,
spanName string,
spanAttributes map[string]string,
) (openTelemetryWrapper, error) {
if spanName == "" {
spanName = defaultSpanName
}
ot := openTelemetryWrapper{
spanName: spanName,
spanName: spanName,
spanAttributes: spanAttributes,
}
version, _ := caddy.Version()
@@ -59,7 +63,7 @@ func newOpenTelemetryWrapper(
return ot, fmt.Errorf("creating resource error: %w", err)
}
traceExporter, err := otlptracegrpc.New(ctx)
traceExporter, err := autoexport.NewSpanExporter(ctx)
if err != nil {
return ot, fmt.Errorf("creating trace exporter error: %w", err)
}
@@ -99,8 +103,22 @@ func (ot *openTelemetryWrapper) serveHTTP(w http.ResponseWriter, r *http.Request
extra.Add(zap.String("spanID", spanID))
}
}
next := ctx.Value(nextCallCtxKey).(*nextCall)
next.err = next.next.ServeHTTP(w, r)
// Add custom span attributes to the current span
span := trace.SpanFromContext(ctx)
if span.IsRecording() && len(ot.spanAttributes) > 0 {
replacer := ctx.Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
attributes := make([]attribute.KeyValue, 0, len(ot.spanAttributes))
for key, value := range ot.spanAttributes {
// Allow placeholder replacement in attribute values
replacedValue := replacer.ReplaceAll(value, "")
attributes = append(attributes, attribute.String(key, replacedValue))
}
span.SetAttributes(attributes...)
}
}
// ServeHTTP propagates call to the by wrapped by `otelhttp` next handler.
+1
View File
@@ -16,6 +16,7 @@ func TestOpenTelemetryWrapper_newOpenTelemetryWrapper(t *testing.T) {
if otw, err = newOpenTelemetryWrapper(ctx,
"",
nil,
); err != nil {
t.Errorf("newOpenTelemetryWrapper() error = %v", err)
t.FailNow()
+45 -21
View File
@@ -181,33 +181,46 @@ func (m VarsMatcher) MatchWithError(r *http.Request) (bool, error) {
vars := r.Context().Value(VarsCtxKey).(map[string]any)
repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
var fromPlaceholder bool
var matcherValExpanded, valExpanded, varStr, v string
var varValue any
for key, vals := range m {
var varValue any
if strings.HasPrefix(key, "{") &&
strings.HasSuffix(key, "}") &&
strings.Count(key, "{") == 1 {
varValue, _ = repl.Get(strings.Trim(key, "{}"))
fromPlaceholder = true
} else {
varValue = vars[key]
fromPlaceholder = false
}
switch vv := varValue.(type) {
case string:
varStr = vv
case fmt.Stringer:
varStr = vv.String()
case error:
varStr = vv.Error()
case nil:
varStr = ""
default:
varStr = fmt.Sprintf("%v", vv)
}
// Only expand placeholders in values from literal variable names
// (e.g. map outputs). Values resolved from placeholder keys are
// already final and must not be re-expanded, as that would allow
// user input like {env.SECRET} to be evaluated.
valExpanded = varStr
if !fromPlaceholder {
valExpanded = repl.ReplaceAll(varStr, "")
}
// see if any of the values given in the matcher match the actual value
for _, v := range vals {
matcherValExpanded := repl.ReplaceAll(v, "")
var varStr string
switch vv := varValue.(type) {
case string:
varStr = vv
case fmt.Stringer:
varStr = vv.String()
case error:
varStr = vv.Error()
case nil:
varStr = ""
default:
varStr = fmt.Sprintf("%v", vv)
}
if varStr == matcherValExpanded {
for _, v = range vals {
matcherValExpanded = repl.ReplaceAll(v, "")
if valExpanded == matcherValExpanded {
return true, nil
}
}
@@ -310,17 +323,21 @@ func (m MatchVarsRE) Match(r *http.Request) bool {
func (m MatchVarsRE) MatchWithError(r *http.Request) (bool, error) {
vars := r.Context().Value(VarsCtxKey).(map[string]any)
repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
var fromPlaceholder, match bool
var valExpanded, varStr string
var varValue any
for key, val := range m {
var varValue any
if strings.HasPrefix(key, "{") &&
strings.HasSuffix(key, "}") &&
strings.Count(key, "{") == 1 {
varValue, _ = repl.Get(strings.Trim(key, "{}"))
fromPlaceholder = true
} else {
varValue = vars[key]
fromPlaceholder = false
}
var varStr string
switch vv := varValue.(type) {
case string:
varStr = vv
@@ -334,8 +351,15 @@ func (m MatchVarsRE) MatchWithError(r *http.Request) (bool, error) {
varStr = fmt.Sprintf("%v", vv)
}
valExpanded := repl.ReplaceAll(varStr, "")
if match := val.Match(valExpanded, repl); match {
// Only expand placeholders in values from literal variable names
// (e.g. map outputs). Values resolved from placeholder keys are
// already final and must not be re-expanded, as that would allow
// user input like {env.SECRET} to be evaluated.
valExpanded = varStr
if !fromPlaceholder {
valExpanded = repl.ReplaceAll(varStr, "")
}
if match = val.Match(valExpanded, repl); match {
return match, nil
}
}
+33 -3
View File
@@ -17,7 +17,7 @@ package acmeserver
import (
"context"
"fmt"
weakrand "math/rand"
weakrand "math/rand/v2"
"net"
"net/http"
"os"
@@ -40,6 +40,7 @@ import (
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
"github.com/caddyserver/caddy/v2/modules/caddypki"
"github.com/caddyserver/caddy/v2/modules/caddytls"
)
func init() {
@@ -140,6 +141,8 @@ func (ash *Handler) Provision(ctx caddy.Context) error {
}
}
ash.warnIfPolicyAllowsAll()
// get a reference to the configured CA
appModule, err := ctx.App("pki")
if err != nil {
@@ -214,6 +217,21 @@ func (ash *Handler) Provision(ctx caddy.Context) error {
return nil
}
func (ash *Handler) warnIfPolicyAllowsAll() {
allow := ash.Policy.normalizeAllowRules()
deny := ash.Policy.normalizeDenyRules()
if allow != nil || deny != nil {
return
}
allowWildcardNames := ash.Policy != nil && ash.Policy.AllowWildcardNames
ash.logger.Warn(
"acme_server policy has no allow/deny rules; order identifiers are unrestricted (allow-all)",
zap.String("ca", ash.CA),
zap.Bool("allow_wildcard_names", allowWildcardNames),
)
}
func (ash Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
if strings.HasPrefix(r.URL.Path, ash.PathPrefix) {
acmeCtx := acme.NewContext(
@@ -287,7 +305,19 @@ func (ash Handler) openDatabase() (*db.AuthDB, error) {
// makeClient creates an ACME client which will use a custom
// resolver instead of net.DefaultResolver.
func (ash Handler) makeClient() (acme.Client, error) {
for _, v := range ash.Resolvers {
// If no local resolvers are configured, check for global resolvers from TLS app
resolversToUse := ash.Resolvers
if len(resolversToUse) == 0 {
tlsAppIface, err := ash.ctx.App("tls")
if err == nil {
tlsApp := tlsAppIface.(*caddytls.TLS)
if len(tlsApp.Resolvers) > 0 {
resolversToUse = tlsApp.Resolvers
}
}
}
for _, v := range resolversToUse {
addr, err := caddy.ParseNetworkAddressWithDefaults(v, "udp", 53)
if err != nil {
return nil, err
@@ -307,7 +337,7 @@ func (ash Handler) makeClient() (acme.Client, error) {
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
//nolint:gosec
addr := ash.resolvers[weakrand.Intn(len(ash.resolvers))]
addr := ash.resolvers[weakrand.IntN(len(ash.resolvers))]
return dialer.DialContext(ctx, addr.Network, addr.JoinHostPort(0))
},
}
@@ -0,0 +1,94 @@
package acmeserver
import (
"strings"
"testing"
"go.uber.org/zap"
"go.uber.org/zap/zaptest/observer"
)
func TestHandler_warnIfPolicyAllowsAll(t *testing.T) {
tests := []struct {
name string
policy *Policy
wantWarns int
wantAllowWildcard bool
}{
{
name: "warns when policy is nil",
policy: nil,
wantWarns: 1,
wantAllowWildcard: false,
},
{
name: "warns when allow/deny rules are empty",
policy: &Policy{},
wantWarns: 1,
wantAllowWildcard: false,
},
{
name: "warns when only allow_wildcard_names is true",
policy: &Policy{
AllowWildcardNames: true,
},
wantWarns: 1,
wantAllowWildcard: true,
},
{
name: "does not warn when allow rules are configured",
policy: &Policy{
Allow: &RuleSet{
Domains: []string{"example.com"},
},
},
wantWarns: 0,
wantAllowWildcard: false,
},
{
name: "does not warn when deny rules are configured",
policy: &Policy{
Deny: &RuleSet{
Domains: []string{"bad.example.com"},
},
},
wantWarns: 0,
wantAllowWildcard: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
core, logs := observer.New(zap.WarnLevel)
ash := &Handler{
CA: "local",
Policy: tt.policy,
logger: zap.New(core),
}
ash.warnIfPolicyAllowsAll()
if logs.Len() != tt.wantWarns {
t.Fatalf("expected %d warning logs, got %d", tt.wantWarns, logs.Len())
}
if tt.wantWarns == 0 {
return
}
entry := logs.All()[0]
if entry.Level != zap.WarnLevel {
t.Fatalf("expected warn level, got %v", entry.Level)
}
if !strings.Contains(entry.Message, "policy has no allow/deny rules") {
t.Fatalf("unexpected log message: %q", entry.Message)
}
ctx := entry.ContextMap()
if ctx["ca"] != "local" {
t.Fatalf("expected ca=local, got %v", ctx["ca"])
}
if ctx["allow_wildcard_names"] != tt.wantAllowWildcard {
t.Fatalf("expected allow_wildcard_names=%v, got %v", tt.wantAllowWildcard, ctx["allow_wildcard_names"])
}
})
}
}
+11 -6
View File
@@ -163,9 +163,9 @@ func (a *adminAPI) handleCACerts(w http.ResponseWriter, r *http.Request) error {
}
w.Header().Set("Content-Type", "application/pem-certificate-chain")
_, err = w.Write(interCert)
_, err = w.Write(interCert) //nolint:gosec // false positive... no XSS in a PEM for cryin' out loud
if err == nil {
_, _ = w.Write(rootCert)
_, _ = w.Write(rootCert) //nolint:gosec // false positive... no XSS in a PEM for cryin' out loud
}
return nil
@@ -222,11 +222,16 @@ func rootAndIntermediatePEM(ca *CA) (root, inter []byte, err error) {
if err != nil {
return root, inter, err
}
inter, err = pemEncodeCert(ca.IntermediateCertificate().Raw)
if err != nil {
return root, inter, err
for _, interCert := range ca.IntermediateCertificateChain() {
pemBytes, err := pemEncodeCert(interCert.Raw)
if err != nil {
return nil, nil, err
}
inter = append(inter, pemBytes...)
}
return root, inter, err
return
}
// caInfo is the response structure for the CA info API endpoint.
+60 -30
View File
@@ -63,6 +63,15 @@ type CA struct {
// The intermediate (signing) certificate; if null, one will be generated.
Intermediate *KeyPair `json:"intermediate,omitempty"`
// How often to check if intermediate (and root, when applicable) certificates need renewal.
// Default: 10m.
MaintenanceInterval caddy.Duration `json:"maintenance_interval,omitempty"`
// The fraction of certificate lifetime (0.01.0) after which renewal is attempted.
// For example, 0.2 means renew when 20% of the lifetime remains (e.g. ~73 days for a 1-year cert).
// Default: 0.2.
RenewalWindowRatio float64 `json:"renewal_window_ratio,omitempty"`
// Optionally configure a separate storage module associated with this
// issuer, instead of using Caddy's global/default-configured storage.
// This can be useful if you want to keep your signing keys in a
@@ -75,10 +84,11 @@ type CA struct {
// and module provisioning.
ID string `json:"-"`
storage certmagic.Storage
root, inter *x509.Certificate
interKey any // TODO: should we just store these as crypto.Signer?
mu *sync.RWMutex
storage certmagic.Storage
root *x509.Certificate
interChain []*x509.Certificate
interKey crypto.Signer
mu *sync.RWMutex
rootCertPath string // mainly used for logging purposes if trusting
log *zap.Logger
@@ -125,16 +135,24 @@ func (ca *CA) Provision(ctx caddy.Context, id string, log *zap.Logger) error {
if ca.IntermediateLifetime == 0 {
ca.IntermediateLifetime = caddy.Duration(defaultIntermediateLifetime)
}
if ca.MaintenanceInterval == 0 {
ca.MaintenanceInterval = caddy.Duration(defaultMaintenanceInterval)
}
if ca.RenewalWindowRatio <= 0 || ca.RenewalWindowRatio > 1 {
ca.RenewalWindowRatio = defaultRenewalWindowRatio
}
// load the certs and key that will be used for signing
var rootCert, interCert *x509.Certificate
var rootCert *x509.Certificate
var rootCertChain, interCertChain []*x509.Certificate
var rootKey, interKey crypto.Signer
var err error
if ca.Root != nil {
if ca.Root.Format == "" || ca.Root.Format == "pem_file" {
ca.rootCertPath = ca.Root.Certificate
}
rootCert, rootKey, err = ca.Root.Load()
rootCertChain, rootKey, err = ca.Root.Load()
rootCert = rootCertChain[0]
} else {
ca.rootCertPath = "storage:" + ca.storageKeyRootCert()
rootCert, rootKey, err = ca.loadOrGenRoot()
@@ -142,21 +160,23 @@ func (ca *CA) Provision(ctx caddy.Context, id string, log *zap.Logger) error {
if err != nil {
return err
}
actualRootLifetime := time.Until(rootCert.NotAfter)
if time.Duration(ca.IntermediateLifetime) >= actualRootLifetime {
return fmt.Errorf("intermediate certificate lifetime must be less than actual root certificate lifetime (%s)", actualRootLifetime)
}
if ca.Intermediate != nil {
interCert, interKey, err = ca.Intermediate.Load()
interCertChain, interKey, err = ca.Intermediate.Load()
} else {
interCert, interKey, err = ca.loadOrGenIntermediate(rootCert, rootKey)
actualRootLifetime := time.Until(rootCert.NotAfter)
if time.Duration(ca.IntermediateLifetime) >= actualRootLifetime {
return fmt.Errorf("intermediate certificate lifetime must be less than actual root certificate lifetime (%s)", actualRootLifetime)
}
interCertChain, interKey, err = ca.loadOrGenIntermediate(rootCert, rootKey)
}
if err != nil {
return err
}
ca.mu.Lock()
ca.root, ca.inter, ca.interKey = rootCert, interCert, interKey
ca.root, ca.interChain, ca.interKey = rootCert, interCertChain, interKey
ca.mu.Unlock()
return nil
@@ -172,21 +192,21 @@ func (ca CA) RootCertificate() *x509.Certificate {
// RootKey returns the CA's root private key. Since the root key is
// not cached in memory long-term, it needs to be loaded from storage,
// which could yield an error.
func (ca CA) RootKey() (any, error) {
func (ca CA) RootKey() (crypto.Signer, error) {
_, rootKey, err := ca.loadOrGenRoot()
return rootKey, err
}
// IntermediateCertificate returns the CA's intermediate
// certificate (public key).
func (ca CA) IntermediateCertificate() *x509.Certificate {
// IntermediateCertificateChain returns the CA's intermediate
// certificate chain.
func (ca CA) IntermediateCertificateChain() []*x509.Certificate {
ca.mu.RLock()
defer ca.mu.RUnlock()
return ca.inter
return ca.interChain
}
// IntermediateKey returns the CA's intermediate private key.
func (ca CA) IntermediateKey() any {
func (ca CA) IntermediateKey() crypto.Signer {
ca.mu.RLock()
defer ca.mu.RUnlock()
return ca.interKey
@@ -207,26 +227,27 @@ func (ca *CA) NewAuthority(authorityConfig AuthorityConfig) (*authority.Authorit
// cert/key directly, since it's unlikely to expire
// while Caddy is running (long lifetime)
var issuerCert *x509.Certificate
var issuerKey any
var issuerKey crypto.Signer
issuerCert = rootCert
var err error
issuerKey, err = ca.RootKey()
if err != nil {
return nil, fmt.Errorf("loading signing key: %v", err)
}
signerOption = authority.WithX509Signer(issuerCert, issuerKey.(crypto.Signer))
signerOption = authority.WithX509Signer(issuerCert, issuerKey)
} else {
// if we're signing with intermediate, we need to make
// sure it's always fresh, because the intermediate may
// renew while Caddy is running (medium lifetime)
signerOption = authority.WithX509SignerFunc(func() ([]*x509.Certificate, crypto.Signer, error) {
issuerCert := ca.IntermediateCertificate()
issuerKey := ca.IntermediateKey().(crypto.Signer)
issuerChain := ca.IntermediateCertificateChain()
issuerCert := issuerChain[0]
issuerKey := ca.IntermediateKey()
ca.log.Debug("using intermediate signer",
zap.String("serial", issuerCert.SerialNumber.String()),
zap.String("not_before", issuerCert.NotBefore.String()),
zap.String("not_after", issuerCert.NotAfter.String()))
return []*x509.Certificate{issuerCert}, issuerKey, nil
return issuerChain, issuerKey, nil
})
}
@@ -252,7 +273,11 @@ func (ca *CA) NewAuthority(authorityConfig AuthorityConfig) (*authority.Authorit
func (ca CA) loadOrGenRoot() (rootCert *x509.Certificate, rootKey crypto.Signer, err error) {
if ca.Root != nil {
return ca.Root.Load()
rootChain, rootSigner, err := ca.Root.Load()
if err != nil {
return nil, nil, err
}
return rootChain[0], rootSigner, nil
}
rootCertPEM, err := ca.storage.Load(ca.ctx, ca.storageKeyRootCert())
if err != nil {
@@ -268,7 +293,7 @@ func (ca CA) loadOrGenRoot() (rootCert *x509.Certificate, rootKey crypto.Signer,
}
if rootCert == nil {
rootCert, err = pemDecodeSingleCert(rootCertPEM)
rootCert, err = pemDecodeCertificate(rootCertPEM)
if err != nil {
return nil, nil, fmt.Errorf("parsing root certificate PEM: %v", err)
}
@@ -314,7 +339,8 @@ func (ca CA) genRoot() (rootCert *x509.Certificate, rootKey crypto.Signer, err e
return rootCert, rootKey, nil
}
func (ca CA) loadOrGenIntermediate(rootCert *x509.Certificate, rootKey crypto.Signer) (interCert *x509.Certificate, interKey crypto.Signer, err error) {
func (ca CA) loadOrGenIntermediate(rootCert *x509.Certificate, rootKey crypto.Signer) (interCertChain []*x509.Certificate, interKey crypto.Signer, err error) {
var interCert *x509.Certificate
interCertPEM, err := ca.storage.Load(ca.ctx, ca.storageKeyIntermediateCert())
if err != nil {
if !errors.Is(err, fs.ErrNotExist) {
@@ -326,10 +352,12 @@ func (ca CA) loadOrGenIntermediate(rootCert *x509.Certificate, rootKey crypto.Si
if err != nil {
return nil, nil, fmt.Errorf("generating new intermediate cert: %v", err)
}
interCertChain = append(interCertChain, interCert)
}
if interCert == nil {
interCert, err = pemDecodeSingleCert(interCertPEM)
if len(interCertChain) == 0 {
interCertChain, err = pemDecodeCertificateChain(interCertPEM)
if err != nil {
return nil, nil, fmt.Errorf("decoding intermediate certificate PEM: %v", err)
}
@@ -346,7 +374,7 @@ func (ca CA) loadOrGenIntermediate(rootCert *x509.Certificate, rootKey crypto.Si
}
}
return interCert, interKey, nil
return interCertChain, interKey, nil
}
func (ca CA) genIntermediate(rootCert *x509.Certificate, rootKey crypto.Signer) (interCert *x509.Certificate, interKey crypto.Signer, err error) {
@@ -443,4 +471,6 @@ const (
defaultRootLifetime = 24 * time.Hour * 30 * 12 * 10
defaultIntermediateLifetime = 24 * time.Hour * 7
defaultMaintenanceInterval = 10 * time.Minute
defaultRenewalWindowRatio = 0.2
)
+61 -6
View File
@@ -17,15 +17,20 @@ package caddypki
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"os"
"github.com/caddyserver/certmagic"
"go.step.sm/crypto/pemutil"
)
func pemDecodeSingleCert(pemDER []byte) (*x509.Certificate, error) {
func pemDecodeCertificate(pemDER []byte) (*x509.Certificate, error) {
pemBlock, remaining := pem.Decode(pemDER)
if pemBlock == nil {
return nil, fmt.Errorf("no PEM block found")
@@ -39,6 +44,15 @@ func pemDecodeSingleCert(pemDER []byte) (*x509.Certificate, error) {
return x509.ParseCertificate(pemBlock.Bytes)
}
func pemDecodeCertificateChain(pemDER []byte) ([]*x509.Certificate, error) {
chain, err := pemutil.ParseCertificateBundle(pemDER)
if err != nil {
return nil, fmt.Errorf("failed parsing certificate chain: %w", err)
}
return chain, nil
}
func pemEncodeCert(der []byte) ([]byte, error) {
return pemEncode("CERTIFICATE", der)
}
@@ -63,22 +77,25 @@ type KeyPair struct {
// The private key. By default, this should be the path to
// a PEM file unless format is something else.
PrivateKey string `json:"private_key,omitempty"`
PrivateKey string `json:"private_key,omitempty"` //nolint:gosec // false positive: yes it's exported, since it needs to encode/decode as JSON; and is often just a filepath
// The format in which the certificate and private
// key are provided. Default: pem_file
Format string `json:"format,omitempty"`
}
// Load loads the certificate and key.
func (kp KeyPair) Load() (*x509.Certificate, crypto.Signer, error) {
// Load loads the certificate chain and (optional) private key from
// the corresponding files, using the configured format. If a
// private key is read, it will be verified to belong to the first
// certificate in the chain.
func (kp KeyPair) Load() ([]*x509.Certificate, crypto.Signer, error) {
switch kp.Format {
case "", "pem_file":
certData, err := os.ReadFile(kp.Certificate)
if err != nil {
return nil, nil, err
}
cert, err := pemDecodeSingleCert(certData)
chain, err := pemDecodeCertificateChain(certData)
if err != nil {
return nil, nil, err
}
@@ -93,11 +110,49 @@ func (kp KeyPair) Load() (*x509.Certificate, crypto.Signer, error) {
if err != nil {
return nil, nil, err
}
if err := verifyKeysMatch(chain[0], key); err != nil {
return nil, nil, err
}
}
return cert, key, nil
return chain, key, nil
default:
return nil, nil, fmt.Errorf("unsupported format: %s", kp.Format)
}
}
// verifyKeysMatch verifies that the public key in the [x509.Certificate] matches
// the public key of the [crypto.Signer].
func verifyKeysMatch(crt *x509.Certificate, signer crypto.Signer) error {
switch pub := crt.PublicKey.(type) {
case *rsa.PublicKey:
pk, ok := signer.Public().(*rsa.PublicKey)
if !ok {
return fmt.Errorf("private key type %T does not match issuer public key type %T", signer.Public(), pub)
}
if !pub.Equal(pk) {
return errors.New("private key does not match issuer public key")
}
case *ecdsa.PublicKey:
pk, ok := signer.Public().(*ecdsa.PublicKey)
if !ok {
return fmt.Errorf("private key type %T does not match issuer public key type %T", signer.Public(), pub)
}
if !pub.Equal(pk) {
return errors.New("private key does not match issuer public key")
}
case ed25519.PublicKey:
pk, ok := signer.Public().(ed25519.PublicKey)
if !ok {
return fmt.Errorf("private key type %T does not match issuer public key type %T", signer.Public(), pub)
}
if !pub.Equal(pk) {
return errors.New("private key does not match issuer public key")
}
default:
return fmt.Errorf("unsupported key type: %T", pub)
}
return nil
}
+314
View File
@@ -0,0 +1,314 @@
// Copyright 2015 Matthew Holt and The Caddy Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package caddypki
import (
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"os"
"path/filepath"
"testing"
"time"
"go.step.sm/crypto/keyutil"
"go.step.sm/crypto/pemutil"
)
func TestKeyPair_Load(t *testing.T) {
rootSigner, err := keyutil.GenerateDefaultSigner()
if err != nil {
t.Fatalf("Failed creating signer: %v", err)
}
tmpl := &x509.Certificate{
Subject: pkix.Name{CommonName: "test-root"},
IsCA: true,
MaxPathLen: 3,
}
rootBytes, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, rootSigner.Public(), rootSigner)
if err != nil {
t.Fatalf("Creating root certificate failed: %v", err)
}
root, err := x509.ParseCertificate(rootBytes)
if err != nil {
t.Fatalf("Parsing root certificate failed: %v", err)
}
intermediateSigner, err := keyutil.GenerateDefaultSigner()
if err != nil {
t.Fatalf("Creating intermedaite signer failed: %v", err)
}
intermediateBytes, err := x509.CreateCertificate(rand.Reader, &x509.Certificate{
Subject: pkix.Name{CommonName: "test-first-intermediate"},
IsCA: true,
MaxPathLen: 2,
NotAfter: time.Now().Add(time.Hour),
}, root, intermediateSigner.Public(), rootSigner)
if err != nil {
t.Fatalf("Creating intermediate certificate failed: %v", err)
}
intermediate, err := x509.ParseCertificate(intermediateBytes)
if err != nil {
t.Fatalf("Parsing intermediate certificate failed: %v", err)
}
var chainContents []byte
chain := []*x509.Certificate{intermediate, root}
for _, cert := range chain {
b, err := pemutil.Serialize(cert)
if err != nil {
t.Fatalf("Failed serializing intermediate certificate: %v", err)
}
chainContents = append(chainContents, pem.EncodeToMemory(b)...)
}
dir := t.TempDir()
rootCertFile := filepath.Join(dir, "root.pem")
if _, err = pemutil.Serialize(root, pemutil.WithFilename(rootCertFile)); err != nil {
t.Fatalf("Failed serializing root certificate: %v", err)
}
rootKeyFile := filepath.Join(dir, "root.key")
if _, err = pemutil.Serialize(rootSigner, pemutil.WithFilename(rootKeyFile)); err != nil {
t.Fatalf("Failed serializing root key: %v", err)
}
intermediateCertFile := filepath.Join(dir, "intermediate.pem")
if _, err = pemutil.Serialize(intermediate, pemutil.WithFilename(intermediateCertFile)); err != nil {
t.Fatalf("Failed serializing intermediate certificate: %v", err)
}
intermediateKeyFile := filepath.Join(dir, "intermediate.key")
if _, err = pemutil.Serialize(intermediateSigner, pemutil.WithFilename(intermediateKeyFile)); err != nil {
t.Fatalf("Failed serializing intermediate key: %v", err)
}
chainFile := filepath.Join(dir, "chain.pem")
if err := os.WriteFile(chainFile, chainContents, 0644); err != nil {
t.Fatalf("Failed writing intermediate chain: %v", err)
}
t.Run("ok/single-certificate-without-signer", func(t *testing.T) {
kp := KeyPair{
Certificate: rootCertFile,
}
chain, signer, err := kp.Load()
if err != nil {
t.Fatalf("Failed loading KeyPair: %v", err)
}
if len(chain) != 1 {
t.Errorf("Expected 1 certificate in chain; got %d", len(chain))
}
if signer != nil {
t.Error("Expected no signer to be returned")
}
})
t.Run("ok/single-certificate-with-signer", func(t *testing.T) {
kp := KeyPair{
Certificate: rootCertFile,
PrivateKey: rootKeyFile,
}
chain, signer, err := kp.Load()
if err != nil {
t.Fatalf("Failed loading KeyPair: %v", err)
}
if len(chain) != 1 {
t.Errorf("Expected 1 certificate in chain; got %d", len(chain))
}
if signer == nil {
t.Error("Expected signer to be returned")
}
})
t.Run("ok/multiple-certificates-with-signer", func(t *testing.T) {
kp := KeyPair{
Certificate: chainFile,
PrivateKey: intermediateKeyFile,
}
chain, signer, err := kp.Load()
if err != nil {
t.Fatalf("Failed loading KeyPair: %v", err)
}
if len(chain) != 2 {
t.Errorf("Expected 2 certificates in chain; got %d", len(chain))
}
if signer == nil {
t.Error("Expected signer to be returned")
}
})
t.Run("fail/non-matching-public-key", func(t *testing.T) {
kp := KeyPair{
Certificate: intermediateCertFile,
PrivateKey: rootKeyFile,
}
chain, signer, err := kp.Load()
if err == nil {
t.Error("Expected loading KeyPair to return an error")
}
if chain != nil {
t.Error("Expected no chain to be returned")
}
if signer != nil {
t.Error("Expected no signer to be returned")
}
})
}
func Test_pemDecodeCertificate(t *testing.T) {
signer, err := keyutil.GenerateDefaultSigner()
if err != nil {
t.Fatalf("Failed creating signer: %v", err)
}
tmpl := &x509.Certificate{
Subject: pkix.Name{CommonName: "test-cert"},
IsCA: true,
MaxPathLen: 3,
}
derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, signer.Public(), signer)
if err != nil {
t.Fatalf("Creating root certificate failed: %v", err)
}
cert, err := x509.ParseCertificate(derBytes)
if err != nil {
t.Fatalf("Parsing root certificate failed: %v", err)
}
pemBlock, err := pemutil.Serialize(cert)
if err != nil {
t.Fatalf("Failed serializing certificate: %v", err)
}
pemData := pem.EncodeToMemory(pemBlock)
t.Run("ok", func(t *testing.T) {
cert, err := pemDecodeCertificate(pemData)
if err != nil {
t.Fatalf("Failed decoding PEM data: %v", err)
}
if cert == nil {
t.Errorf("Expected a certificate in PEM data")
}
})
t.Run("fail/no-pem-data", func(t *testing.T) {
cert, err := pemDecodeCertificate(nil)
if err == nil {
t.Fatalf("Expected pemDecodeCertificate to return an error")
}
if cert != nil {
t.Errorf("Expected pemDecodeCertificate to return nil")
}
})
t.Run("fail/multiple", func(t *testing.T) {
multiplePEMData := append(pemData, pemData...)
cert, err := pemDecodeCertificate(multiplePEMData)
if err == nil {
t.Fatalf("Expected pemDecodeCertificate to return an error")
}
if cert != nil {
t.Errorf("Expected pemDecodeCertificate to return nil")
}
})
t.Run("fail/no-pem-certificate", func(t *testing.T) {
pkData := pem.EncodeToMemory(&pem.Block{
Type: "PRIVATE KEY",
Bytes: []byte("some-bogus-private-key"),
})
cert, err := pemDecodeCertificate(pkData)
if err == nil {
t.Fatalf("Expected pemDecodeCertificate to return an error")
}
if cert != nil {
t.Errorf("Expected pemDecodeCertificate to return nil")
}
})
}
func Test_pemDecodeCertificateChain(t *testing.T) {
signer, err := keyutil.GenerateDefaultSigner()
if err != nil {
t.Fatalf("Failed creating signer: %v", err)
}
tmpl := &x509.Certificate{
Subject: pkix.Name{CommonName: "test-cert"},
IsCA: true,
MaxPathLen: 3,
}
derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, signer.Public(), signer)
if err != nil {
t.Fatalf("Creating root certificate failed: %v", err)
}
cert, err := x509.ParseCertificate(derBytes)
if err != nil {
t.Fatalf("Parsing root certificate failed: %v", err)
}
pemBlock, err := pemutil.Serialize(cert)
if err != nil {
t.Fatalf("Failed serializing certificate: %v", err)
}
pemData := pem.EncodeToMemory(pemBlock)
t.Run("ok/single", func(t *testing.T) {
certs, err := pemDecodeCertificateChain(pemData)
if err != nil {
t.Fatalf("Failed decoding PEM data: %v", err)
}
if len(certs) != 1 {
t.Errorf("Expected 1 certificate in PEM data; got %d", len(certs))
}
})
t.Run("ok/multiple", func(t *testing.T) {
multiplePEMData := append(pemData, pemData...)
certs, err := pemDecodeCertificateChain(multiplePEMData)
if err != nil {
t.Fatalf("Failed decoding PEM data: %v", err)
}
if len(certs) != 2 {
t.Errorf("Expected 2 certificates in PEM data; got %d", len(certs))
}
})
t.Run("fail/no-pem-certificate", func(t *testing.T) {
pkData := pem.EncodeToMemory(&pem.Block{
Type: "PRIVATE KEY",
Bytes: []byte("some-bogus-private-key"),
})
certs, err := pemDecodeCertificateChain(pkData)
if err == nil {
t.Fatalf("Expected pemDecodeCertificateChain to return an error")
}
if len(certs) != 0 {
t.Errorf("Expected 0 certificates in PEM data; got %d", len(certs))
}
})
t.Run("fail/no-der-certificate", func(t *testing.T) {
certs, err := pemDecodeCertificateChain([]byte("invalid-der-data"))
if err == nil {
t.Fatalf("Expected pemDecodeCertificateChain to return an error")
}
if len(certs) != 0 {
t.Errorf("Expected 0 certificates in PEM data; got %d", len(certs))
}
})
}
+22 -14
View File
@@ -24,20 +24,24 @@ import (
"go.uber.org/zap"
)
func (p *PKI) maintenance() {
func (p *PKI) maintenanceForCA(ca *CA) {
defer func() {
if err := recover(); err != nil {
log.Printf("[PANIC] PKI maintenance: %v\n%s", err, debug.Stack())
log.Printf("[PANIC] PKI maintenance for CA %s: %v\n%s", ca.ID, err, debug.Stack())
}
}()
ticker := time.NewTicker(10 * time.Minute) // TODO: make configurable
interval := time.Duration(ca.MaintenanceInterval)
if interval <= 0 {
interval = defaultMaintenanceInterval
}
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
p.renewCerts()
_ = p.renewCertsForCA(ca)
case <-p.ctx.Done():
return
}
@@ -63,19 +67,19 @@ func (p *PKI) renewCertsForCA(ca *CA) error {
// only maintain the root if it's not manually provided in the config
if ca.Root == nil {
if needsRenewal(ca.root) {
if ca.needsRenewal(ca.root) {
// TODO: implement root renewal (use same key)
log.Warn("root certificate expiring soon (FIXME: ROOT RENEWAL NOT YET IMPLEMENTED)",
zap.Duration("time_remaining", time.Until(ca.inter.NotAfter)),
zap.Duration("time_remaining", time.Until(ca.interChain[0].NotAfter)),
)
}
}
// only maintain the intermediate if it's not manually provided in the config
if ca.Intermediate == nil {
if needsRenewal(ca.inter) {
if ca.needsRenewal(ca.interChain[0]) {
log.Info("intermediate expires soon; renewing",
zap.Duration("time_remaining", time.Until(ca.inter.NotAfter)),
zap.Duration("time_remaining", time.Until(ca.interChain[0].NotAfter)),
)
rootCert, rootKey, err := ca.loadOrGenRoot()
@@ -86,10 +90,10 @@ func (p *PKI) renewCertsForCA(ca *CA) error {
if err != nil {
return fmt.Errorf("generating new certificate: %v", err)
}
ca.inter, ca.interKey = interCert, interKey
ca.interChain, ca.interKey = []*x509.Certificate{interCert}, interKey
log.Info("renewed intermediate",
zap.Time("new_expiration", ca.inter.NotAfter),
zap.Time("new_expiration", ca.interChain[0].NotAfter),
)
}
}
@@ -97,11 +101,15 @@ func (p *PKI) renewCertsForCA(ca *CA) error {
return nil
}
func needsRenewal(cert *x509.Certificate) bool {
// needsRenewal reports whether the certificate is within its renewal window
// (i.e. the fraction of lifetime remaining is less than or equal to RenewalWindowRatio).
func (ca *CA) needsRenewal(cert *x509.Certificate) bool {
ratio := ca.RenewalWindowRatio
if ratio <= 0 {
ratio = defaultRenewalWindowRatio
}
lifetime := cert.NotAfter.Sub(cert.NotBefore)
renewalWindow := time.Duration(float64(lifetime) * renewalWindowRatio)
renewalWindow := time.Duration(float64(lifetime) * ratio)
renewalWindowStart := cert.NotAfter.Add(-renewalWindow)
return time.Now().After(renewalWindowStart)
}
const renewalWindowRatio = 0.2 // TODO: make configurable
+86
View File
@@ -0,0 +1,86 @@
// Copyright 2015 Matthew Holt and The Caddy Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package caddypki
import (
"crypto/x509"
"testing"
"time"
)
func TestCA_needsRenewal(t *testing.T) {
now := time.Now()
// cert with 100 days lifetime; last 20% = 20 days before expiry
// So renewal window starts at (NotAfter - 20 days)
makeCert := func(daysUntilExpiry int, lifetimeDays int) *x509.Certificate {
notAfter := now.AddDate(0, 0, daysUntilExpiry)
notBefore := notAfter.AddDate(0, 0, -lifetimeDays)
return &x509.Certificate{NotBefore: notBefore, NotAfter: notAfter}
}
tests := []struct {
name string
ca *CA
cert *x509.Certificate
expect bool
}{
{
name: "inside renewal window with ratio 0.2",
ca: &CA{RenewalWindowRatio: 0.2},
cert: makeCert(10, 100),
expect: true,
},
{
name: "outside renewal window with ratio 0.2",
ca: &CA{RenewalWindowRatio: 0.2},
cert: makeCert(50, 100),
expect: false,
},
{
name: "outside renewal window with 21 days left",
ca: &CA{RenewalWindowRatio: 0.2},
cert: makeCert(21, 100),
expect: false,
},
{
name: "just inside renewal window with ratio 0.5",
ca: &CA{RenewalWindowRatio: 0.5},
cert: makeCert(30, 100),
expect: true,
},
{
name: "zero ratio uses default",
ca: &CA{RenewalWindowRatio: 0},
cert: makeCert(10, 100),
expect: true,
},
{
name: "invalid ratio uses default",
ca: &CA{RenewalWindowRatio: 1.5},
cert: makeCert(10, 100),
expect: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.ca.needsRenewal(tt.cert)
if got != tt.expect {
t.Errorf("needsRenewal() = %v, want %v", got, tt.expect)
}
})
}
}
+4 -2
View File
@@ -109,8 +109,10 @@ func (p *PKI) Start() error {
// see if root/intermediates need renewal...
p.renewCerts()
// ...and keep them renewed
go p.maintenance()
// ...and keep them renewed (one goroutine per CA with its own interval)
for _, ca := range p.CAs {
go p.maintenanceForCA(ca)
}
return nil
}
+4 -3
View File
@@ -178,6 +178,7 @@ func (iss *ACMEIssuer) Provision(ctx caddy.Context) error {
PropagationTimeout: time.Duration(iss.Challenges.DNS.PropagationTimeout),
Resolvers: iss.Challenges.DNS.Resolvers,
OverrideDomain: iss.Challenges.DNS.OverrideDomain,
Logger: iss.logger.Named("dns_manager"),
},
}
}
@@ -336,7 +337,7 @@ func (iss *ACMEIssuer) generateZeroSSLEABCredentials(ctx context.Context, acct a
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("User-Agent", certmagic.UserAgent)
resp, err := http.DefaultClient.Do(req)
resp, err := http.DefaultClient.Do(req) //nolint:gosec // no SSRF since URL is from trusted config
if err != nil {
return nil, acct, fmt.Errorf("performing EAB credentials request: %v", err)
}
@@ -671,7 +672,7 @@ func ParseCaddyfilePreferredChainsOptions(d *caddyfile.Dispenser) (*ChainPrefere
switch d.Val() {
case "root_common_name":
rootCommonNameOpt := d.RemainingArgs()
chainPref.RootCommonName = rootCommonNameOpt
chainPref.RootCommonName = append(chainPref.RootCommonName, rootCommonNameOpt...)
if rootCommonNameOpt == nil {
return nil, d.ArgErr()
}
@@ -681,7 +682,7 @@ func ParseCaddyfilePreferredChainsOptions(d *caddyfile.Dispenser) (*ChainPrefere
case "any_common_name":
anyCommonNameOpt := d.RemainingArgs()
chainPref.AnyCommonName = anyCommonNameOpt
chainPref.AnyCommonName = append(chainPref.AnyCommonName, anyCommonNameOpt...)
if anyCommonNameOpt == nil {
return nil, d.ArgErr()
}
+56 -6
View File
@@ -243,22 +243,49 @@ func (ap *AutomationPolicy) Provision(tlsApp *TLS) error {
}
}
// build certmagic.Config and attach it to the policy
storage := ap.storage
if storage == nil {
storage = tlsApp.ctx.Storage()
}
cfg, err := ap.makeCertMagicConfig(tlsApp, issuers, storage)
if err != nil {
return err
}
certCacheMu.RLock()
ap.magic = certmagic.New(certCache, cfg)
certCacheMu.RUnlock()
// give issuers a chance to see the config pointer
for _, issuer := range ap.magic.Issuers {
if annoying, ok := issuer.(ConfigSetter); ok {
annoying.SetConfig(ap.magic)
}
}
return nil
}
// makeCertMagicConfig constructs a certmagic.Config for this policy using the
// provided issuers and storage. It encapsulates common logic shared between
// Provision and RebuildCertMagic so we don't duplicate code.
func (ap *AutomationPolicy) makeCertMagicConfig(tlsApp *TLS, issuers []certmagic.Issuer, storage certmagic.Storage) (certmagic.Config, error) {
// key source
keyType := ap.KeyType
if keyType != "" {
var err error
keyType, err = caddy.NewReplacer().ReplaceOrErr(ap.KeyType, true, true)
if err != nil {
return fmt.Errorf("invalid key type %s: %s", ap.KeyType, err)
return certmagic.Config{}, fmt.Errorf("invalid key type %s: %s", ap.KeyType, err)
}
if _, ok := supportedCertKeyTypes[keyType]; !ok {
return fmt.Errorf("unrecognized key type: %s", keyType)
return certmagic.Config{}, fmt.Errorf("unrecognized key type: %s", keyType)
}
}
keySource := certmagic.StandardKeyGenerator{
KeyType: supportedCertKeyTypes[keyType],
}
storage := ap.storage
if storage == nil {
storage = tlsApp.ctx.Storage()
}
@@ -277,7 +304,7 @@ func (ap *AutomationPolicy) Provision(tlsApp *TLS) error {
if noProtections {
if !ap.hadExplicitManagers {
// no managers, no explicitly-configured permission module, this is a config error
return fmt.Errorf("on-demand TLS cannot be enabled without a permission module to prevent abuse; please refer to documentation for details")
return certmagic.Config{}, fmt.Errorf("on-demand TLS cannot be enabled without a permission module to prevent abuse; please refer to documentation for details")
}
// allow on-demand to be enabled but only for the purpose of the Managers; issuance won't be allowed from Issuers
tlsApp.logger.Warn("on-demand TLS can only get certificates from the configured external manager(s) because no ask endpoint / permission module is specified")
@@ -334,7 +361,7 @@ func (ap *AutomationPolicy) Provision(tlsApp *TLS) error {
}
}
template := certmagic.Config{
cfg := certmagic.Config{
MustStaple: ap.MustStaple,
RenewalWindowRatio: ap.RenewalWindowRatio,
KeySource: keySource,
@@ -349,8 +376,31 @@ func (ap *AutomationPolicy) Provision(tlsApp *TLS) error {
Issuers: issuers,
Logger: tlsApp.logger,
}
return cfg, nil
}
// IsProvisioned reports whether the automation policy has been
// provisioned. A provisioned policy has an initialized CertMagic
// instance (i.e. ap.magic != nil).
func (ap *AutomationPolicy) IsProvisioned() bool { return ap.magic != nil }
// RebuildCertMagic rebuilds the policy's CertMagic configuration from the
// policy's already-populated fields (Issuers, Managers, storage, etc.) and
// replaces the internal CertMagic instance. This is a lightweight
// alternative to calling Provision because it does not re-provision
// modules or re-run module Provision; instead, it constructs a new
// certmagic.Config and calls SetConfig on issuers so they receive updated
// templates (for example, alternate HTTP/TLS ports supplied by the HTTP
// app). RebuildCertMagic should only be called when the policy's required
// fields are already populated.
func (ap *AutomationPolicy) RebuildCertMagic(tlsApp *TLS) error {
cfg, err := ap.makeCertMagicConfig(tlsApp, ap.Issuers, ap.storage)
if err != nil {
return err
}
certCacheMu.RLock()
ap.magic = certmagic.New(certCache, template)
ap.magic = certmagic.New(certCache, cfg)
certCacheMu.RUnlock()
// sometimes issuers may need the parent certmagic.Config in
+7 -5
View File
@@ -257,7 +257,7 @@ func (PKIIntermediateCAPool) CaddyModule() caddy.ModuleInfo {
}
}
// Loads the PKI app and load the intermediate certificates into the certificate pool
// Loads the PKI app and loads the intermediate certificates into the certificate pool
func (p *PKIIntermediateCAPool) Provision(ctx caddy.Context) error {
pkiApp, err := ctx.AppIfConfigured("pki")
if err != nil {
@@ -274,7 +274,9 @@ func (p *PKIIntermediateCAPool) Provision(ctx caddy.Context) error {
caPool := x509.NewCertPool()
for _, ca := range p.ca {
caPool.AddCert(ca.IntermediateCertificate())
for _, c := range ca.IntermediateCertificateChain() {
caPool.AddCert(c)
}
}
p.pool = caPool
return nil
@@ -500,8 +502,8 @@ func (t *TLSConfig) unmarshalCaddyfile(d *caddyfile.Dispenser) error {
// If there is no custom TLS configuration, a nil config may be returned.
// copied from with minor modifications: modules/caddyhttp/reverseproxy/httptransport.go
func (t *TLSConfig) makeTLSClientConfig(ctx caddy.Context) (*tls.Config, error) {
repl := ctx.Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
if repl == nil {
repl, ok := ctx.Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
if !ok || repl == nil {
repl = caddy.NewReplacer()
}
cfg := new(tls.Config)
@@ -586,7 +588,7 @@ func (hcp *HTTPCertPool) Provision(ctx caddy.Context) error {
if err != nil {
return err
}
res, err := httpClient.Do(req)
res, err := httpClient.Do(req) //nolint:gosec // SSRF false positive... uri comes from config
if err != nil {
return err
}
+1 -1
View File
@@ -155,7 +155,7 @@ func (hcg HTTPCertGetter) GetCertificate(ctx context.Context, hello *tls.ClientH
return nil, err
}
resp, err := http.DefaultClient.Do(req)
resp, err := http.DefaultClient.Do(req) //nolint:gosec // SSRF false positive... request URI comes from config
if err != nil {
return nil, err
}
+20 -25
View File
@@ -168,21 +168,11 @@ func (cp ConnectionPolicies) TLSConfig(ctx caddy.Context) *tls.Config {
tlsApp.RegisterServerNames(echNames)
}
// TODO: Ideally, ECH keys should be rotated. However, as of Go 1.24, the std lib implementation
// does not support safely modifying the tls.Config's EncryptedClientHelloKeys field.
// So, we implement static ECH keys temporarily. See https://github.com/golang/go/issues/71920.
// Revisit this after Go 1.25 is released and implement key rotation.
var stdECHKeys []tls.EncryptedClientHelloKey
for _, echConfigs := range tlsApp.EncryptedClientHello.configs {
for _, c := range echConfigs {
stdECHKeys = append(stdECHKeys, tls.EncryptedClientHelloKey{
Config: c.configBin,
PrivateKey: c.privKeyBin,
SendAsRetry: c.sendAsRetry,
})
}
tlsCfg.GetEncryptedClientHelloKeys = func(chi *tls.ClientHelloInfo) ([]tls.EncryptedClientHelloKey, error) {
tlsApp.EncryptedClientHello.configsMu.RLock()
defer tlsApp.EncryptedClientHello.configsMu.RUnlock()
return tlsApp.EncryptedClientHello.stdlibReady, nil
}
tlsCfg.EncryptedClientHelloKeys = stdECHKeys
}
}
@@ -794,7 +784,7 @@ func (clientauth *ClientAuthentication) provision(ctx caddy.Context) error {
for _, fpath := range clientauth.TrustedCACertPEMFiles {
ders, err := convertPEMFilesToDER(fpath)
if err != nil {
return nil
return err
}
clientauth.TrustedCACerts = append(clientauth.TrustedCACerts, ders...)
}
@@ -807,7 +797,7 @@ func (clientauth *ClientAuthentication) provision(ctx caddy.Context) error {
}
err := caPool.Provision(ctx)
if err != nil {
return nil
return err
}
clientauth.ca = caPool
}
@@ -895,24 +885,29 @@ func (clientauth *ClientAuthentication) ConfigureTLSConfig(cfg *tls.Config) erro
// if a custom verification function already exists, wrap it
clientauth.existingVerifyPeerCert = cfg.VerifyPeerCertificate
cfg.VerifyPeerCertificate = clientauth.verifyPeerCertificate
cfg.VerifyConnection = clientauth.verifyConnection
return nil
}
// verifyPeerCertificate is for use as a tls.Config.VerifyPeerCertificate
// callback to do custom client certificate verification. It is intended
// for installation only by clientauth.ConfigureTLSConfig().
func (clientauth *ClientAuthentication) verifyPeerCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
// verifyConnection is for use as a tls.Config.VerifyConnection callback
// to do custom client certificate verification. It is intended for
// installation only by clientauth.ConfigureTLSConfig().
//
// Unlike VerifyPeerCertificate, VerifyConnection is called on every
// connection including resumed sessions, preventing session-resumption bypass.
func (clientauth *ClientAuthentication) verifyConnection(cs tls.ConnectionState) error {
// first use any pre-existing custom verification function
if clientauth.existingVerifyPeerCert != nil {
err := clientauth.existingVerifyPeerCert(rawCerts, verifiedChains)
if err != nil {
rawCerts := make([][]byte, len(cs.PeerCertificates))
for i, cert := range cs.PeerCertificates {
rawCerts[i] = cert.Raw
}
if err := clientauth.existingVerifyPeerCert(rawCerts, cs.VerifiedChains); err != nil {
return err
}
}
for _, verifier := range clientauth.verifiers {
err := verifier.VerifyClientCertificate(rawCerts, verifiedChains)
if err != nil {
if err := verifier.VerifyClientCertificate(nil, cs.VerifiedChains); err != nil {
return err
}
}
+214 -64
View File
@@ -2,6 +2,7 @@ package caddytls
import (
"context"
"crypto/tls"
"encoding/base64"
"encoding/json"
"errors"
@@ -11,6 +12,7 @@ import (
"path"
"strconv"
"strings"
"sync"
"time"
"github.com/caddyserver/certmagic"
@@ -48,12 +50,6 @@ func init() {
// applied will automatically upgrade the minimum TLS version to 1.3, even if
// configured to a lower version.
//
// Note that, as of Caddy 2.10.0 (~March 2025), ECH keys are not automatically
// rotated due to a limitation in the Go standard library (see
// https://github.com/golang/go/issues/71920). This should be resolved when
// Go 1.25 is released (~Aug. 2025), and Caddy will be updated to automatically
// rotate ECH keys/configs at that point.
//
// EXPERIMENTAL: Subject to change.
type ECH struct {
// The list of ECH configurations for which to automatically generate
@@ -73,14 +69,17 @@ type ECH struct {
// DNS RRs. (This also typically requires that they use DoH or DoT.)
Publication []*ECHPublication `json:"publication,omitempty"`
// map of public_name to list of configs
configs map[string][]echConfig
configsMu *sync.RWMutex // protects both configs and the list of configs/keys the standard library uses
configs map[string][]echConfig // map of public_name to list of configs
stdlibReady []tls.EncryptedClientHelloKey // ECH configs+keys in a format the standard library can use
}
// Provision loads or creates ECH configs and returns outer names (for certificate
// management), but does not publish any ECH configs. The DNS module is used as
// a default for later publishing if needed.
func (ech *ECH) Provision(ctx caddy.Context) ([]string, error) {
ech.configsMu = new(sync.RWMutex)
logger := ctx.Logger().Named("ech")
// set up publication modules before we need to obtain a lock in storage,
@@ -98,17 +97,60 @@ func (ech *ECH) Provision(ctx caddy.Context) ([]string, error) {
// the rest of provisioning needs an exclusive lock so that instances aren't
// stepping on each other when setting up ECH configs
storage := ctx.Storage()
const echLockName = "ech_provision"
if err := storage.Lock(ctx, echLockName); err != nil {
if err := storage.Lock(ctx, echStorageLockName); err != nil {
return nil, err
}
defer func() {
if err := storage.Unlock(ctx, echLockName); err != nil {
if err := storage.Unlock(ctx, echStorageLockName); err != nil {
logger.Error("unable to unlock ECH provisioning in storage", zap.Error(err))
}
}()
var outerNames []string //nolint:prealloc // (FALSE POSITIVE - see https://github.com/alexkohler/prealloc/issues/30)
ech.configsMu.Lock()
defer ech.configsMu.Unlock()
outerNames, err := ech.setConfigsFromStorage(ctx, logger)
if err != nil {
return nil, fmt.Errorf("loading configs from storage: %w", err)
}
// see if we need to make any new ones based on the input configuration
for _, cfg := range ech.Configs {
publicName := strings.ToLower(strings.TrimSpace(cfg.PublicName))
if list, ok := ech.configs[publicName]; !ok || len(list) == 0 {
// no config with this public name was loaded, so create one
echCfg, err := generateAndStoreECHConfig(ctx, publicName)
if err != nil {
return nil, err
}
logger.Debug("generated new ECH config",
zap.String("public_name", echCfg.RawPublicName),
zap.Uint8("id", echCfg.ConfigID))
ech.configs[publicName] = append(ech.configs[publicName], echCfg)
outerNames = append(outerNames, publicName)
}
}
// convert the configs into a structure ready for the std lib to use
ech.updateKeyList()
// ensure any old keys are rotated out
if err = ech.rotateECHKeys(ctx, logger, true); err != nil {
return nil, fmt.Errorf("rotating ECH configs: %w", err)
}
return outerNames, nil
}
// setConfigsFromStorage sets the ECH configs in memory to those in storage.
// It must be called in a write lock on ech.configsMu.
func (ech *ECH) setConfigsFromStorage(ctx caddy.Context, logger *zap.Logger) ([]string, error) {
storage := ctx.Storage()
ech.configs = make(map[string][]echConfig)
var outerNames []string
// start by loading all the existing configs (even the older ones on the way out,
// since some clients may still be using them if they haven't yet picked up on the
@@ -131,48 +173,145 @@ func (ech *ECH) Provision(ctx caddy.Context) ([]string, error) {
logger.Debug("loaded ECH config",
zap.String("public_name", cfg.RawPublicName),
zap.Uint8("id", cfg.ConfigID))
ech.configs[cfg.RawPublicName] = append(ech.configs[cfg.RawPublicName], cfg)
outerNames = append(outerNames, cfg.RawPublicName)
}
// all existing configs are now loaded; see if we need to make any new ones
// based on the input configuration, and also mark the most recent one(s) as
// current/active, so they can be used for ECH retries
for _, cfg := range ech.Configs {
publicName := strings.ToLower(strings.TrimSpace(cfg.PublicName))
if list, ok := ech.configs[publicName]; ok && len(list) > 0 {
// at least one config with this public name was loaded, so find the
// most recent one and mark it as active to be used with retries
var mostRecentDate time.Time
var mostRecentIdx int
for i, c := range list {
if mostRecentDate.IsZero() || c.meta.Created.After(mostRecentDate) {
mostRecentDate = c.meta.Created
mostRecentIdx = i
}
}
list[mostRecentIdx].sendAsRetry = true
} else {
// no config with this public name was loaded, so create one
echCfg, err := generateAndStoreECHConfig(ctx, publicName)
if err != nil {
return nil, err
}
logger.Debug("generated new ECH config",
zap.String("public_name", echCfg.RawPublicName),
zap.Uint8("id", echCfg.ConfigID))
ech.configs[publicName] = append(ech.configs[publicName], echCfg)
outerNames = append(outerNames, publicName)
if _, seen := ech.configs[cfg.RawPublicName]; !seen {
outerNames = append(outerNames, cfg.RawPublicName)
}
ech.configs[cfg.RawPublicName] = append(ech.configs[cfg.RawPublicName], cfg)
}
return outerNames, nil
}
func (t *TLS) publishECHConfigs() error {
logger := t.logger.Named("ech")
// rotateECHKeys updates the ECH keys/configs that are outdated if rotation is needed.
// It should be called in a write lock on ech.configsMu. If a lock is already obtained
// in storage, then pass true for storageSynced.
//
// This function sets/updates the stdlib-ready key list only if a rotation occurs.
func (ech *ECH) rotateECHKeys(ctx caddy.Context, logger *zap.Logger, storageSynced bool) error {
storage := ctx.Storage()
// all existing configs are now loaded; rotate keys "regularly" as recommended by the spec
// (also: "Rotating too frequently limits the client anonymity set." - but the more server
// names, the more frequently rotation can be done safely)
const (
rotationInterval = 24 * time.Hour * 30
deleteAfter = 24 * time.Hour * 90
)
if !ech.rotationNeeded(rotationInterval, deleteAfter) {
return nil
}
// sync this operation across cluster if not already
if !storageSynced {
if err := storage.Lock(ctx, echStorageLockName); err != nil {
return err
}
defer func() {
if err := storage.Unlock(ctx, echStorageLockName); err != nil {
logger.Error("unable to unlock ECH rotation in storage", zap.Error(err))
}
}()
}
// update what storage has, in case another instance already updated things
if _, err := ech.setConfigsFromStorage(ctx, logger); err != nil {
return fmt.Errorf("updating ECH keys from storage: %v", err)
}
// iterate the updated list and do any updates as needed
for publicName := range ech.configs {
for i := 0; i < len(ech.configs[publicName]); i++ {
cfg := ech.configs[publicName][i]
if time.Since(cfg.meta.Created) >= rotationInterval && cfg.meta.Replaced.IsZero() {
// key is due for rotation and it hasn't been replaced yet; do that now
logger.Debug("ECH config is due for rotation",
zap.String("public_name", cfg.RawPublicName),
zap.Uint8("id", cfg.ConfigID),
zap.Time("created", cfg.meta.Created),
zap.Duration("age", time.Since(cfg.meta.Created)),
zap.Duration("rotation_interval", rotationInterval))
// start by generating and storing the replacement ECH config
newCfg, err := generateAndStoreECHConfig(ctx, publicName)
if err != nil {
return fmt.Errorf("generating and storing new replacement ECH config: %w", err)
}
// mark the key as replaced so we don't rotate it again, and instead delete it later
ech.configs[publicName][i].meta.Replaced = time.Now()
// persist the updated metadata
metaBytes, err := json.Marshal(ech.configs[publicName][i].meta)
if err != nil {
return fmt.Errorf("marshaling updated ECH config metadata: %v", err)
}
if err := storage.Store(ctx, echMetaKey(cfg.ConfigID), metaBytes); err != nil {
return fmt.Errorf("storing updated ECH config metadata: %v", err)
}
ech.configs[publicName] = append(ech.configs[publicName], newCfg)
logger.Debug("rotated ECH key",
zap.String("public_name", cfg.RawPublicName),
zap.Uint8("old_id", cfg.ConfigID),
zap.Uint8("new_id", newCfg.ConfigID))
} else if time.Since(cfg.meta.Created) >= deleteAfter && !cfg.meta.Replaced.IsZero() {
// key has expired and is no longer supported; delete it from storage and memory
cfgIDKey := path.Join(echConfigsKey, strconv.Itoa(int(cfg.ConfigID)))
if err := storage.Delete(ctx, cfgIDKey); err != nil {
return fmt.Errorf("deleting expired ECH config: %v", err)
}
ech.configs[publicName] = append(ech.configs[publicName][:i], ech.configs[publicName][i+1:]...)
i--
logger.Debug("deleted expired ECH key",
zap.String("public_name", cfg.RawPublicName),
zap.Uint8("id", cfg.ConfigID),
zap.Duration("age", time.Since(cfg.meta.Created)))
}
}
}
ech.updateKeyList()
return nil
}
// rotationNeeded returns true if any ECH key needs to be replaced, or deleted.
// It must be called inside a read or write lock of ech.configsMu (probably a
// write lock, so that the rotation can occur correctly in the same lock).)
func (ech *ECH) rotationNeeded(rotationInterval, deleteAfter time.Duration) bool {
for publicName := range ech.configs {
for i := 0; i < len(ech.configs[publicName]); i++ {
cfg := ech.configs[publicName][i]
if (time.Since(cfg.meta.Created) >= rotationInterval && cfg.meta.Replaced.IsZero()) ||
(time.Since(cfg.meta.Created) >= deleteAfter && !cfg.meta.Replaced.IsZero()) {
return true
}
}
}
return false
}
// updateKeyList updates the list of ECH keys the std lib uses to serve ECH.
// It must be called inside a write lock on ech.configsMu.
func (ech *ECH) updateKeyList() {
ech.stdlibReady = []tls.EncryptedClientHelloKey{}
for _, cfgs := range ech.configs {
for _, cfg := range cfgs {
ech.stdlibReady = append(ech.stdlibReady, tls.EncryptedClientHelloKey{
Config: cfg.configBin,
PrivateKey: cfg.privKeyBin,
SendAsRetry: cfg.meta.Replaced.IsZero(), // only send during retries if key has not been rotated out
})
}
}
}
// publishECHConfigs publishes any configs that are configured for publication and which haven't been published already.
func (t *TLS) publishECHConfigs(logger *zap.Logger) error {
// make publication exclusive, since we don't need to repeat this unnecessarily
storage := t.ctx.Storage()
const echLockName = "ech_publish"
@@ -197,7 +336,7 @@ func (t *TLS) publishECHConfigs() error {
publishers: []ECHPublisher{
&ECHDNSPublisher{
provider: dnsProv,
logger: t.logger,
logger: logger,
},
},
},
@@ -209,6 +348,7 @@ func (t *TLS) publishECHConfigs() error {
// publish with it, and figure out which inner names to publish
// to/for, then publish
for _, publication := range publicationList {
t.EncryptedClientHello.configsMu.RLock()
// this publication is either configured for specific ECH configs,
// or we just use an implied default of all ECH configs
var echCfgList echConfigList
@@ -231,6 +371,7 @@ func (t *TLS) publishECHConfigs() error {
}
}
}
t.EncryptedClientHello.configsMu.RUnlock()
// marshal the ECH config list as binary for publication
echCfgListBin, err := echCfgList.MarshalBinary()
@@ -250,6 +391,10 @@ func (t *TLS) publishECHConfigs() error {
if publication.Domains == nil {
serverNamesSet = make(map[string]struct{}, len(t.serverNames))
for name := range t.serverNames {
// skip Tailscale names, a special case we also handle differently in our auto-HTTPS
if strings.HasSuffix(name, ".ts.net") {
continue
}
serverNamesSet[name] = struct{}{}
}
} else {
@@ -304,7 +449,7 @@ func (t *TLS) publishECHConfigs() error {
// at least a partial failure, maybe a complete failure, but we can
// log each error by domain
for innerName, domainErr := range publishErrs {
t.logger.Error("failed to publish ECH configuration list",
logger.Error("failed to publish ECH configuration list",
zap.String("publisher", publisherKey),
zap.String("domain", innerName),
zap.Uint8s("config_ids", configIDs),
@@ -312,7 +457,7 @@ func (t *TLS) publishECHConfigs() error {
}
} else if err != nil {
// generic error; assume the entire thing failed, I guess
t.logger.Error("failed publishing ECH configuration list",
logger.Error("failed publishing ECH configuration list",
zap.String("publisher", publisherKey),
zap.Strings("domains", dnsNamesToPublish),
zap.Uint8s("config_ids", configIDs),
@@ -334,7 +479,7 @@ func (t *TLS) publishECHConfigs() error {
successNames = append(successNames, name)
}
}
t.logger.Info("successfully published ECH configuration list for "+someAll+" domains",
logger.Info("successfully published ECH configuration list for "+someAll+" domains",
zap.String("publisher", publisherKey),
zap.Strings("domains", successNames),
zap.Uint8s("config_ids", configIDs))
@@ -353,13 +498,12 @@ func (t *TLS) publishECHConfigs() error {
if err != nil {
return fmt.Errorf("marshaling ECH config metadata: %v", err)
}
metaKey := path.Join(echConfigsKey, strconv.Itoa(int(cfg.ConfigID)), "meta.json")
if err := t.ctx.Storage().Store(t.ctx, metaKey, metaBytes); err != nil {
if err := t.ctx.Storage().Store(t.ctx, echMetaKey(cfg.ConfigID), metaBytes); err != nil {
return fmt.Errorf("storing updated ECH config metadata: %v", err)
}
}
} else {
t.logger.Error("all domains failed to publish ECH configuration list (see earlier errors)",
logger.Error("all domains failed to publish ECH configuration list (see earlier errors)",
zap.String("publisher", publisherKey),
zap.Strings("domains", dnsNamesToPublish),
zap.Uint8s("config_ids", configIDs))
@@ -489,7 +633,7 @@ func generateAndStoreECHConfig(ctx caddy.Context, publicName string) (echConfig,
echCfg := echConfig{
PublicKey: publicKey,
Version: draftTLSESNI22,
Version: draftTLSESNI25,
ConfigID: configID,
RawPublicName: publicName,
KEMID: kemChoice,
@@ -507,7 +651,6 @@ func generateAndStoreECHConfig(ctx caddy.Context, publicName string) (echConfig,
AEADID: hpke.AEAD_ChaCha20Poly1305,
},
},
sendAsRetry: true,
}
meta := echConfigMeta{
Created: time.Now(),
@@ -786,10 +929,9 @@ type echConfig struct {
// these fields are not part of the spec, but are here for
// our use when setting up TLS servers or maintenance
configBin []byte
privKeyBin []byte
meta echConfigMeta
sendAsRetry bool
configBin []byte
privKeyBin []byte
meta echConfigMeta
}
func (echCfg echConfig) MarshalBinary() ([]byte, error) {
@@ -811,8 +953,8 @@ func (echCfg *echConfig) UnmarshalBinary(data []byte) error {
if !b.ReadUint16(&echCfg.Version) {
return errInvalidLen
}
if echCfg.Version != draftTLSESNI22 {
return fmt.Errorf("supported version must be %d: got %d", draftTLSESNI22, echCfg.Version)
if echCfg.Version != draftTLSESNI25 {
return fmt.Errorf("supported version must be %d: got %d", draftTLSESNI25, echCfg.Version)
}
if !b.ReadUint16LengthPrefixed(&content) || !b.Empty() {
@@ -1022,19 +1164,27 @@ func (p PublishECHConfigListErrors) Error() string {
type echConfigMeta struct {
Created time.Time `json:"created"`
Replaced time.Time `json:"replaced,omitzero"`
Publications publicationHistory `json:"publications"`
}
func echMetaKey(configID uint8) string {
return path.Join(echConfigsKey, strconv.Itoa(int(configID)), "meta.json")
}
// publicationHistory is a map of publisher key to
// map of inner name to timestamp
type publicationHistory map[string]map[string]time.Time
// echStorageLockName is the name of the storage lock to sync ECH updates.
const echStorageLockName = "ech_rotation"
// The key prefix when putting ECH configs in storage. After this
// comes the config ID.
const echConfigsKey = "ech/configs"
// https://www.ietf.org/archive/id/draft-ietf-tls-esni-22.html
const draftTLSESNI22 = 0xfe0d
// https://www.ietf.org/archive/id/draft-ietf-tls-esni-25.html
const draftTLSESNI25 = 0xfe0d
// Interface guard
var _ ECHPublisher = (*ECHDNSPublisher)(nil)
+16 -6
View File
@@ -19,6 +19,7 @@ import (
"crypto/tls"
"encoding/pem"
"fmt"
"io/fs"
"os"
"path/filepath"
"strings"
@@ -62,18 +63,27 @@ func (fl FolderLoader) Provision(ctx caddy.Context) error {
func (fl FolderLoader) LoadCertificates() ([]Certificate, error) {
var certs []Certificate
for _, dir := range fl {
err := filepath.Walk(dir, func(fpath string, info os.FileInfo, err error) error {
root, err := os.OpenRoot(dir)
if err != nil {
return nil, fmt.Errorf("unable to open root directory %s: %w", dir, err)
}
err = filepath.WalkDir(dir, func(fpath string, d fs.DirEntry, err error) error {
if err != nil {
return fmt.Errorf("unable to traverse into path: %s", fpath)
}
if info.IsDir() {
if d.IsDir() {
return nil
}
if !strings.HasSuffix(strings.ToLower(info.Name()), ".pem") {
if !strings.HasSuffix(strings.ToLower(d.Name()), ".pem") {
return nil
}
bundle, err := os.ReadFile(fpath)
rel, err := filepath.Rel(dir, fpath)
if err != nil {
return fmt.Errorf("unable to get relative path for %s: %w", fpath, err)
}
bundle, err := root.ReadFile(rel)
if err != nil {
return err
}
@@ -83,11 +93,11 @@ func (fl FolderLoader) LoadCertificates() ([]Certificate, error) {
}
certs = append(certs, Certificate{Certificate: cert})
return nil
})
_ = root.Close()
if err != nil {
return nil, err
return nil, fmt.Errorf("walking certificates directory %s: %w", dir, err)
}
}
return certs, nil
+2 -1
View File
@@ -115,7 +115,8 @@ func (iss InternalIssuer) Issue(ctx context.Context, csr *x509.CertificateReques
if iss.SignWithRoot {
issuerCert = iss.ca.RootCertificate()
} else {
issuerCert = iss.ca.IntermediateCertificate()
chain := iss.ca.IntermediateCertificateChain()
issuerCert = chain[0]
}
// ensure issued certificate does not expire later than its issuer
+262
View File
@@ -0,0 +1,262 @@
// Copyright 2015 Matthew Holt and The Caddy Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package caddytls
import (
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"os"
"path/filepath"
"testing"
"time"
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/modules/caddypki"
"go.uber.org/zap"
"go.step.sm/crypto/keyutil"
"go.step.sm/crypto/pemutil"
)
func TestInternalIssuer_Issue(t *testing.T) {
rootSigner, err := keyutil.GenerateDefaultSigner()
if err != nil {
t.Fatalf("Creating root signer failed: %v", err)
}
tmpl := &x509.Certificate{
Subject: pkix.Name{CommonName: "test-root"},
IsCA: true,
MaxPathLen: 3,
NotAfter: time.Now().Add(7 * 24 * time.Hour),
NotBefore: time.Now().Add(-7 * 24 * time.Hour),
}
rootBytes, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, rootSigner.Public(), rootSigner)
if err != nil {
t.Fatalf("Creating root certificate failed: %v", err)
}
root, err := x509.ParseCertificate(rootBytes)
if err != nil {
t.Fatalf("Parsing root certificate failed: %v", err)
}
firstIntermediateSigner, err := keyutil.GenerateDefaultSigner()
if err != nil {
t.Fatalf("Creating intermedaite signer failed: %v", err)
}
firstIntermediateBytes, err := x509.CreateCertificate(rand.Reader, &x509.Certificate{
Subject: pkix.Name{CommonName: "test-first-intermediate"},
IsCA: true,
MaxPathLen: 2,
NotAfter: time.Now().Add(24 * time.Hour),
NotBefore: time.Now().Add(-24 * time.Hour),
}, root, firstIntermediateSigner.Public(), rootSigner)
if err != nil {
t.Fatalf("Creating intermediate certificate failed: %v", err)
}
firstIntermediate, err := x509.ParseCertificate(firstIntermediateBytes)
if err != nil {
t.Fatalf("Parsing intermediate certificate failed: %v", err)
}
secondIntermediateSigner, err := keyutil.GenerateDefaultSigner()
if err != nil {
t.Fatalf("Creating second intermedaite signer failed: %v", err)
}
secondIntermediateBytes, err := x509.CreateCertificate(rand.Reader, &x509.Certificate{
Subject: pkix.Name{CommonName: "test-second-intermediate"},
IsCA: true,
MaxPathLen: 2,
NotAfter: time.Now().Add(24 * time.Hour),
NotBefore: time.Now().Add(-24 * time.Hour),
}, firstIntermediate, secondIntermediateSigner.Public(), firstIntermediateSigner)
if err != nil {
t.Fatalf("Creating second intermediate certificate failed: %v", err)
}
secondIntermediate, err := x509.ParseCertificate(secondIntermediateBytes)
if err != nil {
t.Fatalf("Parsing second intermediate certificate failed: %v", err)
}
dir := t.TempDir()
storageDir := filepath.Join(dir, "certmagic")
rootCertFile := filepath.Join(dir, "root.pem")
if _, err = pemutil.Serialize(root, pemutil.WithFilename(rootCertFile)); err != nil {
t.Fatalf("Failed serializing root certificate: %v", err)
}
intermediateCertFile := filepath.Join(dir, "intermediate.pem")
if _, err = pemutil.Serialize(firstIntermediate, pemutil.WithFilename(intermediateCertFile)); err != nil {
t.Fatalf("Failed serializing intermediate certificate: %v", err)
}
intermediateKeyFile := filepath.Join(dir, "intermediate.key")
if _, err = pemutil.Serialize(firstIntermediateSigner, pemutil.WithFilename(intermediateKeyFile)); err != nil {
t.Fatalf("Failed serializing intermediate key: %v", err)
}
var intermediateChainContents []byte
intermediateChain := []*x509.Certificate{secondIntermediate, firstIntermediate}
for _, cert := range intermediateChain {
b, err := pemutil.Serialize(cert)
if err != nil {
t.Fatalf("Failed serializing intermediate certificate: %v", err)
}
intermediateChainContents = append(intermediateChainContents, pem.EncodeToMemory(b)...)
}
intermediateChainFile := filepath.Join(dir, "intermediates.pem")
if err := os.WriteFile(intermediateChainFile, intermediateChainContents, 0644); err != nil {
t.Fatalf("Failed writing intermediate chain: %v", err)
}
intermediateChainKeyFile := filepath.Join(dir, "intermediates.key")
if _, err = pemutil.Serialize(secondIntermediateSigner, pemutil.WithFilename(intermediateChainKeyFile)); err != nil {
t.Fatalf("Failed serializing intermediate key: %v", err)
}
signer, err := keyutil.GenerateDefaultSigner()
if err != nil {
t.Fatalf("Failed creating signer: %v", err)
}
csrBytes, err := x509.CreateCertificateRequest(rand.Reader, &x509.CertificateRequest{
Subject: pkix.Name{CommonName: "test"},
}, signer)
if err != nil {
t.Fatalf("Failed creating CSR: %v", err)
}
csr, err := x509.ParseCertificateRequest(csrBytes)
if err != nil {
t.Fatalf("Failed parsing CSR: %v", err)
}
t.Run("generated-with-defaults", func(t *testing.T) {
caddyCtx, cancel := caddy.NewContext(caddy.Context{Context: t.Context()})
t.Cleanup(cancel)
logger := zap.NewNop()
ca := &caddypki.CA{
StorageRaw: []byte(fmt.Sprintf(`{"module": "file_system", "root": %q}`, storageDir)),
}
if err := ca.Provision(caddyCtx, "local-test-generated", logger); err != nil {
t.Fatalf("Failed provisioning CA: %v", err)
}
iss := InternalIssuer{
SignWithRoot: false,
ca: ca,
logger: logger,
}
c, err := iss.Issue(t.Context(), csr)
if err != nil {
t.Fatalf("Failed issuing certificate: %v", err)
}
chain, err := pemutil.ParseCertificateBundle(c.Certificate)
if err != nil {
t.Errorf("Failed issuing certificate: %v", err)
}
if len(chain) != 2 {
t.Errorf("Expected 2 certificates in chain; got %d", len(chain))
}
})
t.Run("single-intermediate-from-disk", func(t *testing.T) {
caddyCtx, cancel := caddy.NewContext(caddy.Context{Context: t.Context()})
t.Cleanup(cancel)
logger := zap.NewNop()
ca := &caddypki.CA{
Root: &caddypki.KeyPair{
Certificate: rootCertFile,
},
Intermediate: &caddypki.KeyPair{
Certificate: intermediateCertFile,
PrivateKey: intermediateKeyFile,
},
StorageRaw: []byte(fmt.Sprintf(`{"module": "file_system", "root": %q}`, storageDir)),
}
if err := ca.Provision(caddyCtx, "local-test-single-intermediate", logger); err != nil {
t.Fatalf("Failed provisioning CA: %v", err)
}
iss := InternalIssuer{
ca: ca,
SignWithRoot: false,
logger: logger,
}
c, err := iss.Issue(t.Context(), csr)
if err != nil {
t.Fatalf("Failed issuing certificate: %v", err)
}
chain, err := pemutil.ParseCertificateBundle(c.Certificate)
if err != nil {
t.Errorf("Failed issuing certificate: %v", err)
}
if len(chain) != 2 {
t.Errorf("Expected 2 certificates in chain; got %d", len(chain))
}
})
t.Run("multiple-intermediates-from-disk", func(t *testing.T) {
caddyCtx, cancel := caddy.NewContext(caddy.Context{Context: t.Context()})
t.Cleanup(cancel)
logger := zap.NewNop()
ca := &caddypki.CA{
Root: &caddypki.KeyPair{
Certificate: rootCertFile,
},
Intermediate: &caddypki.KeyPair{
Certificate: intermediateChainFile,
PrivateKey: intermediateChainKeyFile,
},
StorageRaw: []byte(fmt.Sprintf(`{"module": "file_system", "root": %q}`, storageDir)),
}
if err := ca.Provision(caddyCtx, "local-test", zap.NewNop()); err != nil {
t.Fatalf("Failed provisioning CA: %v", err)
}
iss := InternalIssuer{
ca: ca,
SignWithRoot: false,
logger: logger,
}
c, err := iss.Issue(t.Context(), csr)
if err != nil {
t.Fatalf("Failed issuing certificate: %v", err)
}
chain, err := pemutil.ParseCertificateBundle(c.Certificate)
if err != nil {
t.Errorf("Failed issuing certificate: %v", err)
}
if len(chain) != 3 {
t.Errorf("Expected 3 certificates in chain; got %d", len(chain))
}
})
}
+2 -2
View File
@@ -29,9 +29,9 @@ func init() {
caddy.RegisterModule(LeafFolderLoader{})
}
// LeafFolderLoader loads certificates and their associated keys from disk
// LeafFolderLoader loads certificates from disk
// by recursively walking the specified directories, looking for PEM
// files which contain both a certificate and a key.
// files which contain a certificate.
type LeafFolderLoader struct {
Folders []string `json:"folders,omitempty"`
}
+37 -6
View File
@@ -123,8 +123,15 @@ type TLS struct {
//
// EXPERIMENTAL: Subject to change.
DNSRaw json.RawMessage `json:"dns,omitempty" caddy:"namespace=dns.providers inline_key=name"`
dns any // technically, it should be any/all of the libdns interfaces (RecordSetter, RecordAppender, etc.)
// The default DNS resolvers to use for TLS-related DNS operations, specifically
// for ACME DNS challenges and ACME server DNS validations.
// If not specified, the system default resolvers will be used.
//
// EXPERIMENTAL: Subject to change.
Resolvers []string `json:"resolvers,omitempty"`
dns any // technically, it should be any/all of the libdns interfaces (RecordSetter, RecordAppender, etc.)
certificateLoaders []CertificateLoader
automateNames map[string]struct{}
ctx caddy.Context
@@ -335,7 +342,6 @@ func (t *TLS) Provision(ctx caddy.Context) error {
// ECH (Encrypted ClientHello) initialization
if t.EncryptedClientHello != nil {
t.EncryptedClientHello.configs = make(map[string][]echConfig)
outerNames, err := t.EncryptedClientHello.Provision(ctx)
if err != nil {
return fmt.Errorf("provisioning Encrypted ClientHello components: %v", err)
@@ -411,12 +417,37 @@ func (t *TLS) Start() error {
return fmt.Errorf("automate: managing %v: %v", t.automateNames, err)
}
// publish ECH configs in the background; does not need to block
// server startup, as it could take a while
if t.EncryptedClientHello != nil {
echLogger := t.logger.Named("ech")
// publish ECH configs in the background; does not need to block
// server startup, as it could take a while; then keep keys rotated
go func() {
if err := t.publishECHConfigs(); err != nil {
t.logger.Named("ech").Error("publication(s) failed", zap.Error(err))
// publish immediately first
if err := t.publishECHConfigs(echLogger); err != nil {
echLogger.Error("publication(s) failed", zap.Error(err))
}
// then every so often, rotate and publish if needed
// (both of these functions only do something if needed)
for {
select {
case <-time.After(1 * time.Hour):
// ensure old keys are rotated out
t.EncryptedClientHello.configsMu.Lock()
err = t.EncryptedClientHello.rotateECHKeys(t.ctx, echLogger, false)
t.EncryptedClientHello.configsMu.Unlock()
if err != nil {
echLogger.Error("rotating ECH configs failed", zap.Error(err))
return
}
err := t.publishECHConfigs(echLogger)
if err != nil {
echLogger.Error("publication(s) failed", zap.Error(err))
}
case <-t.ctx.Done():
return
}
}
}()
}
+1 -1
View File
@@ -40,7 +40,7 @@ func init() {
type ZeroSSLIssuer struct {
// The API key (or "access key") for using the ZeroSSL API.
// REQUIRED.
APIKey string `json:"api_key,omitempty"`
APIKey string `json:"api_key,omitempty"` //nolint:gosec // false positive... yes this is exported, for JSON interop
// How many days the certificate should be valid for.
// Only certain values are accepted; see ZeroSSL docs.
+184 -23
View File
@@ -63,7 +63,7 @@ func (m *fileMode) UnmarshalJSON(b []byte) error {
// MarshalJSON satisfies json.Marshaler.
func (m *fileMode) MarshalJSON() ([]byte, error) {
return []byte(fmt.Sprintf("\"%04o\"", *m)), nil
return fmt.Appendf(nil, "\"%04o\"", *m), nil
}
// parseFileMode parses a file mode string,
@@ -90,6 +90,15 @@ type FileWriter struct {
// 0600 by default.
Mode fileMode `json:"mode,omitempty"`
// DirMode controls permissions for any directories created to reach Filename.
// Default: 0700 (current behavior).
//
// Special values:
// - "inherit" → copy the nearest existing parent directory's perms (with r→x normalization)
// - "from_file" → derive from the file Mode (with r→x), e.g. 0644 → 0755, 0600 → 0700
// Numeric octal strings (e.g. "0755") are also accepted. Subject to process umask.
DirMode string `json:"dir_mode,omitempty"`
// Roll toggles log rolling or rotation, which is
// enabled by default.
Roll *bool `json:"roll,omitempty"`
@@ -113,9 +122,16 @@ type FileWriter struct {
// See https://github.com/DeRuina/timberjack#%EF%B8%8F-rotation-notes--warnings for caveats
RollAt []string `json:"roll_at,omitempty"`
// Whether to compress rolled files. Default: true
// Whether to compress rolled files.
// Default: true.
// Deprecated: Use RollCompression instead, setting it to "none".
RollCompress *bool `json:"roll_gzip,omitempty"`
// RollCompression selects the compression algorithm for rolled files.
// Accepted values: "none", "gzip", "zstd".
// Default: gzip
RollCompression string `json:"roll_compression,omitempty"`
// Whether to use local timestamps in rolled filenames.
// Default: false
RollLocalTime bool `json:"roll_local_time,omitempty"`
@@ -177,11 +193,33 @@ func (fw FileWriter) OpenWriter() (io.WriteCloser, error) {
// roll log files as a sensible default to avoid disk space exhaustion
roll := fw.Roll == nil || *fw.Roll
// create the file if it does not exist; create with the configured mode, or default
// to restrictive if not set. (timberjack will reuse the file mode across log rotation)
if err := os.MkdirAll(filepath.Dir(fw.Filename), 0o700); err != nil {
return nil, err
// Ensure directory exists before opening the file.
dirPath := filepath.Dir(fw.Filename)
switch strings.ToLower(strings.TrimSpace(fw.DirMode)) {
case "", "0":
// Preserve current behavior: locked-down directories by default.
if err := os.MkdirAll(dirPath, 0o700); err != nil {
return nil, err
}
case "inherit":
if err := mkdirAllInherit(dirPath); err != nil {
return nil, err
}
case "from_file":
if err := mkdirAllFromFile(dirPath, os.FileMode(fw.Mode)); err != nil {
return nil, err
}
default:
dm, err := parseFileMode(fw.DirMode)
if err != nil {
return nil, fmt.Errorf("dir_mode: %w", err)
}
if err := os.MkdirAll(dirPath, dm); err != nil {
return nil, err
}
}
// create/open the file
file, err := os.OpenFile(fw.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, modeIfCreating)
if err != nil {
return nil, err
@@ -223,27 +261,104 @@ func (fw FileWriter) OpenWriter() (io.WriteCloser, error) {
if fw.RollKeepDays == 0 {
fw.RollKeepDays = 90
}
// Determine compression algorithm to use. Priority:
// 1) explicit RollCompression (none|gzip|zstd)
// 2) if RollCompress is unset or true -> gzip
// 3) if RollCompress is false -> none
var compression string
if fw.RollCompression != "" {
compression = strings.ToLower(strings.TrimSpace(fw.RollCompression))
if compression != "none" && compression != "gzip" && compression != "zstd" {
return nil, fmt.Errorf("invalid roll_compression: %s", fw.RollCompression)
}
} else {
if fw.RollCompress == nil || *fw.RollCompress {
compression = "gzip"
} else {
compression = "none"
}
}
return &timberjack.Logger{
Filename: fw.Filename,
MaxSize: fw.RollSizeMB,
MaxAge: fw.RollKeepDays,
MaxBackups: fw.RollKeep,
LocalTime: fw.RollLocalTime,
Compress: *fw.RollCompress,
Compression: compression,
RotationInterval: fw.RollInterval,
RotateAtMinutes: fw.RollAtMinutes,
RotateAt: fw.RollAt,
BackupTimeFormat: fw.BackupTimeFormat,
FileMode: os.FileMode(fw.Mode),
}, nil
}
// normalizeDirPerm ensures that read bits also have execute bits set.
func normalizeDirPerm(p os.FileMode) os.FileMode {
if p&0o400 != 0 {
p |= 0o100
}
if p&0o040 != 0 {
p |= 0o010
}
if p&0o004 != 0 {
p |= 0o001
}
return p
}
// mkdirAllInherit creates missing dirs using the nearest existing parent's
// permissions, normalized with r→x.
func mkdirAllInherit(dir string) error {
if fi, err := os.Stat(dir); err == nil && fi.IsDir() {
return nil
}
cur := dir
var parent string
for {
next := filepath.Dir(cur)
if next == cur {
parent = next
break
}
if fi, err := os.Stat(next); err == nil {
if !fi.IsDir() {
return fmt.Errorf("path component %s exists and is not a directory", next)
}
parent = next
break
}
cur = next
}
perm := os.FileMode(0o700)
if fi, err := os.Stat(parent); err == nil && fi.IsDir() {
perm = fi.Mode().Perm()
}
perm = normalizeDirPerm(perm)
return os.MkdirAll(dir, perm)
}
// mkdirAllFromFile creates missing dirs using the file's mode (with r→x) so
// 0644 → 0755, 0600 → 0700, etc.
func mkdirAllFromFile(dir string, fileMode os.FileMode) error {
if fi, err := os.Stat(dir); err == nil && fi.IsDir() {
return nil
}
perm := normalizeDirPerm(fileMode.Perm()) | 0o200 // ensure owner write on dir so files can be created
return os.MkdirAll(dir, perm)
}
// UnmarshalCaddyfile sets up the module from Caddyfile tokens. Syntax:
//
// file <filename> {
// mode <mode>
// dir_mode <mode|inherit|from_file>
// roll_disabled
// roll_size <size>
// roll_uncompressed
// roll_compression <none|gzip|zstd>
// roll_local_time
// roll_keep <num>
// roll_keep_for <days>
@@ -284,6 +399,22 @@ func (fw *FileWriter) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
}
fw.Mode = fileMode(mode)
case "dir_mode":
var val string
if !d.AllArgs(&val) {
return d.ArgErr()
}
val = strings.TrimSpace(val)
switch strings.ToLower(val) {
case "inherit", "from_file":
fw.DirMode = val
default:
if _, err := parseFileMode(val); err != nil {
return d.Errf("parsing dir_mode: %v", err)
}
fw.DirMode = val
}
case "roll_disabled":
var f bool
fw.Roll = &f
@@ -309,6 +440,19 @@ func (fw *FileWriter) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
return d.ArgErr()
}
case "roll_compression":
var comp string
if !d.AllArgs(&comp) {
return d.ArgErr()
}
comp = strings.ToLower(strings.TrimSpace(comp))
switch comp {
case "none", "gzip", "zstd":
fw.RollCompression = comp
default:
return d.Errf("parsing roll_compression: must be 'none', 'gzip' or 'zstd'")
}
case "roll_local_time":
fw.RollLocalTime = true
if d.NextArg() {
@@ -352,31 +496,48 @@ func (fw *FileWriter) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
fw.RollInterval = duration
case "roll_minutes":
var minutesArrayStr string
if !d.AllArgs(&minutesArrayStr) {
// Accept either a single comma-separated argument or
// multiple space-separated arguments. Collect all
// remaining args on the line and split on commas.
args := d.RemainingArgs()
if len(args) == 0 {
return d.ArgErr()
}
minutesStr := strings.Split(minutesArrayStr, ",")
minutes := make([]int, len(minutesStr))
for i := range minutesStr {
ms := strings.Trim(minutesStr[i], " ")
m, err := strconv.Atoi(ms)
if err != nil {
return d.Errf("parsing roll_minutes number: %v", err)
var minutes []int
for _, arg := range args {
parts := strings.SplitSeq(arg, ",")
for p := range parts {
ms := strings.TrimSpace(p)
if ms == "" {
return d.Errf("parsing roll_minutes: empty value")
}
m, err := strconv.Atoi(ms)
if err != nil {
return d.Errf("parsing roll_minutes number: %v", err)
}
minutes = append(minutes, m)
}
minutes[i] = m
}
fw.RollAtMinutes = minutes
case "roll_at":
var timeArrayStr string
if !d.AllArgs(&timeArrayStr) {
// Accept either a single comma-separated argument or
// multiple space-separated arguments. Collect all
// remaining args on the line and split on commas.
args := d.RemainingArgs()
if len(args) == 0 {
return d.ArgErr()
}
timeStr := strings.Split(timeArrayStr, ",")
times := make([]string, len(timeStr))
for i := range timeStr {
times[i] = strings.Trim(timeStr[i], " ")
var times []string
for _, arg := range args {
parts := strings.SplitSeq(arg, ",")
for p := range parts {
ts := strings.TrimSpace(p)
if ts == "" {
return d.Errf("parsing roll_at: empty value")
}
times = append(times, ts)
}
}
fw.RollAt = times
+222
View File
@@ -385,3 +385,225 @@ func TestFileModeModification(t *testing.T) {
t.Errorf("file mode is %v, want %v", st.Mode(), want)
}
}
func TestDirMode_Inherit(t *testing.T) {
m := syscall.Umask(0)
defer syscall.Umask(m)
parent := t.TempDir()
if err := os.Chmod(parent, 0o755); err != nil {
t.Fatal(err)
}
targetDir := filepath.Join(parent, "a", "b")
fw := &FileWriter{
Filename: filepath.Join(targetDir, "test.log"),
DirMode: "inherit",
Mode: 0o640,
Roll: func() *bool { f := false; return &f }(),
}
w, err := fw.OpenWriter()
if err != nil {
t.Fatal(err)
}
_ = w.Close()
st, err := os.Stat(targetDir)
if err != nil {
t.Fatal(err)
}
if got := st.Mode().Perm(); got != 0o755 {
t.Fatalf("dir perm = %o, want 0755", got)
}
}
func TestDirMode_FromFile(t *testing.T) {
m := syscall.Umask(0)
defer syscall.Umask(m)
base := t.TempDir()
dir1 := filepath.Join(base, "logs1")
fw1 := &FileWriter{
Filename: filepath.Join(dir1, "app.log"),
DirMode: "from_file",
Mode: 0o644, // => dir 0755
Roll: func() *bool { f := false; return &f }(),
}
w1, err := fw1.OpenWriter()
if err != nil {
t.Fatal(err)
}
_ = w1.Close()
st1, err := os.Stat(dir1)
if err != nil {
t.Fatal(err)
}
if got := st1.Mode().Perm(); got != 0o755 {
t.Fatalf("dir perm = %o, want 0755", got)
}
dir2 := filepath.Join(base, "logs2")
fw2 := &FileWriter{
Filename: filepath.Join(dir2, "app.log"),
DirMode: "from_file",
Mode: 0o600, // => dir 0700
Roll: func() *bool { f := false; return &f }(),
}
w2, err := fw2.OpenWriter()
if err != nil {
t.Fatal(err)
}
_ = w2.Close()
st2, err := os.Stat(dir2)
if err != nil {
t.Fatal(err)
}
if got := st2.Mode().Perm(); got != 0o700 {
t.Fatalf("dir perm = %o, want 0700", got)
}
}
func TestDirMode_ExplicitOctal(t *testing.T) {
m := syscall.Umask(0)
defer syscall.Umask(m)
base := t.TempDir()
dest := filepath.Join(base, "logs3")
fw := &FileWriter{
Filename: filepath.Join(dest, "app.log"),
DirMode: "0750",
Mode: 0o640,
Roll: func() *bool { f := false; return &f }(),
}
w, err := fw.OpenWriter()
if err != nil {
t.Fatal(err)
}
_ = w.Close()
st, err := os.Stat(dest)
if err != nil {
t.Fatal(err)
}
if got := st.Mode().Perm(); got != 0o750 {
t.Fatalf("dir perm = %o, want 0750", got)
}
}
func TestDirMode_Default0700(t *testing.T) {
m := syscall.Umask(0)
defer syscall.Umask(m)
base := t.TempDir()
dest := filepath.Join(base, "logs4")
fw := &FileWriter{
Filename: filepath.Join(dest, "app.log"),
Mode: 0o640,
Roll: func() *bool { f := false; return &f }(),
}
w, err := fw.OpenWriter()
if err != nil {
t.Fatal(err)
}
_ = w.Close()
st, err := os.Stat(dest)
if err != nil {
t.Fatal(err)
}
if got := st.Mode().Perm(); got != 0o700 {
t.Fatalf("dir perm = %o, want 0700", got)
}
}
func TestDirMode_UmaskInteraction(t *testing.T) {
_ = syscall.Umask(0o022) // typical umask; restore after
defer syscall.Umask(0)
base := t.TempDir()
dest := filepath.Join(base, "logs5")
fw := &FileWriter{
Filename: filepath.Join(dest, "app.log"),
DirMode: "0755",
Mode: 0o644,
Roll: func() *bool { f := false; return &f }(),
}
w, err := fw.OpenWriter()
if err != nil {
t.Fatal(err)
}
_ = w.Close()
st, err := os.Stat(dest)
if err != nil {
t.Fatal(err)
}
// 0755 &^ 0022 still 0755 for dirs; this just sanity-checks we didn't get stricter unexpectedly
if got := st.Mode().Perm(); got != 0o755 {
t.Fatalf("dir perm = %o, want 0755 (considering umask)", got)
}
}
func TestCaddyfile_DirMode_Inherit(t *testing.T) {
d := caddyfile.NewTestDispenser(`
file /var/log/app.log {
dir_mode inherit
mode 0640
}`)
var fw FileWriter
if err := fw.UnmarshalCaddyfile(d); err != nil {
t.Fatal(err)
}
if fw.DirMode != "inherit" {
t.Fatalf("got %q", fw.DirMode)
}
if fw.Mode != 0o640 {
t.Fatalf("mode = %o", fw.Mode)
}
}
func TestCaddyfile_DirMode_FromFile(t *testing.T) {
d := caddyfile.NewTestDispenser(`
file /var/log/app.log {
dir_mode from_file
mode 0600
}`)
var fw FileWriter
if err := fw.UnmarshalCaddyfile(d); err != nil {
t.Fatal(err)
}
if fw.DirMode != "from_file" {
t.Fatalf("got %q", fw.DirMode)
}
if fw.Mode != 0o600 {
t.Fatalf("mode = %o", fw.Mode)
}
}
func TestCaddyfile_DirMode_Octal(t *testing.T) {
d := caddyfile.NewTestDispenser(`
file /var/log/app.log {
dir_mode 0755
}`)
var fw FileWriter
if err := fw.UnmarshalCaddyfile(d); err != nil {
t.Fatal(err)
}
if fw.DirMode != "0755" {
t.Fatalf("got %q", fw.DirMode)
}
}
func TestCaddyfile_DirMode_Invalid(t *testing.T) {
d := caddyfile.NewTestDispenser(`
file /var/log/app.log {
dir_mode nope
}`)
var fw FileWriter
if err := fw.UnmarshalCaddyfile(d); err == nil {
t.Fatal("expected error for invalid dir_mode")
}
}
+40 -2
View File
@@ -23,9 +23,9 @@ import (
)
// Windows relies on ACLs instead of unix permissions model.
// Go allows to open files with a particular mode put it is limited to read or write.
// Go allows to open files with a particular mode but it is limited to read or write.
// See https://cs.opensource.google/go/go/+/refs/tags/go1.22.3:src/syscall/syscall_windows.go;l=708.
// This is pretty restrictive and has few interest for log files and thus we just test that log files are
// This is pretty restrictive and has little interest for log files and thus we just test that log files are
// opened with R/W permissions by default on Windows too.
func TestFileCreationMode(t *testing.T) {
dir, err := os.MkdirTemp("", "caddytest")
@@ -53,3 +53,41 @@ func TestFileCreationMode(t *testing.T) {
t.Fatalf("file mode is %v, want rw for user", st.Mode().Perm())
}
}
func TestDirMode_Windows_CreateSucceeds(t *testing.T) {
dir, err := os.MkdirTemp("", "caddytest")
if err != nil {
t.Fatalf("failed to create tempdir: %v", err)
}
defer os.RemoveAll(dir)
tests := []struct {
name string
dirMode string
}{
{"inherit", "inherit"},
{"from_file", "from_file"},
{"octal", "0755"},
{"default", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
subdir := path.Join(dir, "logs-"+tt.name)
fw := &FileWriter{
Filename: path.Join(subdir, "test.log"),
DirMode: tt.dirMode,
Mode: 0o600,
}
w, err := fw.OpenWriter()
if err != nil {
t.Fatalf("failed to open writer: %v", err)
}
defer w.Close()
if _, err := os.Stat(fw.Filename); err != nil {
t.Fatalf("expected file to exist: %v", err)
}
})
}
}
+5 -5
View File
@@ -255,7 +255,7 @@ func (m IPMaskFilter) Filter(in zapcore.Field) zapcore.Field {
}
func (m IPMaskFilter) mask(s string) string {
output := ""
parts := make([]string, 0)
for value := range strings.SplitSeq(s, ",") {
value = strings.TrimSpace(value)
host, port, err := net.SplitHostPort(value)
@@ -264,7 +264,7 @@ func (m IPMaskFilter) mask(s string) string {
}
ipAddr := net.ParseIP(host)
if ipAddr == nil {
output += value + ", "
parts = append(parts, value)
continue
}
mask := m.v4Mask
@@ -273,13 +273,13 @@ func (m IPMaskFilter) mask(s string) string {
}
masked := ipAddr.Mask(mask)
if port == "" {
output += masked.String() + ", "
parts = append(parts, masked.String())
continue
}
output += net.JoinHostPort(masked.String(), port) + ", "
parts = append(parts, net.JoinHostPort(masked.String(), port))
}
return strings.TrimSuffix(output, ", ")
return strings.Join(parts, ", ")
}
type filterAction string