mirror of
https://github.com/caddyserver/caddy.git
synced 2026-05-27 01:02:29 -04:00
Merge branch 'master' into add-tests
This commit is contained in:
@@ -18,6 +18,7 @@ import (
|
||||
"cmp"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net"
|
||||
@@ -51,6 +52,7 @@ func init() {
|
||||
// Placeholder | Description
|
||||
// ------------|---------------
|
||||
// `{http.request.body}` | The request body (⚠️ inefficient; use only for debugging)
|
||||
// `{http.request.body_base64}` | The request body, base64-encoded (⚠️ for debugging)
|
||||
// `{http.request.cookie.*}` | HTTP request cookie
|
||||
// `{http.request.duration}` | Time up to now spent handling the request (after decoding headers from client)
|
||||
// `{http.request.duration_ms}` | Same as 'duration', but in milliseconds.
|
||||
@@ -82,6 +84,7 @@ func init() {
|
||||
// `{http.request.tls.proto}` | The negotiated next protocol
|
||||
// `{http.request.tls.proto_mutual}` | The negotiated next protocol was advertised by the server
|
||||
// `{http.request.tls.server_name}` | The server name requested by the client, if any
|
||||
// `{http.request.tls.ech}` | Whether ECH was offered by the client and accepted by the server
|
||||
// `{http.request.tls.client.fingerprint}` | The SHA256 checksum of the client certificate
|
||||
// `{http.request.tls.client.public_key}` | The public key of the client certificate.
|
||||
// `{http.request.tls.client.public_key_sha256}` | The SHA256 checksum of the client's public key.
|
||||
@@ -198,6 +201,8 @@ func (app *App) Provision(ctx caddy.Context) error {
|
||||
if app.Metrics != nil {
|
||||
app.Metrics.init = sync.Once{}
|
||||
app.Metrics.httpMetrics = &httpMetrics{}
|
||||
// Scan config for allowed hosts to prevent cardinality explosion
|
||||
app.Metrics.scanConfigForHosts(app)
|
||||
}
|
||||
// prepare each server
|
||||
oldContext := ctx.Context
|
||||
@@ -344,6 +349,20 @@ func (app *App) Provision(ctx caddy.Context) error {
|
||||
srv.listenerWrappers = append([]caddy.ListenerWrapper{new(tlsPlaceholderWrapper)}, srv.listenerWrappers...)
|
||||
}
|
||||
}
|
||||
|
||||
// set up each packet conn modifier
|
||||
if srv.PacketConnWrappersRaw != nil {
|
||||
vals, err := ctx.LoadModule(srv, "PacketConnWrappersRaw")
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading packet conn wrapper modules: %v", err)
|
||||
}
|
||||
// if any wrappers were configured, they come before the QUIC handshake;
|
||||
// unlike TLS above, there is no QUIC placeholder
|
||||
for _, val := range vals.([]any) {
|
||||
srv.packetConnWrappers = append(srv.packetConnWrappers, val.(caddy.PacketConnWrapper))
|
||||
}
|
||||
}
|
||||
|
||||
// pre-compile the primary handler chain, and be sure to wrap it in our
|
||||
// route handler so that important security checks are done, etc.
|
||||
primaryRoute := emptyHandler
|
||||
@@ -693,9 +712,10 @@ func (app *App) Stop() error {
|
||||
// enforce grace period if configured
|
||||
if app.GracePeriod > 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, time.Duration(app.GracePeriod))
|
||||
timeout := time.Duration(app.GracePeriod)
|
||||
ctx, cancel = context.WithTimeoutCause(ctx, timeout, fmt.Errorf("server graceful shutdown %ds timeout", int(timeout.Seconds())))
|
||||
defer cancel()
|
||||
app.logger.Info("servers shutting down; grace period initiated", zap.Duration("duration", time.Duration(app.GracePeriod)))
|
||||
app.logger.Info("servers shutting down; grace period initiated", zap.Duration("duration", timeout))
|
||||
} else {
|
||||
app.logger.Info("servers shutting down with eternal grace period")
|
||||
}
|
||||
@@ -721,6 +741,9 @@ func (app *App) Stop() error {
|
||||
}
|
||||
|
||||
if err := server.server.Shutdown(ctx); err != nil {
|
||||
if cause := context.Cause(ctx); cause != nil && errors.Is(err, context.DeadlineExceeded) {
|
||||
err = cause
|
||||
}
|
||||
app.logger.Error("server shutdown",
|
||||
zap.Error(err),
|
||||
zap.Strings("addresses", server.Listen))
|
||||
@@ -744,6 +767,9 @@ func (app *App) Stop() error {
|
||||
}
|
||||
|
||||
if err := server.h3server.Shutdown(ctx); err != nil {
|
||||
if cause := context.Cause(ctx); cause != nil && errors.Is(err, context.DeadlineExceeded) {
|
||||
err = cause
|
||||
}
|
||||
app.logger.Error("HTTP/3 server shutdown",
|
||||
zap.Error(err),
|
||||
zap.Strings("addresses", server.Listen))
|
||||
|
||||
@@ -90,7 +90,16 @@ func (app *App) automaticHTTPSPhase1(ctx caddy.Context, repl *caddy.Replacer) er
|
||||
// the log configuration for an HTTPS enabled server
|
||||
var logCfg *ServerLogConfig
|
||||
|
||||
for srvName, srv := range app.Servers {
|
||||
// Sort server names to ensure deterministic iteration.
|
||||
// This prevents race conditions where the order of server processing
|
||||
// could affect which server gets assigned the HTTP->HTTPS redirect listener.
|
||||
srvNames := make([]string, 0, len(app.Servers))
|
||||
for name := range app.Servers {
|
||||
srvNames = append(srvNames, name)
|
||||
}
|
||||
slices.Sort(srvNames)
|
||||
for _, srvName := range srvNames {
|
||||
srv := app.Servers[srvName]
|
||||
// as a prerequisite, provision route matchers; this is
|
||||
// required for all routes on all servers, and must be
|
||||
// done before we attempt to do phase 1 of auto HTTPS,
|
||||
@@ -398,15 +407,60 @@ uniqueDomainsLoop:
|
||||
return append(routes, app.makeRedirRoute(uint(app.httpsPort()), MatcherSet{MatchProtocol("http")}))
|
||||
}
|
||||
|
||||
// Sort redirect addresses to ensure deterministic process
|
||||
redirServerAddrsSorted := make([]string, 0, len(redirServers))
|
||||
for addr := range redirServers {
|
||||
redirServerAddrsSorted = append(redirServerAddrsSorted, addr)
|
||||
}
|
||||
slices.Sort(redirServerAddrsSorted)
|
||||
|
||||
redirServersLoop:
|
||||
for redirServerAddr, routes := range redirServers {
|
||||
for _, redirServerAddr := range redirServerAddrsSorted {
|
||||
routes := redirServers[redirServerAddr]
|
||||
// for each redirect listener, see if there's already a
|
||||
// server configured to listen on that exact address; if so,
|
||||
// insert the redirect route to the end of its route list
|
||||
// after any other routes with host matchers; otherwise,
|
||||
// we'll create a new server for all the listener addresses
|
||||
// that are unused and serve the remaining redirects from it
|
||||
for _, srv := range app.Servers {
|
||||
|
||||
// Sort redirect routes by host specificity to ensure exact matches
|
||||
// take precedence over wildcards, preventing ambiguous routing.
|
||||
slices.SortFunc(routes, func(a, b Route) int {
|
||||
hostA := getFirstHostFromRoute(a)
|
||||
hostB := getFirstHostFromRoute(b)
|
||||
|
||||
// Catch-all routes (empty host) have the lowest priority
|
||||
if hostA == "" && hostB != "" {
|
||||
return 1
|
||||
}
|
||||
if hostB == "" && hostA != "" {
|
||||
return -1
|
||||
}
|
||||
|
||||
hasWildcardA := strings.Contains(hostA, "*")
|
||||
hasWildcardB := strings.Contains(hostB, "*")
|
||||
|
||||
// Exact domains take precedence over wildcards
|
||||
if !hasWildcardA && hasWildcardB {
|
||||
return -1
|
||||
}
|
||||
if hasWildcardA && !hasWildcardB {
|
||||
return 1
|
||||
}
|
||||
|
||||
// If both are exact or both are wildcards, the longer one is more specific
|
||||
if len(hostA) != len(hostB) {
|
||||
return len(hostB) - len(hostA)
|
||||
}
|
||||
|
||||
// Tie-breaker: alphabetical order to ensure determinism
|
||||
return strings.Compare(hostA, hostB)
|
||||
})
|
||||
|
||||
// Use the sorted srvNames to consistently find the target server
|
||||
for _, srvName := range srvNames {
|
||||
srv := app.Servers[srvName]
|
||||
// only look at servers which listen on an address which
|
||||
// we want to add redirects to
|
||||
if !srv.hasListenerAddress(redirServerAddr) {
|
||||
@@ -560,6 +614,27 @@ func (app *App) createAutomationPolicies(ctx caddy.Context, internalNames, tails
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure automation policies' CertMagic configs are rebuilt when
|
||||
// ACME issuer templates may have been modified above (for example,
|
||||
// alternate ports filled in by the HTTP app). If a policy is already
|
||||
// provisioned, perform a lightweight rebuild of the CertMagic config
|
||||
// so issuers receive SetConfig with the updated templates; otherwise
|
||||
// run a normal Provision to initialize the policy.
|
||||
for i, ap := range app.tlsApp.Automation.Policies {
|
||||
// If the policy is already provisioned, rebuild only the CertMagic
|
||||
// config so issuers get SetConfig with updated templates. Otherwise
|
||||
// provision the policy normally (which may load modules).
|
||||
if ap.IsProvisioned() {
|
||||
if err := ap.RebuildCertMagic(app.tlsApp); err != nil {
|
||||
return fmt.Errorf("rebuilding certmagic config for automation policy %d: %v", i, err)
|
||||
}
|
||||
} else {
|
||||
if err := ap.Provision(app.tlsApp); err != nil {
|
||||
return fmt.Errorf("provisioning automation policy %d after auto-HTTPS defaults: %v", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if basePolicy == nil {
|
||||
// no base policy found; we will make one
|
||||
basePolicy = new(caddytls.AutomationPolicy)
|
||||
@@ -773,3 +848,26 @@ func isTailscaleDomain(name string) bool {
|
||||
}
|
||||
|
||||
type acmeCapable interface{ GetACMEIssuer() *caddytls.ACMEIssuer }
|
||||
|
||||
// getFirstHostFromRoute traverses a route's matchers to find the Host rule.
|
||||
// Since we are dealing with internally generated redirect routes, the host
|
||||
// is typically the first string within the MatchHost.
|
||||
func getFirstHostFromRoute(r Route) string {
|
||||
for _, matcherSet := range r.MatcherSets {
|
||||
for _, m := range matcherSet {
|
||||
// Check if the matcher is of type MatchHost (value or pointer)
|
||||
switch hm := m.(type) {
|
||||
case MatchHost:
|
||||
if len(hm) > 0 {
|
||||
return hm[0]
|
||||
}
|
||||
case *MatchHost:
|
||||
if len(*hm) > 0 {
|
||||
return (*hm)[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Return an empty string if it's a catch-all route (no specific host)
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
weakrand "math/rand"
|
||||
weakrand "math/rand/v2"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -244,7 +244,7 @@ func (c *Cache) makeRoom() {
|
||||
// strategy; generating random numbers is cheap and
|
||||
// ensures a much better distribution.
|
||||
//nolint:gosec
|
||||
rnd := weakrand.Intn(len(c.cache))
|
||||
rnd := weakrand.IntN(len(c.cache))
|
||||
i := 0
|
||||
for key := range c.cache {
|
||||
if i == rnd {
|
||||
@@ -287,7 +287,7 @@ type Account struct {
|
||||
|
||||
// The user's hashed password, in Modular Crypt Format (with `$` prefix)
|
||||
// or base64-encoded.
|
||||
Password string `json:"password"`
|
||||
Password string `json:"password"` //nolint:gosec // false positive, this is a hashed password
|
||||
|
||||
password []byte
|
||||
}
|
||||
|
||||
@@ -412,10 +412,12 @@ func CELMatcherImpl(macroName, funcName string, matcherDataTypes []*cel.Type, fa
|
||||
return nil, fmt.Errorf("unsupported matcher data type: %s, %s", matcherDataTypes[0], matcherDataTypes[1])
|
||||
}
|
||||
case 3:
|
||||
// nolint:gosec // false positive, impossible to be out of bounds; see: https://github.com/securego/gosec/issues/1525
|
||||
if matcherDataTypes[0] == cel.StringType && matcherDataTypes[1] == cel.StringType && matcherDataTypes[2] == cel.StringType {
|
||||
macro = parser.NewGlobalMacro(macroName, 3, celMatcherStringListMacroExpander(funcName))
|
||||
matcherDataTypes = []*cel.Type{cel.ListType(cel.StringType)}
|
||||
} else {
|
||||
// nolint:gosec // false positive, impossible to be out of bounds; see: https://github.com/securego/gosec/issues/1525
|
||||
return nil, fmt.Errorf("unsupported matcher data type: %s, %s, %s", matcherDataTypes[0], matcherDataTypes[1], matcherDataTypes[2])
|
||||
}
|
||||
}
|
||||
@@ -665,12 +667,29 @@ func celMatcherJSONMacroExpander(funcName string) parser.MacroExpander {
|
||||
// map literals containing heterogeneous values, in this case string and list
|
||||
// of string.
|
||||
func CELValueToMapStrList(data ref.Val) (map[string][]string, error) {
|
||||
mapStrType := reflect.TypeOf(map[string]any{})
|
||||
// Prefer map[string]any, but newer cel-go versions may return map[any]any
|
||||
mapStrType := reflect.TypeFor[map[string]any]()
|
||||
mapStrRaw, err := data.ConvertToNative(mapStrType)
|
||||
var mapStrIface map[string]any
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// Try map[any]any and convert keys to strings
|
||||
mapAnyType := reflect.TypeFor[map[any]any]()
|
||||
mapAnyRaw, err2 := data.ConvertToNative(mapAnyType)
|
||||
if err2 != nil {
|
||||
return nil, err
|
||||
}
|
||||
mapAnyIface := mapAnyRaw.(map[any]any)
|
||||
mapStrIface = make(map[string]any, len(mapAnyIface))
|
||||
for k, v := range mapAnyIface {
|
||||
ks, ok := k.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unsupported map key type in header match: %T", k)
|
||||
}
|
||||
mapStrIface[ks] = v
|
||||
}
|
||||
} else {
|
||||
mapStrIface = mapStrRaw.(map[string]any)
|
||||
}
|
||||
mapStrIface := mapStrRaw.(map[string]any)
|
||||
mapStrListStr := make(map[string][]string, len(mapStrIface))
|
||||
for k, v := range mapStrIface {
|
||||
switch val := v.(type) {
|
||||
@@ -685,13 +704,26 @@ func CELValueToMapStrList(data ref.Val) (map[string][]string, error) {
|
||||
for i, elem := range val {
|
||||
strVal, ok := elem.(types.String)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unsupported value type in header match: %T", val)
|
||||
return nil, fmt.Errorf("unsupported value type in matcher input: %T", val)
|
||||
}
|
||||
convVals[i] = string(strVal)
|
||||
}
|
||||
mapStrListStr[k] = convVals
|
||||
case []any:
|
||||
convVals := make([]string, len(val))
|
||||
for i, elem := range val {
|
||||
switch e := elem.(type) {
|
||||
case string:
|
||||
convVals[i] = e
|
||||
case types.String:
|
||||
convVals[i] = string(e)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported element type in matcher input list: %T", elem)
|
||||
}
|
||||
}
|
||||
mapStrListStr[k] = convVals
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported value type in header match: %T", val)
|
||||
return nil, fmt.Errorf("unsupported value type in matcher input: %T", val)
|
||||
}
|
||||
}
|
||||
return mapStrListStr, nil
|
||||
|
||||
@@ -168,8 +168,8 @@ func (enc *Encode) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyh
|
||||
// caches without knowing about our changes...
|
||||
if etag := r.Header.Get("If-None-Match"); etag != "" && !strings.HasPrefix(etag, "W/") {
|
||||
ourSuffix := "-" + encName + `"`
|
||||
if strings.HasSuffix(etag, ourSuffix) {
|
||||
etag = strings.TrimSuffix(etag, ourSuffix) + `"`
|
||||
if before, ok := strings.CutSuffix(etag, ourSuffix); ok {
|
||||
etag = before + `"`
|
||||
r.Header.Set("If-None-Match", etag)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ package caddyhttp
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
weakrand "math/rand"
|
||||
weakrand "math/rand/v2"
|
||||
"path"
|
||||
"runtime"
|
||||
"strings"
|
||||
@@ -101,7 +101,7 @@ func randString(n int, sameCase bool) string {
|
||||
b := make([]byte, n)
|
||||
for i := range b {
|
||||
//nolint:gosec
|
||||
b[i] = dict[weakrand.Int63()%int64(len(dict))]
|
||||
b[i] = dict[weakrand.IntN(len(dict))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
@@ -169,6 +169,7 @@ func (fsrv *FileServer) serveBrowse(fileSystem fs.FS, root, dirPath string, w ht
|
||||
|
||||
// Actual files
|
||||
for _, item := range listing.Items {
|
||||
//nolint:gosec // not sure how this could be XSS unless you lose control of the file system (like aren't sanitizing) and client ignores Content-Type of text/plain
|
||||
if _, err := fmt.Fprintf(writer, "%s\t%s\t%s\n",
|
||||
item.Name, item.HumanSize(), item.HumanModTime("January 2, 2006 at 15:04:05"),
|
||||
); err != nil {
|
||||
|
||||
@@ -404,7 +404,7 @@ func (m MatchFile) selectFile(r *http.Request) (bool, error) {
|
||||
}
|
||||
|
||||
// for each glob result, combine all the forms of the path
|
||||
var candidates []matchCandidate
|
||||
candidates := make([]matchCandidate, 0, len(globResults))
|
||||
for _, result := range globResults {
|
||||
candidates = append(candidates, matchCandidate{
|
||||
fullpath: result,
|
||||
@@ -720,6 +720,7 @@ var globSafeRepl = strings.NewReplacer(
|
||||
"*", "\\*",
|
||||
"[", "\\[",
|
||||
"?", "\\?",
|
||||
"\\", "\\\\",
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -20,7 +20,9 @@ import (
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/caddyserver/caddy/v2"
|
||||
@@ -28,6 +30,13 @@ import (
|
||||
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
|
||||
)
|
||||
|
||||
type testCase struct {
|
||||
path string
|
||||
expectedPath string
|
||||
expectedType string
|
||||
matched bool
|
||||
}
|
||||
|
||||
func TestFileMatcher(t *testing.T) {
|
||||
// Windows doesn't like colons in files names
|
||||
isWindows := runtime.GOOS == "windows"
|
||||
@@ -45,12 +54,7 @@ func TestFileMatcher(t *testing.T) {
|
||||
f.Close()
|
||||
}
|
||||
|
||||
for i, tc := range []struct {
|
||||
path string
|
||||
expectedPath string
|
||||
expectedType string
|
||||
matched bool
|
||||
}{
|
||||
for i, tc := range []testCase{
|
||||
{
|
||||
path: "/foo.txt",
|
||||
expectedPath: "/foo.txt",
|
||||
@@ -116,44 +120,71 @@ func TestFileMatcher(t *testing.T) {
|
||||
matched: !isWindows,
|
||||
},
|
||||
} {
|
||||
m := &MatchFile{
|
||||
fsmap: &filesystems.FileSystemMap{},
|
||||
Root: "./testdata",
|
||||
TryFiles: []string{"{http.request.uri.path}", "{http.request.uri.path}/"},
|
||||
}
|
||||
fileMatcherTest(t, i, tc)
|
||||
}
|
||||
}
|
||||
|
||||
u, err := url.Parse(tc.path)
|
||||
if err != nil {
|
||||
t.Errorf("Test %d: parsing path: %v", i, err)
|
||||
}
|
||||
func TestFileMatcherNonWindows(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
return
|
||||
}
|
||||
|
||||
req := &http.Request{URL: u}
|
||||
repl := caddyhttp.NewTestReplacer(req)
|
||||
// this is impossible to test on Windows, but tests a security patch for other platforms
|
||||
tc := testCase{
|
||||
path: "/foodir/secr%5Cet.txt",
|
||||
expectedPath: "/foodir/secr\\et.txt",
|
||||
expectedType: "file",
|
||||
matched: true,
|
||||
}
|
||||
|
||||
result, err := m.MatchWithError(req)
|
||||
if err != nil {
|
||||
t.Errorf("Test %d: unexpected error: %v", i, err)
|
||||
}
|
||||
if result != tc.matched {
|
||||
t.Errorf("Test %d: expected match=%t, got %t", i, tc.matched, result)
|
||||
}
|
||||
f, err := os.Create(filepath.Join("testdata", strings.TrimPrefix(tc.expectedPath, "/")))
|
||||
if err != nil {
|
||||
t.Fatalf("could not create test file: %v", err)
|
||||
}
|
||||
defer f.Close()
|
||||
defer os.Remove(f.Name())
|
||||
|
||||
rel, ok := repl.Get("http.matchers.file.relative")
|
||||
if !ok && result {
|
||||
t.Errorf("Test %d: expected replacer value", i)
|
||||
}
|
||||
if !result {
|
||||
continue
|
||||
}
|
||||
fileMatcherTest(t, 0, tc)
|
||||
}
|
||||
|
||||
if rel != tc.expectedPath {
|
||||
t.Errorf("Test %d: actual path: %v, expected: %v", i, rel, tc.expectedPath)
|
||||
}
|
||||
func fileMatcherTest(t *testing.T, i int, tc testCase) {
|
||||
m := &MatchFile{
|
||||
fsmap: &filesystems.FileSystemMap{},
|
||||
Root: "./testdata",
|
||||
TryFiles: []string{"{http.request.uri.path}", "{http.request.uri.path}/"},
|
||||
}
|
||||
|
||||
fileType, _ := repl.Get("http.matchers.file.type")
|
||||
if fileType != tc.expectedType {
|
||||
t.Errorf("Test %d: actual file type: %v, expected: %v", i, fileType, tc.expectedType)
|
||||
}
|
||||
u, err := url.Parse(tc.path)
|
||||
if err != nil {
|
||||
t.Errorf("Test %d: parsing path: %v", i, err)
|
||||
}
|
||||
|
||||
req := &http.Request{URL: u}
|
||||
repl := caddyhttp.NewTestReplacer(req)
|
||||
|
||||
result, err := m.MatchWithError(req)
|
||||
if err != nil {
|
||||
t.Errorf("Test %d: unexpected error: %v", i, err)
|
||||
}
|
||||
if result != tc.matched {
|
||||
t.Errorf("Test %d: expected match=%t, got %t", i, tc.matched, result)
|
||||
}
|
||||
|
||||
rel, ok := repl.Get("http.matchers.file.relative")
|
||||
if !ok && result {
|
||||
t.Errorf("Test %d: expected replacer value", i)
|
||||
}
|
||||
if !result {
|
||||
return
|
||||
}
|
||||
|
||||
if rel != tc.expectedPath {
|
||||
t.Errorf("Test %d: actual path: %v, expected: %v", i, rel, tc.expectedPath)
|
||||
}
|
||||
|
||||
fileType, _ := repl.Get("http.matchers.file.type")
|
||||
if fileType != tc.expectedType {
|
||||
t.Errorf("Test %d: actual file type: %v, expected: %v", i, fileType, tc.expectedType)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
weakrand "math/rand"
|
||||
weakrand "math/rand/v2"
|
||||
"mime"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -125,6 +125,11 @@ type FileServer struct {
|
||||
// When possible, all paths are resolved to their absolute form before
|
||||
// comparisons are made. For maximum clarity and explictness, use complete,
|
||||
// absolute paths; or, for greater portability, use relative paths instead.
|
||||
//
|
||||
// Note that hide comparisons are case-sensitive. On case-insensitive
|
||||
// filesystems, requests with different path casing may still resolve to the
|
||||
// same file or directory on disk, so hide should not be treated as a
|
||||
// security boundary for sensitive paths.
|
||||
Hide []string `json:"hide,omitempty"`
|
||||
|
||||
// The names of files to try as index files if a folder is requested.
|
||||
@@ -601,7 +606,7 @@ func (fsrv *FileServer) openFile(fileSystem fs.FS, filename string, w http.Respo
|
||||
// maybe the server is under load and ran out of file descriptors?
|
||||
// have client wait arbitrary seconds to help prevent a stampede
|
||||
//nolint:gosec
|
||||
backoff := weakrand.Intn(maxBackoff-minBackoff) + minBackoff
|
||||
backoff := weakrand.IntN(maxBackoff-minBackoff) + minBackoff
|
||||
w.Header().Set("Retry-After", strconv.Itoa(backoff))
|
||||
if c := fsrv.logger.Check(zapcore.DebugLevel, "retry after backoff"); c != nil {
|
||||
c.Write(zap.String("filename", filename), zap.Int("backoff", backoff), zap.Error(err))
|
||||
|
||||
@@ -168,8 +168,6 @@ func parseReqHdrCaddyfile(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue,
|
||||
}
|
||||
h.Next() // consume the directive name again (matcher parsing resets)
|
||||
|
||||
configValues := []httpcaddyfile.ConfigValue{}
|
||||
|
||||
if !h.NextArg() {
|
||||
return nil, h.ArgErr()
|
||||
}
|
||||
@@ -204,7 +202,7 @@ func parseReqHdrCaddyfile(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue,
|
||||
return nil, h.Err(err.Error())
|
||||
}
|
||||
|
||||
configValues = append(configValues, h.NewRoute(matcherSet, hdr)...)
|
||||
configValues := h.NewRoute(matcherSet, hdr)
|
||||
|
||||
if h.NextArg() {
|
||||
return nil, h.ArgErr()
|
||||
|
||||
@@ -161,11 +161,11 @@ func (ops *HeaderOps) Provision(_ caddy.Context) error {
|
||||
|
||||
// containsPlaceholders checks if the string contains Caddy placeholder syntax {key}
|
||||
func containsPlaceholders(s string) bool {
|
||||
openIdx := strings.Index(s, "{")
|
||||
if openIdx == -1 {
|
||||
_, after, ok := strings.Cut(s, "{")
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
closeIdx := strings.Index(s[openIdx+1:], "}")
|
||||
closeIdx := strings.Index(after, "}")
|
||||
if closeIdx == -1 {
|
||||
return false
|
||||
}
|
||||
@@ -217,7 +217,10 @@ type RespHeaderOps struct {
|
||||
}
|
||||
|
||||
// ApplyTo applies ops to hdr using repl.
|
||||
func (ops HeaderOps) ApplyTo(hdr http.Header, repl *caddy.Replacer) {
|
||||
func (ops *HeaderOps) ApplyTo(hdr http.Header, repl *caddy.Replacer) {
|
||||
if ops == nil {
|
||||
return
|
||||
}
|
||||
// before manipulating headers in other ways, check if there
|
||||
// is configuration to delete all headers, and do that first
|
||||
// because if a header is to be added, we don't want to delete
|
||||
|
||||
@@ -17,6 +17,7 @@ package intercept
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -175,10 +176,35 @@ func (ir Intercept) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddy
|
||||
c.Write(zap.Int("handler", rec.handlerIndex))
|
||||
}
|
||||
|
||||
// pass the request through the response handler routes
|
||||
return rec.handler.Routes.Compile(next).ServeHTTP(w, r)
|
||||
// response recorder doesn't create a new copy of the original headers, they're
|
||||
// present in the original response writer
|
||||
// create a new recorder to see if any response body from the new handler is present,
|
||||
// if not, use the already buffered response body
|
||||
recorder := caddyhttp.NewResponseRecorder(w, nil, nil)
|
||||
if err := rec.handler.Routes.Compile(emptyHandler).ServeHTTP(recorder, r); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// no new response status and the status is not 0
|
||||
if recorder.Status() == 0 && rec.Status() != 0 {
|
||||
w.WriteHeader(rec.Status())
|
||||
}
|
||||
|
||||
// no new response body and there is some in the original response
|
||||
// TODO: what if the new response doesn't have a body by design?
|
||||
// see: https://github.com/caddyserver/caddy/pull/6232#issue-2235224400
|
||||
if recorder.Size() == 0 && buf.Len() > 0 {
|
||||
_, err := io.Copy(w, buf)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// this handler does nothing because everything we need is already buffered
|
||||
var emptyHandler caddyhttp.Handler = caddyhttp.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
// UnmarshalCaddyfile sets up the handler from Caddyfile tokens. Syntax:
|
||||
//
|
||||
// intercept [<matcher>] {
|
||||
|
||||
@@ -15,18 +15,28 @@
|
||||
package caddyhttp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/exp/zapslog"
|
||||
"go.uber.org/zap/zapcore"
|
||||
|
||||
"github.com/caddyserver/caddy/v2"
|
||||
)
|
||||
|
||||
func init() {
|
||||
caddy.RegisterSlogHandlerFactory(func(handler slog.Handler, core zapcore.Core, moduleID string) slog.Handler {
|
||||
return &extraFieldsSlogHandler{defaultHandler: handler, core: core, moduleID: moduleID}
|
||||
})
|
||||
}
|
||||
|
||||
// ServerLogConfig describes a server's logging configuration. If
|
||||
// enabled without customization, all requests to this server are
|
||||
// logged to the default logger; logger destinations may be
|
||||
@@ -223,17 +233,21 @@ func errLogValues(err error) (status int, msg string, fields func() []zapcore.Fi
|
||||
|
||||
// ExtraLogFields is a list of extra fields to log with every request.
|
||||
type ExtraLogFields struct {
|
||||
fields []zapcore.Field
|
||||
fields []zapcore.Field
|
||||
handlers sync.Map
|
||||
}
|
||||
|
||||
// Add adds a field to the list of extra fields to log.
|
||||
func (e *ExtraLogFields) Add(field zap.Field) {
|
||||
e.handlers.Clear()
|
||||
e.fields = append(e.fields, field)
|
||||
}
|
||||
|
||||
// Set sets a field in the list of extra fields to log.
|
||||
// If the field already exists, it is replaced.
|
||||
func (e *ExtraLogFields) Set(field zap.Field) {
|
||||
e.handlers.Clear()
|
||||
|
||||
for i := range e.fields {
|
||||
if e.fields[i].Key == field.Key {
|
||||
e.fields[i] = field
|
||||
@@ -243,6 +257,29 @@ func (e *ExtraLogFields) Set(field zap.Field) {
|
||||
e.fields = append(e.fields, field)
|
||||
}
|
||||
|
||||
func (e *ExtraLogFields) getSloggerHandler(handler *extraFieldsSlogHandler) (h slog.Handler) {
|
||||
if existing, ok := e.handlers.Load(handler); ok {
|
||||
return existing.(slog.Handler)
|
||||
}
|
||||
|
||||
if handler.moduleID == "" {
|
||||
h = zapslog.NewHandler(handler.core.With(e.fields))
|
||||
} else {
|
||||
h = zapslog.NewHandler(handler.core.With(e.fields), zapslog.WithName(handler.moduleID))
|
||||
}
|
||||
|
||||
if handler.group != "" {
|
||||
h = h.WithGroup(handler.group)
|
||||
}
|
||||
if handler.attrs != nil {
|
||||
h = h.WithAttrs(handler.attrs)
|
||||
}
|
||||
|
||||
e.handlers.Store(handler, h)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
const (
|
||||
// Variable name used to indicate that this request
|
||||
// should be omitted from the access logs
|
||||
@@ -254,3 +291,43 @@ const (
|
||||
// Variable name used to indicate the logger to be used
|
||||
AccessLoggerNameVarKey string = "access_logger_names"
|
||||
)
|
||||
|
||||
type extraFieldsSlogHandler struct {
|
||||
defaultHandler slog.Handler
|
||||
core zapcore.Core
|
||||
moduleID string
|
||||
group string
|
||||
attrs []slog.Attr
|
||||
}
|
||||
|
||||
func (e *extraFieldsSlogHandler) Enabled(ctx context.Context, level slog.Level) bool {
|
||||
return e.defaultHandler.Enabled(ctx, level)
|
||||
}
|
||||
|
||||
func (e *extraFieldsSlogHandler) Handle(ctx context.Context, record slog.Record) error {
|
||||
if elf, ok := ctx.Value(ExtraLogFieldsCtxKey).(*ExtraLogFields); ok {
|
||||
return elf.getSloggerHandler(e).Handle(ctx, record)
|
||||
}
|
||||
|
||||
return e.defaultHandler.Handle(ctx, record)
|
||||
}
|
||||
|
||||
func (e *extraFieldsSlogHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
|
||||
return &extraFieldsSlogHandler{
|
||||
e.defaultHandler.WithAttrs(attrs),
|
||||
e.core,
|
||||
e.moduleID,
|
||||
e.group,
|
||||
append(e.attrs, attrs...),
|
||||
}
|
||||
}
|
||||
|
||||
func (e *extraFieldsSlogHandler) WithGroup(name string) slog.Handler {
|
||||
return &extraFieldsSlogHandler{
|
||||
e.defaultHandler.WithGroup(name),
|
||||
e.core,
|
||||
e.moduleID,
|
||||
name,
|
||||
e.attrs,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,8 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
|
||||
"github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile"
|
||||
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
|
||||
@@ -26,7 +28,7 @@ func init() {
|
||||
|
||||
// parseCaddyfile sets up the log_append handler from Caddyfile tokens. Syntax:
|
||||
//
|
||||
// log_append [<matcher>] <key> <value>
|
||||
// log_append [<matcher>] [<]<key> <value>
|
||||
func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) {
|
||||
handler := new(LogAppend)
|
||||
err := handler.UnmarshalCaddyfile(h.Dispenser)
|
||||
@@ -43,6 +45,10 @@ func (h *LogAppend) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
if strings.HasPrefix(h.Key, "<") && len(h.Key) > 1 {
|
||||
h.Early = true
|
||||
h.Key = h.Key[1:]
|
||||
}
|
||||
h.Value = d.Val()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -15,6 +15,8 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
@@ -42,6 +44,12 @@ type LogAppend struct {
|
||||
// map, the value of that key will be used. Otherwise
|
||||
// the value will be used as-is as a constant string.
|
||||
Value string `json:"value,omitempty"`
|
||||
|
||||
// Early, if true, adds the log field before calling
|
||||
// the next handler in the chain. By default, the log
|
||||
// field is added on the way back up the middleware chain,
|
||||
// after all subsequent handlers have completed.
|
||||
Early bool `json:"early,omitempty"`
|
||||
}
|
||||
|
||||
// CaddyModule returns the Caddy module information.
|
||||
@@ -53,13 +61,63 @@ func (LogAppend) CaddyModule() caddy.ModuleInfo {
|
||||
}
|
||||
|
||||
func (h LogAppend) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
|
||||
// Run the next handler in the chain first.
|
||||
// Determine if we need to add the log field early.
|
||||
// We do if the Early flag is set, or for convenience,
|
||||
// if the value is a special placeholder for the request body.
|
||||
needsEarly := h.Early || h.Value == placeholderRequestBody || h.Value == placeholderRequestBodyBase64
|
||||
|
||||
// Check if we need to buffer the response for special placeholders
|
||||
needsResponseBody := h.Value == placeholderResponseBody || h.Value == placeholderResponseBodyBase64
|
||||
|
||||
if needsEarly && !needsResponseBody {
|
||||
// Add the log field before calling the next handler
|
||||
// (but not if we need the response body, which isn't available yet)
|
||||
h.addLogField(r, nil)
|
||||
}
|
||||
|
||||
var rec caddyhttp.ResponseRecorder
|
||||
var buf *bytes.Buffer
|
||||
|
||||
if needsResponseBody {
|
||||
// Wrap the response writer with a recorder to capture the response body
|
||||
buf = new(bytes.Buffer)
|
||||
rec = caddyhttp.NewResponseRecorder(w, buf, func(status int, header http.Header) bool {
|
||||
// Always buffer the response when we need to log the body
|
||||
return true
|
||||
})
|
||||
w = rec
|
||||
}
|
||||
|
||||
// Run the next handler in the chain.
|
||||
// If an error occurs, we still want to add
|
||||
// any extra log fields that we can, so we
|
||||
// hold onto the error and return it later.
|
||||
handlerErr := next.ServeHTTP(w, r)
|
||||
|
||||
// On the way back up the chain, add the extra log field
|
||||
if needsResponseBody {
|
||||
// Write the buffered response to the client
|
||||
if rec.Buffered() {
|
||||
h.addLogField(r, buf)
|
||||
err := rec.WriteResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return handlerErr
|
||||
}
|
||||
|
||||
if !h.Early {
|
||||
// Add the log field after the handler completes
|
||||
h.addLogField(r, buf)
|
||||
}
|
||||
|
||||
return handlerErr
|
||||
}
|
||||
|
||||
// addLogField adds the log field to the request's extra log fields.
|
||||
// If buf is not nil, it contains the buffered response body for special
|
||||
// response body placeholders.
|
||||
func (h LogAppend) addLogField(r *http.Request, buf *bytes.Buffer) {
|
||||
ctx := r.Context()
|
||||
|
||||
vars := ctx.Value(caddyhttp.VarsCtxKey).(map[string]any)
|
||||
@@ -67,7 +125,21 @@ func (h LogAppend) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyh
|
||||
extra := ctx.Value(caddyhttp.ExtraLogFieldsCtxKey).(*caddyhttp.ExtraLogFields)
|
||||
|
||||
var varValue any
|
||||
if strings.HasPrefix(h.Value, "{") &&
|
||||
|
||||
// Handle special case placeholders for response body
|
||||
if h.Value == placeholderResponseBody {
|
||||
if buf != nil {
|
||||
varValue = buf.String()
|
||||
} else {
|
||||
varValue = ""
|
||||
}
|
||||
} else if h.Value == placeholderResponseBodyBase64 {
|
||||
if buf != nil {
|
||||
varValue = base64.StdEncoding.EncodeToString(buf.Bytes())
|
||||
} else {
|
||||
varValue = ""
|
||||
}
|
||||
} else if strings.HasPrefix(h.Value, "{") &&
|
||||
strings.HasSuffix(h.Value, "}") &&
|
||||
strings.Count(h.Value, "{") == 1 {
|
||||
// the value looks like a placeholder, so get its value
|
||||
@@ -84,10 +156,17 @@ func (h LogAppend) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyh
|
||||
// We use zap.Any because it will reflect
|
||||
// to the correct type for us.
|
||||
extra.Add(zap.Any(h.Key, varValue))
|
||||
|
||||
return handlerErr
|
||||
}
|
||||
|
||||
const (
|
||||
// Special placeholder values that are handled by log_append
|
||||
// rather than by the replacer.
|
||||
placeholderRequestBody = "{http.request.body}"
|
||||
placeholderRequestBodyBase64 = "{http.request.body_base64}"
|
||||
placeholderResponseBody = "{http.response.body}"
|
||||
placeholderResponseBodyBase64 = "{http.response.body_base64}"
|
||||
)
|
||||
|
||||
// Interface guards
|
||||
var (
|
||||
_ caddyhttp.MiddlewareHandler = (*LogAppend)(nil)
|
||||
@@ -110,6 +110,7 @@ func (t LoggableTLSConnState) MarshalLogObject(enc zapcore.ObjectEncoder) error
|
||||
enc.AddUint16("cipher_suite", t.CipherSuite)
|
||||
enc.AddString("proto", t.NegotiatedProtocol)
|
||||
enc.AddString("server_name", t.ServerName)
|
||||
enc.AddBool("ech", t.ECHAccepted)
|
||||
if len(t.PeerCertificates) > 0 {
|
||||
enc.AddString("client_common_name", t.PeerCertificates[0].Subject.CommonName)
|
||||
enc.AddString("client_serial", t.PeerCertificates[0].SerialNumber.String())
|
||||
|
||||
@@ -262,13 +262,17 @@ func (m MatchHost) Provision(_ caddy.Context) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("converting hostname '%s' to ASCII: %v", host, err)
|
||||
}
|
||||
if asciiHost != host {
|
||||
m[i] = asciiHost
|
||||
}
|
||||
normalizedHost := strings.ToLower(asciiHost)
|
||||
if firstI, ok := seen[normalizedHost]; ok {
|
||||
return fmt.Errorf("host at index %d is repeated at index %d: %s", firstI, i, host)
|
||||
}
|
||||
// Normalize exact hosts for standardized comparison in large-list fastpath later on.
|
||||
// Keep wildcards/placeholders untouched.
|
||||
if m.fuzzy(asciiHost) {
|
||||
m[i] = asciiHost
|
||||
} else {
|
||||
m[i] = normalizedHost
|
||||
}
|
||||
seen[normalizedHost] = i
|
||||
}
|
||||
|
||||
@@ -312,14 +316,15 @@ func (m MatchHost) MatchWithError(r *http.Request) (bool, error) {
|
||||
}
|
||||
|
||||
if m.large() {
|
||||
reqHostLower := strings.ToLower(reqHost)
|
||||
// fast path: locate exact match using binary search (about 100-1000x faster for large lists)
|
||||
pos := sort.Search(len(m), func(i int) bool {
|
||||
if m.fuzzy(m[i]) {
|
||||
return false
|
||||
}
|
||||
return m[i] >= reqHost
|
||||
return m[i] >= reqHostLower
|
||||
})
|
||||
if pos < len(m) && m[pos] == reqHost {
|
||||
if pos < len(m) && m[pos] == reqHostLower {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
@@ -533,6 +538,7 @@ func (m MatchPath) MatchWithError(r *http.Request) (bool, error) {
|
||||
}
|
||||
|
||||
func (MatchPath) matchPatternWithEscapeSequence(escapedPath, matchPath string) bool {
|
||||
escapedPath = strings.ToLower(escapedPath)
|
||||
// We would just compare the pattern against r.URL.Path,
|
||||
// but the pattern contains %, indicating that we should
|
||||
// compare at least some part of the path in raw/escaped
|
||||
@@ -632,8 +638,8 @@ func (MatchPath) matchPatternWithEscapeSequence(escapedPath, matchPath string) b
|
||||
// we can now treat rawpath globs (%*) as regular globs (*)
|
||||
matchPath = strings.ReplaceAll(matchPath, "%*", "*")
|
||||
|
||||
// ignore error here because we can't handle it anyway=
|
||||
matches, _ := path.Match(matchPath, sb.String())
|
||||
// ignore error here because we can't handle it anyway
|
||||
matches, _ := path.Match(matchPath, strings.ToLower(sb.String()))
|
||||
return matches
|
||||
}
|
||||
|
||||
|
||||
@@ -412,6 +412,16 @@ func TestPathMatcher(t *testing.T) {
|
||||
input: "/foo%2fbar/baz",
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
match: MatchPath{"/admin%2fpanel"},
|
||||
input: "/ADMIN%2fpanel",
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
match: MatchPath{"/admin%2fpa*el"},
|
||||
input: "/ADMIN%2fPaAzZLm123NEL",
|
||||
expect: true,
|
||||
},
|
||||
} {
|
||||
err := tc.match.Provision(caddy.Context{})
|
||||
if err == nil && tc.provisionErr {
|
||||
@@ -957,6 +967,7 @@ func TestVarREMatcher(t *testing.T) {
|
||||
desc string
|
||||
match MatchVarsRE
|
||||
input VarsMiddleware
|
||||
headers http.Header
|
||||
expect bool
|
||||
expectRepl map[string]string
|
||||
}{
|
||||
@@ -991,6 +1002,14 @@ func TestVarREMatcher(t *testing.T) {
|
||||
input: VarsMiddleware{"Var1": "var1Value"},
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
desc: "placeholder key value containing braces is not double-expanded",
|
||||
match: MatchVarsRE{"{http.request.header.X-Input}": &MatchRegexp{Pattern: ".+", Name: "val"}},
|
||||
input: VarsMiddleware{},
|
||||
headers: http.Header{"X-Input": []string{"{env.HOME}"}},
|
||||
expect: true,
|
||||
expectRepl: map[string]string{"val.0": "{env.HOME}"},
|
||||
},
|
||||
} {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
@@ -1007,7 +1026,7 @@ func TestVarREMatcher(t *testing.T) {
|
||||
}
|
||||
|
||||
// set up the fake request and its Replacer
|
||||
req := &http.Request{URL: new(url.URL), Method: http.MethodGet}
|
||||
req := &http.Request{URL: new(url.URL), Method: http.MethodGet, Header: tc.headers}
|
||||
repl := caddy.NewReplacer()
|
||||
ctx := context.WithValue(req.Context(), caddy.ReplacerCtxKey, repl)
|
||||
ctx = context.WithValue(ctx, VarsCtxKey, make(map[string]any))
|
||||
|
||||
+128
-12
@@ -17,14 +17,60 @@ import (
|
||||
|
||||
// Metrics configures metrics observations.
|
||||
// EXPERIMENTAL and subject to change or removal.
|
||||
//
|
||||
// Example configuration:
|
||||
//
|
||||
// {
|
||||
// "apps": {
|
||||
// "http": {
|
||||
// "metrics": {
|
||||
// "per_host": true,
|
||||
// "observe_catchall_hosts": false
|
||||
// },
|
||||
// "servers": {
|
||||
// "srv0": {
|
||||
// "routes": [{
|
||||
// "match": [{"host": ["example.com", "www.example.com"]}],
|
||||
// "handle": [{"handler": "static_response", "body": "Hello"}]
|
||||
// }]
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// In this configuration:
|
||||
// - Requests to example.com and www.example.com get individual host labels
|
||||
// - All other hosts (e.g., attacker.com) are aggregated under "_other" label
|
||||
// - This prevents unlimited cardinality from arbitrary Host headers
|
||||
type Metrics struct {
|
||||
// Enable per-host metrics. Enabling this option may
|
||||
// incur high-memory consumption, depending on the number of hosts
|
||||
// managed by Caddy.
|
||||
//
|
||||
// CARDINALITY PROTECTION: To prevent unbounded cardinality attacks,
|
||||
// only explicitly configured hosts (via host matchers) are allowed
|
||||
// by default. Other hosts are aggregated under the "_other" label.
|
||||
// See AllowCatchAllHosts to change this behavior.
|
||||
PerHost bool `json:"per_host,omitempty"`
|
||||
|
||||
init sync.Once
|
||||
httpMetrics *httpMetrics `json:"-"`
|
||||
// Allow metrics for catch-all hosts (hosts without explicit configuration).
|
||||
// When false (default), only hosts explicitly configured via host matchers
|
||||
// will get individual metrics labels. All other hosts will be aggregated
|
||||
// under the "_other" label to prevent cardinality explosion.
|
||||
//
|
||||
// This is automatically enabled for HTTPS servers (since certificates provide
|
||||
// some protection against unbounded cardinality), but disabled for HTTP servers
|
||||
// by default to prevent cardinality attacks from arbitrary Host headers.
|
||||
//
|
||||
// Set to true to allow all hosts to get individual metrics (NOT RECOMMENDED
|
||||
// for production environments exposed to the internet).
|
||||
ObserveCatchallHosts bool `json:"observe_catchall_hosts,omitempty"`
|
||||
|
||||
init sync.Once
|
||||
httpMetrics *httpMetrics
|
||||
allowedHosts map[string]struct{}
|
||||
hasHTTPSServer bool
|
||||
}
|
||||
|
||||
type httpMetrics struct {
|
||||
@@ -101,6 +147,63 @@ func initHTTPMetrics(ctx caddy.Context, metrics *Metrics) {
|
||||
}, httpLabels)
|
||||
}
|
||||
|
||||
// scanConfigForHosts scans the HTTP app configuration to build a set of allowed hosts
|
||||
// for metrics collection, similar to how auto-HTTPS scans for domain names.
|
||||
func (m *Metrics) scanConfigForHosts(app *App) {
|
||||
if !m.PerHost {
|
||||
return
|
||||
}
|
||||
|
||||
m.allowedHosts = make(map[string]struct{})
|
||||
m.hasHTTPSServer = false
|
||||
|
||||
for _, srv := range app.Servers {
|
||||
// Check if this server has TLS enabled
|
||||
serverHasTLS := len(srv.TLSConnPolicies) > 0
|
||||
if serverHasTLS {
|
||||
m.hasHTTPSServer = true
|
||||
}
|
||||
|
||||
// Collect hosts from route matchers
|
||||
for _, route := range srv.Routes {
|
||||
for _, matcherSet := range route.MatcherSets {
|
||||
for _, matcher := range matcherSet {
|
||||
if hm, ok := matcher.(*MatchHost); ok {
|
||||
for _, host := range *hm {
|
||||
// Only allow non-fuzzy hosts to prevent unbounded cardinality
|
||||
if !hm.fuzzy(host) {
|
||||
m.allowedHosts[strings.ToLower(host)] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// shouldAllowHostMetrics determines if metrics should be collected for the given host.
|
||||
// This implements the cardinality protection by only allowing metrics for:
|
||||
// 1. Explicitly configured hosts
|
||||
// 2. Catch-all requests on HTTPS servers (if AllowCatchAllHosts is true or auto-enabled)
|
||||
// 3. Catch-all requests on HTTP servers only if explicitly allowed
|
||||
func (m *Metrics) shouldAllowHostMetrics(host string, isHTTPS bool) bool {
|
||||
if !m.PerHost {
|
||||
return true // host won't be used in labels anyway
|
||||
}
|
||||
|
||||
normalizedHost := strings.ToLower(host)
|
||||
|
||||
// Always allow explicitly configured hosts
|
||||
if _, exists := m.allowedHosts[normalizedHost]; exists {
|
||||
return true
|
||||
}
|
||||
|
||||
// For catch-all requests (not in allowed hosts)
|
||||
allowCatchAll := m.ObserveCatchallHosts || (isHTTPS && m.hasHTTPSServer)
|
||||
return allowCatchAll
|
||||
}
|
||||
|
||||
// serverNameFromContext extracts the current server name from the context.
|
||||
// Returns "UNKNOWN" if none is available (should probably never happen).
|
||||
func serverNameFromContext(ctx context.Context) string {
|
||||
@@ -111,21 +214,24 @@ func serverNameFromContext(ctx context.Context) string {
|
||||
return srv.name
|
||||
}
|
||||
|
||||
type metricsInstrumentedHandler struct {
|
||||
// metricsInstrumentedRoute wraps a compiled route Handler with metrics
|
||||
// instrumentation. It wraps the entire compiled route chain once,
|
||||
// collecting metrics only once per route match.
|
||||
type metricsInstrumentedRoute struct {
|
||||
handler string
|
||||
mh MiddlewareHandler
|
||||
next Handler
|
||||
metrics *Metrics
|
||||
}
|
||||
|
||||
func newMetricsInstrumentedHandler(ctx caddy.Context, handler string, mh MiddlewareHandler, metrics *Metrics) *metricsInstrumentedHandler {
|
||||
metrics.init.Do(func() {
|
||||
initHTTPMetrics(ctx, metrics)
|
||||
func newMetricsInstrumentedRoute(ctx caddy.Context, handler string, next Handler, m *Metrics) *metricsInstrumentedRoute {
|
||||
m.init.Do(func() {
|
||||
initHTTPMetrics(ctx, m)
|
||||
})
|
||||
|
||||
return &metricsInstrumentedHandler{handler, mh, metrics}
|
||||
return &metricsInstrumentedRoute{handler: handler, next: next, metrics: m}
|
||||
}
|
||||
|
||||
func (h *metricsInstrumentedHandler) ServeHTTP(w http.ResponseWriter, r *http.Request, next Handler) error {
|
||||
func (h *metricsInstrumentedRoute) ServeHTTP(w http.ResponseWriter, r *http.Request) error {
|
||||
server := serverNameFromContext(r.Context())
|
||||
labels := prometheus.Labels{"server": server, "handler": h.handler}
|
||||
method := metrics.SanitizeMethod(r.Method)
|
||||
@@ -133,9 +239,19 @@ func (h *metricsInstrumentedHandler) ServeHTTP(w http.ResponseWriter, r *http.Re
|
||||
// of a panic
|
||||
statusLabels := prometheus.Labels{"server": server, "handler": h.handler, "method": method, "code": ""}
|
||||
|
||||
// Determine if this is an HTTPS request
|
||||
isHTTPS := r.TLS != nil
|
||||
|
||||
if h.metrics.PerHost {
|
||||
labels["host"] = strings.ToLower(r.Host)
|
||||
statusLabels["host"] = strings.ToLower(r.Host)
|
||||
// Apply cardinality protection for host metrics
|
||||
if h.metrics.shouldAllowHostMetrics(r.Host, isHTTPS) {
|
||||
labels["host"] = strings.ToLower(r.Host)
|
||||
statusLabels["host"] = strings.ToLower(r.Host)
|
||||
} else {
|
||||
// Use a catch-all label for unallowed hosts to prevent cardinality explosion
|
||||
labels["host"] = "_other"
|
||||
statusLabels["host"] = "_other"
|
||||
}
|
||||
}
|
||||
|
||||
inFlight := h.metrics.httpMetrics.requestInFlight.With(labels)
|
||||
@@ -154,7 +270,7 @@ func (h *metricsInstrumentedHandler) ServeHTTP(w http.ResponseWriter, r *http.Re
|
||||
return false
|
||||
})
|
||||
wrec := NewResponseRecorder(w, nil, writeHeaderRecorder)
|
||||
err := h.mh.ServeHTTP(wrec, r, next)
|
||||
err := h.next.ServeHTTP(wrec, r)
|
||||
dur := time.Since(start).Seconds()
|
||||
h.metrics.httpMetrics.requestCount.With(labels).Inc()
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package caddyhttp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -46,16 +47,12 @@ func TestMetricsInstrumentedHandler(t *testing.T) {
|
||||
return handlerErr
|
||||
})
|
||||
|
||||
mh := middlewareHandlerFunc(func(w http.ResponseWriter, r *http.Request, h Handler) error {
|
||||
return h.ServeHTTP(w, r)
|
||||
})
|
||||
|
||||
ih := newMetricsInstrumentedHandler(ctx, "bar", mh, metrics)
|
||||
ih := newMetricsInstrumentedRoute(ctx, "bar", h, metrics)
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
if actual := ih.ServeHTTP(w, r, h); actual != handlerErr {
|
||||
if actual := ih.ServeHTTP(w, r); actual != handlerErr {
|
||||
t.Errorf("Not same: expected %#v, but got %#v", handlerErr, actual)
|
||||
}
|
||||
if actual := testutil.ToFloat64(metrics.httpMetrics.requestInFlight); actual != 0.0 {
|
||||
@@ -63,19 +60,19 @@ func TestMetricsInstrumentedHandler(t *testing.T) {
|
||||
}
|
||||
|
||||
handlerErr = nil
|
||||
if err := ih.ServeHTTP(w, r, h); err != nil {
|
||||
if err := ih.ServeHTTP(w, r); err != nil {
|
||||
t.Errorf("Received unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// an empty handler - no errors, no header written
|
||||
mh = middlewareHandlerFunc(func(w http.ResponseWriter, r *http.Request, h Handler) error {
|
||||
emptyHandler := HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
return nil
|
||||
})
|
||||
ih = newMetricsInstrumentedHandler(ctx, "empty", mh, metrics)
|
||||
ih = newMetricsInstrumentedRoute(ctx, "empty", emptyHandler, metrics)
|
||||
r = httptest.NewRequest("GET", "/", nil)
|
||||
w = httptest.NewRecorder()
|
||||
|
||||
if err := ih.ServeHTTP(w, r, h); err != nil {
|
||||
if err := ih.ServeHTTP(w, r); err != nil {
|
||||
t.Errorf("Received unexpected error: %v", err)
|
||||
}
|
||||
if actual := w.Result().StatusCode; actual != 200 {
|
||||
@@ -86,16 +83,16 @@ func TestMetricsInstrumentedHandler(t *testing.T) {
|
||||
}
|
||||
|
||||
// handler returning an error with an HTTP status
|
||||
mh = middlewareHandlerFunc(func(w http.ResponseWriter, r *http.Request, h Handler) error {
|
||||
errHandler := HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
return Error(http.StatusTooManyRequests, nil)
|
||||
})
|
||||
|
||||
ih = newMetricsInstrumentedHandler(ctx, "foo", mh, metrics)
|
||||
ih = newMetricsInstrumentedRoute(ctx, "foo", errHandler, metrics)
|
||||
|
||||
r = httptest.NewRequest("GET", "/", nil)
|
||||
w = httptest.NewRecorder()
|
||||
|
||||
if err := ih.ServeHTTP(w, r, nil); err == nil {
|
||||
if err := ih.ServeHTTP(w, r); err == nil {
|
||||
t.Errorf("expected error to be propagated")
|
||||
}
|
||||
|
||||
@@ -206,9 +203,11 @@ func TestMetricsInstrumentedHandler(t *testing.T) {
|
||||
func TestMetricsInstrumentedHandlerPerHost(t *testing.T) {
|
||||
ctx, _ := caddy.NewContext(caddy.Context{Context: context.Background()})
|
||||
metrics := &Metrics{
|
||||
PerHost: true,
|
||||
init: sync.Once{},
|
||||
httpMetrics: &httpMetrics{},
|
||||
PerHost: true,
|
||||
ObserveCatchallHosts: true, // Allow all hosts for testing
|
||||
init: sync.Once{},
|
||||
httpMetrics: &httpMetrics{},
|
||||
allowedHosts: make(map[string]struct{}),
|
||||
}
|
||||
handlerErr := errors.New("oh noes")
|
||||
response := []byte("hello world!")
|
||||
@@ -222,16 +221,12 @@ func TestMetricsInstrumentedHandlerPerHost(t *testing.T) {
|
||||
return handlerErr
|
||||
})
|
||||
|
||||
mh := middlewareHandlerFunc(func(w http.ResponseWriter, r *http.Request, h Handler) error {
|
||||
return h.ServeHTTP(w, r)
|
||||
})
|
||||
|
||||
ih := newMetricsInstrumentedHandler(ctx, "bar", mh, metrics)
|
||||
ih := newMetricsInstrumentedRoute(ctx, "bar", h, metrics)
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
if actual := ih.ServeHTTP(w, r, h); actual != handlerErr {
|
||||
if actual := ih.ServeHTTP(w, r); actual != handlerErr {
|
||||
t.Errorf("Not same: expected %#v, but got %#v", handlerErr, actual)
|
||||
}
|
||||
if actual := testutil.ToFloat64(metrics.httpMetrics.requestInFlight); actual != 0.0 {
|
||||
@@ -239,19 +234,19 @@ func TestMetricsInstrumentedHandlerPerHost(t *testing.T) {
|
||||
}
|
||||
|
||||
handlerErr = nil
|
||||
if err := ih.ServeHTTP(w, r, h); err != nil {
|
||||
if err := ih.ServeHTTP(w, r); err != nil {
|
||||
t.Errorf("Received unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// an empty handler - no errors, no header written
|
||||
mh = middlewareHandlerFunc(func(w http.ResponseWriter, r *http.Request, h Handler) error {
|
||||
emptyHandler := HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
return nil
|
||||
})
|
||||
ih = newMetricsInstrumentedHandler(ctx, "empty", mh, metrics)
|
||||
ih = newMetricsInstrumentedRoute(ctx, "empty", emptyHandler, metrics)
|
||||
r = httptest.NewRequest("GET", "/", nil)
|
||||
w = httptest.NewRecorder()
|
||||
|
||||
if err := ih.ServeHTTP(w, r, h); err != nil {
|
||||
if err := ih.ServeHTTP(w, r); err != nil {
|
||||
t.Errorf("Received unexpected error: %v", err)
|
||||
}
|
||||
if actual := w.Result().StatusCode; actual != 200 {
|
||||
@@ -262,16 +257,16 @@ func TestMetricsInstrumentedHandlerPerHost(t *testing.T) {
|
||||
}
|
||||
|
||||
// handler returning an error with an HTTP status
|
||||
mh = middlewareHandlerFunc(func(w http.ResponseWriter, r *http.Request, h Handler) error {
|
||||
errHandler := HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
return Error(http.StatusTooManyRequests, nil)
|
||||
})
|
||||
|
||||
ih = newMetricsInstrumentedHandler(ctx, "foo", mh, metrics)
|
||||
ih = newMetricsInstrumentedRoute(ctx, "foo", errHandler, metrics)
|
||||
|
||||
r = httptest.NewRequest("GET", "/", nil)
|
||||
w = httptest.NewRecorder()
|
||||
|
||||
if err := ih.ServeHTTP(w, r, nil); err == nil {
|
||||
if err := ih.ServeHTTP(w, r); err == nil {
|
||||
t.Errorf("expected error to be propagated")
|
||||
}
|
||||
|
||||
@@ -379,8 +374,208 @@ func TestMetricsInstrumentedHandlerPerHost(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type middlewareHandlerFunc func(http.ResponseWriter, *http.Request, Handler) error
|
||||
func TestMetricsCardinalityProtection(t *testing.T) {
|
||||
ctx, _ := caddy.NewContext(caddy.Context{Context: context.Background()})
|
||||
|
||||
func (f middlewareHandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request, h Handler) error {
|
||||
return f(w, r, h)
|
||||
// Test 1: Without AllowCatchAllHosts, arbitrary hosts should be mapped to "_other"
|
||||
metrics := &Metrics{
|
||||
PerHost: true,
|
||||
ObserveCatchallHosts: false, // Default - should map unknown hosts to "_other"
|
||||
init: sync.Once{},
|
||||
httpMetrics: &httpMetrics{},
|
||||
allowedHosts: make(map[string]struct{}),
|
||||
}
|
||||
|
||||
// Add one allowed host
|
||||
metrics.allowedHosts["allowed.com"] = struct{}{}
|
||||
|
||||
h := HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
w.Write([]byte("hello"))
|
||||
return nil
|
||||
})
|
||||
|
||||
ih := newMetricsInstrumentedRoute(ctx, "test", h, metrics)
|
||||
|
||||
// Test request to allowed host
|
||||
r1 := httptest.NewRequest("GET", "http://allowed.com/", nil)
|
||||
r1.Host = "allowed.com"
|
||||
w1 := httptest.NewRecorder()
|
||||
ih.ServeHTTP(w1, r1)
|
||||
|
||||
// Test request to unknown host (should be mapped to "_other")
|
||||
r2 := httptest.NewRequest("GET", "http://attacker.com/", nil)
|
||||
r2.Host = "attacker.com"
|
||||
w2 := httptest.NewRecorder()
|
||||
ih.ServeHTTP(w2, r2)
|
||||
|
||||
// Test request to another unknown host (should also be mapped to "_other")
|
||||
r3 := httptest.NewRequest("GET", "http://evil.com/", nil)
|
||||
r3.Host = "evil.com"
|
||||
w3 := httptest.NewRecorder()
|
||||
ih.ServeHTTP(w3, r3)
|
||||
|
||||
// Check that metrics contain:
|
||||
// - One entry for "allowed.com"
|
||||
// - One entry for "_other" (aggregating attacker.com and evil.com)
|
||||
expected := `
|
||||
# HELP caddy_http_requests_total Counter of HTTP(S) requests made.
|
||||
# TYPE caddy_http_requests_total counter
|
||||
caddy_http_requests_total{handler="test",host="_other",server="UNKNOWN"} 2
|
||||
caddy_http_requests_total{handler="test",host="allowed.com",server="UNKNOWN"} 1
|
||||
`
|
||||
|
||||
if err := testutil.GatherAndCompare(ctx.GetMetricsRegistry(), strings.NewReader(expected),
|
||||
"caddy_http_requests_total",
|
||||
); err != nil {
|
||||
t.Errorf("Cardinality protection test failed: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetricsHTTPSCatchAll(t *testing.T) {
|
||||
ctx, _ := caddy.NewContext(caddy.Context{Context: context.Background()})
|
||||
|
||||
// Test that HTTPS requests allow catch-all even when AllowCatchAllHosts is false
|
||||
metrics := &Metrics{
|
||||
PerHost: true,
|
||||
ObserveCatchallHosts: false,
|
||||
hasHTTPSServer: true, // Simulate having HTTPS servers
|
||||
init: sync.Once{},
|
||||
httpMetrics: &httpMetrics{},
|
||||
allowedHosts: make(map[string]struct{}), // Empty - no explicitly allowed hosts
|
||||
}
|
||||
|
||||
h := HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
w.Write([]byte("hello"))
|
||||
return nil
|
||||
})
|
||||
|
||||
ih := newMetricsInstrumentedRoute(ctx, "test", h, metrics)
|
||||
|
||||
// Test HTTPS request (should be allowed even though not in allowedHosts)
|
||||
r1 := httptest.NewRequest("GET", "https://unknown.com/", nil)
|
||||
r1.Host = "unknown.com"
|
||||
r1.TLS = &tls.ConnectionState{} // Mark as TLS/HTTPS
|
||||
w1 := httptest.NewRecorder()
|
||||
ih.ServeHTTP(w1, r1)
|
||||
|
||||
// Test HTTP request (should be mapped to "_other")
|
||||
r2 := httptest.NewRequest("GET", "http://unknown.com/", nil)
|
||||
r2.Host = "unknown.com"
|
||||
// No TLS field = HTTP request
|
||||
w2 := httptest.NewRecorder()
|
||||
ih.ServeHTTP(w2, r2)
|
||||
|
||||
// Check that HTTPS request gets real host, HTTP gets "_other"
|
||||
expected := `
|
||||
# HELP caddy_http_requests_total Counter of HTTP(S) requests made.
|
||||
# TYPE caddy_http_requests_total counter
|
||||
caddy_http_requests_total{handler="test",host="_other",server="UNKNOWN"} 1
|
||||
caddy_http_requests_total{handler="test",host="unknown.com",server="UNKNOWN"} 1
|
||||
`
|
||||
|
||||
if err := testutil.GatherAndCompare(ctx.GetMetricsRegistry(), strings.NewReader(expected),
|
||||
"caddy_http_requests_total",
|
||||
); err != nil {
|
||||
t.Errorf("HTTPS catch-all test failed: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetricsInstrumentedRoute(t *testing.T) {
|
||||
ctx, _ := caddy.NewContext(caddy.Context{Context: context.Background()})
|
||||
m := &Metrics{
|
||||
init: sync.Once{},
|
||||
httpMetrics: &httpMetrics{},
|
||||
}
|
||||
|
||||
handlerErr := errors.New("oh noes")
|
||||
response := []byte("hello world!")
|
||||
innerHandler := HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
if actual := testutil.ToFloat64(m.httpMetrics.requestInFlight); actual != 1.0 {
|
||||
t.Errorf("Expected requestInFlight to be 1.0, got %v", actual)
|
||||
}
|
||||
if handlerErr == nil {
|
||||
w.Write(response)
|
||||
}
|
||||
return handlerErr
|
||||
})
|
||||
|
||||
ih := newMetricsInstrumentedRoute(ctx, "test_handler", innerHandler, m)
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Test with error
|
||||
if actual := ih.ServeHTTP(w, r); actual != handlerErr {
|
||||
t.Errorf("Expected error %v, got %v", handlerErr, actual)
|
||||
}
|
||||
if actual := testutil.ToFloat64(m.httpMetrics.requestInFlight); actual != 0.0 {
|
||||
t.Errorf("Expected requestInFlight to be 0.0 after request, got %v", actual)
|
||||
}
|
||||
if actual := testutil.ToFloat64(m.httpMetrics.requestErrors); actual != 1.0 {
|
||||
t.Errorf("Expected requestErrors to be 1.0, got %v", actual)
|
||||
}
|
||||
|
||||
// Test without error
|
||||
handlerErr = nil
|
||||
w = httptest.NewRecorder()
|
||||
if err := ih.ServeHTTP(w, r); err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMetricsInstrumentedRoute(b *testing.B) {
|
||||
ctx, _ := caddy.NewContext(caddy.Context{Context: context.Background()})
|
||||
m := &Metrics{
|
||||
init: sync.Once{},
|
||||
httpMetrics: &httpMetrics{},
|
||||
}
|
||||
|
||||
noopHandler := HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
w.Write([]byte("ok"))
|
||||
return nil
|
||||
})
|
||||
|
||||
ih := newMetricsInstrumentedRoute(ctx, "bench_handler", noopHandler, m)
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
ih.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSingleRouteMetrics simulates the new behavior where metrics
|
||||
// are collected once for the entire route.
|
||||
func BenchmarkSingleRouteMetrics(b *testing.B) {
|
||||
ctx, _ := caddy.NewContext(caddy.Context{Context: context.Background()})
|
||||
m := &Metrics{
|
||||
init: sync.Once{},
|
||||
httpMetrics: &httpMetrics{},
|
||||
}
|
||||
|
||||
// Build a chain of 5 plain middleware handlers (no per-handler metrics)
|
||||
var next Handler = HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
return nil
|
||||
})
|
||||
for i := 0; i < 5; i++ {
|
||||
capturedNext := next
|
||||
next = HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
return capturedNext.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// Wrap the entire chain with a single route-level metrics handler
|
||||
ih := newMetricsInstrumentedRoute(ctx, "handler", next, m)
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
ih.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,6 +64,7 @@ func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error)
|
||||
var err error
|
||||
|
||||
// include current token, which we treat as an argument here
|
||||
// nolint:prealloc
|
||||
args := []string{h.Val()}
|
||||
args = append(args, h.RemainingArgs()...)
|
||||
|
||||
|
||||
@@ -229,6 +229,21 @@ func addHTTPVarsToReplacer(repl *caddy.Replacer, req *http.Request, w http.Respo
|
||||
req.Body = io.NopCloser(buf) // replace real body with buffered data
|
||||
return buf.String(), true
|
||||
|
||||
case "http.request.body_base64":
|
||||
if req.Body == nil {
|
||||
return "", true
|
||||
}
|
||||
// normally net/http will close the body for us, but since we
|
||||
// are replacing it with a fake one, we have to ensure we close
|
||||
// the real body ourselves when we're done
|
||||
defer req.Body.Close()
|
||||
// read the request body into a buffer (can't pool because we
|
||||
// don't know its lifetime and would have to make a copy anyway)
|
||||
buf := new(bytes.Buffer)
|
||||
_, _ = io.Copy(buf, req.Body) // can't handle error, so just ignore it
|
||||
req.Body = io.NopCloser(buf) // replace real body with buffered data
|
||||
return base64.StdEncoding.EncodeToString(buf.Bytes()), true
|
||||
|
||||
// original request, before any internal changes
|
||||
case "http.request.orig_method":
|
||||
or, _ := req.Context().Value(OriginalRequestCtxKey).(http.Request)
|
||||
@@ -405,7 +420,16 @@ func getReqTLSReplacement(req *http.Request, key string) (any, bool) {
|
||||
if strings.HasPrefix(field, "client.") {
|
||||
cert := getTLSPeerCert(req.TLS)
|
||||
if cert == nil {
|
||||
return nil, false
|
||||
// Instead of returning (nil, false) here, we set it to a dummy
|
||||
// value to fix #7530. This way, even if there is no client cert,
|
||||
// evaluating placeholders with ReplaceKnown() will still remove
|
||||
// the placeholder, which would be expected. It is not expected
|
||||
// for the placeholder to sometimes get removed based on whether
|
||||
// the client presented a cert. We also do not return true here
|
||||
// because we probably should remain accurate about whether a
|
||||
// placeholder is, in fact, known or not.
|
||||
// (This allocation may be slightly inefficient.)
|
||||
cert = new(x509.Certificate)
|
||||
}
|
||||
|
||||
// subject alternate names (SANs)
|
||||
@@ -511,6 +535,8 @@ func getReqTLSReplacement(req *http.Request, key string) (any, bool) {
|
||||
return true, true
|
||||
case "server_name":
|
||||
return req.TLS.ServerName, true
|
||||
case "ech":
|
||||
return req.TLS.ECHAccepted, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
@@ -73,8 +73,9 @@ func (adminUpstreams) handleUpstreams(w http.ResponseWriter, r *http.Request) er
|
||||
|
||||
// Collect the results to respond with
|
||||
results := []upstreamStatus{}
|
||||
knownHosts := make(map[string]struct{})
|
||||
|
||||
// Iterate over the upstream pool (needs to be fast)
|
||||
// Iterate over the static upstream pool (needs to be fast)
|
||||
var rangeErr error
|
||||
hosts.Range(func(key, val any) bool {
|
||||
address, ok := key.(string)
|
||||
@@ -95,6 +96,8 @@ func (adminUpstreams) handleUpstreams(w http.ResponseWriter, r *http.Request) er
|
||||
return false
|
||||
}
|
||||
|
||||
knownHosts[address] = struct{}{}
|
||||
|
||||
results = append(results, upstreamStatus{
|
||||
Address: address,
|
||||
NumRequests: upstream.NumRequests(),
|
||||
@@ -103,11 +106,32 @@ func (adminUpstreams) handleUpstreams(w http.ResponseWriter, r *http.Request) er
|
||||
return true
|
||||
})
|
||||
|
||||
// If an error happened during the range, return it
|
||||
currentInFlight := getInFlightRequests()
|
||||
for address, count := range currentInFlight {
|
||||
if _, exists := knownHosts[address]; !exists && count > 0 {
|
||||
results = append(results, upstreamStatus{
|
||||
Address: address,
|
||||
NumRequests: int(count),
|
||||
Fails: 0,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if rangeErr != nil {
|
||||
return rangeErr
|
||||
}
|
||||
|
||||
// Also include dynamic upstreams
|
||||
dynamicHostsMu.RLock()
|
||||
for address, entry := range dynamicHosts {
|
||||
results = append(results, upstreamStatus{
|
||||
Address: address,
|
||||
NumRequests: entry.host.NumRequests(),
|
||||
Fails: entry.host.Fails(),
|
||||
})
|
||||
}
|
||||
dynamicHostsMu.RUnlock()
|
||||
|
||||
err := enc.Encode(results)
|
||||
if err != nil {
|
||||
return caddy.APIError{
|
||||
|
||||
@@ -0,0 +1,275 @@
|
||||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package reverseproxy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// adminHandlerFixture sets up the global host state for an admin endpoint test
|
||||
// and returns a cleanup function that must be deferred by the caller.
|
||||
//
|
||||
// staticAddrs are inserted into the UsagePool (as a static upstream would be).
|
||||
// dynamicAddrs are inserted into the dynamicHosts map (as a dynamic upstream would be).
|
||||
func adminHandlerFixture(t *testing.T, staticAddrs, dynamicAddrs []string) func() {
|
||||
t.Helper()
|
||||
|
||||
for _, addr := range staticAddrs {
|
||||
u := &Upstream{Dial: addr}
|
||||
u.fillHost()
|
||||
}
|
||||
|
||||
dynamicHostsMu.Lock()
|
||||
for _, addr := range dynamicAddrs {
|
||||
dynamicHosts[addr] = dynamicHostEntry{host: new(Host), lastSeen: time.Now()}
|
||||
}
|
||||
dynamicHostsMu.Unlock()
|
||||
|
||||
return func() {
|
||||
// Remove static entries from the UsagePool.
|
||||
for _, addr := range staticAddrs {
|
||||
_, _ = hosts.Delete(addr)
|
||||
}
|
||||
// Remove dynamic entries.
|
||||
dynamicHostsMu.Lock()
|
||||
for _, addr := range dynamicAddrs {
|
||||
delete(dynamicHosts, addr)
|
||||
}
|
||||
dynamicHostsMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// callAdminUpstreams fires a GET against handleUpstreams and returns the
|
||||
// decoded response body.
|
||||
func callAdminUpstreams(t *testing.T) []upstreamStatus {
|
||||
t.Helper()
|
||||
req := httptest.NewRequest(http.MethodGet, "/reverse_proxy/upstreams", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler := adminUpstreams{}
|
||||
if err := handler.handleUpstreams(w, req); err != nil {
|
||||
t.Fatalf("handleUpstreams returned unexpected error: %v", err)
|
||||
}
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w.Code)
|
||||
}
|
||||
if ct := w.Header().Get("Content-Type"); ct != "application/json" {
|
||||
t.Fatalf("expected Content-Type application/json, got %q", ct)
|
||||
}
|
||||
|
||||
var results []upstreamStatus
|
||||
if err := json.NewDecoder(w.Body).Decode(&results); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
// resultsByAddress indexes a slice of upstreamStatus by address for easier
|
||||
// lookup in assertions.
|
||||
func resultsByAddress(statuses []upstreamStatus) map[string]upstreamStatus {
|
||||
m := make(map[string]upstreamStatus, len(statuses))
|
||||
for _, s := range statuses {
|
||||
m[s.Address] = s
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// TestAdminUpstreamsMethodNotAllowed verifies that non-GET methods are rejected.
|
||||
func TestAdminUpstreamsMethodNotAllowed(t *testing.T) {
|
||||
for _, method := range []string{http.MethodPost, http.MethodPut, http.MethodDelete} {
|
||||
req := httptest.NewRequest(method, "/reverse_proxy/upstreams", nil)
|
||||
w := httptest.NewRecorder()
|
||||
err := (adminUpstreams{}).handleUpstreams(w, req)
|
||||
if err == nil {
|
||||
t.Errorf("method %s: expected an error, got nil", method)
|
||||
continue
|
||||
}
|
||||
apiErr, ok := err.(interface{ HTTPStatus() int })
|
||||
if !ok {
|
||||
// caddy.APIError stores the code in HTTPStatus field, access via the
|
||||
// exported interface it satisfies indirectly; just check non-nil.
|
||||
continue
|
||||
}
|
||||
if code := apiErr.HTTPStatus(); code != http.StatusMethodNotAllowed {
|
||||
t.Errorf("method %s: expected 405, got %d", method, code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminUpstreamsEmpty verifies that an empty response is valid JSON when
|
||||
// no upstreams are registered.
|
||||
func TestAdminUpstreamsEmpty(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
|
||||
results := callAdminUpstreams(t)
|
||||
if results == nil {
|
||||
t.Error("expected non-nil (empty) slice, got nil")
|
||||
}
|
||||
if len(results) != 0 {
|
||||
t.Errorf("expected 0 results with empty pools, got %d", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminUpstreamsStaticOnly verifies that static upstreams (from the
|
||||
// UsagePool) appear in the response with correct addresses.
|
||||
func TestAdminUpstreamsStaticOnly(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
cleanup := adminHandlerFixture(t,
|
||||
[]string{"10.0.0.1:80", "10.0.0.2:80"},
|
||||
nil,
|
||||
)
|
||||
defer cleanup()
|
||||
|
||||
results := callAdminUpstreams(t)
|
||||
byAddr := resultsByAddress(results)
|
||||
|
||||
for _, addr := range []string{"10.0.0.1:80", "10.0.0.2:80"} {
|
||||
if _, ok := byAddr[addr]; !ok {
|
||||
t.Errorf("expected static upstream %q in response", addr)
|
||||
}
|
||||
}
|
||||
if len(results) != 2 {
|
||||
t.Errorf("expected exactly 2 results, got %d", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminUpstreamsDynamicOnly verifies that dynamic upstreams (from
|
||||
// dynamicHosts) appear in the response with correct addresses.
|
||||
func TestAdminUpstreamsDynamicOnly(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
cleanup := adminHandlerFixture(t,
|
||||
nil,
|
||||
[]string{"10.0.1.1:80", "10.0.1.2:80"},
|
||||
)
|
||||
defer cleanup()
|
||||
|
||||
results := callAdminUpstreams(t)
|
||||
byAddr := resultsByAddress(results)
|
||||
|
||||
for _, addr := range []string{"10.0.1.1:80", "10.0.1.2:80"} {
|
||||
if _, ok := byAddr[addr]; !ok {
|
||||
t.Errorf("expected dynamic upstream %q in response", addr)
|
||||
}
|
||||
}
|
||||
if len(results) != 2 {
|
||||
t.Errorf("expected exactly 2 results, got %d", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminUpstreamsBothPools verifies that static and dynamic upstreams are
|
||||
// both present in the same response and that there is no overlap or omission.
|
||||
func TestAdminUpstreamsBothPools(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
cleanup := adminHandlerFixture(t,
|
||||
[]string{"10.0.2.1:80"},
|
||||
[]string{"10.0.2.2:80"},
|
||||
)
|
||||
defer cleanup()
|
||||
|
||||
results := callAdminUpstreams(t)
|
||||
if len(results) != 2 {
|
||||
t.Fatalf("expected 2 results (1 static + 1 dynamic), got %d", len(results))
|
||||
}
|
||||
|
||||
byAddr := resultsByAddress(results)
|
||||
if _, ok := byAddr["10.0.2.1:80"]; !ok {
|
||||
t.Error("static upstream missing from response")
|
||||
}
|
||||
if _, ok := byAddr["10.0.2.2:80"]; !ok {
|
||||
t.Error("dynamic upstream missing from response")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminUpstreamsNoOverlapBetweenPools verifies that an address registered
|
||||
// only as a static upstream does not also appear as a dynamic entry, and
|
||||
// vice-versa.
|
||||
func TestAdminUpstreamsNoOverlapBetweenPools(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
cleanup := adminHandlerFixture(t,
|
||||
[]string{"10.0.3.1:80"},
|
||||
[]string{"10.0.3.2:80"},
|
||||
)
|
||||
defer cleanup()
|
||||
|
||||
results := callAdminUpstreams(t)
|
||||
seen := make(map[string]int)
|
||||
for _, r := range results {
|
||||
seen[r.Address]++
|
||||
}
|
||||
for addr, count := range seen {
|
||||
if count > 1 {
|
||||
t.Errorf("address %q appeared %d times; expected exactly once", addr, count)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminUpstreamsReportsFailCounts verifies that fail counts accumulated on
|
||||
// a dynamic upstream's Host are reflected in the response.
|
||||
func TestAdminUpstreamsReportsFailCounts(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
|
||||
const addr = "10.0.4.1:80"
|
||||
h := new(Host)
|
||||
_ = h.countFail(3)
|
||||
|
||||
dynamicHostsMu.Lock()
|
||||
dynamicHosts[addr] = dynamicHostEntry{host: h, lastSeen: time.Now()}
|
||||
dynamicHostsMu.Unlock()
|
||||
defer func() {
|
||||
dynamicHostsMu.Lock()
|
||||
delete(dynamicHosts, addr)
|
||||
dynamicHostsMu.Unlock()
|
||||
}()
|
||||
|
||||
results := callAdminUpstreams(t)
|
||||
byAddr := resultsByAddress(results)
|
||||
|
||||
status, ok := byAddr[addr]
|
||||
if !ok {
|
||||
t.Fatalf("expected %q in response", addr)
|
||||
}
|
||||
if status.Fails != 3 {
|
||||
t.Errorf("expected Fails=3, got %d", status.Fails)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminUpstreamsReportsNumRequests verifies that the active request count
|
||||
// for a static upstream is reflected in the response.
|
||||
func TestAdminUpstreamsReportsNumRequests(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
|
||||
const addr = "10.0.4.2:80"
|
||||
u := &Upstream{Dial: addr}
|
||||
u.fillHost()
|
||||
defer func() { _, _ = hosts.Delete(addr) }()
|
||||
|
||||
_ = u.Host.countRequest(2)
|
||||
defer func() { _ = u.Host.countRequest(-2) }()
|
||||
|
||||
results := callAdminUpstreams(t)
|
||||
byAddr := resultsByAddress(results)
|
||||
|
||||
status, ok := byAddr[addr]
|
||||
if !ok {
|
||||
t.Fatalf("expected %q in response", addr)
|
||||
}
|
||||
if status.NumRequests != 2 {
|
||||
t.Errorf("expected NumRequests=2, got %d", status.NumRequests)
|
||||
}
|
||||
}
|
||||
@@ -888,8 +888,11 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
||||
if commonScheme == "http" && te.TLSEnabled() {
|
||||
return d.Errf("upstream address scheme is HTTP but transport is configured for HTTP+TLS (HTTPS)")
|
||||
}
|
||||
if te, ok := transport.(*HTTPTransport); ok && commonScheme == "h2c" {
|
||||
te.Versions = []string{"h2c", "2"}
|
||||
if h2ct, ok := transport.(H2CTransport); ok && commonScheme == "h2c" {
|
||||
err := h2ct.EnableH2C()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else if commonScheme == "https" {
|
||||
return d.Errf("upstreams are configured for HTTPS but transport module does not support TLS: %T", transport)
|
||||
@@ -1525,6 +1528,7 @@ func (u *SRVUpstreams) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
||||
return d.Errf("bad delay value '%s': %v", d.Val(), err)
|
||||
}
|
||||
u.FallbackDelay = caddy.Duration(dur)
|
||||
|
||||
case "grace_period":
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
|
||||
@@ -0,0 +1,345 @@
|
||||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package reverseproxy
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/caddyserver/caddy/v2"
|
||||
)
|
||||
|
||||
// resetDynamicHosts clears global dynamic host state between tests.
|
||||
func resetDynamicHosts() {
|
||||
dynamicHostsMu.Lock()
|
||||
dynamicHosts = make(map[string]dynamicHostEntry)
|
||||
dynamicHostsMu.Unlock()
|
||||
// Reset the Once so cleanup goroutine tests can re-trigger if needed.
|
||||
dynamicHostsCleanerOnce = sync.Once{}
|
||||
}
|
||||
|
||||
// TestFillDynamicHostCreatesEntry verifies that calling fillDynamicHost on a
|
||||
// new address inserts an entry into dynamicHosts and assigns a non-nil Host.
|
||||
func TestFillDynamicHostCreatesEntry(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
|
||||
u := &Upstream{Dial: "192.0.2.1:80"}
|
||||
u.fillDynamicHost()
|
||||
|
||||
if u.Host == nil {
|
||||
t.Fatal("expected Host to be set after fillDynamicHost")
|
||||
}
|
||||
|
||||
dynamicHostsMu.RLock()
|
||||
entry, ok := dynamicHosts["192.0.2.1:80"]
|
||||
dynamicHostsMu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
t.Fatal("expected entry in dynamicHosts map")
|
||||
}
|
||||
if entry.host != u.Host {
|
||||
t.Error("dynamicHosts entry host should be the same pointer assigned to Upstream.Host")
|
||||
}
|
||||
if entry.lastSeen.IsZero() {
|
||||
t.Error("expected lastSeen to be set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFillDynamicHostReusesSameHost verifies that two calls for the same address
|
||||
// return the exact same *Host pointer so that state (e.g. fail counts) is shared.
|
||||
func TestFillDynamicHostReusesSameHost(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
|
||||
u1 := &Upstream{Dial: "192.0.2.2:80"}
|
||||
u1.fillDynamicHost()
|
||||
|
||||
u2 := &Upstream{Dial: "192.0.2.2:80"}
|
||||
u2.fillDynamicHost()
|
||||
|
||||
if u1.Host != u2.Host {
|
||||
t.Error("expected both upstreams to share the same *Host pointer")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFillDynamicHostUpdatesLastSeen verifies that a second call for the same
|
||||
// address advances the lastSeen timestamp.
|
||||
func TestFillDynamicHostUpdatesLastSeen(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
|
||||
u := &Upstream{Dial: "192.0.2.3:80"}
|
||||
u.fillDynamicHost()
|
||||
|
||||
dynamicHostsMu.RLock()
|
||||
first := dynamicHosts["192.0.2.3:80"].lastSeen
|
||||
dynamicHostsMu.RUnlock()
|
||||
|
||||
// Ensure measurable time passes.
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
|
||||
u2 := &Upstream{Dial: "192.0.2.3:80"}
|
||||
u2.fillDynamicHost()
|
||||
|
||||
dynamicHostsMu.RLock()
|
||||
second := dynamicHosts["192.0.2.3:80"].lastSeen
|
||||
dynamicHostsMu.RUnlock()
|
||||
|
||||
if !second.After(first) {
|
||||
t.Error("expected lastSeen to be updated on second fillDynamicHost call")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFillDynamicHostIndependentAddresses verifies that different addresses get
|
||||
// independent Host entries.
|
||||
func TestFillDynamicHostIndependentAddresses(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
|
||||
u1 := &Upstream{Dial: "192.0.2.4:80"}
|
||||
u1.fillDynamicHost()
|
||||
|
||||
u2 := &Upstream{Dial: "192.0.2.5:80"}
|
||||
u2.fillDynamicHost()
|
||||
|
||||
if u1.Host == u2.Host {
|
||||
t.Error("different addresses should have different *Host entries")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFillDynamicHostPreservesFailCount verifies that fail counts on a dynamic
|
||||
// host survive across multiple fillDynamicHost calls (simulating sequential
|
||||
// requests), which is the core behaviour fixed by this change.
|
||||
func TestFillDynamicHostPreservesFailCount(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
|
||||
// First "request": provision and record a failure.
|
||||
u1 := &Upstream{Dial: "192.0.2.6:80"}
|
||||
u1.fillDynamicHost()
|
||||
_ = u1.Host.countFail(1)
|
||||
|
||||
if u1.Host.Fails() != 1 {
|
||||
t.Fatalf("expected 1 fail, got %d", u1.Host.Fails())
|
||||
}
|
||||
|
||||
// Second "request": provision the same address again (new *Upstream, same address).
|
||||
u2 := &Upstream{Dial: "192.0.2.6:80"}
|
||||
u2.fillDynamicHost()
|
||||
|
||||
if u2.Host.Fails() != 1 {
|
||||
t.Errorf("expected fail count to persist across fillDynamicHost calls, got %d", u2.Host.Fails())
|
||||
}
|
||||
}
|
||||
|
||||
// TestProvisionUpstreamDynamic verifies that provisionUpstream with dynamic=true
|
||||
// uses fillDynamicHost (not the UsagePool) and sets healthCheckPolicy /
|
||||
// MaxRequests correctly from handler config.
|
||||
func TestProvisionUpstreamDynamic(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
|
||||
passive := &PassiveHealthChecks{
|
||||
FailDuration: caddy.Duration(10 * time.Second),
|
||||
MaxFails: 3,
|
||||
UnhealthyRequestCount: 5,
|
||||
}
|
||||
h := Handler{
|
||||
HealthChecks: &HealthChecks{
|
||||
Passive: passive,
|
||||
},
|
||||
}
|
||||
|
||||
u := &Upstream{Dial: "192.0.2.7:80"}
|
||||
h.provisionUpstream(u, true)
|
||||
|
||||
if u.Host == nil {
|
||||
t.Fatal("Host should be set after provisionUpstream")
|
||||
}
|
||||
if u.healthCheckPolicy != passive {
|
||||
t.Error("healthCheckPolicy should point to the handler's PassiveHealthChecks")
|
||||
}
|
||||
if u.MaxRequests != 5 {
|
||||
t.Errorf("expected MaxRequests=5 from UnhealthyRequestCount, got %d", u.MaxRequests)
|
||||
}
|
||||
|
||||
// Must be in dynamicHosts, not in the static UsagePool.
|
||||
dynamicHostsMu.RLock()
|
||||
_, inDynamic := dynamicHosts["192.0.2.7:80"]
|
||||
dynamicHostsMu.RUnlock()
|
||||
if !inDynamic {
|
||||
t.Error("dynamic upstream should be stored in dynamicHosts")
|
||||
}
|
||||
_, inPool := hosts.References("192.0.2.7:80")
|
||||
if inPool {
|
||||
t.Error("dynamic upstream should NOT be stored in the static UsagePool")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProvisionUpstreamStatic verifies that provisionUpstream with dynamic=false
|
||||
// uses the UsagePool and does NOT insert into dynamicHosts.
|
||||
func TestProvisionUpstreamStatic(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
|
||||
h := Handler{}
|
||||
|
||||
u := &Upstream{Dial: "192.0.2.8:80"}
|
||||
h.provisionUpstream(u, false)
|
||||
|
||||
if u.Host == nil {
|
||||
t.Fatal("Host should be set after provisionUpstream")
|
||||
}
|
||||
|
||||
refs, inPool := hosts.References("192.0.2.8:80")
|
||||
if !inPool {
|
||||
t.Error("static upstream should be in the UsagePool")
|
||||
}
|
||||
if refs != 1 {
|
||||
t.Errorf("expected ref count 1, got %d", refs)
|
||||
}
|
||||
|
||||
dynamicHostsMu.RLock()
|
||||
_, inDynamic := dynamicHosts["192.0.2.8:80"]
|
||||
dynamicHostsMu.RUnlock()
|
||||
if inDynamic {
|
||||
t.Error("static upstream should NOT be in dynamicHosts")
|
||||
}
|
||||
|
||||
// Clean up the pool entry we just added.
|
||||
_, _ = hosts.Delete("192.0.2.8:80")
|
||||
}
|
||||
|
||||
// TestDynamicHostHealthyConsultsFails verifies the end-to-end passive health
|
||||
// check path: after enough failures are recorded against a dynamic upstream's
|
||||
// shared *Host, Healthy() returns false for a newly provisioned *Upstream with
|
||||
// the same address.
|
||||
func TestDynamicHostHealthyConsultsFails(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
|
||||
passive := &PassiveHealthChecks{
|
||||
FailDuration: caddy.Duration(time.Minute),
|
||||
MaxFails: 2,
|
||||
}
|
||||
h := Handler{
|
||||
HealthChecks: &HealthChecks{Passive: passive},
|
||||
}
|
||||
|
||||
// First request: provision and record two failures.
|
||||
u1 := &Upstream{Dial: "192.0.2.9:80"}
|
||||
h.provisionUpstream(u1, true)
|
||||
|
||||
_ = u1.Host.countFail(1)
|
||||
_ = u1.Host.countFail(1)
|
||||
|
||||
// Second request: fresh *Upstream, same address.
|
||||
u2 := &Upstream{Dial: "192.0.2.9:80"}
|
||||
h.provisionUpstream(u2, true)
|
||||
|
||||
if u2.Healthy() {
|
||||
t.Error("upstream should be unhealthy after MaxFails failures have been recorded against its shared Host")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDynamicHostCleanupEvictsStaleEntries verifies that the cleanup sweep
|
||||
// removes entries whose lastSeen is older than dynamicHostIdleExpiry.
|
||||
func TestDynamicHostCleanupEvictsStaleEntries(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
|
||||
const addr = "192.0.2.10:80"
|
||||
|
||||
// Insert an entry directly with a lastSeen far in the past.
|
||||
dynamicHostsMu.Lock()
|
||||
dynamicHosts[addr] = dynamicHostEntry{
|
||||
host: new(Host),
|
||||
lastSeen: time.Now().Add(-2 * dynamicHostIdleExpiry),
|
||||
}
|
||||
dynamicHostsMu.Unlock()
|
||||
|
||||
// Run the cleanup logic inline (same logic as the goroutine).
|
||||
dynamicHostsMu.Lock()
|
||||
for a, entry := range dynamicHosts {
|
||||
if time.Since(entry.lastSeen) > dynamicHostIdleExpiry {
|
||||
delete(dynamicHosts, a)
|
||||
}
|
||||
}
|
||||
dynamicHostsMu.Unlock()
|
||||
|
||||
dynamicHostsMu.RLock()
|
||||
_, stillPresent := dynamicHosts[addr]
|
||||
dynamicHostsMu.RUnlock()
|
||||
|
||||
if stillPresent {
|
||||
t.Error("stale dynamic host entry should have been evicted by cleanup sweep")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDynamicHostCleanupRetainsFreshEntries verifies that the cleanup sweep
|
||||
// keeps entries whose lastSeen is within dynamicHostIdleExpiry.
|
||||
func TestDynamicHostCleanupRetainsFreshEntries(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
|
||||
const addr = "192.0.2.11:80"
|
||||
|
||||
dynamicHostsMu.Lock()
|
||||
dynamicHosts[addr] = dynamicHostEntry{
|
||||
host: new(Host),
|
||||
lastSeen: time.Now(),
|
||||
}
|
||||
dynamicHostsMu.Unlock()
|
||||
|
||||
// Run the cleanup logic inline.
|
||||
dynamicHostsMu.Lock()
|
||||
for a, entry := range dynamicHosts {
|
||||
if time.Since(entry.lastSeen) > dynamicHostIdleExpiry {
|
||||
delete(dynamicHosts, a)
|
||||
}
|
||||
}
|
||||
dynamicHostsMu.Unlock()
|
||||
|
||||
dynamicHostsMu.RLock()
|
||||
_, stillPresent := dynamicHosts[addr]
|
||||
dynamicHostsMu.RUnlock()
|
||||
|
||||
if !stillPresent {
|
||||
t.Error("fresh dynamic host entry should be retained by cleanup sweep")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDynamicHostConcurrentFillHost verifies that concurrent calls to
|
||||
// fillDynamicHost for the same address all get the same *Host pointer and
|
||||
// don't race (run with -race).
|
||||
func TestDynamicHostConcurrentFillHost(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
|
||||
const addr = "192.0.2.12:80"
|
||||
const goroutines = 50
|
||||
|
||||
var wg sync.WaitGroup
|
||||
hosts := make([]*Host, goroutines)
|
||||
|
||||
for i := range goroutines {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
u := &Upstream{Dial: addr}
|
||||
u.fillDynamicHost()
|
||||
hosts[idx] = u.Host
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
first := hosts[0]
|
||||
for i, h := range hosts {
|
||||
if h != first {
|
||||
t.Errorf("goroutine %d got a different *Host pointer; expected all to share the same entry", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -27,7 +27,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math/rand"
|
||||
"math/rand/v2"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/fcgi"
|
||||
@@ -197,7 +197,7 @@ func generateRandFile(size int) (p string, m string) {
|
||||
h := md5.New()
|
||||
for i := 0; i < size/16; i++ {
|
||||
buf := make([]byte, 16)
|
||||
binary.PutVarint(buf, rand.Int63())
|
||||
binary.PutVarint(buf, rand.Int64())
|
||||
if _, err := fo.Write(buf); err != nil {
|
||||
log.Printf("[ERROR] failed to write buffer: %v\n", err)
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ package fastcgi
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -23,9 +24,12 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
"golang.org/x/text/language"
|
||||
"golang.org/x/text/search"
|
||||
|
||||
"github.com/caddyserver/caddy/v2"
|
||||
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
|
||||
@@ -33,7 +37,11 @@ import (
|
||||
"github.com/caddyserver/caddy/v2/modules/caddytls"
|
||||
)
|
||||
|
||||
var noopLogger = zap.NewNop()
|
||||
var (
|
||||
ErrInvalidSplitPath = errors.New("split path contains non-ASCII characters")
|
||||
|
||||
noopLogger = zap.NewNop()
|
||||
)
|
||||
|
||||
func init() {
|
||||
caddy.RegisterModule(Transport{})
|
||||
@@ -50,6 +58,9 @@ type Transport struct {
|
||||
// actual resource (CGI script) name, and the second piece will be set to
|
||||
// PATH_INFO for the CGI script to use.
|
||||
//
|
||||
// Split paths can only contain ASCII characters.
|
||||
// Comparison is case-insensitive.
|
||||
//
|
||||
// Future enhancements should be careful to avoid CVE-2019-11043,
|
||||
// which can be mitigated with use of a try_files-like behavior
|
||||
// that 404s if the fastcgi path info is not found.
|
||||
@@ -109,9 +120,45 @@ func (t *Transport) Provision(ctx caddy.Context) error {
|
||||
t.DialTimeout = caddy.Duration(3 * time.Second)
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
|
||||
for i, split := range t.SplitPath {
|
||||
b.Grow(len(split))
|
||||
|
||||
for j := 0; j < len(split); j++ {
|
||||
c := split[j]
|
||||
if c >= utf8.RuneSelf {
|
||||
return ErrInvalidSplitPath
|
||||
}
|
||||
|
||||
if 'A' <= c && c <= 'Z' {
|
||||
b.WriteByte(c + 'a' - 'A')
|
||||
} else {
|
||||
b.WriteByte(c)
|
||||
}
|
||||
}
|
||||
|
||||
t.SplitPath[i] = b.String()
|
||||
b.Reset()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DefaultBufferSizes enables request buffering for fastcgi if not configured.
|
||||
// This is because most fastcgi servers are php-fpm that require the content length to be set to read the body, golang
|
||||
// std has fastcgi implementation that doesn't need this value to process the body, but we can safely assume that's
|
||||
// not used.
|
||||
// http3 requests have a negative content length for GET and HEAD requests, if that header is not sent.
|
||||
// see: https://github.com/caddyserver/caddy/issues/6678#issuecomment-2472224182
|
||||
// Though it appears even if CONTENT_LENGTH is invalid, php-fpm can handle just fine if the body is empty (no Stdin records sent).
|
||||
// php-fpm will hang if there is any data in the body though, https://github.com/caddyserver/caddy/issues/5420#issuecomment-2415943516
|
||||
|
||||
// TODO: better default buffering for fastcgi requests without content length, in theory a value of 1 should be enough, make it bigger anyway
|
||||
func (t Transport) DefaultBufferSizes() (int64, int64) {
|
||||
return 4096, 0
|
||||
}
|
||||
|
||||
// RoundTrip implements http.RoundTripper.
|
||||
func (t Transport) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
server := r.Context().Value(caddyhttp.ServerCtxKey).(*caddyhttp.Server)
|
||||
@@ -371,8 +418,15 @@ func (t Transport) buildEnv(r *http.Request) (envVars, error) {
|
||||
return env, nil
|
||||
}
|
||||
|
||||
var splitSearchNonASCII = search.New(language.Und, search.IgnoreCase)
|
||||
|
||||
// splitPos returns the index where path should
|
||||
// be split based on t.SplitPath.
|
||||
//
|
||||
// example: if splitPath is [".php"]
|
||||
// "/path/to/script.php/some/path": ("/path/to/script.php", "/some/path")
|
||||
//
|
||||
// Adapted from FrankenPHP's code (copyright 2026 Kévin Dunglas, MIT license)
|
||||
func (t Transport) splitPos(path string) int {
|
||||
// TODO: from v1...
|
||||
// if httpserver.CaseSensitivePath {
|
||||
@@ -382,12 +436,54 @@ func (t Transport) splitPos(path string) int {
|
||||
return 0
|
||||
}
|
||||
|
||||
lowerPath := strings.ToLower(path)
|
||||
pathLen := len(path)
|
||||
|
||||
// We are sure that split strings are all ASCII-only and lower-case because of validation and normalization in Provision().
|
||||
for _, split := range t.SplitPath {
|
||||
if idx := strings.Index(lowerPath, strings.ToLower(split)); idx > -1 {
|
||||
return idx + len(split)
|
||||
splitLen := len(split)
|
||||
|
||||
for i := range pathLen {
|
||||
if path[i] >= utf8.RuneSelf {
|
||||
if _, end := splitSearchNonASCII.IndexString(path, split); end > -1 {
|
||||
return end
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
if i+splitLen > pathLen {
|
||||
continue
|
||||
}
|
||||
|
||||
match := true
|
||||
for j := range splitLen {
|
||||
c := path[i+j]
|
||||
|
||||
if c >= utf8.RuneSelf {
|
||||
if _, end := splitSearchNonASCII.IndexString(path, split); end > -1 {
|
||||
return end
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
if 'A' <= c && c <= 'Z' {
|
||||
c += 'a' - 'A'
|
||||
}
|
||||
|
||||
if c != split[j] {
|
||||
match = false
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if match {
|
||||
return i + splitLen
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
||||
|
||||
@@ -427,6 +523,7 @@ var headerNameReplacer = strings.NewReplacer(" ", "_", "-", "_")
|
||||
var (
|
||||
_ zapcore.ObjectMarshaler = (*loggableEnv)(nil)
|
||||
|
||||
_ caddy.Provisioner = (*Transport)(nil)
|
||||
_ http.RoundTripper = (*Transport)(nil)
|
||||
_ caddy.Provisioner = (*Transport)(nil)
|
||||
_ http.RoundTripper = (*Transport)(nil)
|
||||
_ reverseproxy.BufferedTransport = (*Transport)(nil)
|
||||
)
|
||||
|
||||
@@ -0,0 +1,246 @@
|
||||
package fastcgi
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/caddyserver/caddy/v2"
|
||||
)
|
||||
|
||||
func TestProvisionSplitPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
splitPath []string
|
||||
wantErr error
|
||||
wantSplitPath []string
|
||||
}{
|
||||
{
|
||||
name: "valid lowercase split path",
|
||||
splitPath: []string{".php"},
|
||||
wantErr: nil,
|
||||
wantSplitPath: []string{".php"},
|
||||
},
|
||||
{
|
||||
name: "valid uppercase split path normalized",
|
||||
splitPath: []string{".PHP"},
|
||||
wantErr: nil,
|
||||
wantSplitPath: []string{".php"},
|
||||
},
|
||||
{
|
||||
name: "valid mixed case split path normalized",
|
||||
splitPath: []string{".PhP", ".PHTML"},
|
||||
wantErr: nil,
|
||||
wantSplitPath: []string{".php", ".phtml"},
|
||||
},
|
||||
{
|
||||
name: "empty split path",
|
||||
splitPath: []string{},
|
||||
wantErr: nil,
|
||||
wantSplitPath: []string{},
|
||||
},
|
||||
{
|
||||
name: "non-ASCII character in split path rejected",
|
||||
splitPath: []string{".php", ".Ⱥphp"},
|
||||
wantErr: ErrInvalidSplitPath,
|
||||
},
|
||||
{
|
||||
name: "unicode character in split path rejected",
|
||||
splitPath: []string{".phpⱥ"},
|
||||
wantErr: ErrInvalidSplitPath,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tr := Transport{SplitPath: tt.splitPath}
|
||||
err := tr.Provision(caddy.Context{})
|
||||
|
||||
if tt.wantErr != nil {
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantSplitPath, tr.SplitPath)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitPos(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
splitPath []string
|
||||
wantPos int
|
||||
}{
|
||||
{
|
||||
name: "simple php extension",
|
||||
path: "/path/to/script.php",
|
||||
splitPath: []string{".php"},
|
||||
wantPos: 19,
|
||||
},
|
||||
{
|
||||
name: "php extension with path info",
|
||||
path: "/path/to/script.php/some/path",
|
||||
splitPath: []string{".php"},
|
||||
wantPos: 19,
|
||||
},
|
||||
{
|
||||
name: "case insensitive match",
|
||||
path: "/path/to/script.PHP",
|
||||
splitPath: []string{".php"},
|
||||
wantPos: 19,
|
||||
},
|
||||
{
|
||||
name: "mixed case match",
|
||||
path: "/path/to/script.PhP/info",
|
||||
splitPath: []string{".php"},
|
||||
wantPos: 19,
|
||||
},
|
||||
{
|
||||
name: "no match",
|
||||
path: "/path/to/script.txt",
|
||||
splitPath: []string{".php"},
|
||||
wantPos: -1,
|
||||
},
|
||||
{
|
||||
name: "empty split path",
|
||||
path: "/path/to/script.php",
|
||||
splitPath: []string{},
|
||||
wantPos: 0,
|
||||
},
|
||||
{
|
||||
name: "multiple split paths first match",
|
||||
path: "/path/to/script.php",
|
||||
splitPath: []string{".php", ".phtml"},
|
||||
wantPos: 19,
|
||||
},
|
||||
{
|
||||
name: "multiple split paths second match",
|
||||
path: "/path/to/script.phtml",
|
||||
splitPath: []string{".php", ".phtml"},
|
||||
wantPos: 21,
|
||||
},
|
||||
// Unicode case-folding tests (security fix for GHSA-g966-83w7-6w38)
|
||||
// U+023A (Ⱥ) lowercases to U+2C65 (ⱥ), which has different UTF-8 byte length
|
||||
// Ⱥ: 2 bytes (C8 BA), ⱥ: 3 bytes (E2 B1 A5)
|
||||
{
|
||||
name: "unicode path with case-folding length expansion",
|
||||
path: "/ȺȺȺȺshell.php",
|
||||
splitPath: []string{".php"},
|
||||
wantPos: 18, // correct position in original string
|
||||
},
|
||||
{
|
||||
name: "unicode path with extension after expansion chars",
|
||||
path: "/ȺȺȺȺshell.php/path/info",
|
||||
splitPath: []string{".php"},
|
||||
wantPos: 18,
|
||||
},
|
||||
{
|
||||
name: "unicode in filename with multiple php occurrences",
|
||||
path: "/ȺȺȺȺshell.php.txt.php",
|
||||
splitPath: []string{".php"},
|
||||
wantPos: 18, // should match first .php, not be confused by byte offset shift
|
||||
},
|
||||
{
|
||||
name: "unicode case insensitive extension",
|
||||
path: "/ȺȺȺȺshell.PHP",
|
||||
splitPath: []string{".php"},
|
||||
wantPos: 18,
|
||||
},
|
||||
{
|
||||
name: "unicode in middle of path",
|
||||
path: "/path/Ⱥtest/script.php",
|
||||
splitPath: []string{".php"},
|
||||
wantPos: 23, // Ⱥ is 2 bytes, so path is 23 bytes total, .php ends at byte 23
|
||||
},
|
||||
{
|
||||
name: "unicode only in directory not filename",
|
||||
path: "/Ⱥ/script.php",
|
||||
splitPath: []string{".php"},
|
||||
wantPos: 14,
|
||||
},
|
||||
// Additional Unicode characters that expand when lowercased
|
||||
// U+0130 (İ - Turkish capital I with dot) lowercases to U+0069 + U+0307
|
||||
{
|
||||
name: "turkish capital I with dot",
|
||||
path: "/İtest.php",
|
||||
splitPath: []string{".php"},
|
||||
wantPos: 11,
|
||||
},
|
||||
// Ensure standard ASCII still works correctly
|
||||
{
|
||||
name: "ascii only path with case variation",
|
||||
path: "/PATH/TO/SCRIPT.PHP/INFO",
|
||||
splitPath: []string{".php"},
|
||||
wantPos: 19,
|
||||
},
|
||||
{
|
||||
name: "path at root",
|
||||
path: "/index.php",
|
||||
splitPath: []string{".php"},
|
||||
wantPos: 10,
|
||||
},
|
||||
{
|
||||
name: "extension in middle of filename",
|
||||
path: "/test.php.bak",
|
||||
splitPath: []string{".php"},
|
||||
wantPos: 9,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotPos := Transport{SplitPath: tt.splitPath}.splitPos(tt.path)
|
||||
assert.Equal(t, tt.wantPos, gotPos, "splitPos(%q, %v)", tt.path, tt.splitPath)
|
||||
|
||||
// Verify that the split produces valid substrings
|
||||
if gotPos > 0 && gotPos <= len(tt.path) {
|
||||
scriptName := tt.path[:gotPos]
|
||||
pathInfo := tt.path[gotPos:]
|
||||
|
||||
// The script name should end with one of the split extensions (case-insensitive)
|
||||
hasValidEnding := false
|
||||
for _, split := range tt.splitPath {
|
||||
if strings.HasSuffix(strings.ToLower(scriptName), split) {
|
||||
hasValidEnding = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, hasValidEnding, "script name %q should end with one of %v", scriptName, tt.splitPath)
|
||||
|
||||
// Original path should be reconstructable
|
||||
assert.Equal(t, tt.path, scriptName+pathInfo, "path should be reconstructable from split parts")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSplitPosUnicodeSecurityRegression specifically tests the vulnerability
|
||||
// described in GHSA-g966-83w7-6w38 where Unicode case-folding caused
|
||||
// incorrect SCRIPT_NAME/PATH_INFO splitting
|
||||
func TestSplitPosUnicodeSecurityRegression(t *testing.T) {
|
||||
// U+023A: Ⱥ (UTF-8: C8 BA). Lowercase is ⱥ (UTF-8: E2 B1 A5), longer in bytes.
|
||||
path := "/ȺȺȺȺshell.php.txt.php"
|
||||
split := []string{".php"}
|
||||
|
||||
pos := Transport{SplitPath: split}.splitPos(path)
|
||||
|
||||
// The vulnerable code would return 22 (computed on lowercased string)
|
||||
// The correct code should return 18 (position in original string)
|
||||
expectedPos := strings.Index(path, ".php") + len(".php")
|
||||
assert.Equal(t, expectedPos, pos, "split position should match first .php in original string")
|
||||
assert.Equal(t, 18, pos, "split position should be 18, not 22")
|
||||
|
||||
if pos > 0 && pos <= len(path) {
|
||||
scriptName := path[:pos]
|
||||
pathInfo := path[pos:]
|
||||
|
||||
assert.Equal(t, "/ȺȺȺȺshell.php", scriptName, "script name should be the path up to first .php")
|
||||
assert.Equal(t, ".txt.php", pathInfo, "path info should be the remainder after first .php")
|
||||
}
|
||||
}
|
||||
@@ -112,7 +112,7 @@ func encodeSize(b []byte, size uint32) int {
|
||||
binary.BigEndian.PutUint32(b, size)
|
||||
return 4
|
||||
}
|
||||
b[0] = byte(size)
|
||||
b[0] = byte(size) //nolint:gosec // false positive; b is made 8 bytes long, then this function is always called with b being at least 4 or 1 byte long
|
||||
return 1
|
||||
}
|
||||
|
||||
|
||||
@@ -208,6 +208,24 @@ func parseCaddyfile(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue, error)
|
||||
for _, from := range sortedHeadersToCopy {
|
||||
to := http.CanonicalHeaderKey(headersToCopy[from])
|
||||
placeholderName := "http.reverse_proxy.header." + http.CanonicalHeaderKey(from)
|
||||
|
||||
// Always delete the client-supplied header before conditionally setting
|
||||
// it from the auth response. Without this, a client that pre-supplies a
|
||||
// header listed in copy_headers can inject arbitrary values when the auth
|
||||
// service does not return that header: the MatchNot guard below would
|
||||
// skip the Set entirely, leaving the original client-controlled value
|
||||
// intact and forwarding it to the backend.
|
||||
copyHeaderRoutes = append(copyHeaderRoutes, caddyhttp.Route{
|
||||
HandlersRaw: []json.RawMessage{caddyconfig.JSONModuleObject(
|
||||
&headers.Handler{
|
||||
Request: &headers.HeaderOps{
|
||||
Delete: []string{to},
|
||||
},
|
||||
},
|
||||
"handler", "headers", nil,
|
||||
)},
|
||||
})
|
||||
|
||||
handler := &headers.Handler{
|
||||
Request: &headers.HeaderOps{
|
||||
Set: http.Header{
|
||||
|
||||
@@ -0,0 +1,127 @@
|
||||
package reverseproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
|
||||
)
|
||||
|
||||
func TestAddForwardedHeadersNonIP(t *testing.T) {
|
||||
h := Handler{}
|
||||
|
||||
// Simulate a request with a non-IP remote address (e.g. SCION, abstract socket, or hostname)
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = "my-weird-network:12345"
|
||||
|
||||
// Mock the context variables required by Caddy.
|
||||
// We need to inject the variable map manually since we aren't running the full server.
|
||||
vars := map[string]interface{}{
|
||||
caddyhttp.TrustedProxyVarKey: false,
|
||||
}
|
||||
ctx := context.WithValue(req.Context(), caddyhttp.VarsCtxKey, vars)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
// Execute the unexported function
|
||||
err := h.addForwardedHeaders(req)
|
||||
|
||||
// Expectation: No error should be returned for non-IP addresses.
|
||||
// The function should simply skip the trusted proxy check.
|
||||
if err != nil {
|
||||
t.Errorf("expected no error for non-IP address, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddForwardedHeaders_UnixSocketTrusted(t *testing.T) {
|
||||
h := Handler{}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/", nil)
|
||||
req.RemoteAddr = "@"
|
||||
req.Header.Set("X-Forwarded-For", "1.2.3.4, 10.0.0.1")
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
req.Header.Set("X-Forwarded-Host", "original.example.com")
|
||||
|
||||
vars := map[string]interface{}{
|
||||
caddyhttp.TrustedProxyVarKey: true,
|
||||
caddyhttp.ClientIPVarKey: "1.2.3.4",
|
||||
}
|
||||
ctx := context.WithValue(req.Context(), caddyhttp.VarsCtxKey, vars)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
err := h.addForwardedHeaders(req)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if got := req.Header.Get("X-Forwarded-For"); got != "1.2.3.4, 10.0.0.1" {
|
||||
t.Errorf("X-Forwarded-For = %q, want %q", got, "1.2.3.4, 10.0.0.1")
|
||||
}
|
||||
if got := req.Header.Get("X-Forwarded-Proto"); got != "https" {
|
||||
t.Errorf("X-Forwarded-Proto = %q, want %q", got, "https")
|
||||
}
|
||||
if got := req.Header.Get("X-Forwarded-Host"); got != "original.example.com" {
|
||||
t.Errorf("X-Forwarded-Host = %q, want %q", got, "original.example.com")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddForwardedHeaders_UnixSocketUntrusted(t *testing.T) {
|
||||
h := Handler{}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/", nil)
|
||||
req.RemoteAddr = "@"
|
||||
req.Header.Set("X-Forwarded-For", "1.2.3.4")
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
req.Header.Set("X-Forwarded-Host", "spoofed.example.com")
|
||||
|
||||
vars := map[string]interface{}{
|
||||
caddyhttp.TrustedProxyVarKey: false,
|
||||
caddyhttp.ClientIPVarKey: "",
|
||||
}
|
||||
ctx := context.WithValue(req.Context(), caddyhttp.VarsCtxKey, vars)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
err := h.addForwardedHeaders(req)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if got := req.Header.Get("X-Forwarded-For"); got != "" {
|
||||
t.Errorf("X-Forwarded-For should be deleted, got %q", got)
|
||||
}
|
||||
if got := req.Header.Get("X-Forwarded-Proto"); got != "" {
|
||||
t.Errorf("X-Forwarded-Proto should be deleted, got %q", got)
|
||||
}
|
||||
if got := req.Header.Get("X-Forwarded-Host"); got != "" {
|
||||
t.Errorf("X-Forwarded-Host should be deleted, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddForwardedHeaders_UnixSocketTrustedNoExistingHeaders(t *testing.T) {
|
||||
h := Handler{}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/", nil)
|
||||
req.RemoteAddr = "@"
|
||||
|
||||
vars := map[string]interface{}{
|
||||
caddyhttp.TrustedProxyVarKey: true,
|
||||
caddyhttp.ClientIPVarKey: "5.6.7.8",
|
||||
}
|
||||
ctx := context.WithValue(req.Context(), caddyhttp.VarsCtxKey, vars)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
err := h.addForwardedHeaders(req)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if got := req.Header.Get("X-Forwarded-For"); got != "" {
|
||||
t.Errorf("X-Forwarded-For should be empty when no prior XFF exists, got %q", got)
|
||||
}
|
||||
if got := req.Header.Get("X-Forwarded-Proto"); got != "http" {
|
||||
t.Errorf("X-Forwarded-Proto = %q, want %q", got, "http")
|
||||
}
|
||||
if got := req.Header.Get("X-Forwarded-Host"); got != "example.com" {
|
||||
t.Errorf("X-Forwarded-Host = %q, want %q", got, "example.com")
|
||||
}
|
||||
}
|
||||
@@ -23,7 +23,6 @@ import (
|
||||
"net/url"
|
||||
"regexp"
|
||||
"runtime/debug"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -360,6 +359,12 @@ func (h *Handler) doActiveHealthCheckForAllHosts() {
|
||||
dialInfoUpstream = &Upstream{
|
||||
Dial: h.HealthChecks.Active.Upstream,
|
||||
}
|
||||
} else if upstream.activeHealthCheckPort != 0 {
|
||||
// health_port overrides the port; addr has already been updated
|
||||
// with the health port, so use its address for dialing
|
||||
dialInfoUpstream = &Upstream{
|
||||
Dial: addr.JoinHostPort(0),
|
||||
}
|
||||
}
|
||||
dialInfo, _ := dialInfoUpstream.fillDialInfo(repl)
|
||||
|
||||
@@ -405,14 +410,9 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, networ
|
||||
u.Host = net.JoinHostPort(host, port)
|
||||
}
|
||||
|
||||
// this is kind of a hacky way to know if we should use HTTPS, but whatever
|
||||
if tt, ok := h.Transport.(TLSTransport); ok && tt.TLSEnabled() {
|
||||
u.Scheme = "https"
|
||||
|
||||
// if the port is in the except list, flip back to HTTP
|
||||
if ht, ok := h.Transport.(*HTTPTransport); ok && slices.Contains(ht.TLS.ExceptPorts, port) {
|
||||
u.Scheme = "http"
|
||||
}
|
||||
// override health check schemes if applicable
|
||||
if hcsot, ok := h.Transport.(HealthCheckSchemeOverriderTransport); ok {
|
||||
hcsot.OverrideHealthCheckScheme(u, port)
|
||||
}
|
||||
|
||||
// if we have a provisioned uri, use that, otherwise use
|
||||
@@ -506,7 +506,7 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, networ
|
||||
}
|
||||
|
||||
// do the request, being careful to tame the response body
|
||||
resp, err := h.HealthChecks.Active.httpClient.Do(req)
|
||||
resp, err := h.HealthChecks.Active.httpClient.Do(req) //nolint:gosec // no SSRF
|
||||
if err != nil {
|
||||
if c := h.HealthChecks.Active.logger.Check(zapcore.InfoLevel, "HTTP request failed"); c != nil {
|
||||
c.Write(
|
||||
|
||||
@@ -19,7 +19,9 @@ import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/caddyserver/caddy/v2"
|
||||
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
|
||||
@@ -132,6 +134,43 @@ func (u *Upstream) fillHost() {
|
||||
u.Host = host
|
||||
}
|
||||
|
||||
// fillDynamicHost is like fillHost, but stores the host in the separate
|
||||
// dynamicHosts map rather than the reference-counted UsagePool. Dynamic
|
||||
// hosts are not reference-counted; instead, they are retained as long as
|
||||
// they are actively seen and are evicted by a background cleanup goroutine
|
||||
// after dynamicHostIdleExpiry of inactivity. This preserves health state
|
||||
// (e.g. passive fail counts) across sequential requests.
|
||||
func (u *Upstream) fillDynamicHost() {
|
||||
dynamicHostsMu.Lock()
|
||||
entry, ok := dynamicHosts[u.String()]
|
||||
if ok {
|
||||
entry.lastSeen = time.Now()
|
||||
dynamicHosts[u.String()] = entry
|
||||
u.Host = entry.host
|
||||
} else {
|
||||
h := new(Host)
|
||||
dynamicHosts[u.String()] = dynamicHostEntry{host: h, lastSeen: time.Now()}
|
||||
u.Host = h
|
||||
}
|
||||
dynamicHostsMu.Unlock()
|
||||
|
||||
// ensure the cleanup goroutine is running
|
||||
dynamicHostsCleanerOnce.Do(func() {
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(dynamicHostCleanupInterval)
|
||||
dynamicHostsMu.Lock()
|
||||
for addr, entry := range dynamicHosts {
|
||||
if time.Since(entry.lastSeen) > dynamicHostIdleExpiry {
|
||||
delete(dynamicHosts, addr)
|
||||
}
|
||||
}
|
||||
dynamicHostsMu.Unlock()
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
// Host is the basic, in-memory representation of the state of a remote host.
|
||||
// Its fields are accessed atomically and Host values must not be copied.
|
||||
type Host struct {
|
||||
@@ -268,6 +307,28 @@ func GetDialInfo(ctx context.Context) (DialInfo, bool) {
|
||||
// through config reloads.
|
||||
var hosts = caddy.NewUsagePool()
|
||||
|
||||
// dynamicHosts tracks hosts that were provisioned from dynamic upstream
|
||||
// sources. Unlike static upstreams which are reference-counted via the
|
||||
// UsagePool, dynamic upstream hosts are not reference-counted. Instead,
|
||||
// their last-seen time is updated on each request, and a background
|
||||
// goroutine evicts entries that have been idle for dynamicHostIdleExpiry.
|
||||
// This preserves health state (e.g. passive fail counts) across requests
|
||||
// to the same dynamic backend.
|
||||
var (
|
||||
dynamicHosts = make(map[string]dynamicHostEntry)
|
||||
dynamicHostsMu sync.RWMutex
|
||||
dynamicHostsCleanerOnce sync.Once
|
||||
dynamicHostCleanupInterval = 5 * time.Minute
|
||||
dynamicHostIdleExpiry = time.Hour
|
||||
)
|
||||
|
||||
// dynamicHostEntry holds a Host and the last time it was seen
|
||||
// in a set of dynamic upstreams returned for a request.
|
||||
type dynamicHostEntry struct {
|
||||
host *Host
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
// dialInfoVarKey is the key used for the variable that holds
|
||||
// the dial info for the upstream connection.
|
||||
const dialInfoVarKey = "reverse_proxy.dial_info"
|
||||
@@ -285,3 +346,6 @@ type ProxyProtocolInfo struct {
|
||||
// tlsH1OnlyVarKey is the key used that indicates the connection will use h1 only for TLS.
|
||||
// https://github.com/caddyserver/caddy/issues/7292
|
||||
const tlsH1OnlyVarKey = "reverse_proxy.tls_h1_only"
|
||||
|
||||
// proxyVarKey is the key used that indicates the proxy server used for a request.
|
||||
const proxyVarKey = "reverse_proxy.proxy"
|
||||
|
||||
@@ -21,9 +21,10 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
weakrand "math/rand"
|
||||
weakrand "math/rand/v2"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"slices"
|
||||
@@ -39,6 +40,7 @@ import (
|
||||
"github.com/caddyserver/caddy/v2"
|
||||
"github.com/caddyserver/caddy/v2/caddyconfig"
|
||||
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
|
||||
"github.com/caddyserver/caddy/v2/modules/caddyhttp/headers"
|
||||
"github.com/caddyserver/caddy/v2/modules/caddytls"
|
||||
"github.com/caddyserver/caddy/v2/modules/internal/network"
|
||||
)
|
||||
@@ -159,8 +161,7 @@ type HTTPTransport struct {
|
||||
// `HTTPS_PROXY`, and `NO_PROXY` environment variables.
|
||||
NetworkProxyRaw json.RawMessage `json:"network_proxy,omitempty" caddy:"namespace=caddy.network_proxy inline_key=from"`
|
||||
|
||||
h2cTransport *http2.Transport
|
||||
h3Transport *http3.Transport // TODO: EXPERIMENTAL (May 2024)
|
||||
h3Transport *http3.Transport // TODO: EXPERIMENTAL (May 2024)
|
||||
}
|
||||
|
||||
// CaddyModule returns the Caddy module information.
|
||||
@@ -204,11 +205,16 @@ func (h *HTTPTransport) Provision(ctx caddy.Context) error {
|
||||
func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, error) {
|
||||
// Set keep-alive defaults if it wasn't otherwise configured
|
||||
if h.KeepAlive == nil {
|
||||
h.KeepAlive = &KeepAlive{
|
||||
ProbeInterval: caddy.Duration(30 * time.Second),
|
||||
IdleConnTimeout: caddy.Duration(2 * time.Minute),
|
||||
MaxIdleConnsPerHost: 32, // seems about optimal, see #2805
|
||||
}
|
||||
h.KeepAlive = new(KeepAlive)
|
||||
}
|
||||
if h.KeepAlive.ProbeInterval == 0 {
|
||||
h.KeepAlive.ProbeInterval = caddy.Duration(30 * time.Second)
|
||||
}
|
||||
if h.KeepAlive.IdleConnTimeout == 0 {
|
||||
h.KeepAlive.IdleConnTimeout = caddy.Duration(2 * time.Minute)
|
||||
}
|
||||
if h.KeepAlive.MaxIdleConnsPerHost == 0 {
|
||||
h.KeepAlive.MaxIdleConnsPerHost = 32 // seems about optimal, see #2805
|
||||
}
|
||||
|
||||
// Set a relatively short default dial timeout.
|
||||
@@ -260,22 +266,22 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
//nolint:gosec
|
||||
addr := h.Resolver.netAddrs[weakrand.Intn(len(h.Resolver.netAddrs))]
|
||||
addr := h.Resolver.netAddrs[weakrand.IntN(len(h.Resolver.netAddrs))]
|
||||
return d.DialContext(ctx, addr.Network, addr.JoinHostPort(0))
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
dialContext := func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
// For unix socket upstreams, we need to recover the dial info from
|
||||
// the request's context, because the Host on the request's URL
|
||||
// will have been modified by directing the request, overwriting
|
||||
// the unix socket filename.
|
||||
// Also, we need to avoid overwriting the address at this point
|
||||
// when not necessary, because http.ProxyFromEnvironment may have
|
||||
// modified the address according to the user's env proxy config.
|
||||
// The network is usually tcp, and the address is the host in http.Request.URL.Host
|
||||
// and that's been overwritten in directRequest
|
||||
// However, if proxy is used according to http.ProxyFromEnvironment or proxy providers,
|
||||
// address will be the address of the proxy server.
|
||||
|
||||
// This means we can safely use the address in dialInfo if proxy is not used (the address and network will be same any way)
|
||||
// or if the upstream is unix (because there is no way socks or http proxy can be used for unix address).
|
||||
if dialInfo, ok := GetDialInfo(ctx); ok {
|
||||
if strings.HasPrefix(dialInfo.Network, "unix") {
|
||||
if caddyhttp.GetVar(ctx, proxyVarKey) == nil || strings.HasPrefix(dialInfo.Network, "unix") {
|
||||
network = dialInfo.Network
|
||||
address = dialInfo.Address
|
||||
}
|
||||
@@ -376,9 +382,22 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
|
||||
return nil, fmt.Errorf("network_proxy module is not `(func(*http.Request) (*url.URL, error))``")
|
||||
}
|
||||
}
|
||||
// we need to keep track if a proxy is used for a request
|
||||
proxyWrapper := func(req *http.Request) (*url.URL, error) {
|
||||
if proxy == nil {
|
||||
return nil, nil
|
||||
}
|
||||
u, err := proxy(req)
|
||||
if u == nil || err != nil {
|
||||
return u, err
|
||||
}
|
||||
// there must be a proxy for this request
|
||||
caddyhttp.SetVar(req.Context(), proxyVarKey, u)
|
||||
return u, nil
|
||||
}
|
||||
|
||||
rt := &http.Transport{
|
||||
Proxy: proxy,
|
||||
Proxy: proxyWrapper,
|
||||
DialContext: dialContext,
|
||||
MaxConnsPerHost: h.MaxConnsPerHost,
|
||||
ResponseHeaderTimeout: time.Duration(h.ResponseHeaderTimeout),
|
||||
@@ -396,8 +415,13 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
|
||||
return nil, fmt.Errorf("making TLS client config: %v", err)
|
||||
}
|
||||
|
||||
// servername has a placeholder, so we need to replace it
|
||||
if strings.Contains(h.TLS.ServerName, "{") {
|
||||
serverNameHasPlaceholder := strings.Contains(h.TLS.ServerName, "{")
|
||||
|
||||
// We need to use custom DialTLSContext if:
|
||||
// 1. ServerName has a placeholder that needs to be replaced at request-time, OR
|
||||
// 2. ProxyProtocol is enabled, because req.URL.Host is modified to include
|
||||
// client address info with "->" separator which breaks Go's address parsing
|
||||
if serverNameHasPlaceholder || h.ProxyProtocol != "" {
|
||||
rt.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
// reuses the dialer from above to establish a plaintext connection
|
||||
conn, err := dialContext(ctx, network, addr)
|
||||
@@ -406,9 +430,11 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
|
||||
}
|
||||
|
||||
// but add our own handshake logic
|
||||
repl := ctx.Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
|
||||
tlsConfig := rt.TLSClientConfig.Clone()
|
||||
tlsConfig.ServerName = repl.ReplaceAll(tlsConfig.ServerName, "")
|
||||
if serverNameHasPlaceholder {
|
||||
repl := ctx.Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
|
||||
tlsConfig.ServerName = repl.ReplaceAll(tlsConfig.ServerName, "")
|
||||
}
|
||||
|
||||
// h1 only
|
||||
if caddyhttp.GetVar(ctx, tlsH1OnlyVarKey) == true {
|
||||
@@ -422,7 +448,7 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
|
||||
// complete the handshake before returning the connection
|
||||
if rt.TLSHandshakeTimeout != 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, rt.TLSHandshakeTimeout)
|
||||
ctx, cancel = context.WithTimeoutCause(ctx, rt.TLSHandshakeTimeout, fmt.Errorf("HTTP transport TLS handshake %ds timeout", int(rt.TLSHandshakeTimeout.Seconds())))
|
||||
defer cancel()
|
||||
}
|
||||
err = tlsConn.HandshakeContext(ctx)
|
||||
@@ -457,24 +483,10 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
|
||||
rt.IdleConnTimeout = time.Duration(h.KeepAlive.IdleConnTimeout)
|
||||
}
|
||||
|
||||
// The proxy protocol header can only be sent once right after opening the connection.
|
||||
// So single connection must not be used for multiple requests, which can potentially
|
||||
// come from different clients.
|
||||
if !rt.DisableKeepAlives && h.ProxyProtocol != "" {
|
||||
caddyCtx.Logger().Warn("disabling keepalives, they are incompatible with using PROXY protocol")
|
||||
rt.DisableKeepAlives = true
|
||||
}
|
||||
|
||||
if h.Compression != nil {
|
||||
rt.DisableCompression = !*h.Compression
|
||||
}
|
||||
|
||||
if slices.Contains(h.Versions, "2") {
|
||||
if err := http2.ConfigureTransport(rt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// configure HTTP/3 transport if enabled; however, this does not
|
||||
// automatically fall back to lower versions like most web browsers
|
||||
// do (that'd add latency and complexity, besides, we expect that
|
||||
@@ -492,30 +504,49 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
|
||||
return nil, fmt.Errorf("if HTTP/3 is enabled to the upstream, no other HTTP versions are supported")
|
||||
}
|
||||
|
||||
// if h2c is enabled, configure its transport (std lib http.Transport
|
||||
// does not "HTTP/2 over cleartext TCP")
|
||||
if slices.Contains(h.Versions, "h2c") {
|
||||
// crafting our own http2.Transport doesn't allow us to utilize
|
||||
// most of the customizations/preferences on the http.Transport,
|
||||
// because, for some reason, only http2.ConfigureTransport()
|
||||
// is allowed to set the unexported field that refers to a base
|
||||
// http.Transport config; oh well
|
||||
h2t := &http2.Transport{
|
||||
// kind of a hack, but for plaintext/H2C requests, pretend to dial TLS
|
||||
DialTLSContext: func(ctx context.Context, network, address string, _ *tls.Config) (net.Conn, error) {
|
||||
return dialContext(ctx, network, address)
|
||||
},
|
||||
AllowHTTP: true,
|
||||
// if h2/c is enabled, configure it explicitly
|
||||
if slices.Contains(h.Versions, "2") || slices.Contains(h.Versions, "h2c") {
|
||||
if err := http2.ConfigureTransport(rt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if h.Compression != nil {
|
||||
h2t.DisableCompression = !*h.Compression
|
||||
|
||||
// DisableCompression from h2 is configured by http2.ConfigureTransport
|
||||
// Likewise, DisableKeepAlives from h1 is used too.
|
||||
|
||||
// Protocols field is only used when the request is not using TLS,
|
||||
// http1/2 over tls is still allowed
|
||||
if slices.Contains(h.Versions, "h2c") {
|
||||
rt.Protocols = new(http.Protocols)
|
||||
rt.Protocols.SetUnencryptedHTTP2(true)
|
||||
rt.Protocols.SetHTTP1(false)
|
||||
}
|
||||
h.h2cTransport = h2t
|
||||
}
|
||||
|
||||
return rt, nil
|
||||
}
|
||||
|
||||
// RequestHeaderOps implements TransportHeaderOpsProvider. It returns header
|
||||
// operations for requests when the transport's configuration indicates they
|
||||
// should be applied. In particular, when TLS is enabled for this transport,
|
||||
// return an operation to set the Host header to the upstream host:port
|
||||
// placeholder so HTTPS upstreams get the proper Host by default.
|
||||
//
|
||||
// Note: this is a provision-time hook; the Handler will call this during
|
||||
// its Provision and cache the resulting HeaderOps. The HeaderOps are
|
||||
// applied per-request (so placeholders are expanded at request time).
|
||||
func (h *HTTPTransport) RequestHeaderOps() *headers.HeaderOps {
|
||||
// If TLS is not configured for this transport, don't inject Host
|
||||
// defaults. TLS being non-nil indicates HTTPS to the upstream.
|
||||
if h.TLS == nil {
|
||||
return nil
|
||||
}
|
||||
return &headers.HeaderOps{
|
||||
Set: http.Header{
|
||||
"Host": []string{"{http.reverse_proxy.upstream.hostport}"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// RoundTrip implements http.RoundTripper.
|
||||
func (h *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
h.SetScheme(req)
|
||||
@@ -525,15 +556,6 @@ func (h *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return h.h3Transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
// if H2C ("HTTP/2 over cleartext") is enabled and the upstream request is
|
||||
// HTTP without TLS, use the alternate H2C-capable transport instead
|
||||
if req.URL.Scheme == "http" && h.h2cTransport != nil {
|
||||
// There is no dedicated DisableKeepAlives field in *http2.Transport.
|
||||
// This is an alternative way to disable keep-alive.
|
||||
req.Close = h.Transport.DisableKeepAlives
|
||||
return h.h2cTransport.RoundTrip(req)
|
||||
}
|
||||
|
||||
return h.Transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
@@ -575,6 +597,26 @@ func (h *HTTPTransport) EnableTLS(base *TLSConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnableH2C enables H2C (HTTP/2 over Cleartext) on the transport.
|
||||
func (h *HTTPTransport) EnableH2C() error {
|
||||
h.Versions = []string{"h2c", "2"}
|
||||
return nil
|
||||
}
|
||||
|
||||
// OverrideHealthCheckScheme overrides the scheme of the given URL
|
||||
// used for health checks.
|
||||
func (h HTTPTransport) OverrideHealthCheckScheme(base *url.URL, port string) {
|
||||
// if tls is enabled and the port isn't in the except list, use HTTPs
|
||||
if h.TLSEnabled() && !slices.Contains(h.TLS.ExceptPorts, port) {
|
||||
base.Scheme = "https"
|
||||
}
|
||||
}
|
||||
|
||||
// ProxyProtocolEnabled returns true if proxy protocol is enabled.
|
||||
func (h HTTPTransport) ProxyProtocolEnabled() bool {
|
||||
return h.ProxyProtocol != ""
|
||||
}
|
||||
|
||||
// Cleanup implements caddy.CleanerUpper and closes any idle connections.
|
||||
func (h HTTPTransport) Cleanup() error {
|
||||
if h.Transport == nil {
|
||||
@@ -831,8 +873,11 @@ func decodeBase64DERCert(certStr string) (*x509.Certificate, error) {
|
||||
|
||||
// Interface guards
|
||||
var (
|
||||
_ caddy.Provisioner = (*HTTPTransport)(nil)
|
||||
_ http.RoundTripper = (*HTTPTransport)(nil)
|
||||
_ caddy.CleanerUpper = (*HTTPTransport)(nil)
|
||||
_ TLSTransport = (*HTTPTransport)(nil)
|
||||
_ caddy.Provisioner = (*HTTPTransport)(nil)
|
||||
_ http.RoundTripper = (*HTTPTransport)(nil)
|
||||
_ caddy.CleanerUpper = (*HTTPTransport)(nil)
|
||||
_ TLSTransport = (*HTTPTransport)(nil)
|
||||
_ H2CTransport = (*HTTPTransport)(nil)
|
||||
_ HealthCheckSchemeOverriderTransport = (*HTTPTransport)(nil)
|
||||
_ ProxyProtocolTransport = (*HTTPTransport)(nil)
|
||||
)
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package reverseproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/caddyserver/caddy/v2"
|
||||
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
|
||||
)
|
||||
|
||||
@@ -94,3 +96,102 @@ func TestHTTPTransportUnmarshalCaddyFileWithCaPools(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTransport_RequestHeaderOps_TLS(t *testing.T) {
|
||||
var ht HTTPTransport
|
||||
// When TLS is nil, expect no header ops
|
||||
if ops := ht.RequestHeaderOps(); ops != nil {
|
||||
t.Fatalf("expected nil HeaderOps when TLS is nil, got: %#v", ops)
|
||||
}
|
||||
|
||||
// When TLS is configured, expect a HeaderOps that sets Host
|
||||
ht.TLS = &TLSConfig{}
|
||||
ops := ht.RequestHeaderOps()
|
||||
if ops == nil {
|
||||
t.Fatal("expected non-nil HeaderOps when TLS is set")
|
||||
}
|
||||
if ops.Set == nil {
|
||||
t.Fatalf("expected ops.Set to be non-nil, got nil")
|
||||
}
|
||||
if got := ops.Set.Get("Host"); got != "{http.reverse_proxy.upstream.hostport}" {
|
||||
t.Fatalf("unexpected Host value; want placeholder, got: %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHTTPTransport_DialTLSContext_ProxyProtocol verifies that when TLS and
|
||||
// ProxyProtocol are both enabled, DialTLSContext is set. This is critical because
|
||||
// ProxyProtocol modifies req.URL.Host to include client info with "->" separator
|
||||
// (e.g., "[2001:db8::1]:12345->127.0.0.1:443"), which breaks Go's address parsing.
|
||||
// Without a custom DialTLSContext, Go's HTTP library would fail with
|
||||
// "too many colons in address" when trying to parse the mangled host.
|
||||
func TestHTTPTransport_DialTLSContext_ProxyProtocol(t *testing.T) {
|
||||
ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tls *TLSConfig
|
||||
proxyProtocol string
|
||||
serverNameHasPlaceholder bool
|
||||
expectDialTLSContext bool
|
||||
}{
|
||||
{
|
||||
name: "no TLS, no proxy protocol",
|
||||
tls: nil,
|
||||
proxyProtocol: "",
|
||||
expectDialTLSContext: false,
|
||||
},
|
||||
{
|
||||
name: "TLS without proxy protocol",
|
||||
tls: &TLSConfig{},
|
||||
proxyProtocol: "",
|
||||
expectDialTLSContext: false,
|
||||
},
|
||||
{
|
||||
name: "TLS with proxy protocol v1",
|
||||
tls: &TLSConfig{},
|
||||
proxyProtocol: "v1",
|
||||
expectDialTLSContext: true,
|
||||
},
|
||||
{
|
||||
name: "TLS with proxy protocol v2",
|
||||
tls: &TLSConfig{},
|
||||
proxyProtocol: "v2",
|
||||
expectDialTLSContext: true,
|
||||
},
|
||||
{
|
||||
name: "TLS with placeholder ServerName",
|
||||
tls: &TLSConfig{ServerName: "{http.request.host}"},
|
||||
proxyProtocol: "",
|
||||
serverNameHasPlaceholder: true,
|
||||
expectDialTLSContext: true,
|
||||
},
|
||||
{
|
||||
name: "TLS with placeholder ServerName and proxy protocol",
|
||||
tls: &TLSConfig{ServerName: "{http.request.host}"},
|
||||
proxyProtocol: "v2",
|
||||
serverNameHasPlaceholder: true,
|
||||
expectDialTLSContext: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ht := &HTTPTransport{
|
||||
TLS: tt.tls,
|
||||
ProxyProtocol: tt.proxyProtocol,
|
||||
}
|
||||
|
||||
rt, err := ht.NewTransport(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTransport() error = %v", err)
|
||||
}
|
||||
|
||||
hasDialTLSContext := rt.DialTLSContext != nil
|
||||
if hasDialTLSContext != tt.expectDialTLSContext {
|
||||
t.Errorf("DialTLSContext set = %v, want %v", hasDialTLSContext, tt.expectDialTLSContext)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,391 @@
|
||||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package reverseproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/caddyserver/caddy/v2"
|
||||
)
|
||||
|
||||
// newPassiveHandler builds a minimal Handler with passive health checks
|
||||
// configured and a live caddy.Context so the fail-forgetter goroutine can
|
||||
// be cancelled cleanly. The caller must call cancel() when done.
|
||||
func newPassiveHandler(t *testing.T, maxFails int, failDuration time.Duration) (*Handler, context.CancelFunc) {
|
||||
t.Helper()
|
||||
caddyCtx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()})
|
||||
h := &Handler{
|
||||
ctx: caddyCtx,
|
||||
HealthChecks: &HealthChecks{
|
||||
Passive: &PassiveHealthChecks{
|
||||
MaxFails: maxFails,
|
||||
FailDuration: caddy.Duration(failDuration),
|
||||
},
|
||||
},
|
||||
}
|
||||
return h, cancel
|
||||
}
|
||||
|
||||
// provisionedStaticUpstream creates a static upstream, registers it in the
|
||||
// UsagePool, and returns a cleanup func that removes it from the pool.
|
||||
func provisionedStaticUpstream(t *testing.T, h *Handler, addr string) (*Upstream, func()) {
|
||||
t.Helper()
|
||||
u := &Upstream{Dial: addr}
|
||||
h.provisionUpstream(u, false)
|
||||
return u, func() { _, _ = hosts.Delete(addr) }
|
||||
}
|
||||
|
||||
// provisionedDynamicUpstream creates a dynamic upstream, registers it in
|
||||
// dynamicHosts, and returns a cleanup func that removes it.
|
||||
func provisionedDynamicUpstream(t *testing.T, h *Handler, addr string) (*Upstream, func()) {
|
||||
t.Helper()
|
||||
u := &Upstream{Dial: addr}
|
||||
h.provisionUpstream(u, true)
|
||||
return u, func() {
|
||||
dynamicHostsMu.Lock()
|
||||
delete(dynamicHosts, addr)
|
||||
dynamicHostsMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// --- countFailure behaviour ---
|
||||
|
||||
// TestCountFailureNoopWhenNoHealthChecks verifies that countFailure is a no-op
|
||||
// when HealthChecks is nil.
|
||||
func TestCountFailureNoopWhenNoHealthChecks(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
h := &Handler{}
|
||||
u := &Upstream{Dial: "10.1.0.1:80", Host: new(Host)}
|
||||
|
||||
h.countFailure(u)
|
||||
|
||||
if u.Host.Fails() != 0 {
|
||||
t.Errorf("expected 0 fails with no HealthChecks config, got %d", u.Host.Fails())
|
||||
}
|
||||
}
|
||||
|
||||
// TestCountFailureNoopWhenZeroDuration verifies that countFailure is a no-op
|
||||
// when FailDuration is 0 (the zero value disables passive checks).
|
||||
func TestCountFailureNoopWhenZeroDuration(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
caddyCtx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
h := &Handler{
|
||||
ctx: caddyCtx,
|
||||
HealthChecks: &HealthChecks{
|
||||
Passive: &PassiveHealthChecks{MaxFails: 1, FailDuration: 0},
|
||||
},
|
||||
}
|
||||
u := &Upstream{Dial: "10.1.0.2:80", Host: new(Host)}
|
||||
|
||||
h.countFailure(u)
|
||||
|
||||
if u.Host.Fails() != 0 {
|
||||
t.Errorf("expected 0 fails with zero FailDuration, got %d", u.Host.Fails())
|
||||
}
|
||||
}
|
||||
|
||||
// TestCountFailureIncrementsCount verifies that countFailure increments the
|
||||
// fail count on the upstream's Host.
|
||||
func TestCountFailureIncrementsCount(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
h, cancel := newPassiveHandler(t, 2, time.Minute)
|
||||
defer cancel()
|
||||
u := &Upstream{Dial: "10.1.0.3:80", Host: new(Host)}
|
||||
|
||||
h.countFailure(u)
|
||||
|
||||
if u.Host.Fails() != 1 {
|
||||
t.Errorf("expected 1 fail after countFailure, got %d", u.Host.Fails())
|
||||
}
|
||||
}
|
||||
|
||||
// TestCountFailureDecrementsAfterDuration verifies that the fail count is
|
||||
// decremented back after FailDuration elapses.
|
||||
func TestCountFailureDecrementsAfterDuration(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
const failDuration = 50 * time.Millisecond
|
||||
h, cancel := newPassiveHandler(t, 2, failDuration)
|
||||
defer cancel()
|
||||
u := &Upstream{Dial: "10.1.0.4:80", Host: new(Host)}
|
||||
|
||||
h.countFailure(u)
|
||||
if u.Host.Fails() != 1 {
|
||||
t.Fatalf("expected 1 fail immediately after countFailure, got %d", u.Host.Fails())
|
||||
}
|
||||
|
||||
// Wait long enough for the forgetter goroutine to fire.
|
||||
time.Sleep(3 * failDuration)
|
||||
|
||||
if u.Host.Fails() != 0 {
|
||||
t.Errorf("expected fail count to return to 0 after FailDuration, got %d", u.Host.Fails())
|
||||
}
|
||||
}
|
||||
|
||||
// TestCountFailureCancelledContextForgets verifies that cancelling the handler
|
||||
// context (simulating a config unload) also triggers the forgetter to run,
|
||||
// decrementing the fail count.
|
||||
func TestCountFailureCancelledContextForgets(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
h, cancel := newPassiveHandler(t, 2, time.Hour) // very long duration
|
||||
u := &Upstream{Dial: "10.1.0.5:80", Host: new(Host)}
|
||||
|
||||
h.countFailure(u)
|
||||
if u.Host.Fails() != 1 {
|
||||
t.Fatalf("expected 1 fail immediately after countFailure, got %d", u.Host.Fails())
|
||||
}
|
||||
|
||||
// Cancelling the context should cause the forgetter goroutine to exit and
|
||||
// decrement the count.
|
||||
cancel()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if u.Host.Fails() != 0 {
|
||||
t.Errorf("expected fail count to be decremented after context cancel, got %d", u.Host.Fails())
|
||||
}
|
||||
}
|
||||
|
||||
// --- static upstream passive health check ---
|
||||
|
||||
// TestStaticUpstreamHealthyWithNoFailures verifies that a static upstream with
|
||||
// no recorded failures is considered healthy.
|
||||
func TestStaticUpstreamHealthyWithNoFailures(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
h, cancel := newPassiveHandler(t, 2, time.Minute)
|
||||
defer cancel()
|
||||
|
||||
u, cleanup := provisionedStaticUpstream(t, h, "10.2.0.1:80")
|
||||
defer cleanup()
|
||||
|
||||
if !u.Healthy() {
|
||||
t.Error("upstream with no failures should be healthy")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStaticUpstreamUnhealthyAtMaxFails verifies that a static upstream is
|
||||
// marked unhealthy once its fail count reaches MaxFails.
|
||||
func TestStaticUpstreamUnhealthyAtMaxFails(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
h, cancel := newPassiveHandler(t, 2, time.Minute)
|
||||
defer cancel()
|
||||
|
||||
u, cleanup := provisionedStaticUpstream(t, h, "10.2.0.2:80")
|
||||
defer cleanup()
|
||||
|
||||
h.countFailure(u)
|
||||
if !u.Healthy() {
|
||||
t.Error("upstream should still be healthy after 1 of 2 allowed failures")
|
||||
}
|
||||
|
||||
h.countFailure(u)
|
||||
if u.Healthy() {
|
||||
t.Error("upstream should be unhealthy after reaching MaxFails=2")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStaticUpstreamRecoversAfterFailDuration verifies that a static upstream
|
||||
// returns to healthy once its failures expire.
|
||||
func TestStaticUpstreamRecoversAfterFailDuration(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
const failDuration = 50 * time.Millisecond
|
||||
h, cancel := newPassiveHandler(t, 1, failDuration)
|
||||
defer cancel()
|
||||
|
||||
u, cleanup := provisionedStaticUpstream(t, h, "10.2.0.3:80")
|
||||
defer cleanup()
|
||||
|
||||
h.countFailure(u)
|
||||
if u.Healthy() {
|
||||
t.Fatal("upstream should be unhealthy immediately after MaxFails failure")
|
||||
}
|
||||
|
||||
time.Sleep(3 * failDuration)
|
||||
|
||||
if !u.Healthy() {
|
||||
t.Errorf("upstream should recover to healthy after FailDuration, Fails=%d", u.Host.Fails())
|
||||
}
|
||||
}
|
||||
|
||||
// TestStaticUpstreamHealthPersistedAcrossReprovisioning verifies that static
|
||||
// upstreams share a Host via the UsagePool, so a second call to provisionUpstream
|
||||
// for the same address (as happens on config reload) sees the accumulated state.
|
||||
func TestStaticUpstreamHealthPersistedAcrossReprovisioning(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
h, cancel := newPassiveHandler(t, 2, time.Minute)
|
||||
defer cancel()
|
||||
|
||||
u1, cleanup1 := provisionedStaticUpstream(t, h, "10.2.0.4:80")
|
||||
defer cleanup1()
|
||||
|
||||
h.countFailure(u1)
|
||||
h.countFailure(u1)
|
||||
|
||||
// Simulate a second handler instance referencing the same upstream
|
||||
// (e.g. after a config reload that keeps the same backend address).
|
||||
u2, cleanup2 := provisionedStaticUpstream(t, h, "10.2.0.4:80")
|
||||
defer cleanup2()
|
||||
|
||||
if u1.Host != u2.Host {
|
||||
t.Fatal("expected both Upstream structs to share the same *Host via UsagePool")
|
||||
}
|
||||
if u2.Healthy() {
|
||||
t.Error("re-provisioned upstream should still see the prior fail count and be unhealthy")
|
||||
}
|
||||
}
|
||||
|
||||
// --- dynamic upstream passive health check ---
|
||||
|
||||
// TestDynamicUpstreamHealthyWithNoFailures verifies that a freshly provisioned
|
||||
// dynamic upstream is healthy.
|
||||
func TestDynamicUpstreamHealthyWithNoFailures(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
h, cancel := newPassiveHandler(t, 2, time.Minute)
|
||||
defer cancel()
|
||||
|
||||
u, cleanup := provisionedDynamicUpstream(t, h, "10.3.0.1:80")
|
||||
defer cleanup()
|
||||
|
||||
if !u.Healthy() {
|
||||
t.Error("dynamic upstream with no failures should be healthy")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDynamicUpstreamUnhealthyAtMaxFails verifies that a dynamic upstream is
|
||||
// marked unhealthy once its fail count reaches MaxFails.
|
||||
func TestDynamicUpstreamUnhealthyAtMaxFails(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
h, cancel := newPassiveHandler(t, 2, time.Minute)
|
||||
defer cancel()
|
||||
|
||||
u, cleanup := provisionedDynamicUpstream(t, h, "10.3.0.2:80")
|
||||
defer cleanup()
|
||||
|
||||
h.countFailure(u)
|
||||
if !u.Healthy() {
|
||||
t.Error("dynamic upstream should still be healthy after 1 of 2 allowed failures")
|
||||
}
|
||||
|
||||
h.countFailure(u)
|
||||
if u.Healthy() {
|
||||
t.Error("dynamic upstream should be unhealthy after reaching MaxFails=2")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDynamicUpstreamFailCountPersistedBetweenRequests is the core regression
|
||||
// test: it simulates two sequential (non-concurrent) requests to the same
|
||||
// dynamic upstream. Before the fix, the UsagePool entry would be deleted
|
||||
// between requests, wiping the fail count. Now it should survive.
|
||||
func TestDynamicUpstreamFailCountPersistedBetweenRequests(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
h, cancel := newPassiveHandler(t, 2, time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// --- first request ---
|
||||
u1 := &Upstream{Dial: "10.3.0.3:80"}
|
||||
h.provisionUpstream(u1, true)
|
||||
h.countFailure(u1)
|
||||
|
||||
if u1.Host.Fails() != 1 {
|
||||
t.Fatalf("expected 1 fail after first request, got %d", u1.Host.Fails())
|
||||
}
|
||||
|
||||
// Simulate end of first request: no delete from any pool (key difference
|
||||
// vs. the old behaviour where hosts.Delete was deferred).
|
||||
|
||||
// --- second request: brand-new *Upstream struct, same dial address ---
|
||||
u2 := &Upstream{Dial: "10.3.0.3:80"}
|
||||
h.provisionUpstream(u2, true)
|
||||
|
||||
if u1.Host != u2.Host {
|
||||
t.Fatal("expected both requests to share the same *Host pointer from dynamicHosts")
|
||||
}
|
||||
if u2.Host.Fails() != 1 {
|
||||
t.Errorf("expected fail count to persist across requests, got %d", u2.Host.Fails())
|
||||
}
|
||||
|
||||
// A second failure now tips it over MaxFails=2.
|
||||
h.countFailure(u2)
|
||||
if u2.Healthy() {
|
||||
t.Error("upstream should be unhealthy after accumulated failures across requests")
|
||||
}
|
||||
|
||||
// Cleanup.
|
||||
dynamicHostsMu.Lock()
|
||||
delete(dynamicHosts, "10.3.0.3:80")
|
||||
dynamicHostsMu.Unlock()
|
||||
}
|
||||
|
||||
// TestDynamicUpstreamRecoveryAfterFailDuration verifies that a dynamic
|
||||
// upstream's fail count expires and it returns to healthy.
|
||||
func TestDynamicUpstreamRecoveryAfterFailDuration(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
const failDuration = 50 * time.Millisecond
|
||||
h, cancel := newPassiveHandler(t, 1, failDuration)
|
||||
defer cancel()
|
||||
|
||||
u, cleanup := provisionedDynamicUpstream(t, h, "10.3.0.4:80")
|
||||
defer cleanup()
|
||||
|
||||
h.countFailure(u)
|
||||
if u.Healthy() {
|
||||
t.Fatal("upstream should be unhealthy immediately after MaxFails failure")
|
||||
}
|
||||
|
||||
time.Sleep(3 * failDuration)
|
||||
|
||||
// Re-provision (as a new request would) to get fresh *Upstream with policy set.
|
||||
u2 := &Upstream{Dial: "10.3.0.4:80"}
|
||||
h.provisionUpstream(u2, true)
|
||||
|
||||
if !u2.Healthy() {
|
||||
t.Errorf("dynamic upstream should recover to healthy after FailDuration, Fails=%d", u2.Host.Fails())
|
||||
}
|
||||
}
|
||||
|
||||
// TestDynamicUpstreamMaxRequestsFromUnhealthyRequestCount verifies that
|
||||
// UnhealthyRequestCount is copied into MaxRequests so Full() works correctly.
|
||||
func TestDynamicUpstreamMaxRequestsFromUnhealthyRequestCount(t *testing.T) {
|
||||
resetDynamicHosts()
|
||||
caddyCtx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
h := &Handler{
|
||||
ctx: caddyCtx,
|
||||
HealthChecks: &HealthChecks{
|
||||
Passive: &PassiveHealthChecks{
|
||||
UnhealthyRequestCount: 3,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
u, cleanup := provisionedDynamicUpstream(t, h, "10.3.0.5:80")
|
||||
defer cleanup()
|
||||
|
||||
if u.MaxRequests != 3 {
|
||||
t.Errorf("expected MaxRequests=3 from UnhealthyRequestCount, got %d", u.MaxRequests)
|
||||
}
|
||||
|
||||
// Should not be full with fewer requests than the limit.
|
||||
_ = u.Host.countRequest(2)
|
||||
if u.Full() {
|
||||
t.Error("upstream should not be full with 2 of 3 allowed requests")
|
||||
}
|
||||
|
||||
_ = u.Host.countRequest(1)
|
||||
if !u.Full() {
|
||||
t.Error("upstream should be full at UnhealthyRequestCount concurrent requests")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,257 @@
|
||||
package reverseproxy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/caddyserver/caddy/v2"
|
||||
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
|
||||
)
|
||||
|
||||
// prepareTestRequest injects the context values that ServeHTTP and
|
||||
// proxyLoopIteration require (caddy.ReplacerCtxKey, VarsCtxKey, etc.) using
|
||||
// the same helper that the real HTTP server uses.
|
||||
//
|
||||
// A zero-value Server is passed so that caddyhttp.ServerCtxKey is set to a
|
||||
// non-nil pointer; reverseProxy dereferences it to check ShouldLogCredentials.
|
||||
func prepareTestRequest(req *http.Request) *http.Request {
|
||||
repl := caddy.NewReplacer()
|
||||
return caddyhttp.PrepareRequest(req, repl, nil, &caddyhttp.Server{})
|
||||
}
|
||||
|
||||
// closeOnCloseReader is an io.ReadCloser whose Close method actually makes
|
||||
// subsequent reads fail, mimicking the behaviour of a real HTTP request body
|
||||
// (as opposed to io.NopCloser, whose Close is a no-op and would mask the bug
|
||||
// we are testing).
|
||||
type closeOnCloseReader struct {
|
||||
mu sync.Mutex
|
||||
r *strings.Reader
|
||||
closed bool
|
||||
}
|
||||
|
||||
func newCloseOnCloseReader(s string) *closeOnCloseReader {
|
||||
return &closeOnCloseReader{r: strings.NewReader(s)}
|
||||
}
|
||||
|
||||
func (c *closeOnCloseReader) Read(p []byte) (int, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.closed {
|
||||
return 0, errors.New("http: invalid Read on closed Body")
|
||||
}
|
||||
return c.r.Read(p)
|
||||
}
|
||||
|
||||
func (c *closeOnCloseReader) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// deadUpstreamAddr returns a TCP address that is guaranteed to refuse
|
||||
// connections: we bind a listener, note its address, close it immediately,
|
||||
// and return the address. Any dial to that address will get ECONNREFUSED.
|
||||
func deadUpstreamAddr(t *testing.T) string {
|
||||
t.Helper()
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create dead upstream listener: %v", err)
|
||||
}
|
||||
addr := ln.Addr().String()
|
||||
ln.Close()
|
||||
return addr
|
||||
}
|
||||
|
||||
// testTransport wraps http.Transport to:
|
||||
// 1. Set the URL scheme to "http" when it is empty (matching what
|
||||
// HTTPTransport.SetScheme does in production; cloneRequest strips the
|
||||
// scheme intentionally so a plain *http.Transport would fail with
|
||||
// "unsupported protocol scheme").
|
||||
// 2. Wrap dial errors as DialError so that tryAgain correctly identifies them
|
||||
// as safe-to-retry regardless of request method (as HTTPTransport does in
|
||||
// production via its custom dialer).
|
||||
type testTransport struct{ *http.Transport }
|
||||
|
||||
func (t testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if req.URL.Scheme == "" {
|
||||
req.URL.Scheme = "http"
|
||||
}
|
||||
resp, err := t.Transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
// Wrap dial errors as DialError to match production behaviour.
|
||||
// Without this wrapping, tryAgain treats ECONNREFUSED on a POST
|
||||
// request as non-retryable (only GET is retried by default when
|
||||
// the error is not a DialError).
|
||||
var opErr *net.OpError
|
||||
if errors.As(err, &opErr) && opErr.Op == "dial" {
|
||||
return nil, DialError{err}
|
||||
}
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// minimalHandler returns a Handler with only the fields required by ServeHTTP
|
||||
// set directly, bypassing Provision (which requires a full Caddy runtime).
|
||||
// RoundRobinSelection is used so that successive iterations of the proxy loop
|
||||
// advance through the upstream pool in a predictable order.
|
||||
func minimalHandler(retries int, upstreams ...*Upstream) *Handler {
|
||||
return &Handler{
|
||||
logger: zap.NewNop(),
|
||||
Transport: testTransport{&http.Transport{}},
|
||||
Upstreams: upstreams,
|
||||
LoadBalancing: &LoadBalancing{
|
||||
Retries: retries,
|
||||
SelectionPolicy: &RoundRobinSelection{},
|
||||
// RetryMatch intentionally nil: dial errors are always retried
|
||||
// regardless of RetryMatch or request method.
|
||||
},
|
||||
// ctx, connections, connectionsMu, events: zero/nil values are safe
|
||||
// for the code paths exercised by these tests (TryInterval=0 so
|
||||
// ctx.Done() is never consulted; no WebSocket hijacking; no passive
|
||||
// health-check event emission).
|
||||
}
|
||||
}
|
||||
|
||||
// TestDialErrorBodyRetry verifies that a POST request whose body has NOT been
|
||||
// pre-buffered via request_buffers can still be retried after a dial error.
|
||||
//
|
||||
// Before the fix, a dial error caused Go's transport to close the shared body
|
||||
// (via cloneRequest's shallow copy), so the retry attempt would read from an
|
||||
// already-closed io.ReadCloser and produce:
|
||||
//
|
||||
// http: invalid Read on closed Body → HTTP 502
|
||||
//
|
||||
// After the fix the handler wraps the body in noCloseBody when retries are
|
||||
// configured, preventing the transport's Close() from propagating to the
|
||||
// shared body. Since dial errors never read any bytes, the body remains at
|
||||
// position 0 for the retry.
|
||||
func TestDialErrorBodyRetry(t *testing.T) {
|
||||
// Good upstream: echoes the request body with 200 OK.
|
||||
goodServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "read body: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(body)
|
||||
}))
|
||||
t.Cleanup(goodServer.Close)
|
||||
|
||||
const requestBody = "hello, retry"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
body string
|
||||
retries int
|
||||
wantStatus int
|
||||
wantBody string
|
||||
}{
|
||||
{
|
||||
// Core regression case: POST with a body, no request_buffers,
|
||||
// dial error on first upstream → retry to second upstream succeeds.
|
||||
name: "POST body retried after dial error",
|
||||
method: http.MethodPost,
|
||||
body: requestBody,
|
||||
retries: 1,
|
||||
wantStatus: http.StatusOK,
|
||||
wantBody: requestBody,
|
||||
},
|
||||
{
|
||||
// Dial errors are always retried regardless of method, but there
|
||||
// is no body to re-read, so GET has always worked. Keep it as a
|
||||
// sanity check that we did not break the no-body path.
|
||||
name: "GET without body retried after dial error",
|
||||
method: http.MethodGet,
|
||||
body: "",
|
||||
retries: 1,
|
||||
wantStatus: http.StatusOK,
|
||||
wantBody: "",
|
||||
},
|
||||
{
|
||||
// Without any retry configuration the handler must give up on the
|
||||
// first dial error and return a 502. Confirms no wrapping occurs
|
||||
// in the no-retry path.
|
||||
name: "no retries configured returns 502 on dial error",
|
||||
method: http.MethodPost,
|
||||
body: requestBody,
|
||||
retries: 0,
|
||||
wantStatus: http.StatusBadGateway,
|
||||
wantBody: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
dead := deadUpstreamAddr(t)
|
||||
|
||||
// Build the upstream pool. RoundRobinSelection starts its
|
||||
// counter at 0 and increments before returning, so with a
|
||||
// two-element pool it picks index 1 first, then index 0.
|
||||
// Put the good upstream at index 0 and the dead one at
|
||||
// index 1 so that:
|
||||
// attempt 1 → pool[1] = dead → DialError (ECONNREFUSED)
|
||||
// attempt 2 → pool[0] = good → 200
|
||||
upstreams := []*Upstream{
|
||||
{Host: new(Host), Dial: goodServer.Listener.Addr().String()},
|
||||
{Host: new(Host), Dial: dead},
|
||||
}
|
||||
if tc.retries == 0 {
|
||||
// For the "no retries" case use only the dead upstream so
|
||||
// there is nowhere to retry to.
|
||||
upstreams = []*Upstream{
|
||||
{Host: new(Host), Dial: dead},
|
||||
}
|
||||
}
|
||||
|
||||
h := minimalHandler(tc.retries, upstreams...)
|
||||
|
||||
// Use closeOnCloseReader so that Close() truly prevents further
|
||||
// reads, matching real http.body semantics. io.NopCloser would
|
||||
// mask the bug because its Close is a no-op.
|
||||
var bodyReader io.ReadCloser
|
||||
if tc.body != "" {
|
||||
bodyReader = newCloseOnCloseReader(tc.body)
|
||||
}
|
||||
req := httptest.NewRequest(tc.method, "http://example.com/", bodyReader)
|
||||
if bodyReader != nil {
|
||||
// httptest.NewRequest wraps the reader in NopCloser; replace
|
||||
// it with our close-aware reader so Close() is propagated.
|
||||
req.Body = bodyReader
|
||||
req.ContentLength = int64(len(tc.body))
|
||||
}
|
||||
req = prepareTestRequest(req)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
err := h.ServeHTTP(rec, req, caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
return nil
|
||||
}))
|
||||
|
||||
// For error cases (e.g. 502) ServeHTTP returns a HandlerError
|
||||
// rather than writing the status itself.
|
||||
gotStatus := rec.Code
|
||||
if err != nil {
|
||||
if herr, ok := err.(caddyhttp.HandlerError); ok {
|
||||
gotStatus = herr.StatusCode
|
||||
}
|
||||
}
|
||||
|
||||
if gotStatus != tc.wantStatus {
|
||||
t.Errorf("status: got %d, want %d (err=%v)", gotStatus, tc.wantStatus, err)
|
||||
}
|
||||
if tc.wantBody != "" && rec.Body.String() != tc.wantBody {
|
||||
t.Errorf("body: got %q, want %q", rec.Body.String(), tc.wantBody)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -32,6 +32,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
@@ -46,6 +47,31 @@ import (
|
||||
"github.com/caddyserver/caddy/v2/modules/caddyhttp/rewrite"
|
||||
)
|
||||
|
||||
// inFlightRequests uses sync.Map with atomic.Int64 for lock-free updates on the hot path
|
||||
var inFlightRequests sync.Map
|
||||
|
||||
func incInFlightRequest(address string) {
|
||||
v, _ := inFlightRequests.LoadOrStore(address, new(atomic.Int64))
|
||||
v.(*atomic.Int64).Add(1)
|
||||
}
|
||||
|
||||
func decInFlightRequest(address string) {
|
||||
if v, ok := inFlightRequests.Load(address); ok {
|
||||
if v.(*atomic.Int64).Add(-1) <= 0 {
|
||||
inFlightRequests.Delete(address)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getInFlightRequests() map[string]int64 {
|
||||
copyMap := make(map[string]int64)
|
||||
inFlightRequests.Range(func(key, value any) bool {
|
||||
copyMap[key.(string)] = value.(*atomic.Int64).Load()
|
||||
return true
|
||||
})
|
||||
return copyMap
|
||||
}
|
||||
|
||||
func init() {
|
||||
caddy.RegisterModule(Handler{})
|
||||
}
|
||||
@@ -192,6 +218,13 @@ type Handler struct {
|
||||
CB CircuitBreaker `json:"-"`
|
||||
DynamicUpstreams UpstreamSource `json:"-"`
|
||||
|
||||
// transportHeaderOps is a set of header operations provided
|
||||
// by the transport at provision time, if the transport
|
||||
// implements TransportHeaderOpsProvider. These ops are
|
||||
// applied before any user-configured header ops so the
|
||||
// user can override transport defaults.
|
||||
transportHeaderOps *headers.HeaderOps
|
||||
|
||||
// Holds the parsed CIDR ranges from TrustedProxies
|
||||
trustedProxies []netip.Prefix
|
||||
|
||||
@@ -243,18 +276,16 @@ func (h *Handler) Provision(ctx caddy.Context) error {
|
||||
return fmt.Errorf("loading transport: %v", err)
|
||||
}
|
||||
h.Transport = mod.(http.RoundTripper)
|
||||
// enable request buffering for fastcgi if not configured
|
||||
// This is because most fastcgi servers are php-fpm that require the content length to be set to read the body, golang
|
||||
// std has fastcgi implementation that doesn't need this value to process the body, but we can safely assume that's
|
||||
// not used.
|
||||
// http3 requests have a negative content length for GET and HEAD requests, if that header is not sent.
|
||||
// see: https://github.com/caddyserver/caddy/issues/6678#issuecomment-2472224182
|
||||
// Though it appears even if CONTENT_LENGTH is invalid, php-fpm can handle just fine if the body is empty (no Stdin records sent).
|
||||
// php-fpm will hang if there is any data in the body though, https://github.com/caddyserver/caddy/issues/5420#issuecomment-2415943516
|
||||
|
||||
// TODO: better default buffering for fastcgi requests without content length, in theory a value of 1 should be enough, make it bigger anyway
|
||||
if module, ok := h.Transport.(caddy.Module); ok && module.CaddyModule().ID.Name() == "fastcgi" && h.RequestBuffers == 0 {
|
||||
h.RequestBuffers = 4096
|
||||
// set default buffer sizes if applicable
|
||||
if bt, ok := h.Transport.(BufferedTransport); ok {
|
||||
reqBuffers, respBuffers := bt.DefaultBufferSizes()
|
||||
if h.RequestBuffers == 0 {
|
||||
h.RequestBuffers = reqBuffers
|
||||
}
|
||||
if h.ResponseBuffers == 0 {
|
||||
h.ResponseBuffers = respBuffers
|
||||
}
|
||||
}
|
||||
}
|
||||
if h.LoadBalancing != nil && h.LoadBalancing.SelectionPolicyRaw != nil {
|
||||
@@ -324,6 +355,18 @@ func (h *Handler) Provision(ctx caddy.Context) error {
|
||||
h.Transport = t
|
||||
}
|
||||
|
||||
// If the transport can provide header ops, cache them now so we don't
|
||||
// have to compute them per-request. Provision the HeaderOps if present
|
||||
// so any runtime artifacts (like precompiled regex) are prepared.
|
||||
if tph, ok := h.Transport.(RequestHeaderOpsTransport); ok {
|
||||
h.transportHeaderOps = tph.RequestHeaderOps()
|
||||
if h.transportHeaderOps != nil {
|
||||
if err := h.transportHeaderOps.Provision(ctx); err != nil {
|
||||
return fmt.Errorf("provisioning transport header ops: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// set up load balancing
|
||||
if h.LoadBalancing == nil {
|
||||
h.LoadBalancing = new(LoadBalancing)
|
||||
@@ -349,7 +392,7 @@ func (h *Handler) Provision(ctx caddy.Context) error {
|
||||
|
||||
// set up upstreams
|
||||
for _, u := range h.Upstreams {
|
||||
h.provisionUpstream(u)
|
||||
h.provisionUpstream(u, false)
|
||||
}
|
||||
|
||||
if h.HealthChecks != nil {
|
||||
@@ -439,6 +482,33 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht
|
||||
reqHost := clonedReq.Host
|
||||
reqHeader := clonedReq.Header
|
||||
|
||||
// When retries are configured and there is a body, wrap it in
|
||||
// io.NopCloser to prevent Go's transport from closing it on dial
|
||||
// errors. cloneRequest does a shallow copy, so clonedReq.Body and
|
||||
// r.Body share the same io.ReadCloser — a dial-failure Close()
|
||||
// would kill the original body for all subsequent retry attempts.
|
||||
// The real body is closed by the HTTP server when the handler
|
||||
// returns.
|
||||
//
|
||||
// If the body was already fully buffered (via request_buffers),
|
||||
// we also extract the buffer so the retry loop can replay it
|
||||
// from the beginning on each attempt. (see #6259, #7546)
|
||||
var bufferedReqBody *bytes.Buffer
|
||||
if clonedReq.Body != nil && h.LoadBalancing != nil &&
|
||||
(h.LoadBalancing.Retries > 0 || h.LoadBalancing.TryDuration > 0) {
|
||||
if reqBodyBuf, ok := clonedReq.Body.(bodyReadCloser); ok && reqBodyBuf.body == nil && reqBodyBuf.buf != nil {
|
||||
bufferedReqBody = reqBodyBuf.buf
|
||||
reqBodyBuf.buf = nil
|
||||
clonedReq.Body = io.NopCloser(bytes.NewReader(bufferedReqBody.Bytes()))
|
||||
defer func() {
|
||||
bufferedReqBody.Reset()
|
||||
bufPool.Put(bufferedReqBody)
|
||||
}()
|
||||
} else {
|
||||
clonedReq.Body = io.NopCloser(clonedReq.Body)
|
||||
}
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
// total proxying duration, including time spent on LB and retries
|
||||
@@ -457,8 +527,8 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht
|
||||
// and reusable, so if a backend partially or fully reads the body but then
|
||||
// produces an error, the request can be repeated to the next backend with
|
||||
// the full body (retries should only happen for idempotent requests) (see #6259)
|
||||
if reqBodyBuf, ok := r.Body.(bodyReadCloser); ok && reqBodyBuf.body == nil {
|
||||
r.Body = io.NopCloser(bytes.NewReader(reqBodyBuf.buf.Bytes()))
|
||||
if bufferedReqBody != nil {
|
||||
clonedReq.Body = io.NopCloser(bytes.NewReader(bufferedReqBody.Bytes()))
|
||||
}
|
||||
|
||||
var done bool
|
||||
@@ -506,18 +576,11 @@ func (h *Handler) proxyLoopIteration(r *http.Request, origReq *http.Request, w h
|
||||
} else {
|
||||
upstreams = dUpstreams
|
||||
for _, dUp := range dUpstreams {
|
||||
h.provisionUpstream(dUp)
|
||||
h.provisionUpstream(dUp, true)
|
||||
}
|
||||
if c := h.logger.Check(zapcore.DebugLevel, "provisioned dynamic upstreams"); c != nil {
|
||||
c.Write(zap.Int("count", len(dUpstreams)))
|
||||
}
|
||||
defer func() {
|
||||
// these upstreams are dynamic, so they are only used for this iteration
|
||||
// of the proxy loop; be sure to let them go away when we're done with them
|
||||
for _, upstream := range dUpstreams {
|
||||
_, _ = hosts.Delete(upstream.String())
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -563,14 +626,26 @@ func (h *Handler) proxyLoopIteration(r *http.Request, origReq *http.Request, w h
|
||||
repl.Set("http.reverse_proxy.upstream.fails", upstream.Host.Fails())
|
||||
|
||||
// mutate request headers according to this upstream;
|
||||
// because we're in a retry loop, we have to copy
|
||||
// headers (and the r.Host value) from the original
|
||||
// so that each retry is identical to the first
|
||||
if h.Headers != nil && h.Headers.Request != nil {
|
||||
// because we're in a retry loop, we have to copy headers
|
||||
// (and the r.Host value) from the original so that each
|
||||
// retry is identical to the first. If either transport or
|
||||
// user ops exist, apply them in order (transport first,
|
||||
// then user, so user's config wins).
|
||||
var userOps *headers.HeaderOps
|
||||
if h.Headers != nil {
|
||||
userOps = h.Headers.Request
|
||||
}
|
||||
transportOps := h.transportHeaderOps
|
||||
if transportOps != nil || userOps != nil {
|
||||
r.Header = make(http.Header)
|
||||
copyHeader(r.Header, reqHeader)
|
||||
r.Host = reqHost
|
||||
h.Headers.Request.ApplyToRequest(r)
|
||||
if transportOps != nil {
|
||||
transportOps.ApplyToRequest(r)
|
||||
}
|
||||
if userOps != nil {
|
||||
userOps.ApplyToRequest(r)
|
||||
}
|
||||
}
|
||||
|
||||
// proxy the request to that upstream
|
||||
@@ -758,48 +833,71 @@ func (h Handler) prepareRequest(req *http.Request, repl *caddy.Replacer) (*http.
|
||||
// the headers at all, then they will be added with the values
|
||||
// that we can glean from the request.
|
||||
func (h Handler) addForwardedHeaders(req *http.Request) error {
|
||||
// Parse the remote IP, ignore the error as non-fatal,
|
||||
// but the remote IP is required to continue, so we
|
||||
// just return early. This should probably never happen
|
||||
// though, unless some other module manipulated the request's
|
||||
// remote address and used an invalid value.
|
||||
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||
if err != nil {
|
||||
// Remove the `X-Forwarded-*` headers to avoid upstreams
|
||||
// potentially trusting a header that came from the client
|
||||
req.Header.Del("X-Forwarded-For")
|
||||
req.Header.Del("X-Forwarded-Proto")
|
||||
req.Header.Del("X-Forwarded-Host")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Client IP may contain a zone if IPv6, so we need
|
||||
// to pull that out before parsing the IP
|
||||
clientIP, _, _ = strings.Cut(clientIP, "%")
|
||||
ipAddr, err := netip.ParseAddr(clientIP)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid IP address: '%s': %v", clientIP, err)
|
||||
}
|
||||
|
||||
// Check if the client is a trusted proxy
|
||||
trusted := caddyhttp.GetVar(req.Context(), caddyhttp.TrustedProxyVarKey).(bool)
|
||||
for _, ipRange := range h.trustedProxies {
|
||||
if ipRange.Contains(ipAddr) {
|
||||
trusted = true
|
||||
break
|
||||
|
||||
var clientIP string
|
||||
|
||||
if req.RemoteAddr == "@" {
|
||||
// For Unix socket connections, RemoteAddr is "@" which cannot
|
||||
// be parsed as host:port. If untrusted, strip forwarded headers
|
||||
// for security. If trusted, there is no peer IP to append to
|
||||
// X-Forwarded-For, so clientIP stays empty.
|
||||
if !trusted {
|
||||
req.Header.Del("X-Forwarded-For")
|
||||
req.Header.Del("X-Forwarded-Proto")
|
||||
req.Header.Del("X-Forwarded-Host")
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
// Parse the remote IP, ignore the error as non-fatal,
|
||||
// but the remote IP is required to continue, so we
|
||||
// just return early. This should probably never happen
|
||||
// though, unless some other module manipulated the request's
|
||||
// remote address and used an invalid value.
|
||||
var err error
|
||||
clientIP, _, err = net.SplitHostPort(req.RemoteAddr)
|
||||
if err != nil {
|
||||
// Remove the `X-Forwarded-*` headers to avoid upstreams
|
||||
// potentially trusting a header that came from the client
|
||||
req.Header.Del("X-Forwarded-For")
|
||||
req.Header.Del("X-Forwarded-Proto")
|
||||
req.Header.Del("X-Forwarded-Host")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Client IP may contain a zone if IPv6, so we need
|
||||
// to pull that out before parsing the IP
|
||||
clientIP, _, _ = strings.Cut(clientIP, "%")
|
||||
ipAddr, err := netip.ParseAddr(clientIP)
|
||||
|
||||
// If ParseAddr fails (e.g. non-IP network like SCION), we cannot check
|
||||
// if it is a trusted proxy by IP range. In this case, we ignore the
|
||||
// error and treat the connection as untrusted (or retain existing status).
|
||||
if err == nil {
|
||||
for _, ipRange := range h.trustedProxies {
|
||||
if ipRange.Contains(ipAddr) {
|
||||
trusted = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we aren't the first proxy, and the proxy is trusted,
|
||||
// retain prior X-Forwarded-For information as a comma+space
|
||||
// separated list and fold multiple headers into one.
|
||||
clientXFF := clientIP
|
||||
prior, ok, omit := allHeaderValues(req.Header, "X-Forwarded-For")
|
||||
if trusted && ok && prior != "" {
|
||||
clientXFF = prior + ", " + clientXFF
|
||||
}
|
||||
if !omit {
|
||||
req.Header.Set("X-Forwarded-For", clientXFF)
|
||||
if trusted && ok && prior != "" {
|
||||
if clientIP != "" {
|
||||
req.Header.Set("X-Forwarded-For", prior+", "+clientIP)
|
||||
} else {
|
||||
req.Header.Set("X-Forwarded-For", prior)
|
||||
}
|
||||
} else if clientIP != "" {
|
||||
req.Header.Set("X-Forwarded-For", clientIP)
|
||||
}
|
||||
}
|
||||
|
||||
// Set X-Forwarded-Proto; many backend apps expect this,
|
||||
@@ -838,8 +936,16 @@ func (h Handler) addForwardedHeaders(req *http.Request) error {
|
||||
// Go standard library which was used as the foundation.)
|
||||
func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origReq *http.Request, repl *caddy.Replacer, di DialInfo, next caddyhttp.Handler) error {
|
||||
_ = di.Upstream.Host.countRequest(1)
|
||||
|
||||
// Increment the in-flight request count
|
||||
incInFlightRequest(di.Address)
|
||||
|
||||
//nolint:errcheck
|
||||
defer di.Upstream.Host.countRequest(-1)
|
||||
defer func() {
|
||||
di.Upstream.Host.countRequest(-1)
|
||||
// Decrement the in-flight request count
|
||||
decInFlightRequest(di.Address)
|
||||
}()
|
||||
|
||||
// point the request to this upstream
|
||||
h.directRequest(req, di)
|
||||
@@ -1198,7 +1304,7 @@ func (lb LoadBalancing) tryAgain(ctx caddy.Context, start time.Time, retries int
|
||||
|
||||
// directRequest modifies only req.URL so that it points to the upstream
|
||||
// in the given DialInfo. It must modify ONLY the request URL.
|
||||
func (Handler) directRequest(req *http.Request, di DialInfo) {
|
||||
func (h *Handler) directRequest(req *http.Request, di DialInfo) {
|
||||
// we need a host, so set the upstream's host address
|
||||
reqHost := di.Address
|
||||
|
||||
@@ -1209,12 +1315,31 @@ func (Handler) directRequest(req *http.Request, di DialInfo) {
|
||||
reqHost = di.Host
|
||||
}
|
||||
|
||||
// add client address to the host to let transport differentiate requests from different clients
|
||||
if ppt, ok := h.Transport.(ProxyProtocolTransport); ok && ppt.ProxyProtocolEnabled() {
|
||||
if proxyProtocolInfo, ok := caddyhttp.GetVar(req.Context(), proxyProtocolInfoVarKey).(ProxyProtocolInfo); ok {
|
||||
// encode the request so it plays well with h2 transport, it's unnecessary for h1 but anyway
|
||||
// The issue is that h2 transport will use the address to determine if new connections are needed
|
||||
// to roundtrip requests but the without escaping, new connections are constantly created and closed until
|
||||
// file descriptors are exhausted.
|
||||
// see: https://github.com/caddyserver/caddy/issues/7529
|
||||
reqHost = url.QueryEscape(proxyProtocolInfo.AddrPort.String() + "->" + reqHost)
|
||||
}
|
||||
}
|
||||
|
||||
req.URL.Host = reqHost
|
||||
}
|
||||
|
||||
func (h Handler) provisionUpstream(upstream *Upstream) {
|
||||
// create or get the host representation for this upstream
|
||||
upstream.fillHost()
|
||||
func (h Handler) provisionUpstream(upstream *Upstream, dynamic bool) {
|
||||
// create or get the host representation for this upstream;
|
||||
// dynamic upstreams are tracked in a separate map with last-seen
|
||||
// timestamps so their health state persists across requests without
|
||||
// being reference-counted (and thus discarded between requests).
|
||||
if dynamic {
|
||||
upstream.fillDynamicHost()
|
||||
} else {
|
||||
upstream.fillHost()
|
||||
}
|
||||
|
||||
// give it the circuit breaker, if any
|
||||
upstream.cb = h.CB
|
||||
@@ -1494,6 +1619,43 @@ type TLSTransport interface {
|
||||
EnableTLS(base *TLSConfig) error
|
||||
}
|
||||
|
||||
// H2CTransport is implemented by transports
|
||||
// that are capable of using h2c.
|
||||
type H2CTransport interface {
|
||||
EnableH2C() error
|
||||
}
|
||||
|
||||
// ProxyProtocolTransport is implemented by transports
|
||||
// that are capable of using proxy protocol.
|
||||
type ProxyProtocolTransport interface {
|
||||
ProxyProtocolEnabled() bool
|
||||
}
|
||||
|
||||
// HealthCheckSchemeOverriderTransport is implemented by transports
|
||||
// that can override the scheme used for health checks.
|
||||
type HealthCheckSchemeOverriderTransport interface {
|
||||
OverrideHealthCheckScheme(base *url.URL, port string)
|
||||
}
|
||||
|
||||
// BufferedTransport is implemented by transports
|
||||
// that needs to buffer requests and/or responses.
|
||||
type BufferedTransport interface {
|
||||
// DefaultBufferSizes returns the default buffer sizes
|
||||
// for requests and responses, respectively if buffering isn't enabled.
|
||||
DefaultBufferSizes() (int64, int64)
|
||||
}
|
||||
|
||||
// RequestHeaderOpsTransport may be implemented by a transport to provide
|
||||
// header operations to apply to requests immediately before the RoundTrip.
|
||||
// For example, overriding the default Host when TLS is enabled.
|
||||
type RequestHeaderOpsTransport interface {
|
||||
// RequestHeaderOps allows a transport to provide header operations
|
||||
// to apply to the request. The transport is asked at provision time
|
||||
// to return a HeaderOps (or nil) that will be applied before
|
||||
// user-configured header ops.
|
||||
RequestHeaderOps() *headers.HeaderOps
|
||||
}
|
||||
|
||||
// roundtripSucceededError is an error type that is returned if the
|
||||
// roundtrip succeeded, but an error occurred after-the-fact.
|
||||
type roundtripSucceededError struct{ error }
|
||||
@@ -1507,7 +1669,12 @@ type bodyReadCloser struct {
|
||||
}
|
||||
|
||||
func (brc bodyReadCloser) Close() error {
|
||||
bufPool.Put(brc.buf)
|
||||
// Inside this package this will be set to nil for fully-buffered
|
||||
// requests due to the possibility of retrial.
|
||||
if brc.buf != nil {
|
||||
bufPool.Put(brc.buf)
|
||||
}
|
||||
// For fully-buffered bodies, body is nil, so Close is a no-op.
|
||||
if brc.body != nil {
|
||||
return brc.body.Close()
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
weakrand "math/rand"
|
||||
weakrand "math/rand/v2"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@@ -225,7 +225,7 @@ func (r RandomChoiceSelection) Select(pool UpstreamPool, _ *http.Request, _ http
|
||||
if !upstream.Available() {
|
||||
continue
|
||||
}
|
||||
j := weakrand.Intn(i + 1) //nolint:gosec
|
||||
j := weakrand.IntN(i + 1) //nolint:gosec
|
||||
if j < k {
|
||||
choices[j] = upstream
|
||||
}
|
||||
@@ -274,7 +274,7 @@ func (LeastConnSelection) Select(pool UpstreamPool, _ *http.Request, _ http.Resp
|
||||
// sample: https://en.wikipedia.org/wiki/Reservoir_sampling
|
||||
if numReqs == leastReqs {
|
||||
count++
|
||||
if count == 1 || (weakrand.Int()%count) == 0 { //nolint:gosec
|
||||
if count == 1 || weakrand.IntN(count) == 0 { //nolint:gosec
|
||||
bestHost = host
|
||||
}
|
||||
}
|
||||
@@ -312,7 +312,7 @@ func (r *RoundRobinSelection) Select(pool UpstreamPool, _ *http.Request, _ http.
|
||||
if n == 0 {
|
||||
return nil
|
||||
}
|
||||
for i := uint32(0); i < n; i++ {
|
||||
for range n {
|
||||
robin := atomic.AddUint32(&r.robin, 1)
|
||||
host := pool[robin%n]
|
||||
if host.Available() {
|
||||
@@ -617,7 +617,7 @@ type CookieHashSelection struct {
|
||||
// The HTTP cookie name whose value is to be hashed and used for upstream selection.
|
||||
Name string `json:"name,omitempty"`
|
||||
// Secret to hash (Hmac256) chosen upstream in cookie
|
||||
Secret string `json:"secret,omitempty"`
|
||||
Secret string `json:"secret,omitempty"` //nolint:gosec // yes it's exported because it needs to encode to JSON
|
||||
// The cookie's Max-Age before it expires. Default is no expiry.
|
||||
MaxAge caddy.Duration `json:"max_age,omitempty"`
|
||||
|
||||
@@ -788,7 +788,7 @@ func selectRandomHost(pool []*Upstream) *Upstream {
|
||||
// upstream will always be chosen if there is at
|
||||
// least one available
|
||||
count++
|
||||
if (weakrand.Int() % count) == 0 { //nolint:gosec
|
||||
if weakrand.IntN(count) == 0 { //nolint:gosec
|
||||
randomHost = upstream
|
||||
}
|
||||
}
|
||||
@@ -827,7 +827,7 @@ func leastRequests(upstreams []*Upstream) *Upstream {
|
||||
if len(best) == 1 {
|
||||
return best[0]
|
||||
}
|
||||
return best[weakrand.Intn(len(best))] //nolint:gosec
|
||||
return best[weakrand.IntN(len(best))] //nolint:gosec
|
||||
}
|
||||
|
||||
// hostByHashing returns an available host from pool based on a hashable string s.
|
||||
|
||||
@@ -24,7 +24,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
weakrand "math/rand"
|
||||
weakrand "math/rand/v2"
|
||||
"mime"
|
||||
"net/http"
|
||||
"sync"
|
||||
@@ -214,7 +214,10 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
|
||||
timeoutc = timer.C
|
||||
}
|
||||
|
||||
errc := make(chan error, 1)
|
||||
// when a stream timeout is encountered, no error will be read from errc
|
||||
// a buffer size of 2 will allow both the read and write goroutines to send the error and exit
|
||||
// see: https://github.com/caddyserver/caddy/issues/7418
|
||||
errc := make(chan error, 2)
|
||||
wg.Add(2)
|
||||
go spc.copyToBackend(errc)
|
||||
go spc.copyFromBackend(errc)
|
||||
@@ -526,14 +529,14 @@ func maskBytes(key [4]byte, pos int, b []byte) int {
|
||||
// Create aligned word size key.
|
||||
var k [wordSize]byte
|
||||
for i := range k {
|
||||
k[i] = key[(pos+i)&3]
|
||||
k[i] = key[(pos+i)&3] // nolint:gosec // false positive, impossible to be out of bounds; see: https://github.com/securego/gosec/issues/1525
|
||||
}
|
||||
kw := *(*uintptr)(unsafe.Pointer(&k))
|
||||
|
||||
// Mask one word at a time.
|
||||
n := (len(b) / wordSize) * wordSize
|
||||
for i := 0; i < n; i += wordSize {
|
||||
*(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw
|
||||
*(*uintptr)(unsafe.Add(unsafe.Pointer(&b[0]), i)) ^= kw
|
||||
}
|
||||
|
||||
// Mask one byte at a time for remaining bytes.
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
weakrand "math/rand"
|
||||
weakrand "math/rand/v2"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@@ -70,6 +70,11 @@ type SRVUpstreams struct {
|
||||
// A negative value disables this.
|
||||
FallbackDelay caddy.Duration `json:"dial_fallback_delay,omitempty"`
|
||||
|
||||
// Specific network to dial when connecting to the upstream(s)
|
||||
// provided by SRV records upstream. See Go's net package for
|
||||
// accepted values. For example, to restrict to IPv4, use "tcp4".
|
||||
DialNetwork string `json:"dial_network,omitempty"`
|
||||
|
||||
resolver *net.Resolver
|
||||
|
||||
logger *zap.Logger
|
||||
@@ -102,7 +107,7 @@ func (su *SRVUpstreams) Provision(ctx caddy.Context) error {
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
//nolint:gosec
|
||||
addr := su.Resolver.netAddrs[weakrand.Intn(len(su.Resolver.netAddrs))]
|
||||
addr := su.Resolver.netAddrs[weakrand.IntN(len(su.Resolver.netAddrs))]
|
||||
return d.DialContext(ctx, addr.Network, addr.JoinHostPort(0))
|
||||
},
|
||||
}
|
||||
@@ -177,6 +182,9 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
|
||||
)
|
||||
}
|
||||
addr := net.JoinHostPort(rec.Target, strconv.Itoa(int(rec.Port)))
|
||||
if su.DialNetwork != "" {
|
||||
addr = su.DialNetwork + "/" + addr
|
||||
}
|
||||
upstreams[i] = Upstream{Dial: addr}
|
||||
}
|
||||
|
||||
@@ -322,7 +330,7 @@ func (au *AUpstreams) Provision(ctx caddy.Context) error {
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
//nolint:gosec
|
||||
addr := au.Resolver.netAddrs[weakrand.Intn(len(au.Resolver.netAddrs))]
|
||||
addr := au.Resolver.netAddrs[weakrand.IntN(len(au.Resolver.netAddrs))]
|
||||
return d.DialContext(ctx, addr.Network, addr.JoinHostPort(0))
|
||||
},
|
||||
}
|
||||
|
||||
@@ -173,6 +173,7 @@ func parseCaddyfileURI(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, err
|
||||
if hasArgs {
|
||||
return nil, h.Err("Cannot specify uri query rewrites in both argument and block")
|
||||
}
|
||||
// nolint:prealloc
|
||||
queryArgs := []string{h.Val()}
|
||||
queryArgs = append(queryArgs, h.RemainingArgs()...)
|
||||
err := applyQueryOps(h, rewr.Query, queryArgs)
|
||||
|
||||
@@ -247,6 +247,7 @@ func (rewr Rewrite) Rewrite(r *http.Request, repl *caddy.Replacer) bool {
|
||||
} else {
|
||||
r.URL.Path = path
|
||||
}
|
||||
r.URL.RawPath = "" // force recomputing when EscapedPath() is called
|
||||
}
|
||||
if qsStart >= 0 {
|
||||
r.URL.RawQuery = newQuery
|
||||
|
||||
@@ -224,6 +224,11 @@ func TestRewrite(t *testing.T) {
|
||||
input: newRequest(t, "GET", "/foo#fragFirst?c=d"),
|
||||
expect: newRequest(t, "GET", "/bar#fragFirst?c=d"),
|
||||
},
|
||||
{
|
||||
rule: Rewrite{URI: "/api/admin/panel"},
|
||||
input: newRequest(t, "GET", "/api/admin%2Fpanel"),
|
||||
expect: newRequest(t, "GET", "/api/admin/panel"),
|
||||
},
|
||||
|
||||
{
|
||||
rule: Rewrite{StripPathPrefix: "/prefix"},
|
||||
|
||||
+38
-18
@@ -18,6 +18,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/caddyserver/caddy/v2"
|
||||
)
|
||||
@@ -96,7 +97,10 @@ type Route struct {
|
||||
MatcherSets MatcherSets `json:"-"`
|
||||
Handlers []MiddlewareHandler `json:"-"`
|
||||
|
||||
middleware []Middleware
|
||||
middleware []Middleware
|
||||
metrics *Metrics
|
||||
metricsCtx caddy.Context
|
||||
handlerName string
|
||||
}
|
||||
|
||||
// Empty returns true if the route has all zero/default values.
|
||||
@@ -110,14 +114,16 @@ func (r Route) Empty() bool {
|
||||
}
|
||||
|
||||
func (r Route) String() string {
|
||||
handlersRaw := "["
|
||||
var handlersRaw strings.Builder
|
||||
handlersRaw.WriteByte('[')
|
||||
for _, hr := range r.HandlersRaw {
|
||||
handlersRaw += " " + string(hr)
|
||||
handlersRaw.WriteByte(' ')
|
||||
handlersRaw.WriteString(string(hr))
|
||||
}
|
||||
handlersRaw += "]"
|
||||
handlersRaw.WriteByte(']')
|
||||
|
||||
return fmt.Sprintf(`{Group:"%s" MatcherSetsRaw:%s HandlersRaw:%s Terminal:%t}`,
|
||||
r.Group, r.MatcherSetsRaw, handlersRaw, r.Terminal)
|
||||
r.Group, r.MatcherSetsRaw, handlersRaw.String(), r.Terminal)
|
||||
}
|
||||
|
||||
// Provision sets up both the matchers and handlers in the route.
|
||||
@@ -159,12 +165,20 @@ func (r *Route) ProvisionHandlers(ctx caddy.Context, metrics *Metrics) error {
|
||||
r.Handlers = append(r.Handlers, handler.(MiddlewareHandler))
|
||||
}
|
||||
|
||||
// Store metrics info for route-level instrumentation (applied once
|
||||
// per route in wrapRoute, instead of per-handler which was redundant).
|
||||
r.metrics = metrics
|
||||
r.metricsCtx = ctx
|
||||
if len(r.Handlers) > 0 {
|
||||
r.handlerName = caddy.GetModuleName(r.Handlers[0])
|
||||
}
|
||||
|
||||
// Make ProvisionHandlers idempotent by clearing the middleware field
|
||||
r.middleware = []Middleware{}
|
||||
|
||||
// pre-compile the middleware handler chain
|
||||
for _, midhandler := range r.Handlers {
|
||||
r.middleware = append(r.middleware, wrapMiddleware(ctx, midhandler, metrics))
|
||||
r.middleware = append(r.middleware, wrapMiddleware(ctx, midhandler))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -295,6 +309,16 @@ func wrapRoute(route Route) Middleware {
|
||||
nextCopy = route.middleware[i](nextCopy)
|
||||
}
|
||||
|
||||
// Apply metrics instrumentation once for the entire route,
|
||||
// rather than wrapping each individual handler. This avoids
|
||||
// redundant metrics collection that caused significant CPU
|
||||
// overhead (see issue #4644).
|
||||
if route.metrics != nil {
|
||||
nextCopy = newMetricsInstrumentedRoute(
|
||||
route.metricsCtx, route.handlerName, nextCopy, route.metrics,
|
||||
)
|
||||
}
|
||||
|
||||
return nextCopy.ServeHTTP(rw, req)
|
||||
})
|
||||
}
|
||||
@@ -303,20 +327,14 @@ func wrapRoute(route Route) Middleware {
|
||||
// wrapMiddleware wraps mh such that it can be correctly
|
||||
// appended to a list of middleware in preparation for
|
||||
// compiling into a handler chain.
|
||||
func wrapMiddleware(ctx caddy.Context, mh MiddlewareHandler, metrics *Metrics) Middleware {
|
||||
handlerToUse := mh
|
||||
if metrics != nil {
|
||||
// wrap the middleware with metrics instrumentation
|
||||
handlerToUse = newMetricsInstrumentedHandler(ctx, caddy.GetModuleName(mh), mh, metrics)
|
||||
}
|
||||
|
||||
func wrapMiddleware(ctx caddy.Context, mh MiddlewareHandler) Middleware {
|
||||
return func(next Handler) Handler {
|
||||
return HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
// EXPERIMENTAL: Trace each module that gets invoked
|
||||
if server, ok := r.Context().Value(ServerCtxKey).(*Server); ok && server != nil {
|
||||
server.logTrace(handlerToUse)
|
||||
server.logTrace(mh)
|
||||
}
|
||||
return handlerToUse.ServeHTTP(w, r, next)
|
||||
return mh.ServeHTTP(w, r, next)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -440,13 +458,15 @@ func (ms *MatcherSets) FromInterface(matcherSets any) error {
|
||||
|
||||
// TODO: Is this used?
|
||||
func (ms MatcherSets) String() string {
|
||||
result := "["
|
||||
var result strings.Builder
|
||||
result.WriteByte('[')
|
||||
for _, matcherSet := range ms {
|
||||
for _, matcher := range matcherSet {
|
||||
result += fmt.Sprintf(" %#v", matcher)
|
||||
fmt.Fprintf(&result, " %#v", matcher)
|
||||
}
|
||||
}
|
||||
return result + " ]"
|
||||
result.WriteByte(']')
|
||||
return result.String()
|
||||
}
|
||||
|
||||
var routeGroupCtxKey = caddy.CtxKey("route_group")
|
||||
|
||||
+117
-77
@@ -18,6 +18,7 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@@ -33,7 +34,7 @@ import (
|
||||
"github.com/caddyserver/certmagic"
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
"github.com/quic-go/quic-go/qlog"
|
||||
h3qlog "github.com/quic-go/quic-go/http3/qlog"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
|
||||
@@ -55,6 +56,10 @@ type Server struct {
|
||||
// of the base listener. They are applied in the given order.
|
||||
ListenerWrappersRaw []json.RawMessage `json:"listener_wrappers,omitempty" caddy:"namespace=caddy.listeners inline_key=wrapper"`
|
||||
|
||||
// A list of packet conn wrapper modules, which can modify the behavior
|
||||
// of the base packet conn. They are applied in the given order.
|
||||
PacketConnWrappersRaw []json.RawMessage `json:"packet_conn_wrappers,omitempty" caddy:"namespace=caddy.packetconns inline_key=wrapper"`
|
||||
|
||||
// How long to allow a read from a client's upload. Setting this
|
||||
// to a short, non-zero value can mitigate slowloris attacks, but
|
||||
// may also affect legitimately slow clients.
|
||||
@@ -248,6 +253,16 @@ type Server struct {
|
||||
// A nil value or element indicates that Protocols will be used instead.
|
||||
ListenProtocols [][]string `json:"listen_protocols,omitempty"`
|
||||
|
||||
// If set, overrides whether QUIC listeners allow 0-RTT (early data).
|
||||
// If nil, the default behavior is used (currently allowed).
|
||||
//
|
||||
// One reason to disable 0-RTT is if a remote IP matcher is used,
|
||||
// which introduces a dependency on the remote address being verified
|
||||
// if routing happens before the TLS handshake completes. An HTTP 425
|
||||
// response is written in that case, but some clients misbehave and
|
||||
// don't perform a retry, so disabling 0-RTT can smooth it out.
|
||||
Allow0RTT *bool `json:"allow_0rtt,omitempty"`
|
||||
|
||||
// If set, metrics observations will be enabled.
|
||||
// This setting is EXPERIMENTAL and subject to change.
|
||||
// DEPRECATED: Use the app-level `metrics` field.
|
||||
@@ -258,7 +273,8 @@ type Server struct {
|
||||
primaryHandlerChain Handler
|
||||
errorHandlerChain Handler
|
||||
listenerWrappers []caddy.ListenerWrapper
|
||||
listeners []net.Listener // stdlib http.Server will close these
|
||||
packetConnWrappers []caddy.PacketConnWrapper
|
||||
listeners []net.Listener
|
||||
quicListeners []http3.QUICListener // http3 now leave the quic.Listener management to us
|
||||
|
||||
tlsApp *caddytls.TLS
|
||||
@@ -285,8 +301,15 @@ type Server struct {
|
||||
onStopFuncs []func(context.Context) error // TODO: Experimental (Nov. 2023)
|
||||
}
|
||||
|
||||
var (
|
||||
ServerHeader = "Caddy"
|
||||
serverHeader = []string{ServerHeader}
|
||||
)
|
||||
|
||||
// ServeHTTP is the entry point for all HTTP requests.
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
// If there are listener wrappers that process tls connections but don't return a *tls.Conn, this field will be nil.
|
||||
if r.TLS == nil {
|
||||
if tlsConnStateFunc, ok := r.Context().Value(tlsConnectionStateFuncCtxKey).(func() *tls.ConnectionState); ok {
|
||||
@@ -294,55 +317,37 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Server", "Caddy")
|
||||
|
||||
// advertise HTTP/3, if enabled
|
||||
if s.h3server != nil {
|
||||
if r.ProtoMajor < 3 {
|
||||
err := s.h3server.SetQUICHeaders(w.Header())
|
||||
if err != nil {
|
||||
if c := s.logger.Check(zapcore.ErrorLevel, "setting HTTP/3 Alt-Svc header"); c != nil {
|
||||
c.Write(zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reject very long methods; probably a mistake or an attack
|
||||
if len(r.Method) > 32 {
|
||||
if s.shouldLogRequest(r) {
|
||||
if c := s.accessLogger.Check(zapcore.DebugLevel, "rejecting request with long method"); c != nil {
|
||||
c.Write(
|
||||
zap.String("method_trunc", r.Method[:32]),
|
||||
zap.String("remote_addr", r.RemoteAddr),
|
||||
)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
repl := caddy.NewReplacer()
|
||||
r = PrepareRequest(r, repl, w, s)
|
||||
|
||||
// enable full-duplex for HTTP/1, ensuring the entire
|
||||
// request body gets consumed before writing the response
|
||||
if s.EnableFullDuplex && r.ProtoMajor == 1 {
|
||||
//nolint:bodyclose
|
||||
err := http.NewResponseController(w).EnableFullDuplex()
|
||||
if err != nil {
|
||||
if err := http.NewResponseController(w).EnableFullDuplex(); err != nil { //nolint:bodyclose
|
||||
if c := s.logger.Check(zapcore.WarnLevel, "failed to enable full duplex"); c != nil {
|
||||
c.Write(zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// clone the request for logging purposes before
|
||||
// it enters any handler chain; this is necessary
|
||||
// to capture the original request in case it gets
|
||||
// modified during handling
|
||||
// cloning the request and using .WithLazy is considerably faster
|
||||
// than using .With, which will JSON encode the request immediately
|
||||
// set the Server header
|
||||
h := w.Header()
|
||||
h["Server"] = serverHeader
|
||||
|
||||
// advertise HTTP/3, if enabled
|
||||
if s.h3server != nil && r.ProtoMajor < 3 {
|
||||
if err := s.h3server.SetQUICHeaders(h); err != nil {
|
||||
if c := s.logger.Check(zapcore.ErrorLevel, "setting HTTP/3 Alt-Svc header"); c != nil {
|
||||
c.Write(zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// prepare internals of the request for the handler pipeline
|
||||
repl := caddy.NewReplacer()
|
||||
r = PrepareRequest(r, repl, w, s)
|
||||
|
||||
// clone the request for logging purposes before it enters any handler chain;
|
||||
// this is necessary to capture the original request in case it gets modified
|
||||
// during handling (cloning the request and using .WithLazy is considerably
|
||||
// faster than using .With, which will JSON-encode the request immediately)
|
||||
shouldLogCredentials := s.Logs != nil && s.Logs.ShouldLogCredentials
|
||||
loggableReq := zap.Object("request", LoggableHTTPRequest{
|
||||
Request: r.Clone(r.Context()),
|
||||
@@ -370,36 +375,33 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// capture the original version of the request
|
||||
accLog := s.accessLogger.With(loggableReq)
|
||||
accLog := s.accessLogger.WithLazy(loggableReq)
|
||||
|
||||
defer s.logRequest(accLog, r, wrec, &duration, repl, bodyReader, shouldLogCredentials)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
// guarantee ACME HTTP challenges; handle them
|
||||
// separately from any user-defined handlers
|
||||
// guarantee ACME HTTP challenges; handle them separately from any user-defined handlers
|
||||
if s.tlsApp.HandleHTTPChallenge(w, r) {
|
||||
duration = time.Since(start)
|
||||
return
|
||||
}
|
||||
|
||||
// execute the primary handler chain
|
||||
err := s.primaryHandlerChain.ServeHTTP(w, r)
|
||||
err := s.serveHTTP(w, r)
|
||||
duration = time.Since(start)
|
||||
|
||||
// if no errors, we're done!
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// restore original request before invoking error handler chain (issue #3717)
|
||||
// TODO: this does not restore original headers, if modified (for efficiency)
|
||||
origReq := r.Context().Value(OriginalRequestCtxKey).(http.Request)
|
||||
r.Method = origReq.Method
|
||||
r.RemoteAddr = origReq.RemoteAddr
|
||||
r.RequestURI = origReq.RequestURI
|
||||
cloneURL(origReq.URL, r.URL)
|
||||
// NOTE: this does not restore original headers if modified (for efficiency)
|
||||
origReq, ok := r.Context().Value(OriginalRequestCtxKey).(http.Request)
|
||||
if ok {
|
||||
r.Method = origReq.Method
|
||||
r.RemoteAddr = origReq.RemoteAddr
|
||||
r.RequestURI = origReq.RequestURI
|
||||
cloneURL(origReq.URL, r.URL)
|
||||
}
|
||||
|
||||
// prepare the error log
|
||||
errLog = errLog.With(zap.Duration("duration", duration))
|
||||
@@ -417,10 +419,8 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
var fields []zapcore.Field
|
||||
if s.Errors != nil && len(s.Errors.Routes) > 0 {
|
||||
// execute user-defined error handling route
|
||||
err2 := s.errorHandlerChain.ServeHTTP(w, r)
|
||||
if err2 == nil {
|
||||
// user's error route handled the error response
|
||||
// successfully, so now just log the error
|
||||
if err2 := s.errorHandlerChain.ServeHTTP(w, r); err2 == nil {
|
||||
// user's error route handled the error response successfully, so now just log the error
|
||||
for _, logger := range errLoggers {
|
||||
if c := logger.Check(zapcore.DebugLevel, errMsg); c != nil {
|
||||
if fields == nil {
|
||||
@@ -468,6 +468,35 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error {
|
||||
// reject very long methods; probably a mistake or an attack
|
||||
if len(r.Method) > 32 {
|
||||
if s.shouldLogRequest(r) {
|
||||
if c := s.accessLogger.Check(zapcore.DebugLevel, "rejecting request with long method"); c != nil {
|
||||
c.Write(
|
||||
zap.String("method_trunc", r.Method[:32]),
|
||||
zap.String("remote_addr", r.RemoteAddr),
|
||||
)
|
||||
}
|
||||
}
|
||||
return HandlerError{StatusCode: http.StatusMethodNotAllowed}
|
||||
}
|
||||
|
||||
// RFC 9112 section 3.2: "A server MUST respond with a 400 (Bad Request) status
|
||||
// code to any HTTP/1.1 request message that lacks a Host header field and to any
|
||||
// request message that contains more than one Host header field line or a Host
|
||||
// header field with an invalid field value."
|
||||
if r.ProtoMajor == 1 && r.ProtoMinor == 1 && r.Host == "" {
|
||||
return HandlerError{
|
||||
Err: errors.New("rfc9112 forbids empty Host"),
|
||||
StatusCode: http.StatusBadRequest,
|
||||
}
|
||||
}
|
||||
|
||||
// execute the primary handler chain
|
||||
return s.primaryHandlerChain.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// wrapPrimaryRoute wraps stack (a compiled middleware handler chain)
|
||||
// in s.enforcementHandler which performs crucial security checks, etc.
|
||||
func (s *Server) wrapPrimaryRoute(stack Handler) Handler {
|
||||
@@ -551,15 +580,21 @@ func (s *Server) hasListenerAddress(fullAddr string) bool {
|
||||
// The second issue seems very similar to a discussion here:
|
||||
// https://github.com/nodejs/node/issues/9390
|
||||
//
|
||||
// This is very easy to reproduce by creating an HTTP server
|
||||
// that listens to both addresses or just one with a host
|
||||
// interface; or for a more confusing reproduction, try
|
||||
// listening on "127.0.0.1:80" and ":443" and you'll see
|
||||
// the error, if you take away the GOOS condition below.
|
||||
//
|
||||
// So, an address is equivalent if the port is in the port
|
||||
// range, and if not on Linux, the host is the same... sigh.
|
||||
if (runtime.GOOS == "linux" || thisAddrs.Host == laddrs.Host) &&
|
||||
// However, binding to *different specific* interfaces
|
||||
// (e.g. 127.0.0.2:80 and 127.0.0.3:80) IS allowed on Linux.
|
||||
// The conflict only happens when mixing specific IPs with
|
||||
// wildcards (0.0.0.0 or ::).
|
||||
|
||||
// Hosts match exactly (e.g. 127.0.0.2 == 127.0.0.2) -> Conflict.
|
||||
hostMatch := thisAddrs.Host == laddrs.Host
|
||||
|
||||
// On Linux, specific IP vs Wildcard fails to bind.
|
||||
// So if we are on Linux AND either host is empty (wildcard), we treat
|
||||
// it as a match (conflict). But if both are specific and different
|
||||
// (127.0.0.2 vs 127.0.0.3), this remains false (no conflict).
|
||||
linuxWildcardConflict := runtime.GOOS == "linux" && (thisAddrs.Host == "" || laddrs.Host == "")
|
||||
|
||||
if (hostMatch || linuxWildcardConflict) &&
|
||||
(laddrs.StartPort <= thisAddrs.EndPort) &&
|
||||
(laddrs.StartPort >= thisAddrs.StartPort) {
|
||||
return true
|
||||
@@ -625,7 +660,7 @@ func (s *Server) serveHTTP3(addr caddy.NetworkAddress, tlsCfg *tls.Config) error
|
||||
return fmt.Errorf("starting HTTP/3 QUIC listener: %v", err)
|
||||
}
|
||||
addr.Network = h3net
|
||||
h3ln, err := addr.ListenQUIC(s.ctx, 0, net.ListenConfig{}, tlsCfg)
|
||||
h3ln, err := addr.ListenQUIC(s.ctx, 0, net.ListenConfig{}, tlsCfg, s.packetConnWrappers, s.Allow0RTT)
|
||||
if err != nil {
|
||||
return fmt.Errorf("starting HTTP/3 QUIC listener: %v", err)
|
||||
}
|
||||
@@ -638,7 +673,7 @@ func (s *Server) serveHTTP3(addr caddy.NetworkAddress, tlsCfg *tls.Config) error
|
||||
MaxHeaderBytes: s.MaxHeaderBytes,
|
||||
QUICConfig: &quic.Config{
|
||||
Versions: []quic.Version{quic.Version1, quic.Version2},
|
||||
Tracer: qlog.DefaultConnectionTracer,
|
||||
Tracer: h3qlog.DefaultConnectionTracer,
|
||||
},
|
||||
IdleTimeout: time.Duration(s.IdleTimeout),
|
||||
}
|
||||
@@ -763,9 +798,11 @@ func (s *Server) shouldLogRequest(r *http.Request) bool {
|
||||
hostWithoutPort = r.Host
|
||||
}
|
||||
|
||||
if _, ok := s.Logs.LoggerNames[hostWithoutPort]; ok {
|
||||
// this host is mapped to a particular logger name
|
||||
return true
|
||||
for loggerName := range s.Logs.LoggerNames {
|
||||
if certmagic.MatchWildcard(hostWithoutPort, loggerName) {
|
||||
// this host is mapped to a particular logger name
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, dh := range s.Logs.SkipHosts {
|
||||
// logging for this particular host is disabled
|
||||
@@ -793,8 +830,10 @@ func (s *Server) logRequest(
|
||||
accLog *zap.Logger, r *http.Request, wrec ResponseRecorder, duration *time.Duration,
|
||||
repl *caddy.Replacer, bodyReader *lengthReader, shouldLogCredentials bool,
|
||||
) {
|
||||
ctx := r.Context()
|
||||
|
||||
// this request may be flagged as omitted from the logs
|
||||
if skip, ok := GetVar(r.Context(), LogSkipVar).(bool); ok && skip {
|
||||
if skip, ok := GetVar(ctx, LogSkipVar).(bool); ok && skip {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -812,7 +851,7 @@ func (s *Server) logRequest(
|
||||
}
|
||||
|
||||
message := "handled request"
|
||||
if nop, ok := GetVar(r.Context(), "unhandled").(bool); ok && nop {
|
||||
if nop, ok := GetVar(ctx, "unhandled").(bool); ok && nop {
|
||||
message = "NOP"
|
||||
}
|
||||
|
||||
@@ -836,7 +875,7 @@ func (s *Server) logRequest(
|
||||
reqBodyLength = bodyReader.Length
|
||||
}
|
||||
|
||||
extra := r.Context().Value(ExtraLogFieldsCtxKey).(*ExtraLogFields)
|
||||
extra := ctx.Value(ExtraLogFieldsCtxKey).(*ExtraLogFields)
|
||||
|
||||
fieldCount := 6
|
||||
fields = make([]zapcore.Field, 0, fieldCount+len(extra.fields))
|
||||
@@ -1001,6 +1040,7 @@ func isTrustedClientIP(ipAddr netip.Addr, trusted []netip.Prefix) bool {
|
||||
// then the first value from those headers is used.
|
||||
func trustedRealClientIP(r *http.Request, headers []string, clientIP string) string {
|
||||
// Read all the values of the configured client IP headers, in order
|
||||
// nolint:prealloc
|
||||
var values []string
|
||||
for _, field := range headers {
|
||||
values = append(values, r.Header.Values(field)...)
|
||||
|
||||
@@ -246,7 +246,7 @@ func (s StaticResponse) ServeHTTP(w http.ResponseWriter, r *http.Request, next H
|
||||
|
||||
// write response body
|
||||
if statusCode != http.StatusEarlyHints && body != "" {
|
||||
fmt.Fprint(w, body)
|
||||
fmt.Fprint(w, body) //nolint:gosec // no XSS unless you sabatoge your own config
|
||||
}
|
||||
|
||||
// continue handling after Early Hints as they are not the final response
|
||||
@@ -257,7 +257,16 @@ func (s StaticResponse) ServeHTTP(w http.ResponseWriter, r *http.Request, next H
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildHTTPServer(i int, port uint, addr string, statusCode int, hdr http.Header, body string, accessLog bool) (*Server, error) {
|
||||
func buildHTTPServer(
|
||||
i int,
|
||||
port uint,
|
||||
addr string,
|
||||
statusCode int,
|
||||
hdr http.Header,
|
||||
body string,
|
||||
accessLog bool,
|
||||
) (*Server, error) {
|
||||
// nolint:prealloc
|
||||
var handlers []json.RawMessage
|
||||
|
||||
// response body supports a basic template; evaluate it
|
||||
|
||||
@@ -306,6 +306,13 @@ func init() {
|
||||
// find the documentation on time layouts [in Go's docs](https://pkg.go.dev/time#pkg-constants).
|
||||
// The default time layout is `RFC1123Z`, i.e. `Mon, 02 Jan 2006 15:04:05 -0700`.
|
||||
//
|
||||
// ```
|
||||
// {{humanize "size" "2048000"}}
|
||||
// {{placeholder "http.response.header.Content-Length" | humanize "size"}}
|
||||
// {{humanize "time" "Fri, 05 May 2022 15:04:05 +0200"}}
|
||||
// {{humanize "time:2006-Jan-02" "2022-May-05"}}
|
||||
// ```
|
||||
//
|
||||
// ##### `pathEscape`
|
||||
//
|
||||
// Passes a string through `url.PathEscape`, replacing characters that have
|
||||
@@ -318,11 +325,22 @@ func init() {
|
||||
// {{pathEscape "50%_valid_filename?.jpg"}}
|
||||
// ```
|
||||
//
|
||||
// ##### `maybe`
|
||||
//
|
||||
// Invokes a custom template function only if it is registered (plugged-in)
|
||||
// in the `http.handlers.templates.functions.*` namespace.
|
||||
//
|
||||
// The first argument is the function name, and any subsequent arguments
|
||||
// are forwarded to that function. If the named function is not available,
|
||||
// the invocation is ignored and a log message is emitted.
|
||||
//
|
||||
// This is useful for templates that optionally use components which may
|
||||
// not be present in every build or environment.
|
||||
//
|
||||
// NOTE: This function is EXPERIMENTAL and subject to change or removal.
|
||||
//
|
||||
// ```
|
||||
// {{humanize "size" "2048000"}}
|
||||
// {{placeholder "http.response.header.Content-Length" | humanize "size"}}
|
||||
// {{humanize "time" "Fri, 05 May 2022 15:04:05 +0200"}}
|
||||
// {{humanize "time:2006-Jan-02" "2022-May-05"}}
|
||||
// {{ maybe "myOptionalFunc" "arg1" 2 }}
|
||||
// ```
|
||||
type Templates struct {
|
||||
// The root path from which to load files. Required if template functions
|
||||
|
||||
@@ -27,6 +27,9 @@ type Tracing struct {
|
||||
// https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/trace/api.md#span
|
||||
SpanName string `json:"span"`
|
||||
|
||||
// SpanAttributes are custom key-value pairs to be added to spans
|
||||
SpanAttributes map[string]string `json:"span_attributes,omitempty"`
|
||||
|
||||
// otel implements opentelemetry related logic.
|
||||
otel openTelemetryWrapper
|
||||
|
||||
@@ -46,7 +49,7 @@ func (ot *Tracing) Provision(ctx caddy.Context) error {
|
||||
ot.logger = ctx.Logger()
|
||||
|
||||
var err error
|
||||
ot.otel, err = newOpenTelemetryWrapper(ctx, ot.SpanName)
|
||||
ot.otel, err = newOpenTelemetryWrapper(ctx, ot.SpanName, ot.SpanAttributes)
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -69,6 +72,10 @@ func (ot *Tracing) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyh
|
||||
//
|
||||
// tracing {
|
||||
// [span <span_name>]
|
||||
// [span_attributes {
|
||||
// attr1 value1
|
||||
// attr2 value2
|
||||
// }]
|
||||
// }
|
||||
func (ot *Tracing) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
||||
setParameter := func(d *caddyfile.Dispenser, val *string) error {
|
||||
@@ -94,12 +101,30 @@ func (ot *Tracing) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
||||
}
|
||||
|
||||
for d.NextBlock(0) {
|
||||
if dst, ok := paramsMap[d.Val()]; ok {
|
||||
if err := setParameter(d, dst); err != nil {
|
||||
return err
|
||||
switch d.Val() {
|
||||
case "span_attributes":
|
||||
if ot.SpanAttributes == nil {
|
||||
ot.SpanAttributes = make(map[string]string)
|
||||
}
|
||||
for d.NextBlock(1) {
|
||||
key := d.Val()
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
value := d.Val()
|
||||
if d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
ot.SpanAttributes[key] = value
|
||||
}
|
||||
default:
|
||||
if dst, ok := paramsMap[d.Val()]; ok {
|
||||
if err := setParameter(d, dst); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return d.ArgErr()
|
||||
}
|
||||
} else {
|
||||
return d.ArgErr()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -2,12 +2,16 @@ package tracing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"go.opentelemetry.io/otel/sdk/trace"
|
||||
"go.opentelemetry.io/otel/sdk/trace/tracetest"
|
||||
|
||||
"github.com/caddyserver/caddy/v2"
|
||||
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
|
||||
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
|
||||
@@ -15,17 +19,26 @@ import (
|
||||
|
||||
func TestTracing_UnmarshalCaddyfile(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
spanName string
|
||||
d *caddyfile.Dispenser
|
||||
wantErr bool
|
||||
name string
|
||||
spanName string
|
||||
spanAttributes map[string]string
|
||||
d *caddyfile.Dispenser
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Full config",
|
||||
spanName: "my-span",
|
||||
spanAttributes: map[string]string{
|
||||
"attr1": "value1",
|
||||
"attr2": "value2",
|
||||
},
|
||||
d: caddyfile.NewTestDispenser(`
|
||||
tracing {
|
||||
span my-span
|
||||
span_attributes {
|
||||
attr1 value1
|
||||
attr2 value2
|
||||
}
|
||||
}`),
|
||||
wantErr: false,
|
||||
},
|
||||
@@ -42,6 +55,21 @@ tracing {
|
||||
name: "Empty config",
|
||||
d: caddyfile.NewTestDispenser(`
|
||||
tracing {
|
||||
}`),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Only span attributes",
|
||||
spanAttributes: map[string]string{
|
||||
"service.name": "my-service",
|
||||
"service.version": "1.0.0",
|
||||
},
|
||||
d: caddyfile.NewTestDispenser(`
|
||||
tracing {
|
||||
span_attributes {
|
||||
service.name my-service
|
||||
service.version 1.0.0
|
||||
}
|
||||
}`),
|
||||
wantErr: false,
|
||||
},
|
||||
@@ -56,6 +84,20 @@ tracing {
|
||||
if ot.SpanName != tt.spanName {
|
||||
t.Errorf("UnmarshalCaddyfile() SpanName = %v, want SpanName %v", ot.SpanName, tt.spanName)
|
||||
}
|
||||
|
||||
if len(tt.spanAttributes) > 0 {
|
||||
if ot.SpanAttributes == nil {
|
||||
t.Errorf("UnmarshalCaddyfile() SpanAttributes is nil, expected %v", tt.spanAttributes)
|
||||
} else {
|
||||
for key, expectedValue := range tt.spanAttributes {
|
||||
if actualValue, exists := ot.SpanAttributes[key]; !exists {
|
||||
t.Errorf("UnmarshalCaddyfile() SpanAttributes missing key %v", key)
|
||||
} else if actualValue != expectedValue {
|
||||
t.Errorf("UnmarshalCaddyfile() SpanAttributes[%v] = %v, want %v", key, actualValue, expectedValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -79,6 +121,26 @@ func TestTracing_UnmarshalCaddyfile_Error(t *testing.T) {
|
||||
d: caddyfile.NewTestDispenser(`
|
||||
tracing {
|
||||
span
|
||||
}`),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Span attributes missing value",
|
||||
d: caddyfile.NewTestDispenser(`
|
||||
tracing {
|
||||
span_attributes {
|
||||
key
|
||||
}
|
||||
}`),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Span attributes too many arguments",
|
||||
d: caddyfile.NewTestDispenser(`
|
||||
tracing {
|
||||
span_attributes {
|
||||
key value extra
|
||||
}
|
||||
}`),
|
||||
wantErr: true,
|
||||
},
|
||||
@@ -181,6 +243,160 @@ func TestTracing_ServeHTTP_Next_Error(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTracing_JSON_Configuration(t *testing.T) {
|
||||
// Test that our struct correctly marshals to and from JSON
|
||||
original := &Tracing{
|
||||
SpanName: "test-span",
|
||||
SpanAttributes: map[string]string{
|
||||
"service.name": "test-service",
|
||||
"service.version": "1.0.0",
|
||||
"env": "test",
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(original)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal to JSON: %v", err)
|
||||
}
|
||||
|
||||
var unmarshaled Tracing
|
||||
if err := json.Unmarshal(jsonData, &unmarshaled); err != nil {
|
||||
t.Fatalf("Failed to unmarshal from JSON: %v", err)
|
||||
}
|
||||
|
||||
if unmarshaled.SpanName != original.SpanName {
|
||||
t.Errorf("Expected SpanName %s, got %s", original.SpanName, unmarshaled.SpanName)
|
||||
}
|
||||
|
||||
if len(unmarshaled.SpanAttributes) != len(original.SpanAttributes) {
|
||||
t.Errorf("Expected %d span attributes, got %d", len(original.SpanAttributes), len(unmarshaled.SpanAttributes))
|
||||
}
|
||||
|
||||
for key, expectedValue := range original.SpanAttributes {
|
||||
if actualValue, exists := unmarshaled.SpanAttributes[key]; !exists {
|
||||
t.Errorf("Expected span attribute %s to exist", key)
|
||||
} else if actualValue != expectedValue {
|
||||
t.Errorf("Expected span attribute %s = %s, got %s", key, expectedValue, actualValue)
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("JSON representation: %s", string(jsonData))
|
||||
}
|
||||
|
||||
func TestTracing_OpenTelemetry_Span_Attributes(t *testing.T) {
|
||||
// Create an in-memory span recorder to capture actual span data
|
||||
spanRecorder := tracetest.NewSpanRecorder()
|
||||
provider := trace.NewTracerProvider(
|
||||
trace.WithSpanProcessor(spanRecorder),
|
||||
)
|
||||
|
||||
// Create our tracing module with span attributes that include placeholders
|
||||
ot := &Tracing{
|
||||
SpanName: "test-span",
|
||||
SpanAttributes: map[string]string{
|
||||
"static": "test-service",
|
||||
"request-placeholder": "{http.request.method}",
|
||||
"response-placeholder": "{http.response.header.X-Some-Header}",
|
||||
"mixed": "prefix-{http.request.method}-{http.response.header.X-Some-Header}",
|
||||
},
|
||||
}
|
||||
|
||||
// Create a specific request to test against
|
||||
req, _ := http.NewRequest("POST", "https://api.example.com/v1/users?id=123", nil)
|
||||
req.Host = "api.example.com"
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Set up the replacer
|
||||
repl := caddy.NewReplacer()
|
||||
ctx := context.WithValue(req.Context(), caddy.ReplacerCtxKey, repl)
|
||||
ctx = context.WithValue(ctx, caddyhttp.VarsCtxKey, make(map[string]any))
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
// Set up request placeholders
|
||||
repl.Set("http.request.method", req.Method)
|
||||
repl.Set("http.request.uri", req.URL.RequestURI())
|
||||
|
||||
// Handler to generate the response
|
||||
var handler caddyhttp.HandlerFunc = func(writer http.ResponseWriter, request *http.Request) error {
|
||||
writer.Header().Set("X-Some-Header", "some-value")
|
||||
writer.WriteHeader(200)
|
||||
|
||||
// Make response headers available to replacer
|
||||
repl.Set("http.response.header.X-Some-Header", writer.Header().Get("X-Some-Header"))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set up Caddy context
|
||||
caddyCtx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
// Override the global tracer provider with our test provider
|
||||
// This is a bit hacky but necessary to capture the actual spans
|
||||
originalProvider := globalTracerProvider
|
||||
globalTracerProvider = &tracerProvider{
|
||||
tracerProvider: provider,
|
||||
tracerProvidersCounter: 1, // Simulate one user
|
||||
}
|
||||
defer func() {
|
||||
globalTracerProvider = originalProvider
|
||||
}()
|
||||
|
||||
// Provision the tracing module
|
||||
if err := ot.Provision(caddyCtx); err != nil {
|
||||
t.Errorf("Provision error: %v", err)
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
// Execute the request
|
||||
if err := ot.ServeHTTP(w, req, handler); err != nil {
|
||||
t.Errorf("ServeHTTP error: %v", err)
|
||||
}
|
||||
|
||||
// Get the recorded spans
|
||||
spans := spanRecorder.Ended()
|
||||
if len(spans) == 0 {
|
||||
t.Fatal("Expected at least one span to be recorded")
|
||||
}
|
||||
|
||||
// Find our span (should be the one with our test span name)
|
||||
var testSpan trace.ReadOnlySpan
|
||||
for _, span := range spans {
|
||||
if span.Name() == "test-span" {
|
||||
testSpan = span
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if testSpan == nil {
|
||||
t.Fatal("Could not find test span in recorded spans")
|
||||
}
|
||||
|
||||
// Verify that the span attributes were set correctly with placeholder replacement
|
||||
expectedAttributes := map[string]string{
|
||||
"static": "test-service",
|
||||
"request-placeholder": "POST",
|
||||
"response-placeholder": "some-value",
|
||||
"mixed": "prefix-POST-some-value",
|
||||
}
|
||||
|
||||
actualAttributes := make(map[string]string)
|
||||
for _, attr := range testSpan.Attributes() {
|
||||
actualAttributes[string(attr.Key)] = attr.Value.AsString()
|
||||
}
|
||||
|
||||
for key, expectedValue := range expectedAttributes {
|
||||
if actualValue, exists := actualAttributes[key]; !exists {
|
||||
t.Errorf("Expected span attribute %s to be set", key)
|
||||
} else if actualValue != expectedValue {
|
||||
t.Errorf("Expected span attribute %s = %s, got %s", key, expectedValue, actualValue)
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("Recorded span attributes: %+v", actualAttributes)
|
||||
}
|
||||
|
||||
func createRequestWithContext(method string, url string) *http.Request {
|
||||
r, _ := http.NewRequest(method, url, nil)
|
||||
repl := caddy.NewReplacer()
|
||||
|
||||
@@ -5,9 +5,10 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"go.opentelemetry.io/contrib/exporters/autoexport"
|
||||
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
|
||||
"go.opentelemetry.io/contrib/propagators/autoprop"
|
||||
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/propagation"
|
||||
"go.opentelemetry.io/otel/sdk/resource"
|
||||
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
||||
@@ -37,20 +38,23 @@ type openTelemetryWrapper struct {
|
||||
|
||||
handler http.Handler
|
||||
|
||||
spanName string
|
||||
spanName string
|
||||
spanAttributes map[string]string
|
||||
}
|
||||
|
||||
// newOpenTelemetryWrapper is responsible for the openTelemetryWrapper initialization using provided configuration.
|
||||
func newOpenTelemetryWrapper(
|
||||
ctx context.Context,
|
||||
spanName string,
|
||||
spanAttributes map[string]string,
|
||||
) (openTelemetryWrapper, error) {
|
||||
if spanName == "" {
|
||||
spanName = defaultSpanName
|
||||
}
|
||||
|
||||
ot := openTelemetryWrapper{
|
||||
spanName: spanName,
|
||||
spanName: spanName,
|
||||
spanAttributes: spanAttributes,
|
||||
}
|
||||
|
||||
version, _ := caddy.Version()
|
||||
@@ -59,7 +63,7 @@ func newOpenTelemetryWrapper(
|
||||
return ot, fmt.Errorf("creating resource error: %w", err)
|
||||
}
|
||||
|
||||
traceExporter, err := otlptracegrpc.New(ctx)
|
||||
traceExporter, err := autoexport.NewSpanExporter(ctx)
|
||||
if err != nil {
|
||||
return ot, fmt.Errorf("creating trace exporter error: %w", err)
|
||||
}
|
||||
@@ -99,8 +103,22 @@ func (ot *openTelemetryWrapper) serveHTTP(w http.ResponseWriter, r *http.Request
|
||||
extra.Add(zap.String("spanID", spanID))
|
||||
}
|
||||
}
|
||||
|
||||
next := ctx.Value(nextCallCtxKey).(*nextCall)
|
||||
next.err = next.next.ServeHTTP(w, r)
|
||||
|
||||
// Add custom span attributes to the current span
|
||||
span := trace.SpanFromContext(ctx)
|
||||
if span.IsRecording() && len(ot.spanAttributes) > 0 {
|
||||
replacer := ctx.Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
|
||||
attributes := make([]attribute.KeyValue, 0, len(ot.spanAttributes))
|
||||
for key, value := range ot.spanAttributes {
|
||||
// Allow placeholder replacement in attribute values
|
||||
replacedValue := replacer.ReplaceAll(value, "")
|
||||
attributes = append(attributes, attribute.String(key, replacedValue))
|
||||
}
|
||||
span.SetAttributes(attributes...)
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP propagates call to the by wrapped by `otelhttp` next handler.
|
||||
|
||||
@@ -16,6 +16,7 @@ func TestOpenTelemetryWrapper_newOpenTelemetryWrapper(t *testing.T) {
|
||||
|
||||
if otw, err = newOpenTelemetryWrapper(ctx,
|
||||
"",
|
||||
nil,
|
||||
); err != nil {
|
||||
t.Errorf("newOpenTelemetryWrapper() error = %v", err)
|
||||
t.FailNow()
|
||||
|
||||
+45
-21
@@ -181,33 +181,46 @@ func (m VarsMatcher) MatchWithError(r *http.Request) (bool, error) {
|
||||
vars := r.Context().Value(VarsCtxKey).(map[string]any)
|
||||
repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
|
||||
|
||||
var fromPlaceholder bool
|
||||
var matcherValExpanded, valExpanded, varStr, v string
|
||||
var varValue any
|
||||
for key, vals := range m {
|
||||
var varValue any
|
||||
if strings.HasPrefix(key, "{") &&
|
||||
strings.HasSuffix(key, "}") &&
|
||||
strings.Count(key, "{") == 1 {
|
||||
varValue, _ = repl.Get(strings.Trim(key, "{}"))
|
||||
fromPlaceholder = true
|
||||
} else {
|
||||
varValue = vars[key]
|
||||
fromPlaceholder = false
|
||||
}
|
||||
|
||||
switch vv := varValue.(type) {
|
||||
case string:
|
||||
varStr = vv
|
||||
case fmt.Stringer:
|
||||
varStr = vv.String()
|
||||
case error:
|
||||
varStr = vv.Error()
|
||||
case nil:
|
||||
varStr = ""
|
||||
default:
|
||||
varStr = fmt.Sprintf("%v", vv)
|
||||
}
|
||||
|
||||
// Only expand placeholders in values from literal variable names
|
||||
// (e.g. map outputs). Values resolved from placeholder keys are
|
||||
// already final and must not be re-expanded, as that would allow
|
||||
// user input like {env.SECRET} to be evaluated.
|
||||
valExpanded = varStr
|
||||
if !fromPlaceholder {
|
||||
valExpanded = repl.ReplaceAll(varStr, "")
|
||||
}
|
||||
|
||||
// see if any of the values given in the matcher match the actual value
|
||||
for _, v := range vals {
|
||||
matcherValExpanded := repl.ReplaceAll(v, "")
|
||||
var varStr string
|
||||
switch vv := varValue.(type) {
|
||||
case string:
|
||||
varStr = vv
|
||||
case fmt.Stringer:
|
||||
varStr = vv.String()
|
||||
case error:
|
||||
varStr = vv.Error()
|
||||
case nil:
|
||||
varStr = ""
|
||||
default:
|
||||
varStr = fmt.Sprintf("%v", vv)
|
||||
}
|
||||
if varStr == matcherValExpanded {
|
||||
for _, v = range vals {
|
||||
matcherValExpanded = repl.ReplaceAll(v, "")
|
||||
if valExpanded == matcherValExpanded {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
@@ -310,17 +323,21 @@ func (m MatchVarsRE) Match(r *http.Request) bool {
|
||||
func (m MatchVarsRE) MatchWithError(r *http.Request) (bool, error) {
|
||||
vars := r.Context().Value(VarsCtxKey).(map[string]any)
|
||||
repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
|
||||
|
||||
var fromPlaceholder, match bool
|
||||
var valExpanded, varStr string
|
||||
var varValue any
|
||||
for key, val := range m {
|
||||
var varValue any
|
||||
if strings.HasPrefix(key, "{") &&
|
||||
strings.HasSuffix(key, "}") &&
|
||||
strings.Count(key, "{") == 1 {
|
||||
varValue, _ = repl.Get(strings.Trim(key, "{}"))
|
||||
fromPlaceholder = true
|
||||
} else {
|
||||
varValue = vars[key]
|
||||
fromPlaceholder = false
|
||||
}
|
||||
|
||||
var varStr string
|
||||
switch vv := varValue.(type) {
|
||||
case string:
|
||||
varStr = vv
|
||||
@@ -334,8 +351,15 @@ func (m MatchVarsRE) MatchWithError(r *http.Request) (bool, error) {
|
||||
varStr = fmt.Sprintf("%v", vv)
|
||||
}
|
||||
|
||||
valExpanded := repl.ReplaceAll(varStr, "")
|
||||
if match := val.Match(valExpanded, repl); match {
|
||||
// Only expand placeholders in values from literal variable names
|
||||
// (e.g. map outputs). Values resolved from placeholder keys are
|
||||
// already final and must not be re-expanded, as that would allow
|
||||
// user input like {env.SECRET} to be evaluated.
|
||||
valExpanded = varStr
|
||||
if !fromPlaceholder {
|
||||
valExpanded = repl.ReplaceAll(varStr, "")
|
||||
}
|
||||
if match = val.Match(valExpanded, repl); match {
|
||||
return match, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ package acmeserver
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
weakrand "math/rand"
|
||||
weakrand "math/rand/v2"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -40,6 +40,7 @@ import (
|
||||
"github.com/caddyserver/caddy/v2"
|
||||
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
|
||||
"github.com/caddyserver/caddy/v2/modules/caddypki"
|
||||
"github.com/caddyserver/caddy/v2/modules/caddytls"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -140,6 +141,8 @@ func (ash *Handler) Provision(ctx caddy.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
ash.warnIfPolicyAllowsAll()
|
||||
|
||||
// get a reference to the configured CA
|
||||
appModule, err := ctx.App("pki")
|
||||
if err != nil {
|
||||
@@ -214,6 +217,21 @@ func (ash *Handler) Provision(ctx caddy.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ash *Handler) warnIfPolicyAllowsAll() {
|
||||
allow := ash.Policy.normalizeAllowRules()
|
||||
deny := ash.Policy.normalizeDenyRules()
|
||||
if allow != nil || deny != nil {
|
||||
return
|
||||
}
|
||||
|
||||
allowWildcardNames := ash.Policy != nil && ash.Policy.AllowWildcardNames
|
||||
ash.logger.Warn(
|
||||
"acme_server policy has no allow/deny rules; order identifiers are unrestricted (allow-all)",
|
||||
zap.String("ca", ash.CA),
|
||||
zap.Bool("allow_wildcard_names", allowWildcardNames),
|
||||
)
|
||||
}
|
||||
|
||||
func (ash Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
|
||||
if strings.HasPrefix(r.URL.Path, ash.PathPrefix) {
|
||||
acmeCtx := acme.NewContext(
|
||||
@@ -287,7 +305,19 @@ func (ash Handler) openDatabase() (*db.AuthDB, error) {
|
||||
// makeClient creates an ACME client which will use a custom
|
||||
// resolver instead of net.DefaultResolver.
|
||||
func (ash Handler) makeClient() (acme.Client, error) {
|
||||
for _, v := range ash.Resolvers {
|
||||
// If no local resolvers are configured, check for global resolvers from TLS app
|
||||
resolversToUse := ash.Resolvers
|
||||
if len(resolversToUse) == 0 {
|
||||
tlsAppIface, err := ash.ctx.App("tls")
|
||||
if err == nil {
|
||||
tlsApp := tlsAppIface.(*caddytls.TLS)
|
||||
if len(tlsApp.Resolvers) > 0 {
|
||||
resolversToUse = tlsApp.Resolvers
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range resolversToUse {
|
||||
addr, err := caddy.ParseNetworkAddressWithDefaults(v, "udp", 53)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -307,7 +337,7 @@ func (ash Handler) makeClient() (acme.Client, error) {
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
//nolint:gosec
|
||||
addr := ash.resolvers[weakrand.Intn(len(ash.resolvers))]
|
||||
addr := ash.resolvers[weakrand.IntN(len(ash.resolvers))]
|
||||
return dialer.DialContext(ctx, addr.Network, addr.JoinHostPort(0))
|
||||
},
|
||||
}
|
||||
|
||||
@@ -0,0 +1,94 @@
|
||||
package acmeserver
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zaptest/observer"
|
||||
)
|
||||
|
||||
func TestHandler_warnIfPolicyAllowsAll(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
policy *Policy
|
||||
wantWarns int
|
||||
wantAllowWildcard bool
|
||||
}{
|
||||
{
|
||||
name: "warns when policy is nil",
|
||||
policy: nil,
|
||||
wantWarns: 1,
|
||||
wantAllowWildcard: false,
|
||||
},
|
||||
{
|
||||
name: "warns when allow/deny rules are empty",
|
||||
policy: &Policy{},
|
||||
wantWarns: 1,
|
||||
wantAllowWildcard: false,
|
||||
},
|
||||
{
|
||||
name: "warns when only allow_wildcard_names is true",
|
||||
policy: &Policy{
|
||||
AllowWildcardNames: true,
|
||||
},
|
||||
wantWarns: 1,
|
||||
wantAllowWildcard: true,
|
||||
},
|
||||
{
|
||||
name: "does not warn when allow rules are configured",
|
||||
policy: &Policy{
|
||||
Allow: &RuleSet{
|
||||
Domains: []string{"example.com"},
|
||||
},
|
||||
},
|
||||
wantWarns: 0,
|
||||
wantAllowWildcard: false,
|
||||
},
|
||||
{
|
||||
name: "does not warn when deny rules are configured",
|
||||
policy: &Policy{
|
||||
Deny: &RuleSet{
|
||||
Domains: []string{"bad.example.com"},
|
||||
},
|
||||
},
|
||||
wantWarns: 0,
|
||||
wantAllowWildcard: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
core, logs := observer.New(zap.WarnLevel)
|
||||
ash := &Handler{
|
||||
CA: "local",
|
||||
Policy: tt.policy,
|
||||
logger: zap.New(core),
|
||||
}
|
||||
|
||||
ash.warnIfPolicyAllowsAll()
|
||||
if logs.Len() != tt.wantWarns {
|
||||
t.Fatalf("expected %d warning logs, got %d", tt.wantWarns, logs.Len())
|
||||
}
|
||||
|
||||
if tt.wantWarns == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
entry := logs.All()[0]
|
||||
if entry.Level != zap.WarnLevel {
|
||||
t.Fatalf("expected warn level, got %v", entry.Level)
|
||||
}
|
||||
if !strings.Contains(entry.Message, "policy has no allow/deny rules") {
|
||||
t.Fatalf("unexpected log message: %q", entry.Message)
|
||||
}
|
||||
ctx := entry.ContextMap()
|
||||
if ctx["ca"] != "local" {
|
||||
t.Fatalf("expected ca=local, got %v", ctx["ca"])
|
||||
}
|
||||
if ctx["allow_wildcard_names"] != tt.wantAllowWildcard {
|
||||
t.Fatalf("expected allow_wildcard_names=%v, got %v", tt.wantAllowWildcard, ctx["allow_wildcard_names"])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -163,9 +163,9 @@ func (a *adminAPI) handleCACerts(w http.ResponseWriter, r *http.Request) error {
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/pem-certificate-chain")
|
||||
_, err = w.Write(interCert)
|
||||
_, err = w.Write(interCert) //nolint:gosec // false positive... no XSS in a PEM for cryin' out loud
|
||||
if err == nil {
|
||||
_, _ = w.Write(rootCert)
|
||||
_, _ = w.Write(rootCert) //nolint:gosec // false positive... no XSS in a PEM for cryin' out loud
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -222,11 +222,16 @@ func rootAndIntermediatePEM(ca *CA) (root, inter []byte, err error) {
|
||||
if err != nil {
|
||||
return root, inter, err
|
||||
}
|
||||
inter, err = pemEncodeCert(ca.IntermediateCertificate().Raw)
|
||||
if err != nil {
|
||||
return root, inter, err
|
||||
|
||||
for _, interCert := range ca.IntermediateCertificateChain() {
|
||||
pemBytes, err := pemEncodeCert(interCert.Raw)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
inter = append(inter, pemBytes...)
|
||||
}
|
||||
return root, inter, err
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// caInfo is the response structure for the CA info API endpoint.
|
||||
|
||||
+60
-30
@@ -63,6 +63,15 @@ type CA struct {
|
||||
// The intermediate (signing) certificate; if null, one will be generated.
|
||||
Intermediate *KeyPair `json:"intermediate,omitempty"`
|
||||
|
||||
// How often to check if intermediate (and root, when applicable) certificates need renewal.
|
||||
// Default: 10m.
|
||||
MaintenanceInterval caddy.Duration `json:"maintenance_interval,omitempty"`
|
||||
|
||||
// The fraction of certificate lifetime (0.0–1.0) after which renewal is attempted.
|
||||
// For example, 0.2 means renew when 20% of the lifetime remains (e.g. ~73 days for a 1-year cert).
|
||||
// Default: 0.2.
|
||||
RenewalWindowRatio float64 `json:"renewal_window_ratio,omitempty"`
|
||||
|
||||
// Optionally configure a separate storage module associated with this
|
||||
// issuer, instead of using Caddy's global/default-configured storage.
|
||||
// This can be useful if you want to keep your signing keys in a
|
||||
@@ -75,10 +84,11 @@ type CA struct {
|
||||
// and module provisioning.
|
||||
ID string `json:"-"`
|
||||
|
||||
storage certmagic.Storage
|
||||
root, inter *x509.Certificate
|
||||
interKey any // TODO: should we just store these as crypto.Signer?
|
||||
mu *sync.RWMutex
|
||||
storage certmagic.Storage
|
||||
root *x509.Certificate
|
||||
interChain []*x509.Certificate
|
||||
interKey crypto.Signer
|
||||
mu *sync.RWMutex
|
||||
|
||||
rootCertPath string // mainly used for logging purposes if trusting
|
||||
log *zap.Logger
|
||||
@@ -125,16 +135,24 @@ func (ca *CA) Provision(ctx caddy.Context, id string, log *zap.Logger) error {
|
||||
if ca.IntermediateLifetime == 0 {
|
||||
ca.IntermediateLifetime = caddy.Duration(defaultIntermediateLifetime)
|
||||
}
|
||||
if ca.MaintenanceInterval == 0 {
|
||||
ca.MaintenanceInterval = caddy.Duration(defaultMaintenanceInterval)
|
||||
}
|
||||
if ca.RenewalWindowRatio <= 0 || ca.RenewalWindowRatio > 1 {
|
||||
ca.RenewalWindowRatio = defaultRenewalWindowRatio
|
||||
}
|
||||
|
||||
// load the certs and key that will be used for signing
|
||||
var rootCert, interCert *x509.Certificate
|
||||
var rootCert *x509.Certificate
|
||||
var rootCertChain, interCertChain []*x509.Certificate
|
||||
var rootKey, interKey crypto.Signer
|
||||
var err error
|
||||
if ca.Root != nil {
|
||||
if ca.Root.Format == "" || ca.Root.Format == "pem_file" {
|
||||
ca.rootCertPath = ca.Root.Certificate
|
||||
}
|
||||
rootCert, rootKey, err = ca.Root.Load()
|
||||
rootCertChain, rootKey, err = ca.Root.Load()
|
||||
rootCert = rootCertChain[0]
|
||||
} else {
|
||||
ca.rootCertPath = "storage:" + ca.storageKeyRootCert()
|
||||
rootCert, rootKey, err = ca.loadOrGenRoot()
|
||||
@@ -142,21 +160,23 @@ func (ca *CA) Provision(ctx caddy.Context, id string, log *zap.Logger) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
actualRootLifetime := time.Until(rootCert.NotAfter)
|
||||
if time.Duration(ca.IntermediateLifetime) >= actualRootLifetime {
|
||||
return fmt.Errorf("intermediate certificate lifetime must be less than actual root certificate lifetime (%s)", actualRootLifetime)
|
||||
}
|
||||
|
||||
if ca.Intermediate != nil {
|
||||
interCert, interKey, err = ca.Intermediate.Load()
|
||||
interCertChain, interKey, err = ca.Intermediate.Load()
|
||||
} else {
|
||||
interCert, interKey, err = ca.loadOrGenIntermediate(rootCert, rootKey)
|
||||
actualRootLifetime := time.Until(rootCert.NotAfter)
|
||||
if time.Duration(ca.IntermediateLifetime) >= actualRootLifetime {
|
||||
return fmt.Errorf("intermediate certificate lifetime must be less than actual root certificate lifetime (%s)", actualRootLifetime)
|
||||
}
|
||||
|
||||
interCertChain, interKey, err = ca.loadOrGenIntermediate(rootCert, rootKey)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ca.mu.Lock()
|
||||
ca.root, ca.inter, ca.interKey = rootCert, interCert, interKey
|
||||
ca.root, ca.interChain, ca.interKey = rootCert, interCertChain, interKey
|
||||
ca.mu.Unlock()
|
||||
|
||||
return nil
|
||||
@@ -172,21 +192,21 @@ func (ca CA) RootCertificate() *x509.Certificate {
|
||||
// RootKey returns the CA's root private key. Since the root key is
|
||||
// not cached in memory long-term, it needs to be loaded from storage,
|
||||
// which could yield an error.
|
||||
func (ca CA) RootKey() (any, error) {
|
||||
func (ca CA) RootKey() (crypto.Signer, error) {
|
||||
_, rootKey, err := ca.loadOrGenRoot()
|
||||
return rootKey, err
|
||||
}
|
||||
|
||||
// IntermediateCertificate returns the CA's intermediate
|
||||
// certificate (public key).
|
||||
func (ca CA) IntermediateCertificate() *x509.Certificate {
|
||||
// IntermediateCertificateChain returns the CA's intermediate
|
||||
// certificate chain.
|
||||
func (ca CA) IntermediateCertificateChain() []*x509.Certificate {
|
||||
ca.mu.RLock()
|
||||
defer ca.mu.RUnlock()
|
||||
return ca.inter
|
||||
return ca.interChain
|
||||
}
|
||||
|
||||
// IntermediateKey returns the CA's intermediate private key.
|
||||
func (ca CA) IntermediateKey() any {
|
||||
func (ca CA) IntermediateKey() crypto.Signer {
|
||||
ca.mu.RLock()
|
||||
defer ca.mu.RUnlock()
|
||||
return ca.interKey
|
||||
@@ -207,26 +227,27 @@ func (ca *CA) NewAuthority(authorityConfig AuthorityConfig) (*authority.Authorit
|
||||
// cert/key directly, since it's unlikely to expire
|
||||
// while Caddy is running (long lifetime)
|
||||
var issuerCert *x509.Certificate
|
||||
var issuerKey any
|
||||
var issuerKey crypto.Signer
|
||||
issuerCert = rootCert
|
||||
var err error
|
||||
issuerKey, err = ca.RootKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading signing key: %v", err)
|
||||
}
|
||||
signerOption = authority.WithX509Signer(issuerCert, issuerKey.(crypto.Signer))
|
||||
signerOption = authority.WithX509Signer(issuerCert, issuerKey)
|
||||
} else {
|
||||
// if we're signing with intermediate, we need to make
|
||||
// sure it's always fresh, because the intermediate may
|
||||
// renew while Caddy is running (medium lifetime)
|
||||
signerOption = authority.WithX509SignerFunc(func() ([]*x509.Certificate, crypto.Signer, error) {
|
||||
issuerCert := ca.IntermediateCertificate()
|
||||
issuerKey := ca.IntermediateKey().(crypto.Signer)
|
||||
issuerChain := ca.IntermediateCertificateChain()
|
||||
issuerCert := issuerChain[0]
|
||||
issuerKey := ca.IntermediateKey()
|
||||
ca.log.Debug("using intermediate signer",
|
||||
zap.String("serial", issuerCert.SerialNumber.String()),
|
||||
zap.String("not_before", issuerCert.NotBefore.String()),
|
||||
zap.String("not_after", issuerCert.NotAfter.String()))
|
||||
return []*x509.Certificate{issuerCert}, issuerKey, nil
|
||||
return issuerChain, issuerKey, nil
|
||||
})
|
||||
}
|
||||
|
||||
@@ -252,7 +273,11 @@ func (ca *CA) NewAuthority(authorityConfig AuthorityConfig) (*authority.Authorit
|
||||
|
||||
func (ca CA) loadOrGenRoot() (rootCert *x509.Certificate, rootKey crypto.Signer, err error) {
|
||||
if ca.Root != nil {
|
||||
return ca.Root.Load()
|
||||
rootChain, rootSigner, err := ca.Root.Load()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return rootChain[0], rootSigner, nil
|
||||
}
|
||||
rootCertPEM, err := ca.storage.Load(ca.ctx, ca.storageKeyRootCert())
|
||||
if err != nil {
|
||||
@@ -268,7 +293,7 @@ func (ca CA) loadOrGenRoot() (rootCert *x509.Certificate, rootKey crypto.Signer,
|
||||
}
|
||||
|
||||
if rootCert == nil {
|
||||
rootCert, err = pemDecodeSingleCert(rootCertPEM)
|
||||
rootCert, err = pemDecodeCertificate(rootCertPEM)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("parsing root certificate PEM: %v", err)
|
||||
}
|
||||
@@ -314,7 +339,8 @@ func (ca CA) genRoot() (rootCert *x509.Certificate, rootKey crypto.Signer, err e
|
||||
return rootCert, rootKey, nil
|
||||
}
|
||||
|
||||
func (ca CA) loadOrGenIntermediate(rootCert *x509.Certificate, rootKey crypto.Signer) (interCert *x509.Certificate, interKey crypto.Signer, err error) {
|
||||
func (ca CA) loadOrGenIntermediate(rootCert *x509.Certificate, rootKey crypto.Signer) (interCertChain []*x509.Certificate, interKey crypto.Signer, err error) {
|
||||
var interCert *x509.Certificate
|
||||
interCertPEM, err := ca.storage.Load(ca.ctx, ca.storageKeyIntermediateCert())
|
||||
if err != nil {
|
||||
if !errors.Is(err, fs.ErrNotExist) {
|
||||
@@ -326,10 +352,12 @@ func (ca CA) loadOrGenIntermediate(rootCert *x509.Certificate, rootKey crypto.Si
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("generating new intermediate cert: %v", err)
|
||||
}
|
||||
|
||||
interCertChain = append(interCertChain, interCert)
|
||||
}
|
||||
|
||||
if interCert == nil {
|
||||
interCert, err = pemDecodeSingleCert(interCertPEM)
|
||||
if len(interCertChain) == 0 {
|
||||
interCertChain, err = pemDecodeCertificateChain(interCertPEM)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("decoding intermediate certificate PEM: %v", err)
|
||||
}
|
||||
@@ -346,7 +374,7 @@ func (ca CA) loadOrGenIntermediate(rootCert *x509.Certificate, rootKey crypto.Si
|
||||
}
|
||||
}
|
||||
|
||||
return interCert, interKey, nil
|
||||
return interCertChain, interKey, nil
|
||||
}
|
||||
|
||||
func (ca CA) genIntermediate(rootCert *x509.Certificate, rootKey crypto.Signer) (interCert *x509.Certificate, interKey crypto.Signer, err error) {
|
||||
@@ -443,4 +471,6 @@ const (
|
||||
|
||||
defaultRootLifetime = 24 * time.Hour * 30 * 12 * 10
|
||||
defaultIntermediateLifetime = 24 * time.Hour * 7
|
||||
defaultMaintenanceInterval = 10 * time.Minute
|
||||
defaultRenewalWindowRatio = 0.2
|
||||
)
|
||||
|
||||
@@ -17,15 +17,20 @@ package caddypki
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/caddyserver/certmagic"
|
||||
"go.step.sm/crypto/pemutil"
|
||||
)
|
||||
|
||||
func pemDecodeSingleCert(pemDER []byte) (*x509.Certificate, error) {
|
||||
func pemDecodeCertificate(pemDER []byte) (*x509.Certificate, error) {
|
||||
pemBlock, remaining := pem.Decode(pemDER)
|
||||
if pemBlock == nil {
|
||||
return nil, fmt.Errorf("no PEM block found")
|
||||
@@ -39,6 +44,15 @@ func pemDecodeSingleCert(pemDER []byte) (*x509.Certificate, error) {
|
||||
return x509.ParseCertificate(pemBlock.Bytes)
|
||||
}
|
||||
|
||||
func pemDecodeCertificateChain(pemDER []byte) ([]*x509.Certificate, error) {
|
||||
chain, err := pemutil.ParseCertificateBundle(pemDER)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed parsing certificate chain: %w", err)
|
||||
}
|
||||
|
||||
return chain, nil
|
||||
}
|
||||
|
||||
func pemEncodeCert(der []byte) ([]byte, error) {
|
||||
return pemEncode("CERTIFICATE", der)
|
||||
}
|
||||
@@ -63,22 +77,25 @@ type KeyPair struct {
|
||||
|
||||
// The private key. By default, this should be the path to
|
||||
// a PEM file unless format is something else.
|
||||
PrivateKey string `json:"private_key,omitempty"`
|
||||
PrivateKey string `json:"private_key,omitempty"` //nolint:gosec // false positive: yes it's exported, since it needs to encode/decode as JSON; and is often just a filepath
|
||||
|
||||
// The format in which the certificate and private
|
||||
// key are provided. Default: pem_file
|
||||
Format string `json:"format,omitempty"`
|
||||
}
|
||||
|
||||
// Load loads the certificate and key.
|
||||
func (kp KeyPair) Load() (*x509.Certificate, crypto.Signer, error) {
|
||||
// Load loads the certificate chain and (optional) private key from
|
||||
// the corresponding files, using the configured format. If a
|
||||
// private key is read, it will be verified to belong to the first
|
||||
// certificate in the chain.
|
||||
func (kp KeyPair) Load() ([]*x509.Certificate, crypto.Signer, error) {
|
||||
switch kp.Format {
|
||||
case "", "pem_file":
|
||||
certData, err := os.ReadFile(kp.Certificate)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
cert, err := pemDecodeSingleCert(certData)
|
||||
chain, err := pemDecodeCertificateChain(certData)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -93,11 +110,49 @@ func (kp KeyPair) Load() (*x509.Certificate, crypto.Signer, error) {
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := verifyKeysMatch(chain[0], key); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return cert, key, nil
|
||||
return chain, key, nil
|
||||
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("unsupported format: %s", kp.Format)
|
||||
}
|
||||
}
|
||||
|
||||
// verifyKeysMatch verifies that the public key in the [x509.Certificate] matches
|
||||
// the public key of the [crypto.Signer].
|
||||
func verifyKeysMatch(crt *x509.Certificate, signer crypto.Signer) error {
|
||||
switch pub := crt.PublicKey.(type) {
|
||||
case *rsa.PublicKey:
|
||||
pk, ok := signer.Public().(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return fmt.Errorf("private key type %T does not match issuer public key type %T", signer.Public(), pub)
|
||||
}
|
||||
if !pub.Equal(pk) {
|
||||
return errors.New("private key does not match issuer public key")
|
||||
}
|
||||
case *ecdsa.PublicKey:
|
||||
pk, ok := signer.Public().(*ecdsa.PublicKey)
|
||||
if !ok {
|
||||
return fmt.Errorf("private key type %T does not match issuer public key type %T", signer.Public(), pub)
|
||||
}
|
||||
if !pub.Equal(pk) {
|
||||
return errors.New("private key does not match issuer public key")
|
||||
}
|
||||
case ed25519.PublicKey:
|
||||
pk, ok := signer.Public().(ed25519.PublicKey)
|
||||
if !ok {
|
||||
return fmt.Errorf("private key type %T does not match issuer public key type %T", signer.Public(), pub)
|
||||
}
|
||||
if !pub.Equal(pk) {
|
||||
return errors.New("private key does not match issuer public key")
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported key type: %T", pub)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,314 @@
|
||||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddypki
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.step.sm/crypto/keyutil"
|
||||
"go.step.sm/crypto/pemutil"
|
||||
)
|
||||
|
||||
func TestKeyPair_Load(t *testing.T) {
|
||||
rootSigner, err := keyutil.GenerateDefaultSigner()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed creating signer: %v", err)
|
||||
}
|
||||
|
||||
tmpl := &x509.Certificate{
|
||||
Subject: pkix.Name{CommonName: "test-root"},
|
||||
IsCA: true,
|
||||
MaxPathLen: 3,
|
||||
}
|
||||
rootBytes, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, rootSigner.Public(), rootSigner)
|
||||
if err != nil {
|
||||
t.Fatalf("Creating root certificate failed: %v", err)
|
||||
}
|
||||
|
||||
root, err := x509.ParseCertificate(rootBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("Parsing root certificate failed: %v", err)
|
||||
}
|
||||
|
||||
intermediateSigner, err := keyutil.GenerateDefaultSigner()
|
||||
if err != nil {
|
||||
t.Fatalf("Creating intermedaite signer failed: %v", err)
|
||||
}
|
||||
|
||||
intermediateBytes, err := x509.CreateCertificate(rand.Reader, &x509.Certificate{
|
||||
Subject: pkix.Name{CommonName: "test-first-intermediate"},
|
||||
IsCA: true,
|
||||
MaxPathLen: 2,
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
}, root, intermediateSigner.Public(), rootSigner)
|
||||
if err != nil {
|
||||
t.Fatalf("Creating intermediate certificate failed: %v", err)
|
||||
}
|
||||
|
||||
intermediate, err := x509.ParseCertificate(intermediateBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("Parsing intermediate certificate failed: %v", err)
|
||||
}
|
||||
|
||||
var chainContents []byte
|
||||
chain := []*x509.Certificate{intermediate, root}
|
||||
for _, cert := range chain {
|
||||
b, err := pemutil.Serialize(cert)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed serializing intermediate certificate: %v", err)
|
||||
}
|
||||
chainContents = append(chainContents, pem.EncodeToMemory(b)...)
|
||||
}
|
||||
|
||||
dir := t.TempDir()
|
||||
rootCertFile := filepath.Join(dir, "root.pem")
|
||||
if _, err = pemutil.Serialize(root, pemutil.WithFilename(rootCertFile)); err != nil {
|
||||
t.Fatalf("Failed serializing root certificate: %v", err)
|
||||
}
|
||||
rootKeyFile := filepath.Join(dir, "root.key")
|
||||
if _, err = pemutil.Serialize(rootSigner, pemutil.WithFilename(rootKeyFile)); err != nil {
|
||||
t.Fatalf("Failed serializing root key: %v", err)
|
||||
}
|
||||
intermediateCertFile := filepath.Join(dir, "intermediate.pem")
|
||||
if _, err = pemutil.Serialize(intermediate, pemutil.WithFilename(intermediateCertFile)); err != nil {
|
||||
t.Fatalf("Failed serializing intermediate certificate: %v", err)
|
||||
}
|
||||
intermediateKeyFile := filepath.Join(dir, "intermediate.key")
|
||||
if _, err = pemutil.Serialize(intermediateSigner, pemutil.WithFilename(intermediateKeyFile)); err != nil {
|
||||
t.Fatalf("Failed serializing intermediate key: %v", err)
|
||||
}
|
||||
chainFile := filepath.Join(dir, "chain.pem")
|
||||
if err := os.WriteFile(chainFile, chainContents, 0644); err != nil {
|
||||
t.Fatalf("Failed writing intermediate chain: %v", err)
|
||||
}
|
||||
|
||||
t.Run("ok/single-certificate-without-signer", func(t *testing.T) {
|
||||
kp := KeyPair{
|
||||
Certificate: rootCertFile,
|
||||
}
|
||||
chain, signer, err := kp.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed loading KeyPair: %v", err)
|
||||
}
|
||||
if len(chain) != 1 {
|
||||
t.Errorf("Expected 1 certificate in chain; got %d", len(chain))
|
||||
}
|
||||
if signer != nil {
|
||||
t.Error("Expected no signer to be returned")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ok/single-certificate-with-signer", func(t *testing.T) {
|
||||
kp := KeyPair{
|
||||
Certificate: rootCertFile,
|
||||
PrivateKey: rootKeyFile,
|
||||
}
|
||||
chain, signer, err := kp.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed loading KeyPair: %v", err)
|
||||
}
|
||||
if len(chain) != 1 {
|
||||
t.Errorf("Expected 1 certificate in chain; got %d", len(chain))
|
||||
}
|
||||
if signer == nil {
|
||||
t.Error("Expected signer to be returned")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ok/multiple-certificates-with-signer", func(t *testing.T) {
|
||||
kp := KeyPair{
|
||||
Certificate: chainFile,
|
||||
PrivateKey: intermediateKeyFile,
|
||||
}
|
||||
chain, signer, err := kp.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed loading KeyPair: %v", err)
|
||||
}
|
||||
if len(chain) != 2 {
|
||||
t.Errorf("Expected 2 certificates in chain; got %d", len(chain))
|
||||
}
|
||||
if signer == nil {
|
||||
t.Error("Expected signer to be returned")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("fail/non-matching-public-key", func(t *testing.T) {
|
||||
kp := KeyPair{
|
||||
Certificate: intermediateCertFile,
|
||||
PrivateKey: rootKeyFile,
|
||||
}
|
||||
chain, signer, err := kp.Load()
|
||||
if err == nil {
|
||||
t.Error("Expected loading KeyPair to return an error")
|
||||
}
|
||||
if chain != nil {
|
||||
t.Error("Expected no chain to be returned")
|
||||
}
|
||||
if signer != nil {
|
||||
t.Error("Expected no signer to be returned")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_pemDecodeCertificate(t *testing.T) {
|
||||
signer, err := keyutil.GenerateDefaultSigner()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed creating signer: %v", err)
|
||||
}
|
||||
|
||||
tmpl := &x509.Certificate{
|
||||
Subject: pkix.Name{CommonName: "test-cert"},
|
||||
IsCA: true,
|
||||
MaxPathLen: 3,
|
||||
}
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, signer.Public(), signer)
|
||||
if err != nil {
|
||||
t.Fatalf("Creating root certificate failed: %v", err)
|
||||
}
|
||||
cert, err := x509.ParseCertificate(derBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("Parsing root certificate failed: %v", err)
|
||||
}
|
||||
|
||||
pemBlock, err := pemutil.Serialize(cert)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed serializing certificate: %v", err)
|
||||
}
|
||||
pemData := pem.EncodeToMemory(pemBlock)
|
||||
|
||||
t.Run("ok", func(t *testing.T) {
|
||||
cert, err := pemDecodeCertificate(pemData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed decoding PEM data: %v", err)
|
||||
}
|
||||
if cert == nil {
|
||||
t.Errorf("Expected a certificate in PEM data")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("fail/no-pem-data", func(t *testing.T) {
|
||||
cert, err := pemDecodeCertificate(nil)
|
||||
if err == nil {
|
||||
t.Fatalf("Expected pemDecodeCertificate to return an error")
|
||||
}
|
||||
if cert != nil {
|
||||
t.Errorf("Expected pemDecodeCertificate to return nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("fail/multiple", func(t *testing.T) {
|
||||
multiplePEMData := append(pemData, pemData...)
|
||||
cert, err := pemDecodeCertificate(multiplePEMData)
|
||||
if err == nil {
|
||||
t.Fatalf("Expected pemDecodeCertificate to return an error")
|
||||
}
|
||||
if cert != nil {
|
||||
t.Errorf("Expected pemDecodeCertificate to return nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("fail/no-pem-certificate", func(t *testing.T) {
|
||||
pkData := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PRIVATE KEY",
|
||||
Bytes: []byte("some-bogus-private-key"),
|
||||
})
|
||||
cert, err := pemDecodeCertificate(pkData)
|
||||
if err == nil {
|
||||
t.Fatalf("Expected pemDecodeCertificate to return an error")
|
||||
}
|
||||
if cert != nil {
|
||||
t.Errorf("Expected pemDecodeCertificate to return nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_pemDecodeCertificateChain(t *testing.T) {
|
||||
signer, err := keyutil.GenerateDefaultSigner()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed creating signer: %v", err)
|
||||
}
|
||||
|
||||
tmpl := &x509.Certificate{
|
||||
Subject: pkix.Name{CommonName: "test-cert"},
|
||||
IsCA: true,
|
||||
MaxPathLen: 3,
|
||||
}
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, signer.Public(), signer)
|
||||
if err != nil {
|
||||
t.Fatalf("Creating root certificate failed: %v", err)
|
||||
}
|
||||
cert, err := x509.ParseCertificate(derBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("Parsing root certificate failed: %v", err)
|
||||
}
|
||||
|
||||
pemBlock, err := pemutil.Serialize(cert)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed serializing certificate: %v", err)
|
||||
}
|
||||
pemData := pem.EncodeToMemory(pemBlock)
|
||||
|
||||
t.Run("ok/single", func(t *testing.T) {
|
||||
certs, err := pemDecodeCertificateChain(pemData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed decoding PEM data: %v", err)
|
||||
}
|
||||
if len(certs) != 1 {
|
||||
t.Errorf("Expected 1 certificate in PEM data; got %d", len(certs))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ok/multiple", func(t *testing.T) {
|
||||
multiplePEMData := append(pemData, pemData...)
|
||||
certs, err := pemDecodeCertificateChain(multiplePEMData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed decoding PEM data: %v", err)
|
||||
}
|
||||
if len(certs) != 2 {
|
||||
t.Errorf("Expected 2 certificates in PEM data; got %d", len(certs))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("fail/no-pem-certificate", func(t *testing.T) {
|
||||
pkData := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PRIVATE KEY",
|
||||
Bytes: []byte("some-bogus-private-key"),
|
||||
})
|
||||
certs, err := pemDecodeCertificateChain(pkData)
|
||||
if err == nil {
|
||||
t.Fatalf("Expected pemDecodeCertificateChain to return an error")
|
||||
}
|
||||
if len(certs) != 0 {
|
||||
t.Errorf("Expected 0 certificates in PEM data; got %d", len(certs))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("fail/no-der-certificate", func(t *testing.T) {
|
||||
certs, err := pemDecodeCertificateChain([]byte("invalid-der-data"))
|
||||
if err == nil {
|
||||
t.Fatalf("Expected pemDecodeCertificateChain to return an error")
|
||||
}
|
||||
if len(certs) != 0 {
|
||||
t.Errorf("Expected 0 certificates in PEM data; got %d", len(certs))
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -24,20 +24,24 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func (p *PKI) maintenance() {
|
||||
func (p *PKI) maintenanceForCA(ca *CA) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
log.Printf("[PANIC] PKI maintenance: %v\n%s", err, debug.Stack())
|
||||
log.Printf("[PANIC] PKI maintenance for CA %s: %v\n%s", ca.ID, err, debug.Stack())
|
||||
}
|
||||
}()
|
||||
|
||||
ticker := time.NewTicker(10 * time.Minute) // TODO: make configurable
|
||||
interval := time.Duration(ca.MaintenanceInterval)
|
||||
if interval <= 0 {
|
||||
interval = defaultMaintenanceInterval
|
||||
}
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
p.renewCerts()
|
||||
_ = p.renewCertsForCA(ca)
|
||||
case <-p.ctx.Done():
|
||||
return
|
||||
}
|
||||
@@ -63,19 +67,19 @@ func (p *PKI) renewCertsForCA(ca *CA) error {
|
||||
|
||||
// only maintain the root if it's not manually provided in the config
|
||||
if ca.Root == nil {
|
||||
if needsRenewal(ca.root) {
|
||||
if ca.needsRenewal(ca.root) {
|
||||
// TODO: implement root renewal (use same key)
|
||||
log.Warn("root certificate expiring soon (FIXME: ROOT RENEWAL NOT YET IMPLEMENTED)",
|
||||
zap.Duration("time_remaining", time.Until(ca.inter.NotAfter)),
|
||||
zap.Duration("time_remaining", time.Until(ca.interChain[0].NotAfter)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// only maintain the intermediate if it's not manually provided in the config
|
||||
if ca.Intermediate == nil {
|
||||
if needsRenewal(ca.inter) {
|
||||
if ca.needsRenewal(ca.interChain[0]) {
|
||||
log.Info("intermediate expires soon; renewing",
|
||||
zap.Duration("time_remaining", time.Until(ca.inter.NotAfter)),
|
||||
zap.Duration("time_remaining", time.Until(ca.interChain[0].NotAfter)),
|
||||
)
|
||||
|
||||
rootCert, rootKey, err := ca.loadOrGenRoot()
|
||||
@@ -86,10 +90,10 @@ func (p *PKI) renewCertsForCA(ca *CA) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("generating new certificate: %v", err)
|
||||
}
|
||||
ca.inter, ca.interKey = interCert, interKey
|
||||
ca.interChain, ca.interKey = []*x509.Certificate{interCert}, interKey
|
||||
|
||||
log.Info("renewed intermediate",
|
||||
zap.Time("new_expiration", ca.inter.NotAfter),
|
||||
zap.Time("new_expiration", ca.interChain[0].NotAfter),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -97,11 +101,15 @@ func (p *PKI) renewCertsForCA(ca *CA) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func needsRenewal(cert *x509.Certificate) bool {
|
||||
// needsRenewal reports whether the certificate is within its renewal window
|
||||
// (i.e. the fraction of lifetime remaining is less than or equal to RenewalWindowRatio).
|
||||
func (ca *CA) needsRenewal(cert *x509.Certificate) bool {
|
||||
ratio := ca.RenewalWindowRatio
|
||||
if ratio <= 0 {
|
||||
ratio = defaultRenewalWindowRatio
|
||||
}
|
||||
lifetime := cert.NotAfter.Sub(cert.NotBefore)
|
||||
renewalWindow := time.Duration(float64(lifetime) * renewalWindowRatio)
|
||||
renewalWindow := time.Duration(float64(lifetime) * ratio)
|
||||
renewalWindowStart := cert.NotAfter.Add(-renewalWindow)
|
||||
return time.Now().After(renewalWindowStart)
|
||||
}
|
||||
|
||||
const renewalWindowRatio = 0.2 // TODO: make configurable
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddypki
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCA_needsRenewal(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
// cert with 100 days lifetime; last 20% = 20 days before expiry
|
||||
// So renewal window starts at (NotAfter - 20 days)
|
||||
makeCert := func(daysUntilExpiry int, lifetimeDays int) *x509.Certificate {
|
||||
notAfter := now.AddDate(0, 0, daysUntilExpiry)
|
||||
notBefore := notAfter.AddDate(0, 0, -lifetimeDays)
|
||||
return &x509.Certificate{NotBefore: notBefore, NotAfter: notAfter}
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ca *CA
|
||||
cert *x509.Certificate
|
||||
expect bool
|
||||
}{
|
||||
{
|
||||
name: "inside renewal window with ratio 0.2",
|
||||
ca: &CA{RenewalWindowRatio: 0.2},
|
||||
cert: makeCert(10, 100),
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "outside renewal window with ratio 0.2",
|
||||
ca: &CA{RenewalWindowRatio: 0.2},
|
||||
cert: makeCert(50, 100),
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
name: "outside renewal window with 21 days left",
|
||||
ca: &CA{RenewalWindowRatio: 0.2},
|
||||
cert: makeCert(21, 100),
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
name: "just inside renewal window with ratio 0.5",
|
||||
ca: &CA{RenewalWindowRatio: 0.5},
|
||||
cert: makeCert(30, 100),
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "zero ratio uses default",
|
||||
ca: &CA{RenewalWindowRatio: 0},
|
||||
cert: makeCert(10, 100),
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "invalid ratio uses default",
|
||||
ca: &CA{RenewalWindowRatio: 1.5},
|
||||
cert: makeCert(10, 100),
|
||||
expect: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := tt.ca.needsRenewal(tt.cert)
|
||||
if got != tt.expect {
|
||||
t.Errorf("needsRenewal() = %v, want %v", got, tt.expect)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -109,8 +109,10 @@ func (p *PKI) Start() error {
|
||||
// see if root/intermediates need renewal...
|
||||
p.renewCerts()
|
||||
|
||||
// ...and keep them renewed
|
||||
go p.maintenance()
|
||||
// ...and keep them renewed (one goroutine per CA with its own interval)
|
||||
for _, ca := range p.CAs {
|
||||
go p.maintenanceForCA(ca)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -178,6 +178,7 @@ func (iss *ACMEIssuer) Provision(ctx caddy.Context) error {
|
||||
PropagationTimeout: time.Duration(iss.Challenges.DNS.PropagationTimeout),
|
||||
Resolvers: iss.Challenges.DNS.Resolvers,
|
||||
OverrideDomain: iss.Challenges.DNS.OverrideDomain,
|
||||
Logger: iss.logger.Named("dns_manager"),
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -336,7 +337,7 @@ func (iss *ACMEIssuer) generateZeroSSLEABCredentials(ctx context.Context, acct a
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("User-Agent", certmagic.UserAgent)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
resp, err := http.DefaultClient.Do(req) //nolint:gosec // no SSRF since URL is from trusted config
|
||||
if err != nil {
|
||||
return nil, acct, fmt.Errorf("performing EAB credentials request: %v", err)
|
||||
}
|
||||
@@ -671,7 +672,7 @@ func ParseCaddyfilePreferredChainsOptions(d *caddyfile.Dispenser) (*ChainPrefere
|
||||
switch d.Val() {
|
||||
case "root_common_name":
|
||||
rootCommonNameOpt := d.RemainingArgs()
|
||||
chainPref.RootCommonName = rootCommonNameOpt
|
||||
chainPref.RootCommonName = append(chainPref.RootCommonName, rootCommonNameOpt...)
|
||||
if rootCommonNameOpt == nil {
|
||||
return nil, d.ArgErr()
|
||||
}
|
||||
@@ -681,7 +682,7 @@ func ParseCaddyfilePreferredChainsOptions(d *caddyfile.Dispenser) (*ChainPrefere
|
||||
|
||||
case "any_common_name":
|
||||
anyCommonNameOpt := d.RemainingArgs()
|
||||
chainPref.AnyCommonName = anyCommonNameOpt
|
||||
chainPref.AnyCommonName = append(chainPref.AnyCommonName, anyCommonNameOpt...)
|
||||
if anyCommonNameOpt == nil {
|
||||
return nil, d.ArgErr()
|
||||
}
|
||||
|
||||
@@ -243,22 +243,49 @@ func (ap *AutomationPolicy) Provision(tlsApp *TLS) error {
|
||||
}
|
||||
}
|
||||
|
||||
// build certmagic.Config and attach it to the policy
|
||||
storage := ap.storage
|
||||
if storage == nil {
|
||||
storage = tlsApp.ctx.Storage()
|
||||
}
|
||||
cfg, err := ap.makeCertMagicConfig(tlsApp, issuers, storage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
certCacheMu.RLock()
|
||||
ap.magic = certmagic.New(certCache, cfg)
|
||||
certCacheMu.RUnlock()
|
||||
|
||||
// give issuers a chance to see the config pointer
|
||||
for _, issuer := range ap.magic.Issuers {
|
||||
if annoying, ok := issuer.(ConfigSetter); ok {
|
||||
annoying.SetConfig(ap.magic)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// makeCertMagicConfig constructs a certmagic.Config for this policy using the
|
||||
// provided issuers and storage. It encapsulates common logic shared between
|
||||
// Provision and RebuildCertMagic so we don't duplicate code.
|
||||
func (ap *AutomationPolicy) makeCertMagicConfig(tlsApp *TLS, issuers []certmagic.Issuer, storage certmagic.Storage) (certmagic.Config, error) {
|
||||
// key source
|
||||
keyType := ap.KeyType
|
||||
if keyType != "" {
|
||||
var err error
|
||||
keyType, err = caddy.NewReplacer().ReplaceOrErr(ap.KeyType, true, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid key type %s: %s", ap.KeyType, err)
|
||||
return certmagic.Config{}, fmt.Errorf("invalid key type %s: %s", ap.KeyType, err)
|
||||
}
|
||||
if _, ok := supportedCertKeyTypes[keyType]; !ok {
|
||||
return fmt.Errorf("unrecognized key type: %s", keyType)
|
||||
return certmagic.Config{}, fmt.Errorf("unrecognized key type: %s", keyType)
|
||||
}
|
||||
}
|
||||
keySource := certmagic.StandardKeyGenerator{
|
||||
KeyType: supportedCertKeyTypes[keyType],
|
||||
}
|
||||
|
||||
storage := ap.storage
|
||||
if storage == nil {
|
||||
storage = tlsApp.ctx.Storage()
|
||||
}
|
||||
@@ -277,7 +304,7 @@ func (ap *AutomationPolicy) Provision(tlsApp *TLS) error {
|
||||
if noProtections {
|
||||
if !ap.hadExplicitManagers {
|
||||
// no managers, no explicitly-configured permission module, this is a config error
|
||||
return fmt.Errorf("on-demand TLS cannot be enabled without a permission module to prevent abuse; please refer to documentation for details")
|
||||
return certmagic.Config{}, fmt.Errorf("on-demand TLS cannot be enabled without a permission module to prevent abuse; please refer to documentation for details")
|
||||
}
|
||||
// allow on-demand to be enabled but only for the purpose of the Managers; issuance won't be allowed from Issuers
|
||||
tlsApp.logger.Warn("on-demand TLS can only get certificates from the configured external manager(s) because no ask endpoint / permission module is specified")
|
||||
@@ -334,7 +361,7 @@ func (ap *AutomationPolicy) Provision(tlsApp *TLS) error {
|
||||
}
|
||||
}
|
||||
|
||||
template := certmagic.Config{
|
||||
cfg := certmagic.Config{
|
||||
MustStaple: ap.MustStaple,
|
||||
RenewalWindowRatio: ap.RenewalWindowRatio,
|
||||
KeySource: keySource,
|
||||
@@ -349,8 +376,31 @@ func (ap *AutomationPolicy) Provision(tlsApp *TLS) error {
|
||||
Issuers: issuers,
|
||||
Logger: tlsApp.logger,
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// IsProvisioned reports whether the automation policy has been
|
||||
// provisioned. A provisioned policy has an initialized CertMagic
|
||||
// instance (i.e. ap.magic != nil).
|
||||
func (ap *AutomationPolicy) IsProvisioned() bool { return ap.magic != nil }
|
||||
|
||||
// RebuildCertMagic rebuilds the policy's CertMagic configuration from the
|
||||
// policy's already-populated fields (Issuers, Managers, storage, etc.) and
|
||||
// replaces the internal CertMagic instance. This is a lightweight
|
||||
// alternative to calling Provision because it does not re-provision
|
||||
// modules or re-run module Provision; instead, it constructs a new
|
||||
// certmagic.Config and calls SetConfig on issuers so they receive updated
|
||||
// templates (for example, alternate HTTP/TLS ports supplied by the HTTP
|
||||
// app). RebuildCertMagic should only be called when the policy's required
|
||||
// fields are already populated.
|
||||
func (ap *AutomationPolicy) RebuildCertMagic(tlsApp *TLS) error {
|
||||
cfg, err := ap.makeCertMagicConfig(tlsApp, ap.Issuers, ap.storage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
certCacheMu.RLock()
|
||||
ap.magic = certmagic.New(certCache, template)
|
||||
ap.magic = certmagic.New(certCache, cfg)
|
||||
certCacheMu.RUnlock()
|
||||
|
||||
// sometimes issuers may need the parent certmagic.Config in
|
||||
|
||||
@@ -257,7 +257,7 @@ func (PKIIntermediateCAPool) CaddyModule() caddy.ModuleInfo {
|
||||
}
|
||||
}
|
||||
|
||||
// Loads the PKI app and load the intermediate certificates into the certificate pool
|
||||
// Loads the PKI app and loads the intermediate certificates into the certificate pool
|
||||
func (p *PKIIntermediateCAPool) Provision(ctx caddy.Context) error {
|
||||
pkiApp, err := ctx.AppIfConfigured("pki")
|
||||
if err != nil {
|
||||
@@ -274,7 +274,9 @@ func (p *PKIIntermediateCAPool) Provision(ctx caddy.Context) error {
|
||||
|
||||
caPool := x509.NewCertPool()
|
||||
for _, ca := range p.ca {
|
||||
caPool.AddCert(ca.IntermediateCertificate())
|
||||
for _, c := range ca.IntermediateCertificateChain() {
|
||||
caPool.AddCert(c)
|
||||
}
|
||||
}
|
||||
p.pool = caPool
|
||||
return nil
|
||||
@@ -500,8 +502,8 @@ func (t *TLSConfig) unmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
||||
// If there is no custom TLS configuration, a nil config may be returned.
|
||||
// copied from with minor modifications: modules/caddyhttp/reverseproxy/httptransport.go
|
||||
func (t *TLSConfig) makeTLSClientConfig(ctx caddy.Context) (*tls.Config, error) {
|
||||
repl := ctx.Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
|
||||
if repl == nil {
|
||||
repl, ok := ctx.Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
|
||||
if !ok || repl == nil {
|
||||
repl = caddy.NewReplacer()
|
||||
}
|
||||
cfg := new(tls.Config)
|
||||
@@ -586,7 +588,7 @@ func (hcp *HTTPCertPool) Provision(ctx caddy.Context) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res, err := httpClient.Do(req)
|
||||
res, err := httpClient.Do(req) //nolint:gosec // SSRF false positive... uri comes from config
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -155,7 +155,7 @@ func (hcg HTTPCertGetter) GetCertificate(ctx context.Context, hello *tls.ClientH
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
resp, err := http.DefaultClient.Do(req) //nolint:gosec // SSRF false positive... request URI comes from config
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -168,21 +168,11 @@ func (cp ConnectionPolicies) TLSConfig(ctx caddy.Context) *tls.Config {
|
||||
tlsApp.RegisterServerNames(echNames)
|
||||
}
|
||||
|
||||
// TODO: Ideally, ECH keys should be rotated. However, as of Go 1.24, the std lib implementation
|
||||
// does not support safely modifying the tls.Config's EncryptedClientHelloKeys field.
|
||||
// So, we implement static ECH keys temporarily. See https://github.com/golang/go/issues/71920.
|
||||
// Revisit this after Go 1.25 is released and implement key rotation.
|
||||
var stdECHKeys []tls.EncryptedClientHelloKey
|
||||
for _, echConfigs := range tlsApp.EncryptedClientHello.configs {
|
||||
for _, c := range echConfigs {
|
||||
stdECHKeys = append(stdECHKeys, tls.EncryptedClientHelloKey{
|
||||
Config: c.configBin,
|
||||
PrivateKey: c.privKeyBin,
|
||||
SendAsRetry: c.sendAsRetry,
|
||||
})
|
||||
}
|
||||
tlsCfg.GetEncryptedClientHelloKeys = func(chi *tls.ClientHelloInfo) ([]tls.EncryptedClientHelloKey, error) {
|
||||
tlsApp.EncryptedClientHello.configsMu.RLock()
|
||||
defer tlsApp.EncryptedClientHello.configsMu.RUnlock()
|
||||
return tlsApp.EncryptedClientHello.stdlibReady, nil
|
||||
}
|
||||
tlsCfg.EncryptedClientHelloKeys = stdECHKeys
|
||||
}
|
||||
}
|
||||
|
||||
@@ -794,7 +784,7 @@ func (clientauth *ClientAuthentication) provision(ctx caddy.Context) error {
|
||||
for _, fpath := range clientauth.TrustedCACertPEMFiles {
|
||||
ders, err := convertPEMFilesToDER(fpath)
|
||||
if err != nil {
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
clientauth.TrustedCACerts = append(clientauth.TrustedCACerts, ders...)
|
||||
}
|
||||
@@ -807,7 +797,7 @@ func (clientauth *ClientAuthentication) provision(ctx caddy.Context) error {
|
||||
}
|
||||
err := caPool.Provision(ctx)
|
||||
if err != nil {
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
clientauth.ca = caPool
|
||||
}
|
||||
@@ -895,24 +885,29 @@ func (clientauth *ClientAuthentication) ConfigureTLSConfig(cfg *tls.Config) erro
|
||||
|
||||
// if a custom verification function already exists, wrap it
|
||||
clientauth.existingVerifyPeerCert = cfg.VerifyPeerCertificate
|
||||
cfg.VerifyPeerCertificate = clientauth.verifyPeerCertificate
|
||||
cfg.VerifyConnection = clientauth.verifyConnection
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyPeerCertificate is for use as a tls.Config.VerifyPeerCertificate
|
||||
// callback to do custom client certificate verification. It is intended
|
||||
// for installation only by clientauth.ConfigureTLSConfig().
|
||||
func (clientauth *ClientAuthentication) verifyPeerCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||
// verifyConnection is for use as a tls.Config.VerifyConnection callback
|
||||
// to do custom client certificate verification. It is intended for
|
||||
// installation only by clientauth.ConfigureTLSConfig().
|
||||
//
|
||||
// Unlike VerifyPeerCertificate, VerifyConnection is called on every
|
||||
// connection including resumed sessions, preventing session-resumption bypass.
|
||||
func (clientauth *ClientAuthentication) verifyConnection(cs tls.ConnectionState) error {
|
||||
// first use any pre-existing custom verification function
|
||||
if clientauth.existingVerifyPeerCert != nil {
|
||||
err := clientauth.existingVerifyPeerCert(rawCerts, verifiedChains)
|
||||
if err != nil {
|
||||
rawCerts := make([][]byte, len(cs.PeerCertificates))
|
||||
for i, cert := range cs.PeerCertificates {
|
||||
rawCerts[i] = cert.Raw
|
||||
}
|
||||
if err := clientauth.existingVerifyPeerCert(rawCerts, cs.VerifiedChains); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for _, verifier := range clientauth.verifiers {
|
||||
err := verifier.VerifyClientCertificate(rawCerts, verifiedChains)
|
||||
if err != nil {
|
||||
if err := verifier.VerifyClientCertificate(nil, cs.VerifiedChains); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
+214
-64
@@ -2,6 +2,7 @@ package caddytls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/caddyserver/certmagic"
|
||||
@@ -48,12 +50,6 @@ func init() {
|
||||
// applied will automatically upgrade the minimum TLS version to 1.3, even if
|
||||
// configured to a lower version.
|
||||
//
|
||||
// Note that, as of Caddy 2.10.0 (~March 2025), ECH keys are not automatically
|
||||
// rotated due to a limitation in the Go standard library (see
|
||||
// https://github.com/golang/go/issues/71920). This should be resolved when
|
||||
// Go 1.25 is released (~Aug. 2025), and Caddy will be updated to automatically
|
||||
// rotate ECH keys/configs at that point.
|
||||
//
|
||||
// EXPERIMENTAL: Subject to change.
|
||||
type ECH struct {
|
||||
// The list of ECH configurations for which to automatically generate
|
||||
@@ -73,14 +69,17 @@ type ECH struct {
|
||||
// DNS RRs. (This also typically requires that they use DoH or DoT.)
|
||||
Publication []*ECHPublication `json:"publication,omitempty"`
|
||||
|
||||
// map of public_name to list of configs
|
||||
configs map[string][]echConfig
|
||||
configsMu *sync.RWMutex // protects both configs and the list of configs/keys the standard library uses
|
||||
configs map[string][]echConfig // map of public_name to list of configs
|
||||
stdlibReady []tls.EncryptedClientHelloKey // ECH configs+keys in a format the standard library can use
|
||||
}
|
||||
|
||||
// Provision loads or creates ECH configs and returns outer names (for certificate
|
||||
// management), but does not publish any ECH configs. The DNS module is used as
|
||||
// a default for later publishing if needed.
|
||||
func (ech *ECH) Provision(ctx caddy.Context) ([]string, error) {
|
||||
ech.configsMu = new(sync.RWMutex)
|
||||
|
||||
logger := ctx.Logger().Named("ech")
|
||||
|
||||
// set up publication modules before we need to obtain a lock in storage,
|
||||
@@ -98,17 +97,60 @@ func (ech *ECH) Provision(ctx caddy.Context) ([]string, error) {
|
||||
// the rest of provisioning needs an exclusive lock so that instances aren't
|
||||
// stepping on each other when setting up ECH configs
|
||||
storage := ctx.Storage()
|
||||
const echLockName = "ech_provision"
|
||||
if err := storage.Lock(ctx, echLockName); err != nil {
|
||||
if err := storage.Lock(ctx, echStorageLockName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if err := storage.Unlock(ctx, echLockName); err != nil {
|
||||
if err := storage.Unlock(ctx, echStorageLockName); err != nil {
|
||||
logger.Error("unable to unlock ECH provisioning in storage", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
var outerNames []string //nolint:prealloc // (FALSE POSITIVE - see https://github.com/alexkohler/prealloc/issues/30)
|
||||
ech.configsMu.Lock()
|
||||
defer ech.configsMu.Unlock()
|
||||
|
||||
outerNames, err := ech.setConfigsFromStorage(ctx, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading configs from storage: %w", err)
|
||||
}
|
||||
|
||||
// see if we need to make any new ones based on the input configuration
|
||||
for _, cfg := range ech.Configs {
|
||||
publicName := strings.ToLower(strings.TrimSpace(cfg.PublicName))
|
||||
|
||||
if list, ok := ech.configs[publicName]; !ok || len(list) == 0 {
|
||||
// no config with this public name was loaded, so create one
|
||||
echCfg, err := generateAndStoreECHConfig(ctx, publicName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logger.Debug("generated new ECH config",
|
||||
zap.String("public_name", echCfg.RawPublicName),
|
||||
zap.Uint8("id", echCfg.ConfigID))
|
||||
ech.configs[publicName] = append(ech.configs[publicName], echCfg)
|
||||
outerNames = append(outerNames, publicName)
|
||||
}
|
||||
}
|
||||
|
||||
// convert the configs into a structure ready for the std lib to use
|
||||
ech.updateKeyList()
|
||||
|
||||
// ensure any old keys are rotated out
|
||||
if err = ech.rotateECHKeys(ctx, logger, true); err != nil {
|
||||
return nil, fmt.Errorf("rotating ECH configs: %w", err)
|
||||
}
|
||||
|
||||
return outerNames, nil
|
||||
}
|
||||
|
||||
// setConfigsFromStorage sets the ECH configs in memory to those in storage.
|
||||
// It must be called in a write lock on ech.configsMu.
|
||||
func (ech *ECH) setConfigsFromStorage(ctx caddy.Context, logger *zap.Logger) ([]string, error) {
|
||||
storage := ctx.Storage()
|
||||
|
||||
ech.configs = make(map[string][]echConfig)
|
||||
|
||||
var outerNames []string
|
||||
|
||||
// start by loading all the existing configs (even the older ones on the way out,
|
||||
// since some clients may still be using them if they haven't yet picked up on the
|
||||
@@ -131,48 +173,145 @@ func (ech *ECH) Provision(ctx caddy.Context) ([]string, error) {
|
||||
logger.Debug("loaded ECH config",
|
||||
zap.String("public_name", cfg.RawPublicName),
|
||||
zap.Uint8("id", cfg.ConfigID))
|
||||
ech.configs[cfg.RawPublicName] = append(ech.configs[cfg.RawPublicName], cfg)
|
||||
outerNames = append(outerNames, cfg.RawPublicName)
|
||||
}
|
||||
|
||||
// all existing configs are now loaded; see if we need to make any new ones
|
||||
// based on the input configuration, and also mark the most recent one(s) as
|
||||
// current/active, so they can be used for ECH retries
|
||||
for _, cfg := range ech.Configs {
|
||||
publicName := strings.ToLower(strings.TrimSpace(cfg.PublicName))
|
||||
|
||||
if list, ok := ech.configs[publicName]; ok && len(list) > 0 {
|
||||
// at least one config with this public name was loaded, so find the
|
||||
// most recent one and mark it as active to be used with retries
|
||||
var mostRecentDate time.Time
|
||||
var mostRecentIdx int
|
||||
for i, c := range list {
|
||||
if mostRecentDate.IsZero() || c.meta.Created.After(mostRecentDate) {
|
||||
mostRecentDate = c.meta.Created
|
||||
mostRecentIdx = i
|
||||
}
|
||||
}
|
||||
list[mostRecentIdx].sendAsRetry = true
|
||||
} else {
|
||||
// no config with this public name was loaded, so create one
|
||||
echCfg, err := generateAndStoreECHConfig(ctx, publicName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logger.Debug("generated new ECH config",
|
||||
zap.String("public_name", echCfg.RawPublicName),
|
||||
zap.Uint8("id", echCfg.ConfigID))
|
||||
ech.configs[publicName] = append(ech.configs[publicName], echCfg)
|
||||
outerNames = append(outerNames, publicName)
|
||||
if _, seen := ech.configs[cfg.RawPublicName]; !seen {
|
||||
outerNames = append(outerNames, cfg.RawPublicName)
|
||||
}
|
||||
ech.configs[cfg.RawPublicName] = append(ech.configs[cfg.RawPublicName], cfg)
|
||||
}
|
||||
|
||||
return outerNames, nil
|
||||
}
|
||||
|
||||
func (t *TLS) publishECHConfigs() error {
|
||||
logger := t.logger.Named("ech")
|
||||
// rotateECHKeys updates the ECH keys/configs that are outdated if rotation is needed.
|
||||
// It should be called in a write lock on ech.configsMu. If a lock is already obtained
|
||||
// in storage, then pass true for storageSynced.
|
||||
//
|
||||
// This function sets/updates the stdlib-ready key list only if a rotation occurs.
|
||||
func (ech *ECH) rotateECHKeys(ctx caddy.Context, logger *zap.Logger, storageSynced bool) error {
|
||||
storage := ctx.Storage()
|
||||
|
||||
// all existing configs are now loaded; rotate keys "regularly" as recommended by the spec
|
||||
// (also: "Rotating too frequently limits the client anonymity set." - but the more server
|
||||
// names, the more frequently rotation can be done safely)
|
||||
const (
|
||||
rotationInterval = 24 * time.Hour * 30
|
||||
deleteAfter = 24 * time.Hour * 90
|
||||
)
|
||||
|
||||
if !ech.rotationNeeded(rotationInterval, deleteAfter) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// sync this operation across cluster if not already
|
||||
if !storageSynced {
|
||||
if err := storage.Lock(ctx, echStorageLockName); err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := storage.Unlock(ctx, echStorageLockName); err != nil {
|
||||
logger.Error("unable to unlock ECH rotation in storage", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// update what storage has, in case another instance already updated things
|
||||
if _, err := ech.setConfigsFromStorage(ctx, logger); err != nil {
|
||||
return fmt.Errorf("updating ECH keys from storage: %v", err)
|
||||
}
|
||||
|
||||
// iterate the updated list and do any updates as needed
|
||||
for publicName := range ech.configs {
|
||||
for i := 0; i < len(ech.configs[publicName]); i++ {
|
||||
cfg := ech.configs[publicName][i]
|
||||
if time.Since(cfg.meta.Created) >= rotationInterval && cfg.meta.Replaced.IsZero() {
|
||||
// key is due for rotation and it hasn't been replaced yet; do that now
|
||||
logger.Debug("ECH config is due for rotation",
|
||||
zap.String("public_name", cfg.RawPublicName),
|
||||
zap.Uint8("id", cfg.ConfigID),
|
||||
zap.Time("created", cfg.meta.Created),
|
||||
zap.Duration("age", time.Since(cfg.meta.Created)),
|
||||
zap.Duration("rotation_interval", rotationInterval))
|
||||
|
||||
// start by generating and storing the replacement ECH config
|
||||
newCfg, err := generateAndStoreECHConfig(ctx, publicName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generating and storing new replacement ECH config: %w", err)
|
||||
}
|
||||
|
||||
// mark the key as replaced so we don't rotate it again, and instead delete it later
|
||||
ech.configs[publicName][i].meta.Replaced = time.Now()
|
||||
|
||||
// persist the updated metadata
|
||||
metaBytes, err := json.Marshal(ech.configs[publicName][i].meta)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshaling updated ECH config metadata: %v", err)
|
||||
}
|
||||
if err := storage.Store(ctx, echMetaKey(cfg.ConfigID), metaBytes); err != nil {
|
||||
return fmt.Errorf("storing updated ECH config metadata: %v", err)
|
||||
}
|
||||
|
||||
ech.configs[publicName] = append(ech.configs[publicName], newCfg)
|
||||
|
||||
logger.Debug("rotated ECH key",
|
||||
zap.String("public_name", cfg.RawPublicName),
|
||||
zap.Uint8("old_id", cfg.ConfigID),
|
||||
zap.Uint8("new_id", newCfg.ConfigID))
|
||||
} else if time.Since(cfg.meta.Created) >= deleteAfter && !cfg.meta.Replaced.IsZero() {
|
||||
// key has expired and is no longer supported; delete it from storage and memory
|
||||
cfgIDKey := path.Join(echConfigsKey, strconv.Itoa(int(cfg.ConfigID)))
|
||||
if err := storage.Delete(ctx, cfgIDKey); err != nil {
|
||||
return fmt.Errorf("deleting expired ECH config: %v", err)
|
||||
}
|
||||
|
||||
ech.configs[publicName] = append(ech.configs[publicName][:i], ech.configs[publicName][i+1:]...)
|
||||
i--
|
||||
|
||||
logger.Debug("deleted expired ECH key",
|
||||
zap.String("public_name", cfg.RawPublicName),
|
||||
zap.Uint8("id", cfg.ConfigID),
|
||||
zap.Duration("age", time.Since(cfg.meta.Created)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ech.updateKeyList()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// rotationNeeded returns true if any ECH key needs to be replaced, or deleted.
|
||||
// It must be called inside a read or write lock of ech.configsMu (probably a
|
||||
// write lock, so that the rotation can occur correctly in the same lock).)
|
||||
func (ech *ECH) rotationNeeded(rotationInterval, deleteAfter time.Duration) bool {
|
||||
for publicName := range ech.configs {
|
||||
for i := 0; i < len(ech.configs[publicName]); i++ {
|
||||
cfg := ech.configs[publicName][i]
|
||||
if (time.Since(cfg.meta.Created) >= rotationInterval && cfg.meta.Replaced.IsZero()) ||
|
||||
(time.Since(cfg.meta.Created) >= deleteAfter && !cfg.meta.Replaced.IsZero()) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// updateKeyList updates the list of ECH keys the std lib uses to serve ECH.
|
||||
// It must be called inside a write lock on ech.configsMu.
|
||||
func (ech *ECH) updateKeyList() {
|
||||
ech.stdlibReady = []tls.EncryptedClientHelloKey{}
|
||||
for _, cfgs := range ech.configs {
|
||||
for _, cfg := range cfgs {
|
||||
ech.stdlibReady = append(ech.stdlibReady, tls.EncryptedClientHelloKey{
|
||||
Config: cfg.configBin,
|
||||
PrivateKey: cfg.privKeyBin,
|
||||
SendAsRetry: cfg.meta.Replaced.IsZero(), // only send during retries if key has not been rotated out
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// publishECHConfigs publishes any configs that are configured for publication and which haven't been published already.
|
||||
func (t *TLS) publishECHConfigs(logger *zap.Logger) error {
|
||||
// make publication exclusive, since we don't need to repeat this unnecessarily
|
||||
storage := t.ctx.Storage()
|
||||
const echLockName = "ech_publish"
|
||||
@@ -197,7 +336,7 @@ func (t *TLS) publishECHConfigs() error {
|
||||
publishers: []ECHPublisher{
|
||||
&ECHDNSPublisher{
|
||||
provider: dnsProv,
|
||||
logger: t.logger,
|
||||
logger: logger,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -209,6 +348,7 @@ func (t *TLS) publishECHConfigs() error {
|
||||
// publish with it, and figure out which inner names to publish
|
||||
// to/for, then publish
|
||||
for _, publication := range publicationList {
|
||||
t.EncryptedClientHello.configsMu.RLock()
|
||||
// this publication is either configured for specific ECH configs,
|
||||
// or we just use an implied default of all ECH configs
|
||||
var echCfgList echConfigList
|
||||
@@ -231,6 +371,7 @@ func (t *TLS) publishECHConfigs() error {
|
||||
}
|
||||
}
|
||||
}
|
||||
t.EncryptedClientHello.configsMu.RUnlock()
|
||||
|
||||
// marshal the ECH config list as binary for publication
|
||||
echCfgListBin, err := echCfgList.MarshalBinary()
|
||||
@@ -250,6 +391,10 @@ func (t *TLS) publishECHConfigs() error {
|
||||
if publication.Domains == nil {
|
||||
serverNamesSet = make(map[string]struct{}, len(t.serverNames))
|
||||
for name := range t.serverNames {
|
||||
// skip Tailscale names, a special case we also handle differently in our auto-HTTPS
|
||||
if strings.HasSuffix(name, ".ts.net") {
|
||||
continue
|
||||
}
|
||||
serverNamesSet[name] = struct{}{}
|
||||
}
|
||||
} else {
|
||||
@@ -304,7 +449,7 @@ func (t *TLS) publishECHConfigs() error {
|
||||
// at least a partial failure, maybe a complete failure, but we can
|
||||
// log each error by domain
|
||||
for innerName, domainErr := range publishErrs {
|
||||
t.logger.Error("failed to publish ECH configuration list",
|
||||
logger.Error("failed to publish ECH configuration list",
|
||||
zap.String("publisher", publisherKey),
|
||||
zap.String("domain", innerName),
|
||||
zap.Uint8s("config_ids", configIDs),
|
||||
@@ -312,7 +457,7 @@ func (t *TLS) publishECHConfigs() error {
|
||||
}
|
||||
} else if err != nil {
|
||||
// generic error; assume the entire thing failed, I guess
|
||||
t.logger.Error("failed publishing ECH configuration list",
|
||||
logger.Error("failed publishing ECH configuration list",
|
||||
zap.String("publisher", publisherKey),
|
||||
zap.Strings("domains", dnsNamesToPublish),
|
||||
zap.Uint8s("config_ids", configIDs),
|
||||
@@ -334,7 +479,7 @@ func (t *TLS) publishECHConfigs() error {
|
||||
successNames = append(successNames, name)
|
||||
}
|
||||
}
|
||||
t.logger.Info("successfully published ECH configuration list for "+someAll+" domains",
|
||||
logger.Info("successfully published ECH configuration list for "+someAll+" domains",
|
||||
zap.String("publisher", publisherKey),
|
||||
zap.Strings("domains", successNames),
|
||||
zap.Uint8s("config_ids", configIDs))
|
||||
@@ -353,13 +498,12 @@ func (t *TLS) publishECHConfigs() error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshaling ECH config metadata: %v", err)
|
||||
}
|
||||
metaKey := path.Join(echConfigsKey, strconv.Itoa(int(cfg.ConfigID)), "meta.json")
|
||||
if err := t.ctx.Storage().Store(t.ctx, metaKey, metaBytes); err != nil {
|
||||
if err := t.ctx.Storage().Store(t.ctx, echMetaKey(cfg.ConfigID), metaBytes); err != nil {
|
||||
return fmt.Errorf("storing updated ECH config metadata: %v", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
t.logger.Error("all domains failed to publish ECH configuration list (see earlier errors)",
|
||||
logger.Error("all domains failed to publish ECH configuration list (see earlier errors)",
|
||||
zap.String("publisher", publisherKey),
|
||||
zap.Strings("domains", dnsNamesToPublish),
|
||||
zap.Uint8s("config_ids", configIDs))
|
||||
@@ -489,7 +633,7 @@ func generateAndStoreECHConfig(ctx caddy.Context, publicName string) (echConfig,
|
||||
|
||||
echCfg := echConfig{
|
||||
PublicKey: publicKey,
|
||||
Version: draftTLSESNI22,
|
||||
Version: draftTLSESNI25,
|
||||
ConfigID: configID,
|
||||
RawPublicName: publicName,
|
||||
KEMID: kemChoice,
|
||||
@@ -507,7 +651,6 @@ func generateAndStoreECHConfig(ctx caddy.Context, publicName string) (echConfig,
|
||||
AEADID: hpke.AEAD_ChaCha20Poly1305,
|
||||
},
|
||||
},
|
||||
sendAsRetry: true,
|
||||
}
|
||||
meta := echConfigMeta{
|
||||
Created: time.Now(),
|
||||
@@ -786,10 +929,9 @@ type echConfig struct {
|
||||
|
||||
// these fields are not part of the spec, but are here for
|
||||
// our use when setting up TLS servers or maintenance
|
||||
configBin []byte
|
||||
privKeyBin []byte
|
||||
meta echConfigMeta
|
||||
sendAsRetry bool
|
||||
configBin []byte
|
||||
privKeyBin []byte
|
||||
meta echConfigMeta
|
||||
}
|
||||
|
||||
func (echCfg echConfig) MarshalBinary() ([]byte, error) {
|
||||
@@ -811,8 +953,8 @@ func (echCfg *echConfig) UnmarshalBinary(data []byte) error {
|
||||
if !b.ReadUint16(&echCfg.Version) {
|
||||
return errInvalidLen
|
||||
}
|
||||
if echCfg.Version != draftTLSESNI22 {
|
||||
return fmt.Errorf("supported version must be %d: got %d", draftTLSESNI22, echCfg.Version)
|
||||
if echCfg.Version != draftTLSESNI25 {
|
||||
return fmt.Errorf("supported version must be %d: got %d", draftTLSESNI25, echCfg.Version)
|
||||
}
|
||||
|
||||
if !b.ReadUint16LengthPrefixed(&content) || !b.Empty() {
|
||||
@@ -1022,19 +1164,27 @@ func (p PublishECHConfigListErrors) Error() string {
|
||||
|
||||
type echConfigMeta struct {
|
||||
Created time.Time `json:"created"`
|
||||
Replaced time.Time `json:"replaced,omitzero"`
|
||||
Publications publicationHistory `json:"publications"`
|
||||
}
|
||||
|
||||
func echMetaKey(configID uint8) string {
|
||||
return path.Join(echConfigsKey, strconv.Itoa(int(configID)), "meta.json")
|
||||
}
|
||||
|
||||
// publicationHistory is a map of publisher key to
|
||||
// map of inner name to timestamp
|
||||
type publicationHistory map[string]map[string]time.Time
|
||||
|
||||
// echStorageLockName is the name of the storage lock to sync ECH updates.
|
||||
const echStorageLockName = "ech_rotation"
|
||||
|
||||
// The key prefix when putting ECH configs in storage. After this
|
||||
// comes the config ID.
|
||||
const echConfigsKey = "ech/configs"
|
||||
|
||||
// https://www.ietf.org/archive/id/draft-ietf-tls-esni-22.html
|
||||
const draftTLSESNI22 = 0xfe0d
|
||||
// https://www.ietf.org/archive/id/draft-ietf-tls-esni-25.html
|
||||
const draftTLSESNI25 = 0xfe0d
|
||||
|
||||
// Interface guard
|
||||
var _ ECHPublisher = (*ECHDNSPublisher)(nil)
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"crypto/tls"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -62,18 +63,27 @@ func (fl FolderLoader) Provision(ctx caddy.Context) error {
|
||||
func (fl FolderLoader) LoadCertificates() ([]Certificate, error) {
|
||||
var certs []Certificate
|
||||
for _, dir := range fl {
|
||||
err := filepath.Walk(dir, func(fpath string, info os.FileInfo, err error) error {
|
||||
root, err := os.OpenRoot(dir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to open root directory %s: %w", dir, err)
|
||||
}
|
||||
err = filepath.WalkDir(dir, func(fpath string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to traverse into path: %s", fpath)
|
||||
}
|
||||
if info.IsDir() {
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
if !strings.HasSuffix(strings.ToLower(info.Name()), ".pem") {
|
||||
if !strings.HasSuffix(strings.ToLower(d.Name()), ".pem") {
|
||||
return nil
|
||||
}
|
||||
|
||||
bundle, err := os.ReadFile(fpath)
|
||||
rel, err := filepath.Rel(dir, fpath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get relative path for %s: %w", fpath, err)
|
||||
}
|
||||
|
||||
bundle, err := root.ReadFile(rel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -83,11 +93,11 @@ func (fl FolderLoader) LoadCertificates() ([]Certificate, error) {
|
||||
}
|
||||
|
||||
certs = append(certs, Certificate{Certificate: cert})
|
||||
|
||||
return nil
|
||||
})
|
||||
_ = root.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("walking certificates directory %s: %w", dir, err)
|
||||
}
|
||||
}
|
||||
return certs, nil
|
||||
|
||||
@@ -115,7 +115,8 @@ func (iss InternalIssuer) Issue(ctx context.Context, csr *x509.CertificateReques
|
||||
if iss.SignWithRoot {
|
||||
issuerCert = iss.ca.RootCertificate()
|
||||
} else {
|
||||
issuerCert = iss.ca.IntermediateCertificate()
|
||||
chain := iss.ca.IntermediateCertificateChain()
|
||||
issuerCert = chain[0]
|
||||
}
|
||||
|
||||
// ensure issued certificate does not expire later than its issuer
|
||||
|
||||
@@ -0,0 +1,262 @@
|
||||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddytls
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/caddyserver/caddy/v2"
|
||||
"github.com/caddyserver/caddy/v2/modules/caddypki"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"go.step.sm/crypto/keyutil"
|
||||
"go.step.sm/crypto/pemutil"
|
||||
)
|
||||
|
||||
func TestInternalIssuer_Issue(t *testing.T) {
|
||||
rootSigner, err := keyutil.GenerateDefaultSigner()
|
||||
if err != nil {
|
||||
t.Fatalf("Creating root signer failed: %v", err)
|
||||
}
|
||||
|
||||
tmpl := &x509.Certificate{
|
||||
Subject: pkix.Name{CommonName: "test-root"},
|
||||
IsCA: true,
|
||||
MaxPathLen: 3,
|
||||
NotAfter: time.Now().Add(7 * 24 * time.Hour),
|
||||
NotBefore: time.Now().Add(-7 * 24 * time.Hour),
|
||||
}
|
||||
rootBytes, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, rootSigner.Public(), rootSigner)
|
||||
if err != nil {
|
||||
t.Fatalf("Creating root certificate failed: %v", err)
|
||||
}
|
||||
|
||||
root, err := x509.ParseCertificate(rootBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("Parsing root certificate failed: %v", err)
|
||||
}
|
||||
|
||||
firstIntermediateSigner, err := keyutil.GenerateDefaultSigner()
|
||||
if err != nil {
|
||||
t.Fatalf("Creating intermedaite signer failed: %v", err)
|
||||
}
|
||||
|
||||
firstIntermediateBytes, err := x509.CreateCertificate(rand.Reader, &x509.Certificate{
|
||||
Subject: pkix.Name{CommonName: "test-first-intermediate"},
|
||||
IsCA: true,
|
||||
MaxPathLen: 2,
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
NotBefore: time.Now().Add(-24 * time.Hour),
|
||||
}, root, firstIntermediateSigner.Public(), rootSigner)
|
||||
if err != nil {
|
||||
t.Fatalf("Creating intermediate certificate failed: %v", err)
|
||||
}
|
||||
|
||||
firstIntermediate, err := x509.ParseCertificate(firstIntermediateBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("Parsing intermediate certificate failed: %v", err)
|
||||
}
|
||||
|
||||
secondIntermediateSigner, err := keyutil.GenerateDefaultSigner()
|
||||
if err != nil {
|
||||
t.Fatalf("Creating second intermedaite signer failed: %v", err)
|
||||
}
|
||||
|
||||
secondIntermediateBytes, err := x509.CreateCertificate(rand.Reader, &x509.Certificate{
|
||||
Subject: pkix.Name{CommonName: "test-second-intermediate"},
|
||||
IsCA: true,
|
||||
MaxPathLen: 2,
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
NotBefore: time.Now().Add(-24 * time.Hour),
|
||||
}, firstIntermediate, secondIntermediateSigner.Public(), firstIntermediateSigner)
|
||||
if err != nil {
|
||||
t.Fatalf("Creating second intermediate certificate failed: %v", err)
|
||||
}
|
||||
|
||||
secondIntermediate, err := x509.ParseCertificate(secondIntermediateBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("Parsing second intermediate certificate failed: %v", err)
|
||||
}
|
||||
|
||||
dir := t.TempDir()
|
||||
storageDir := filepath.Join(dir, "certmagic")
|
||||
rootCertFile := filepath.Join(dir, "root.pem")
|
||||
if _, err = pemutil.Serialize(root, pemutil.WithFilename(rootCertFile)); err != nil {
|
||||
t.Fatalf("Failed serializing root certificate: %v", err)
|
||||
}
|
||||
intermediateCertFile := filepath.Join(dir, "intermediate.pem")
|
||||
if _, err = pemutil.Serialize(firstIntermediate, pemutil.WithFilename(intermediateCertFile)); err != nil {
|
||||
t.Fatalf("Failed serializing intermediate certificate: %v", err)
|
||||
}
|
||||
intermediateKeyFile := filepath.Join(dir, "intermediate.key")
|
||||
if _, err = pemutil.Serialize(firstIntermediateSigner, pemutil.WithFilename(intermediateKeyFile)); err != nil {
|
||||
t.Fatalf("Failed serializing intermediate key: %v", err)
|
||||
}
|
||||
|
||||
var intermediateChainContents []byte
|
||||
intermediateChain := []*x509.Certificate{secondIntermediate, firstIntermediate}
|
||||
for _, cert := range intermediateChain {
|
||||
b, err := pemutil.Serialize(cert)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed serializing intermediate certificate: %v", err)
|
||||
}
|
||||
intermediateChainContents = append(intermediateChainContents, pem.EncodeToMemory(b)...)
|
||||
}
|
||||
intermediateChainFile := filepath.Join(dir, "intermediates.pem")
|
||||
if err := os.WriteFile(intermediateChainFile, intermediateChainContents, 0644); err != nil {
|
||||
t.Fatalf("Failed writing intermediate chain: %v", err)
|
||||
}
|
||||
intermediateChainKeyFile := filepath.Join(dir, "intermediates.key")
|
||||
if _, err = pemutil.Serialize(secondIntermediateSigner, pemutil.WithFilename(intermediateChainKeyFile)); err != nil {
|
||||
t.Fatalf("Failed serializing intermediate key: %v", err)
|
||||
}
|
||||
|
||||
signer, err := keyutil.GenerateDefaultSigner()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed creating signer: %v", err)
|
||||
}
|
||||
|
||||
csrBytes, err := x509.CreateCertificateRequest(rand.Reader, &x509.CertificateRequest{
|
||||
Subject: pkix.Name{CommonName: "test"},
|
||||
}, signer)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed creating CSR: %v", err)
|
||||
}
|
||||
|
||||
csr, err := x509.ParseCertificateRequest(csrBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed parsing CSR: %v", err)
|
||||
}
|
||||
|
||||
t.Run("generated-with-defaults", func(t *testing.T) {
|
||||
caddyCtx, cancel := caddy.NewContext(caddy.Context{Context: t.Context()})
|
||||
t.Cleanup(cancel)
|
||||
logger := zap.NewNop()
|
||||
|
||||
ca := &caddypki.CA{
|
||||
StorageRaw: []byte(fmt.Sprintf(`{"module": "file_system", "root": %q}`, storageDir)),
|
||||
}
|
||||
if err := ca.Provision(caddyCtx, "local-test-generated", logger); err != nil {
|
||||
t.Fatalf("Failed provisioning CA: %v", err)
|
||||
}
|
||||
|
||||
iss := InternalIssuer{
|
||||
SignWithRoot: false,
|
||||
ca: ca,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
c, err := iss.Issue(t.Context(), csr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed issuing certificate: %v", err)
|
||||
}
|
||||
|
||||
chain, err := pemutil.ParseCertificateBundle(c.Certificate)
|
||||
if err != nil {
|
||||
t.Errorf("Failed issuing certificate: %v", err)
|
||||
}
|
||||
if len(chain) != 2 {
|
||||
t.Errorf("Expected 2 certificates in chain; got %d", len(chain))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("single-intermediate-from-disk", func(t *testing.T) {
|
||||
caddyCtx, cancel := caddy.NewContext(caddy.Context{Context: t.Context()})
|
||||
t.Cleanup(cancel)
|
||||
logger := zap.NewNop()
|
||||
|
||||
ca := &caddypki.CA{
|
||||
Root: &caddypki.KeyPair{
|
||||
Certificate: rootCertFile,
|
||||
},
|
||||
Intermediate: &caddypki.KeyPair{
|
||||
Certificate: intermediateCertFile,
|
||||
PrivateKey: intermediateKeyFile,
|
||||
},
|
||||
StorageRaw: []byte(fmt.Sprintf(`{"module": "file_system", "root": %q}`, storageDir)),
|
||||
}
|
||||
|
||||
if err := ca.Provision(caddyCtx, "local-test-single-intermediate", logger); err != nil {
|
||||
t.Fatalf("Failed provisioning CA: %v", err)
|
||||
}
|
||||
|
||||
iss := InternalIssuer{
|
||||
ca: ca,
|
||||
SignWithRoot: false,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
c, err := iss.Issue(t.Context(), csr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed issuing certificate: %v", err)
|
||||
}
|
||||
|
||||
chain, err := pemutil.ParseCertificateBundle(c.Certificate)
|
||||
if err != nil {
|
||||
t.Errorf("Failed issuing certificate: %v", err)
|
||||
}
|
||||
if len(chain) != 2 {
|
||||
t.Errorf("Expected 2 certificates in chain; got %d", len(chain))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multiple-intermediates-from-disk", func(t *testing.T) {
|
||||
caddyCtx, cancel := caddy.NewContext(caddy.Context{Context: t.Context()})
|
||||
t.Cleanup(cancel)
|
||||
logger := zap.NewNop()
|
||||
|
||||
ca := &caddypki.CA{
|
||||
Root: &caddypki.KeyPair{
|
||||
Certificate: rootCertFile,
|
||||
},
|
||||
Intermediate: &caddypki.KeyPair{
|
||||
Certificate: intermediateChainFile,
|
||||
PrivateKey: intermediateChainKeyFile,
|
||||
},
|
||||
StorageRaw: []byte(fmt.Sprintf(`{"module": "file_system", "root": %q}`, storageDir)),
|
||||
}
|
||||
|
||||
if err := ca.Provision(caddyCtx, "local-test", zap.NewNop()); err != nil {
|
||||
t.Fatalf("Failed provisioning CA: %v", err)
|
||||
}
|
||||
|
||||
iss := InternalIssuer{
|
||||
ca: ca,
|
||||
SignWithRoot: false,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
c, err := iss.Issue(t.Context(), csr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed issuing certificate: %v", err)
|
||||
}
|
||||
|
||||
chain, err := pemutil.ParseCertificateBundle(c.Certificate)
|
||||
if err != nil {
|
||||
t.Errorf("Failed issuing certificate: %v", err)
|
||||
}
|
||||
if len(chain) != 3 {
|
||||
t.Errorf("Expected 3 certificates in chain; got %d", len(chain))
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -29,9 +29,9 @@ func init() {
|
||||
caddy.RegisterModule(LeafFolderLoader{})
|
||||
}
|
||||
|
||||
// LeafFolderLoader loads certificates and their associated keys from disk
|
||||
// LeafFolderLoader loads certificates from disk
|
||||
// by recursively walking the specified directories, looking for PEM
|
||||
// files which contain both a certificate and a key.
|
||||
// files which contain a certificate.
|
||||
type LeafFolderLoader struct {
|
||||
Folders []string `json:"folders,omitempty"`
|
||||
}
|
||||
|
||||
+37
-6
@@ -123,8 +123,15 @@ type TLS struct {
|
||||
//
|
||||
// EXPERIMENTAL: Subject to change.
|
||||
DNSRaw json.RawMessage `json:"dns,omitempty" caddy:"namespace=dns.providers inline_key=name"`
|
||||
dns any // technically, it should be any/all of the libdns interfaces (RecordSetter, RecordAppender, etc.)
|
||||
|
||||
// The default DNS resolvers to use for TLS-related DNS operations, specifically
|
||||
// for ACME DNS challenges and ACME server DNS validations.
|
||||
// If not specified, the system default resolvers will be used.
|
||||
//
|
||||
// EXPERIMENTAL: Subject to change.
|
||||
Resolvers []string `json:"resolvers,omitempty"`
|
||||
|
||||
dns any // technically, it should be any/all of the libdns interfaces (RecordSetter, RecordAppender, etc.)
|
||||
certificateLoaders []CertificateLoader
|
||||
automateNames map[string]struct{}
|
||||
ctx caddy.Context
|
||||
@@ -335,7 +342,6 @@ func (t *TLS) Provision(ctx caddy.Context) error {
|
||||
|
||||
// ECH (Encrypted ClientHello) initialization
|
||||
if t.EncryptedClientHello != nil {
|
||||
t.EncryptedClientHello.configs = make(map[string][]echConfig)
|
||||
outerNames, err := t.EncryptedClientHello.Provision(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("provisioning Encrypted ClientHello components: %v", err)
|
||||
@@ -411,12 +417,37 @@ func (t *TLS) Start() error {
|
||||
return fmt.Errorf("automate: managing %v: %v", t.automateNames, err)
|
||||
}
|
||||
|
||||
// publish ECH configs in the background; does not need to block
|
||||
// server startup, as it could take a while
|
||||
if t.EncryptedClientHello != nil {
|
||||
echLogger := t.logger.Named("ech")
|
||||
|
||||
// publish ECH configs in the background; does not need to block
|
||||
// server startup, as it could take a while; then keep keys rotated
|
||||
go func() {
|
||||
if err := t.publishECHConfigs(); err != nil {
|
||||
t.logger.Named("ech").Error("publication(s) failed", zap.Error(err))
|
||||
// publish immediately first
|
||||
if err := t.publishECHConfigs(echLogger); err != nil {
|
||||
echLogger.Error("publication(s) failed", zap.Error(err))
|
||||
}
|
||||
|
||||
// then every so often, rotate and publish if needed
|
||||
// (both of these functions only do something if needed)
|
||||
for {
|
||||
select {
|
||||
case <-time.After(1 * time.Hour):
|
||||
// ensure old keys are rotated out
|
||||
t.EncryptedClientHello.configsMu.Lock()
|
||||
err = t.EncryptedClientHello.rotateECHKeys(t.ctx, echLogger, false)
|
||||
t.EncryptedClientHello.configsMu.Unlock()
|
||||
if err != nil {
|
||||
echLogger.Error("rotating ECH configs failed", zap.Error(err))
|
||||
return
|
||||
}
|
||||
err := t.publishECHConfigs(echLogger)
|
||||
if err != nil {
|
||||
echLogger.Error("publication(s) failed", zap.Error(err))
|
||||
}
|
||||
case <-t.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ func init() {
|
||||
type ZeroSSLIssuer struct {
|
||||
// The API key (or "access key") for using the ZeroSSL API.
|
||||
// REQUIRED.
|
||||
APIKey string `json:"api_key,omitempty"`
|
||||
APIKey string `json:"api_key,omitempty"` //nolint:gosec // false positive... yes this is exported, for JSON interop
|
||||
|
||||
// How many days the certificate should be valid for.
|
||||
// Only certain values are accepted; see ZeroSSL docs.
|
||||
|
||||
+184
-23
@@ -63,7 +63,7 @@ func (m *fileMode) UnmarshalJSON(b []byte) error {
|
||||
|
||||
// MarshalJSON satisfies json.Marshaler.
|
||||
func (m *fileMode) MarshalJSON() ([]byte, error) {
|
||||
return []byte(fmt.Sprintf("\"%04o\"", *m)), nil
|
||||
return fmt.Appendf(nil, "\"%04o\"", *m), nil
|
||||
}
|
||||
|
||||
// parseFileMode parses a file mode string,
|
||||
@@ -90,6 +90,15 @@ type FileWriter struct {
|
||||
// 0600 by default.
|
||||
Mode fileMode `json:"mode,omitempty"`
|
||||
|
||||
// DirMode controls permissions for any directories created to reach Filename.
|
||||
// Default: 0700 (current behavior).
|
||||
//
|
||||
// Special values:
|
||||
// - "inherit" → copy the nearest existing parent directory's perms (with r→x normalization)
|
||||
// - "from_file" → derive from the file Mode (with r→x), e.g. 0644 → 0755, 0600 → 0700
|
||||
// Numeric octal strings (e.g. "0755") are also accepted. Subject to process umask.
|
||||
DirMode string `json:"dir_mode,omitempty"`
|
||||
|
||||
// Roll toggles log rolling or rotation, which is
|
||||
// enabled by default.
|
||||
Roll *bool `json:"roll,omitempty"`
|
||||
@@ -113,9 +122,16 @@ type FileWriter struct {
|
||||
// See https://github.com/DeRuina/timberjack#%EF%B8%8F-rotation-notes--warnings for caveats
|
||||
RollAt []string `json:"roll_at,omitempty"`
|
||||
|
||||
// Whether to compress rolled files. Default: true
|
||||
// Whether to compress rolled files.
|
||||
// Default: true.
|
||||
// Deprecated: Use RollCompression instead, setting it to "none".
|
||||
RollCompress *bool `json:"roll_gzip,omitempty"`
|
||||
|
||||
// RollCompression selects the compression algorithm for rolled files.
|
||||
// Accepted values: "none", "gzip", "zstd".
|
||||
// Default: gzip
|
||||
RollCompression string `json:"roll_compression,omitempty"`
|
||||
|
||||
// Whether to use local timestamps in rolled filenames.
|
||||
// Default: false
|
||||
RollLocalTime bool `json:"roll_local_time,omitempty"`
|
||||
@@ -177,11 +193,33 @@ func (fw FileWriter) OpenWriter() (io.WriteCloser, error) {
|
||||
// roll log files as a sensible default to avoid disk space exhaustion
|
||||
roll := fw.Roll == nil || *fw.Roll
|
||||
|
||||
// create the file if it does not exist; create with the configured mode, or default
|
||||
// to restrictive if not set. (timberjack will reuse the file mode across log rotation)
|
||||
if err := os.MkdirAll(filepath.Dir(fw.Filename), 0o700); err != nil {
|
||||
return nil, err
|
||||
// Ensure directory exists before opening the file.
|
||||
dirPath := filepath.Dir(fw.Filename)
|
||||
switch strings.ToLower(strings.TrimSpace(fw.DirMode)) {
|
||||
case "", "0":
|
||||
// Preserve current behavior: locked-down directories by default.
|
||||
if err := os.MkdirAll(dirPath, 0o700); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case "inherit":
|
||||
if err := mkdirAllInherit(dirPath); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case "from_file":
|
||||
if err := mkdirAllFromFile(dirPath, os.FileMode(fw.Mode)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
dm, err := parseFileMode(fw.DirMode)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dir_mode: %w", err)
|
||||
}
|
||||
if err := os.MkdirAll(dirPath, dm); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// create/open the file
|
||||
file, err := os.OpenFile(fw.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, modeIfCreating)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -223,27 +261,104 @@ func (fw FileWriter) OpenWriter() (io.WriteCloser, error) {
|
||||
if fw.RollKeepDays == 0 {
|
||||
fw.RollKeepDays = 90
|
||||
}
|
||||
|
||||
// Determine compression algorithm to use. Priority:
|
||||
// 1) explicit RollCompression (none|gzip|zstd)
|
||||
// 2) if RollCompress is unset or true -> gzip
|
||||
// 3) if RollCompress is false -> none
|
||||
var compression string
|
||||
if fw.RollCompression != "" {
|
||||
compression = strings.ToLower(strings.TrimSpace(fw.RollCompression))
|
||||
if compression != "none" && compression != "gzip" && compression != "zstd" {
|
||||
return nil, fmt.Errorf("invalid roll_compression: %s", fw.RollCompression)
|
||||
}
|
||||
} else {
|
||||
if fw.RollCompress == nil || *fw.RollCompress {
|
||||
compression = "gzip"
|
||||
} else {
|
||||
compression = "none"
|
||||
}
|
||||
}
|
||||
|
||||
return &timberjack.Logger{
|
||||
Filename: fw.Filename,
|
||||
MaxSize: fw.RollSizeMB,
|
||||
MaxAge: fw.RollKeepDays,
|
||||
MaxBackups: fw.RollKeep,
|
||||
LocalTime: fw.RollLocalTime,
|
||||
Compress: *fw.RollCompress,
|
||||
Compression: compression,
|
||||
RotationInterval: fw.RollInterval,
|
||||
RotateAtMinutes: fw.RollAtMinutes,
|
||||
RotateAt: fw.RollAt,
|
||||
BackupTimeFormat: fw.BackupTimeFormat,
|
||||
FileMode: os.FileMode(fw.Mode),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// normalizeDirPerm ensures that read bits also have execute bits set.
|
||||
func normalizeDirPerm(p os.FileMode) os.FileMode {
|
||||
if p&0o400 != 0 {
|
||||
p |= 0o100
|
||||
}
|
||||
if p&0o040 != 0 {
|
||||
p |= 0o010
|
||||
}
|
||||
if p&0o004 != 0 {
|
||||
p |= 0o001
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// mkdirAllInherit creates missing dirs using the nearest existing parent's
|
||||
// permissions, normalized with r→x.
|
||||
func mkdirAllInherit(dir string) error {
|
||||
if fi, err := os.Stat(dir); err == nil && fi.IsDir() {
|
||||
return nil
|
||||
}
|
||||
cur := dir
|
||||
var parent string
|
||||
for {
|
||||
next := filepath.Dir(cur)
|
||||
if next == cur {
|
||||
parent = next
|
||||
break
|
||||
}
|
||||
if fi, err := os.Stat(next); err == nil {
|
||||
if !fi.IsDir() {
|
||||
return fmt.Errorf("path component %s exists and is not a directory", next)
|
||||
}
|
||||
parent = next
|
||||
break
|
||||
}
|
||||
cur = next
|
||||
}
|
||||
perm := os.FileMode(0o700)
|
||||
if fi, err := os.Stat(parent); err == nil && fi.IsDir() {
|
||||
perm = fi.Mode().Perm()
|
||||
}
|
||||
perm = normalizeDirPerm(perm)
|
||||
return os.MkdirAll(dir, perm)
|
||||
}
|
||||
|
||||
// mkdirAllFromFile creates missing dirs using the file's mode (with r→x) so
|
||||
// 0644 → 0755, 0600 → 0700, etc.
|
||||
func mkdirAllFromFile(dir string, fileMode os.FileMode) error {
|
||||
if fi, err := os.Stat(dir); err == nil && fi.IsDir() {
|
||||
return nil
|
||||
}
|
||||
perm := normalizeDirPerm(fileMode.Perm()) | 0o200 // ensure owner write on dir so files can be created
|
||||
return os.MkdirAll(dir, perm)
|
||||
}
|
||||
|
||||
// UnmarshalCaddyfile sets up the module from Caddyfile tokens. Syntax:
|
||||
//
|
||||
// file <filename> {
|
||||
// mode <mode>
|
||||
// dir_mode <mode|inherit|from_file>
|
||||
// roll_disabled
|
||||
// roll_size <size>
|
||||
// roll_uncompressed
|
||||
// roll_compression <none|gzip|zstd>
|
||||
// roll_local_time
|
||||
// roll_keep <num>
|
||||
// roll_keep_for <days>
|
||||
@@ -284,6 +399,22 @@ func (fw *FileWriter) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
||||
}
|
||||
fw.Mode = fileMode(mode)
|
||||
|
||||
case "dir_mode":
|
||||
var val string
|
||||
if !d.AllArgs(&val) {
|
||||
return d.ArgErr()
|
||||
}
|
||||
val = strings.TrimSpace(val)
|
||||
switch strings.ToLower(val) {
|
||||
case "inherit", "from_file":
|
||||
fw.DirMode = val
|
||||
default:
|
||||
if _, err := parseFileMode(val); err != nil {
|
||||
return d.Errf("parsing dir_mode: %v", err)
|
||||
}
|
||||
fw.DirMode = val
|
||||
}
|
||||
|
||||
case "roll_disabled":
|
||||
var f bool
|
||||
fw.Roll = &f
|
||||
@@ -309,6 +440,19 @@ func (fw *FileWriter) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
||||
return d.ArgErr()
|
||||
}
|
||||
|
||||
case "roll_compression":
|
||||
var comp string
|
||||
if !d.AllArgs(&comp) {
|
||||
return d.ArgErr()
|
||||
}
|
||||
comp = strings.ToLower(strings.TrimSpace(comp))
|
||||
switch comp {
|
||||
case "none", "gzip", "zstd":
|
||||
fw.RollCompression = comp
|
||||
default:
|
||||
return d.Errf("parsing roll_compression: must be 'none', 'gzip' or 'zstd'")
|
||||
}
|
||||
|
||||
case "roll_local_time":
|
||||
fw.RollLocalTime = true
|
||||
if d.NextArg() {
|
||||
@@ -352,31 +496,48 @@ func (fw *FileWriter) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
||||
fw.RollInterval = duration
|
||||
|
||||
case "roll_minutes":
|
||||
var minutesArrayStr string
|
||||
if !d.AllArgs(&minutesArrayStr) {
|
||||
// Accept either a single comma-separated argument or
|
||||
// multiple space-separated arguments. Collect all
|
||||
// remaining args on the line and split on commas.
|
||||
args := d.RemainingArgs()
|
||||
if len(args) == 0 {
|
||||
return d.ArgErr()
|
||||
}
|
||||
minutesStr := strings.Split(minutesArrayStr, ",")
|
||||
minutes := make([]int, len(minutesStr))
|
||||
for i := range minutesStr {
|
||||
ms := strings.Trim(minutesStr[i], " ")
|
||||
m, err := strconv.Atoi(ms)
|
||||
if err != nil {
|
||||
return d.Errf("parsing roll_minutes number: %v", err)
|
||||
var minutes []int
|
||||
for _, arg := range args {
|
||||
parts := strings.SplitSeq(arg, ",")
|
||||
for p := range parts {
|
||||
ms := strings.TrimSpace(p)
|
||||
if ms == "" {
|
||||
return d.Errf("parsing roll_minutes: empty value")
|
||||
}
|
||||
m, err := strconv.Atoi(ms)
|
||||
if err != nil {
|
||||
return d.Errf("parsing roll_minutes number: %v", err)
|
||||
}
|
||||
minutes = append(minutes, m)
|
||||
}
|
||||
minutes[i] = m
|
||||
}
|
||||
fw.RollAtMinutes = minutes
|
||||
|
||||
case "roll_at":
|
||||
var timeArrayStr string
|
||||
if !d.AllArgs(&timeArrayStr) {
|
||||
// Accept either a single comma-separated argument or
|
||||
// multiple space-separated arguments. Collect all
|
||||
// remaining args on the line and split on commas.
|
||||
args := d.RemainingArgs()
|
||||
if len(args) == 0 {
|
||||
return d.ArgErr()
|
||||
}
|
||||
timeStr := strings.Split(timeArrayStr, ",")
|
||||
times := make([]string, len(timeStr))
|
||||
for i := range timeStr {
|
||||
times[i] = strings.Trim(timeStr[i], " ")
|
||||
var times []string
|
||||
for _, arg := range args {
|
||||
parts := strings.SplitSeq(arg, ",")
|
||||
for p := range parts {
|
||||
ts := strings.TrimSpace(p)
|
||||
if ts == "" {
|
||||
return d.Errf("parsing roll_at: empty value")
|
||||
}
|
||||
times = append(times, ts)
|
||||
}
|
||||
}
|
||||
fw.RollAt = times
|
||||
|
||||
|
||||
@@ -385,3 +385,225 @@ func TestFileModeModification(t *testing.T) {
|
||||
t.Errorf("file mode is %v, want %v", st.Mode(), want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDirMode_Inherit(t *testing.T) {
|
||||
m := syscall.Umask(0)
|
||||
defer syscall.Umask(m)
|
||||
|
||||
parent := t.TempDir()
|
||||
if err := os.Chmod(parent, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
targetDir := filepath.Join(parent, "a", "b")
|
||||
fw := &FileWriter{
|
||||
Filename: filepath.Join(targetDir, "test.log"),
|
||||
DirMode: "inherit",
|
||||
Mode: 0o640,
|
||||
Roll: func() *bool { f := false; return &f }(),
|
||||
}
|
||||
w, err := fw.OpenWriter()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_ = w.Close()
|
||||
|
||||
st, err := os.Stat(targetDir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got := st.Mode().Perm(); got != 0o755 {
|
||||
t.Fatalf("dir perm = %o, want 0755", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDirMode_FromFile(t *testing.T) {
|
||||
m := syscall.Umask(0)
|
||||
defer syscall.Umask(m)
|
||||
|
||||
base := t.TempDir()
|
||||
|
||||
dir1 := filepath.Join(base, "logs1")
|
||||
fw1 := &FileWriter{
|
||||
Filename: filepath.Join(dir1, "app.log"),
|
||||
DirMode: "from_file",
|
||||
Mode: 0o644, // => dir 0755
|
||||
Roll: func() *bool { f := false; return &f }(),
|
||||
}
|
||||
w1, err := fw1.OpenWriter()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_ = w1.Close()
|
||||
|
||||
st1, err := os.Stat(dir1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got := st1.Mode().Perm(); got != 0o755 {
|
||||
t.Fatalf("dir perm = %o, want 0755", got)
|
||||
}
|
||||
|
||||
dir2 := filepath.Join(base, "logs2")
|
||||
fw2 := &FileWriter{
|
||||
Filename: filepath.Join(dir2, "app.log"),
|
||||
DirMode: "from_file",
|
||||
Mode: 0o600, // => dir 0700
|
||||
Roll: func() *bool { f := false; return &f }(),
|
||||
}
|
||||
w2, err := fw2.OpenWriter()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_ = w2.Close()
|
||||
|
||||
st2, err := os.Stat(dir2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got := st2.Mode().Perm(); got != 0o700 {
|
||||
t.Fatalf("dir perm = %o, want 0700", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDirMode_ExplicitOctal(t *testing.T) {
|
||||
m := syscall.Umask(0)
|
||||
defer syscall.Umask(m)
|
||||
|
||||
base := t.TempDir()
|
||||
dest := filepath.Join(base, "logs3")
|
||||
fw := &FileWriter{
|
||||
Filename: filepath.Join(dest, "app.log"),
|
||||
DirMode: "0750",
|
||||
Mode: 0o640,
|
||||
Roll: func() *bool { f := false; return &f }(),
|
||||
}
|
||||
w, err := fw.OpenWriter()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_ = w.Close()
|
||||
|
||||
st, err := os.Stat(dest)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got := st.Mode().Perm(); got != 0o750 {
|
||||
t.Fatalf("dir perm = %o, want 0750", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDirMode_Default0700(t *testing.T) {
|
||||
m := syscall.Umask(0)
|
||||
defer syscall.Umask(m)
|
||||
|
||||
base := t.TempDir()
|
||||
dest := filepath.Join(base, "logs4")
|
||||
fw := &FileWriter{
|
||||
Filename: filepath.Join(dest, "app.log"),
|
||||
Mode: 0o640,
|
||||
Roll: func() *bool { f := false; return &f }(),
|
||||
}
|
||||
w, err := fw.OpenWriter()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_ = w.Close()
|
||||
|
||||
st, err := os.Stat(dest)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got := st.Mode().Perm(); got != 0o700 {
|
||||
t.Fatalf("dir perm = %o, want 0700", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDirMode_UmaskInteraction(t *testing.T) {
|
||||
_ = syscall.Umask(0o022) // typical umask; restore after
|
||||
defer syscall.Umask(0)
|
||||
|
||||
base := t.TempDir()
|
||||
dest := filepath.Join(base, "logs5")
|
||||
fw := &FileWriter{
|
||||
Filename: filepath.Join(dest, "app.log"),
|
||||
DirMode: "0755",
|
||||
Mode: 0o644,
|
||||
Roll: func() *bool { f := false; return &f }(),
|
||||
}
|
||||
w, err := fw.OpenWriter()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_ = w.Close()
|
||||
|
||||
st, err := os.Stat(dest)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// 0755 &^ 0022 still 0755 for dirs; this just sanity-checks we didn't get stricter unexpectedly
|
||||
if got := st.Mode().Perm(); got != 0o755 {
|
||||
t.Fatalf("dir perm = %o, want 0755 (considering umask)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCaddyfile_DirMode_Inherit(t *testing.T) {
|
||||
d := caddyfile.NewTestDispenser(`
|
||||
file /var/log/app.log {
|
||||
dir_mode inherit
|
||||
mode 0640
|
||||
}`)
|
||||
var fw FileWriter
|
||||
if err := fw.UnmarshalCaddyfile(d); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if fw.DirMode != "inherit" {
|
||||
t.Fatalf("got %q", fw.DirMode)
|
||||
}
|
||||
if fw.Mode != 0o640 {
|
||||
t.Fatalf("mode = %o", fw.Mode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCaddyfile_DirMode_FromFile(t *testing.T) {
|
||||
d := caddyfile.NewTestDispenser(`
|
||||
file /var/log/app.log {
|
||||
dir_mode from_file
|
||||
mode 0600
|
||||
}`)
|
||||
var fw FileWriter
|
||||
if err := fw.UnmarshalCaddyfile(d); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if fw.DirMode != "from_file" {
|
||||
t.Fatalf("got %q", fw.DirMode)
|
||||
}
|
||||
if fw.Mode != 0o600 {
|
||||
t.Fatalf("mode = %o", fw.Mode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCaddyfile_DirMode_Octal(t *testing.T) {
|
||||
d := caddyfile.NewTestDispenser(`
|
||||
file /var/log/app.log {
|
||||
dir_mode 0755
|
||||
}`)
|
||||
var fw FileWriter
|
||||
if err := fw.UnmarshalCaddyfile(d); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if fw.DirMode != "0755" {
|
||||
t.Fatalf("got %q", fw.DirMode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCaddyfile_DirMode_Invalid(t *testing.T) {
|
||||
d := caddyfile.NewTestDispenser(`
|
||||
file /var/log/app.log {
|
||||
dir_mode nope
|
||||
}`)
|
||||
var fw FileWriter
|
||||
if err := fw.UnmarshalCaddyfile(d); err == nil {
|
||||
t.Fatal("expected error for invalid dir_mode")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,9 +23,9 @@ import (
|
||||
)
|
||||
|
||||
// Windows relies on ACLs instead of unix permissions model.
|
||||
// Go allows to open files with a particular mode put it is limited to read or write.
|
||||
// Go allows to open files with a particular mode but it is limited to read or write.
|
||||
// See https://cs.opensource.google/go/go/+/refs/tags/go1.22.3:src/syscall/syscall_windows.go;l=708.
|
||||
// This is pretty restrictive and has few interest for log files and thus we just test that log files are
|
||||
// This is pretty restrictive and has little interest for log files and thus we just test that log files are
|
||||
// opened with R/W permissions by default on Windows too.
|
||||
func TestFileCreationMode(t *testing.T) {
|
||||
dir, err := os.MkdirTemp("", "caddytest")
|
||||
@@ -53,3 +53,41 @@ func TestFileCreationMode(t *testing.T) {
|
||||
t.Fatalf("file mode is %v, want rw for user", st.Mode().Perm())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDirMode_Windows_CreateSucceeds(t *testing.T) {
|
||||
dir, err := os.MkdirTemp("", "caddytest")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create tempdir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
dirMode string
|
||||
}{
|
||||
{"inherit", "inherit"},
|
||||
{"from_file", "from_file"},
|
||||
{"octal", "0755"},
|
||||
{"default", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
subdir := path.Join(dir, "logs-"+tt.name)
|
||||
fw := &FileWriter{
|
||||
Filename: path.Join(subdir, "test.log"),
|
||||
DirMode: tt.dirMode,
|
||||
Mode: 0o600,
|
||||
}
|
||||
w, err := fw.OpenWriter()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open writer: %v", err)
|
||||
}
|
||||
defer w.Close()
|
||||
|
||||
if _, err := os.Stat(fw.Filename); err != nil {
|
||||
t.Fatalf("expected file to exist: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -255,7 +255,7 @@ func (m IPMaskFilter) Filter(in zapcore.Field) zapcore.Field {
|
||||
}
|
||||
|
||||
func (m IPMaskFilter) mask(s string) string {
|
||||
output := ""
|
||||
parts := make([]string, 0)
|
||||
for value := range strings.SplitSeq(s, ",") {
|
||||
value = strings.TrimSpace(value)
|
||||
host, port, err := net.SplitHostPort(value)
|
||||
@@ -264,7 +264,7 @@ func (m IPMaskFilter) mask(s string) string {
|
||||
}
|
||||
ipAddr := net.ParseIP(host)
|
||||
if ipAddr == nil {
|
||||
output += value + ", "
|
||||
parts = append(parts, value)
|
||||
continue
|
||||
}
|
||||
mask := m.v4Mask
|
||||
@@ -273,13 +273,13 @@ func (m IPMaskFilter) mask(s string) string {
|
||||
}
|
||||
masked := ipAddr.Mask(mask)
|
||||
if port == "" {
|
||||
output += masked.String() + ", "
|
||||
parts = append(parts, masked.String())
|
||||
continue
|
||||
}
|
||||
|
||||
output += net.JoinHostPort(masked.String(), port) + ", "
|
||||
parts = append(parts, net.JoinHostPort(masked.String(), port))
|
||||
}
|
||||
return strings.TrimSuffix(output, ", ")
|
||||
return strings.Join(parts, ", ")
|
||||
}
|
||||
|
||||
type filterAction string
|
||||
|
||||
Reference in New Issue
Block a user