mirror of
				https://github.com/caddyserver/caddy.git
				synced 2025-10-28 09:12:52 -04:00 
			
		
		
		
	
		
			
				
	
	
		
			504 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			504 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright 2015 Matthew Holt and The Caddy Authors
 | |
| //
 | |
| // Licensed under the Apache License, Version 2.0 (the "License");
 | |
| // you may not use this file except in compliance with the License.
 | |
| // You may obtain a copy of the License at
 | |
| //
 | |
| //     http://www.apache.org/licenses/LICENSE-2.0
 | |
| //
 | |
| // Unless required by applicable law or agreed to in writing, software
 | |
| // distributed under the License is distributed on an "AS IS" BASIS,
 | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| // See the License for the specific language governing permissions and
 | |
| // limitations under the License.
 | |
| 
 | |
| package caddyhttp
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"context"
 | |
| 	"crypto/ecdsa"
 | |
| 	"crypto/ed25519"
 | |
| 	"crypto/elliptic"
 | |
| 	"crypto/rsa"
 | |
| 	"crypto/sha256"
 | |
| 	"crypto/tls"
 | |
| 	"crypto/x509"
 | |
| 	"encoding/asn1"
 | |
| 	"encoding/base64"
 | |
| 	"encoding/pem"
 | |
| 	"fmt"
 | |
| 	"io"
 | |
| 	"net"
 | |
| 	"net/http"
 | |
| 	"net/netip"
 | |
| 	"net/textproto"
 | |
| 	"net/url"
 | |
| 	"path"
 | |
| 	"strconv"
 | |
| 	"strings"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/caddyserver/caddy/v2"
 | |
| 	"github.com/caddyserver/caddy/v2/modules/caddytls"
 | |
| 	"github.com/google/uuid"
 | |
| )
 | |
| 
 | |
| // NewTestReplacer creates a replacer for an http.Request
 | |
| // for use in tests that are not in this package
 | |
| func NewTestReplacer(req *http.Request) *caddy.Replacer {
 | |
| 	repl := caddy.NewReplacer()
 | |
| 	ctx := context.WithValue(req.Context(), caddy.ReplacerCtxKey, repl)
 | |
| 	*req = *req.WithContext(ctx)
 | |
| 	addHTTPVarsToReplacer(repl, req, nil)
 | |
| 	return repl
 | |
| }
 | |
| 
 | |
| func addHTTPVarsToReplacer(repl *caddy.Replacer, req *http.Request, w http.ResponseWriter) {
 | |
| 	SetVar(req.Context(), "start_time", time.Now())
 | |
| 	SetVar(req.Context(), "uuid", new(requestID))
 | |
| 
 | |
| 	httpVars := func(key string) (any, bool) {
 | |
| 		if req != nil {
 | |
| 			// query string parameters
 | |
| 			if strings.HasPrefix(key, reqURIQueryReplPrefix) {
 | |
| 				vals := req.URL.Query()[key[len(reqURIQueryReplPrefix):]]
 | |
| 				// always return true, since the query param might
 | |
| 				// be present only in some requests
 | |
| 				return strings.Join(vals, ","), true
 | |
| 			}
 | |
| 
 | |
| 			// request header fields
 | |
| 			if strings.HasPrefix(key, reqHeaderReplPrefix) {
 | |
| 				field := key[len(reqHeaderReplPrefix):]
 | |
| 				vals := req.Header[textproto.CanonicalMIMEHeaderKey(field)]
 | |
| 				// always return true, since the header field might
 | |
| 				// be present only in some requests
 | |
| 				return strings.Join(vals, ","), true
 | |
| 			}
 | |
| 
 | |
| 			// cookies
 | |
| 			if strings.HasPrefix(key, reqCookieReplPrefix) {
 | |
| 				name := key[len(reqCookieReplPrefix):]
 | |
| 				for _, cookie := range req.Cookies() {
 | |
| 					if strings.EqualFold(name, cookie.Name) {
 | |
| 						// always return true, since the cookie might
 | |
| 						// be present only in some requests
 | |
| 						return cookie.Value, true
 | |
| 					}
 | |
| 				}
 | |
| 			}
 | |
| 
 | |
| 			// http.request.tls.*
 | |
| 			if strings.HasPrefix(key, reqTLSReplPrefix) {
 | |
| 				return getReqTLSReplacement(req, key)
 | |
| 			}
 | |
| 
 | |
| 			switch key {
 | |
| 			case "http.request.method":
 | |
| 				return req.Method, true
 | |
| 			case "http.request.scheme":
 | |
| 				if req.TLS != nil {
 | |
| 					return "https", true
 | |
| 				}
 | |
| 				return "http", true
 | |
| 			case "http.request.proto":
 | |
| 				return req.Proto, true
 | |
| 			case "http.request.host":
 | |
| 				host, _, err := net.SplitHostPort(req.Host)
 | |
| 				if err != nil {
 | |
| 					return req.Host, true // OK; there probably was no port
 | |
| 				}
 | |
| 				return host, true
 | |
| 			case "http.request.port":
 | |
| 				_, port, _ := net.SplitHostPort(req.Host)
 | |
| 				if portNum, err := strconv.Atoi(port); err == nil {
 | |
| 					return portNum, true
 | |
| 				}
 | |
| 				return port, true
 | |
| 			case "http.request.hostport":
 | |
| 				return req.Host, true
 | |
| 			case "http.request.remote":
 | |
| 				return req.RemoteAddr, true
 | |
| 			case "http.request.remote.host":
 | |
| 				host, _, err := net.SplitHostPort(req.RemoteAddr)
 | |
| 				if err != nil {
 | |
| 					return req.RemoteAddr, true
 | |
| 				}
 | |
| 				return host, true
 | |
| 			case "http.request.remote.port":
 | |
| 				_, port, _ := net.SplitHostPort(req.RemoteAddr)
 | |
| 				if portNum, err := strconv.Atoi(port); err == nil {
 | |
| 					return portNum, true
 | |
| 				}
 | |
| 				return port, true
 | |
| 
 | |
| 			// current URI, including any internal rewrites
 | |
| 			case "http.request.uri":
 | |
| 				return req.URL.RequestURI(), true
 | |
| 			case "http.request.uri.path":
 | |
| 				return req.URL.Path, true
 | |
| 			case "http.request.uri.path.file":
 | |
| 				_, file := path.Split(req.URL.Path)
 | |
| 				return file, true
 | |
| 			case "http.request.uri.path.dir":
 | |
| 				dir, _ := path.Split(req.URL.Path)
 | |
| 				return dir, true
 | |
| 			case "http.request.uri.path.file.base":
 | |
| 				return strings.TrimSuffix(path.Base(req.URL.Path), path.Ext(req.URL.Path)), true
 | |
| 			case "http.request.uri.path.file.ext":
 | |
| 				return path.Ext(req.URL.Path), true
 | |
| 			case "http.request.uri.query":
 | |
| 				return req.URL.RawQuery, true
 | |
| 			case "http.request.duration":
 | |
| 				start := GetVar(req.Context(), "start_time").(time.Time)
 | |
| 				return time.Since(start), true
 | |
| 			case "http.request.duration_ms":
 | |
| 				start := GetVar(req.Context(), "start_time").(time.Time)
 | |
| 				return time.Since(start).Seconds() * 1e3, true // multiply seconds to preserve decimal (see #4666)
 | |
| 			case "http.request.uuid":
 | |
| 				id := GetVar(req.Context(), "uuid").(*requestID)
 | |
| 				return id.String(), true
 | |
| 			case "http.request.body":
 | |
| 				if req.Body == nil {
 | |
| 					return "", true
 | |
| 				}
 | |
| 				// normally net/http will close the body for us, but since we
 | |
| 				// are replacing it with a fake one, we have to ensure we close
 | |
| 				// the real body ourselves when we're done
 | |
| 				defer req.Body.Close()
 | |
| 				// read the request body into a buffer (can't pool because we
 | |
| 				// don't know its lifetime and would have to make a copy anyway)
 | |
| 				buf := new(bytes.Buffer)
 | |
| 				_, _ = io.Copy(buf, req.Body) // can't handle error, so just ignore it
 | |
| 				req.Body = io.NopCloser(buf)  // replace real body with buffered data
 | |
| 				return buf.String(), true
 | |
| 
 | |
| 			// original request, before any internal changes
 | |
| 			case "http.request.orig_method":
 | |
| 				or, _ := req.Context().Value(OriginalRequestCtxKey).(http.Request)
 | |
| 				return or.Method, true
 | |
| 			case "http.request.orig_uri":
 | |
| 				or, _ := req.Context().Value(OriginalRequestCtxKey).(http.Request)
 | |
| 				return or.RequestURI, true
 | |
| 			case "http.request.orig_uri.path":
 | |
| 				or, _ := req.Context().Value(OriginalRequestCtxKey).(http.Request)
 | |
| 				return or.URL.Path, true
 | |
| 			case "http.request.orig_uri.path.file":
 | |
| 				or, _ := req.Context().Value(OriginalRequestCtxKey).(http.Request)
 | |
| 				_, file := path.Split(or.URL.Path)
 | |
| 				return file, true
 | |
| 			case "http.request.orig_uri.path.dir":
 | |
| 				or, _ := req.Context().Value(OriginalRequestCtxKey).(http.Request)
 | |
| 				dir, _ := path.Split(or.URL.Path)
 | |
| 				return dir, true
 | |
| 			case "http.request.orig_uri.query":
 | |
| 				or, _ := req.Context().Value(OriginalRequestCtxKey).(http.Request)
 | |
| 				return or.URL.RawQuery, true
 | |
| 			}
 | |
| 
 | |
| 			// remote IP range/prefix (e.g. keep top 24 bits of 1.2.3.4  => "1.2.3.0/24")
 | |
| 			// syntax: "/V4,V6" where V4 = IPv4 bits, and V6 = IPv6 bits; if no comma, then same bit length used for both
 | |
| 			// (EXPERIMENTAL)
 | |
| 			if strings.HasPrefix(key, "http.request.remote.host/") {
 | |
| 				host, _, err := net.SplitHostPort(req.RemoteAddr)
 | |
| 				if err != nil {
 | |
| 					host = req.RemoteAddr // assume no port, I guess?
 | |
| 				}
 | |
| 				addr, err := netip.ParseAddr(host)
 | |
| 				if err != nil {
 | |
| 					return host, true // not an IP address
 | |
| 				}
 | |
| 				// extract the bits from the end of the placeholder (start after "/") then split on ","
 | |
| 				bitsBoth := key[strings.Index(key, "/")+1:]
 | |
| 				ipv4BitsStr, ipv6BitsStr, cutOK := strings.Cut(bitsBoth, ",")
 | |
| 				bitsStr := ipv4BitsStr
 | |
| 				if addr.Is6() && cutOK {
 | |
| 					bitsStr = ipv6BitsStr
 | |
| 				}
 | |
| 				// convert to integer then compute prefix
 | |
| 				bits, err := strconv.Atoi(bitsStr)
 | |
| 				if err != nil {
 | |
| 					return "", true
 | |
| 				}
 | |
| 				prefix, err := addr.Prefix(bits)
 | |
| 				if err != nil {
 | |
| 					return "", true
 | |
| 				}
 | |
| 				return prefix.String(), true
 | |
| 			}
 | |
| 
 | |
| 			// hostname labels
 | |
| 			if strings.HasPrefix(key, reqHostLabelsReplPrefix) {
 | |
| 				idxStr := key[len(reqHostLabelsReplPrefix):]
 | |
| 				idx, err := strconv.Atoi(idxStr)
 | |
| 				if err != nil || idx < 0 {
 | |
| 					return "", false
 | |
| 				}
 | |
| 				reqHost, _, err := net.SplitHostPort(req.Host)
 | |
| 				if err != nil {
 | |
| 					reqHost = req.Host // OK; assume there was no port
 | |
| 				}
 | |
| 				hostLabels := strings.Split(reqHost, ".")
 | |
| 				if idx >= len(hostLabels) {
 | |
| 					return "", true
 | |
| 				}
 | |
| 				return hostLabels[len(hostLabels)-idx-1], true
 | |
| 			}
 | |
| 
 | |
| 			// path parts
 | |
| 			if strings.HasPrefix(key, reqURIPathReplPrefix) {
 | |
| 				idxStr := key[len(reqURIPathReplPrefix):]
 | |
| 				idx, err := strconv.Atoi(idxStr)
 | |
| 				if err != nil {
 | |
| 					return "", false
 | |
| 				}
 | |
| 				pathParts := strings.Split(req.URL.Path, "/")
 | |
| 				if len(pathParts) > 0 && pathParts[0] == "" {
 | |
| 					pathParts = pathParts[1:]
 | |
| 				}
 | |
| 				if idx < 0 {
 | |
| 					return "", false
 | |
| 				}
 | |
| 				if idx >= len(pathParts) {
 | |
| 					return "", true
 | |
| 				}
 | |
| 				return pathParts[idx], true
 | |
| 			}
 | |
| 
 | |
| 			// orig uri path parts
 | |
| 			if strings.HasPrefix(key, reqOrigURIPathReplPrefix) {
 | |
| 				idxStr := key[len(reqOrigURIPathReplPrefix):]
 | |
| 				idx, err := strconv.Atoi(idxStr)
 | |
| 				if err != nil {
 | |
| 					return "", false
 | |
| 				}
 | |
| 				or, _ := req.Context().Value(OriginalRequestCtxKey).(http.Request)
 | |
| 				pathParts := strings.Split(or.URL.Path, "/")
 | |
| 				if len(pathParts) > 0 && pathParts[0] == "" {
 | |
| 					pathParts = pathParts[1:]
 | |
| 				}
 | |
| 				if idx < 0 {
 | |
| 					return "", false
 | |
| 				}
 | |
| 				if idx >= len(pathParts) {
 | |
| 					return "", true
 | |
| 				}
 | |
| 				return pathParts[idx], true
 | |
| 			}
 | |
| 
 | |
| 			// middleware variables
 | |
| 			if strings.HasPrefix(key, varsReplPrefix) {
 | |
| 				varName := key[len(varsReplPrefix):]
 | |
| 				raw := GetVar(req.Context(), varName)
 | |
| 				// variables can be dynamic, so always return true
 | |
| 				// even when it may not be set; treat as empty then
 | |
| 				return raw, true
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		if w != nil {
 | |
| 			// response header fields
 | |
| 			if strings.HasPrefix(key, respHeaderReplPrefix) {
 | |
| 				field := key[len(respHeaderReplPrefix):]
 | |
| 				vals := w.Header()[textproto.CanonicalMIMEHeaderKey(field)]
 | |
| 				// always return true, since the header field might
 | |
| 				// be present only in some responses
 | |
| 				return strings.Join(vals, ","), true
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		switch {
 | |
| 		case key == "http.shutting_down":
 | |
| 			server := req.Context().Value(ServerCtxKey).(*Server)
 | |
| 			server.shutdownAtMu.RLock()
 | |
| 			defer server.shutdownAtMu.RUnlock()
 | |
| 			return !server.shutdownAt.IsZero(), true
 | |
| 		case key == "http.time_until_shutdown":
 | |
| 			server := req.Context().Value(ServerCtxKey).(*Server)
 | |
| 			server.shutdownAtMu.RLock()
 | |
| 			defer server.shutdownAtMu.RUnlock()
 | |
| 			if server.shutdownAt.IsZero() {
 | |
| 				return nil, true
 | |
| 			}
 | |
| 			return time.Until(server.shutdownAt), true
 | |
| 		}
 | |
| 
 | |
| 		return nil, false
 | |
| 	}
 | |
| 
 | |
| 	repl.Map(httpVars)
 | |
| }
 | |
| 
 | |
| func getReqTLSReplacement(req *http.Request, key string) (any, bool) {
 | |
| 	if req == nil || req.TLS == nil {
 | |
| 		return nil, false
 | |
| 	}
 | |
| 
 | |
| 	if len(key) < len(reqTLSReplPrefix) {
 | |
| 		return nil, false
 | |
| 	}
 | |
| 
 | |
| 	field := strings.ToLower(key[len(reqTLSReplPrefix):])
 | |
| 
 | |
| 	if strings.HasPrefix(field, "client.") {
 | |
| 		cert := getTLSPeerCert(req.TLS)
 | |
| 		if cert == nil {
 | |
| 			return nil, false
 | |
| 		}
 | |
| 
 | |
| 		// subject alternate names (SANs)
 | |
| 		if strings.HasPrefix(field, "client.san.") {
 | |
| 			field = field[len("client.san."):]
 | |
| 			var fieldName string
 | |
| 			var fieldValue any
 | |
| 			switch {
 | |
| 			case strings.HasPrefix(field, "dns_names"):
 | |
| 				fieldName = "dns_names"
 | |
| 				fieldValue = cert.DNSNames
 | |
| 			case strings.HasPrefix(field, "emails"):
 | |
| 				fieldName = "emails"
 | |
| 				fieldValue = cert.EmailAddresses
 | |
| 			case strings.HasPrefix(field, "ips"):
 | |
| 				fieldName = "ips"
 | |
| 				fieldValue = cert.IPAddresses
 | |
| 			case strings.HasPrefix(field, "uris"):
 | |
| 				fieldName = "uris"
 | |
| 				fieldValue = cert.URIs
 | |
| 			default:
 | |
| 				return nil, false
 | |
| 			}
 | |
| 			field = field[len(fieldName):]
 | |
| 
 | |
| 			// if no index was specified, return the whole list
 | |
| 			if field == "" {
 | |
| 				return fieldValue, true
 | |
| 			}
 | |
| 			if len(field) < 2 || field[0] != '.' {
 | |
| 				return nil, false
 | |
| 			}
 | |
| 			field = field[1:] // trim '.' between field name and index
 | |
| 
 | |
| 			// get the numeric index
 | |
| 			idx, err := strconv.Atoi(field)
 | |
| 			if err != nil || idx < 0 {
 | |
| 				return nil, false
 | |
| 			}
 | |
| 
 | |
| 			// access the indexed element and return it
 | |
| 			switch v := fieldValue.(type) {
 | |
| 			case []string:
 | |
| 				if idx >= len(v) {
 | |
| 					return nil, true
 | |
| 				}
 | |
| 				return v[idx], true
 | |
| 			case []net.IP:
 | |
| 				if idx >= len(v) {
 | |
| 					return nil, true
 | |
| 				}
 | |
| 				return v[idx], true
 | |
| 			case []*url.URL:
 | |
| 				if idx >= len(v) {
 | |
| 					return nil, true
 | |
| 				}
 | |
| 				return v[idx], true
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		switch field {
 | |
| 		case "client.fingerprint":
 | |
| 			return fmt.Sprintf("%x", sha256.Sum256(cert.Raw)), true
 | |
| 		case "client.public_key", "client.public_key_sha256":
 | |
| 			if cert.PublicKey == nil {
 | |
| 				return nil, true
 | |
| 			}
 | |
| 			pubKeyBytes, err := marshalPublicKey(cert.PublicKey)
 | |
| 			if err != nil {
 | |
| 				return nil, true
 | |
| 			}
 | |
| 			if strings.HasSuffix(field, "_sha256") {
 | |
| 				return fmt.Sprintf("%x", sha256.Sum256(pubKeyBytes)), true
 | |
| 			}
 | |
| 			return fmt.Sprintf("%x", pubKeyBytes), true
 | |
| 		case "client.issuer":
 | |
| 			return cert.Issuer, true
 | |
| 		case "client.serial":
 | |
| 			return cert.SerialNumber, true
 | |
| 		case "client.subject":
 | |
| 			return cert.Subject, true
 | |
| 		case "client.certificate_pem":
 | |
| 			block := pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}
 | |
| 			return pem.EncodeToMemory(&block), true
 | |
| 		case "client.certificate_der_base64":
 | |
| 			return base64.StdEncoding.EncodeToString(cert.Raw), true
 | |
| 		default:
 | |
| 			return nil, false
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	switch field {
 | |
| 	case "version":
 | |
| 		return caddytls.ProtocolName(req.TLS.Version), true
 | |
| 	case "cipher_suite":
 | |
| 		return tls.CipherSuiteName(req.TLS.CipherSuite), true
 | |
| 	case "resumed":
 | |
| 		return req.TLS.DidResume, true
 | |
| 	case "proto":
 | |
| 		return req.TLS.NegotiatedProtocol, true
 | |
| 	case "proto_mutual":
 | |
| 		// req.TLS.NegotiatedProtocolIsMutual is deprecated - it's always true.
 | |
| 		return true, true
 | |
| 	case "server_name":
 | |
| 		return req.TLS.ServerName, true
 | |
| 	}
 | |
| 	return nil, false
 | |
| }
 | |
| 
 | |
// marshalPublicKey returns a byte encoding of pubKey:
//   - *rsa.PublicKey: DER-encoded PKCS #1 RSAPublicKey (SEQUENCE of N, E)
//   - *ecdsa.PublicKey: the uncompressed point produced by elliptic.Marshal
//   - ed25519.PublicKey: the raw key bytes
//
// It returns an error for nil key pointers and for any other key type.
func marshalPublicKey(pubKey any) ([]byte, error) {
	switch key := pubKey.(type) {
	case *rsa.PublicKey:
		if key == nil {
			return nil, fmt.Errorf("nil public key of type %T", pubKey)
		}
		// encoding/asn1 cannot marshal pointer values (other than *big.Int),
		// so the struct itself must be passed, not the pointer.
		return asn1.Marshal(*key)
	case *ecdsa.PublicKey:
		if key == nil {
			return nil, fmt.Errorf("nil public key of type %T", pubKey)
		}
		// NOTE(review): elliptic.Marshal is deprecated; key.ECDH().Bytes()
		// yields the same encoding but rejects curves crypto/ecdh lacks (P-224).
		return elliptic.Marshal(key.Curve, key.X, key.Y), nil
	case ed25519.PublicKey:
		return key, nil
	}
	return nil, fmt.Errorf("unrecognized public key type: %T", pubKey)
}
 | |
| 
 | |
// getTLSPeerCert retrieves the first peer certificate from a TLS session,
// which is the peer's leaf certificate. It returns nil if cs is nil or the
// peer presented no certificate.
func getTLSPeerCert(cs *tls.ConnectionState) *x509.Certificate {
	if cs == nil || len(cs.PeerCertificates) == 0 {
		return nil
	}
	return cs.PeerCertificates[0]
}
 | |
| 
 | |
| type requestID struct {
 | |
| 	value string
 | |
| }
 | |
| 
 | |
| // Lazy generates UUID string or return cached value if present
 | |
| func (rid *requestID) String() string {
 | |
| 	if rid.value == "" {
 | |
| 		if id, err := uuid.NewRandom(); err == nil {
 | |
| 			rid.value = id.String()
 | |
| 		}
 | |
| 	}
 | |
| 	return rid.value
 | |
| }
 | |
| 
 | |
// Placeholder key prefixes. Each is followed by a caller-chosen suffix, for
// example reqHeaderReplPrefix + "User-Agent".
const (
	// request headers, cookies and query parameters
	reqHeaderReplPrefix   = "http.request.header."
	reqCookieReplPrefix   = "http.request.cookie."
	reqURIQueryReplPrefix = "http.request.uri.query."

	// request host labels, path segments and TLS details
	reqHostLabelsReplPrefix  = "http.request.host.labels."
	reqURIPathReplPrefix     = "http.request.uri.path."
	reqOrigURIPathReplPrefix = "http.request.orig_uri.path."
	reqTLSReplPrefix         = "http.request.tls."

	// response header fields
	respHeaderReplPrefix = "http.response.header."

	// middleware variables
	varsReplPrefix = "http.vars."
)
 |