mirror of
				https://github.com/caddyserver/caddy.git
				synced 2025-11-03 19:17:29 -05:00 
			
		
		
		
	1. Replace original `maxrequestbody` directive. 2. Add request header limit. fix issue #1587 Signed-off-by: Tw <tw19881113@gmail.com>
		
			
				
	
	
		
			220 lines
		
	
	
		
			4.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			220 lines
		
	
	
		
			4.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package limits
 | 
						|
 | 
						|
import (
 | 
						|
	"errors"
 | 
						|
	"sort"
 | 
						|
	"strconv"
 | 
						|
	"strings"
 | 
						|
 | 
						|
	"github.com/mholt/caddy"
 | 
						|
	"github.com/mholt/caddy/caddyhttp/httpserver"
 | 
						|
)
 | 
						|
 | 
						|
const (
 | 
						|
	serverType = "http"
 | 
						|
	pluginName = "limits"
 | 
						|
)
 | 
						|
 | 
						|
func init() {
 | 
						|
	caddy.RegisterPlugin(pluginName, caddy.Plugin{
 | 
						|
		ServerType: serverType,
 | 
						|
		Action:     setupLimits,
 | 
						|
	})
 | 
						|
}
 | 
						|
 | 
						|
// pathLimitUnparsed is a PathLimit before it's parsed
 | 
						|
type pathLimitUnparsed struct {
 | 
						|
	Path  string
 | 
						|
	Limit string
 | 
						|
}
 | 
						|
 | 
						|
func setupLimits(c *caddy.Controller) error {
 | 
						|
	bls, err := parseLimits(c)
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	httpserver.GetConfig(c).AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
 | 
						|
		return Limit{Next: next, BodyLimits: bls}
 | 
						|
	})
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func parseLimits(c *caddy.Controller) ([]httpserver.PathLimit, error) {
 | 
						|
	config := httpserver.GetConfig(c)
 | 
						|
 | 
						|
	if !c.Next() {
 | 
						|
		return nil, c.ArgErr()
 | 
						|
	}
 | 
						|
 | 
						|
	args := c.RemainingArgs()
 | 
						|
	argList := []pathLimitUnparsed{}
 | 
						|
	headerLimit := ""
 | 
						|
 | 
						|
	switch len(args) {
 | 
						|
	case 0:
 | 
						|
		// Format: limits {
 | 
						|
		//	header <limit>
 | 
						|
		//	body <path> <limit>
 | 
						|
		//	body <limit>
 | 
						|
		//	...
 | 
						|
		// }
 | 
						|
		for c.NextBlock() {
 | 
						|
			kind := c.Val()
 | 
						|
			pathOrLimit := c.RemainingArgs()
 | 
						|
			switch kind {
 | 
						|
			case "header":
 | 
						|
				if len(pathOrLimit) != 1 {
 | 
						|
					return nil, c.ArgErr()
 | 
						|
				}
 | 
						|
				headerLimit = pathOrLimit[0]
 | 
						|
			case "body":
 | 
						|
				if len(pathOrLimit) == 1 {
 | 
						|
					argList = append(argList, pathLimitUnparsed{
 | 
						|
						Path:  "/",
 | 
						|
						Limit: pathOrLimit[0],
 | 
						|
					})
 | 
						|
					break
 | 
						|
				}
 | 
						|
 | 
						|
				if len(pathOrLimit) == 2 {
 | 
						|
					argList = append(argList, pathLimitUnparsed{
 | 
						|
						Path:  pathOrLimit[0],
 | 
						|
						Limit: pathOrLimit[1],
 | 
						|
					})
 | 
						|
					break
 | 
						|
				}
 | 
						|
 | 
						|
				fallthrough
 | 
						|
			default:
 | 
						|
				return nil, c.ArgErr()
 | 
						|
			}
 | 
						|
		}
 | 
						|
	case 1:
 | 
						|
		// Format: limits <limit>
 | 
						|
		headerLimit = args[0]
 | 
						|
		argList = []pathLimitUnparsed{{
 | 
						|
			Path:  "/",
 | 
						|
			Limit: args[0],
 | 
						|
		}}
 | 
						|
	default:
 | 
						|
		return nil, c.ArgErr()
 | 
						|
	}
 | 
						|
 | 
						|
	if headerLimit != "" {
 | 
						|
		size := parseSize(headerLimit)
 | 
						|
		if size < 1 { // also disallow size = 0
 | 
						|
			return nil, c.ArgErr()
 | 
						|
		}
 | 
						|
		config.Limits.MaxRequestHeaderSize = size
 | 
						|
	}
 | 
						|
 | 
						|
	if len(argList) > 0 {
 | 
						|
		pathLimit, err := parseArguments(argList)
 | 
						|
		if err != nil {
 | 
						|
			return nil, c.ArgErr()
 | 
						|
		}
 | 
						|
		SortPathLimits(pathLimit)
 | 
						|
		config.Limits.MaxRequestBodySizes = pathLimit
 | 
						|
	}
 | 
						|
 | 
						|
	return config.Limits.MaxRequestBodySizes, nil
 | 
						|
}
 | 
						|
 | 
						|
func parseArguments(args []pathLimitUnparsed) ([]httpserver.PathLimit, error) {
 | 
						|
	pathLimit := []httpserver.PathLimit{}
 | 
						|
 | 
						|
	for _, pair := range args {
 | 
						|
		size := parseSize(pair.Limit)
 | 
						|
		if size < 1 { // also disallow size = 0
 | 
						|
			return pathLimit, errors.New("Parse failed")
 | 
						|
		}
 | 
						|
		pathLimit = addPathLimit(pathLimit, pair.Path, size)
 | 
						|
	}
 | 
						|
	return pathLimit, nil
 | 
						|
}
 | 
						|
 | 
						|
var validUnits = []struct {
 | 
						|
	symbol     string
 | 
						|
	multiplier int64
 | 
						|
}{
 | 
						|
	{"KB", 1024},
 | 
						|
	{"MB", 1024 * 1024},
 | 
						|
	{"GB", 1024 * 1024 * 1024},
 | 
						|
	{"B", 1},
 | 
						|
	{"", 1}, // defaulting to "B"
 | 
						|
}
 | 
						|
 | 
						|
// parseSize parses the given string as size limit
 | 
						|
// Size are positive numbers followed by a unit (case insensitive)
 | 
						|
// Allowed units: "B" (bytes), "KB" (kilo), "MB" (mega), "GB" (giga)
 | 
						|
// If the unit is omitted, "b" is assumed
 | 
						|
// Returns the parsed size in bytes, or -1 if cannot parse
 | 
						|
func parseSize(sizeStr string) int64 {
 | 
						|
	sizeStr = strings.ToUpper(sizeStr)
 | 
						|
 | 
						|
	for _, unit := range validUnits {
 | 
						|
		if strings.HasSuffix(sizeStr, unit.symbol) {
 | 
						|
			size, err := strconv.ParseInt(sizeStr[0:len(sizeStr)-len(unit.symbol)], 10, 64)
 | 
						|
			if err != nil {
 | 
						|
				return -1
 | 
						|
			}
 | 
						|
			return size * unit.multiplier
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	// Unreachable code
 | 
						|
	return -1
 | 
						|
}
 | 
						|
 | 
						|
// addPathLimit appends the path-to-request body limit mapping to pathLimit
 | 
						|
// Slashes are checked and added to path if necessary. Duplicates are ignored.
 | 
						|
func addPathLimit(pathLimit []httpserver.PathLimit, path string, limit int64) []httpserver.PathLimit {
 | 
						|
	// Enforces preceding slash
 | 
						|
	if path[0] != '/' {
 | 
						|
		path = "/" + path
 | 
						|
	}
 | 
						|
 | 
						|
	// Use the last value if there are duplicates
 | 
						|
	for i, p := range pathLimit {
 | 
						|
		if p.Path == path {
 | 
						|
			pathLimit[i].Limit = limit
 | 
						|
			return pathLimit
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return append(pathLimit, httpserver.PathLimit{Path: path, Limit: limit})
 | 
						|
}
 | 
						|
 | 
						|
// SortPathLimits sort pathLimits by their paths length, longest first
 | 
						|
func SortPathLimits(pathLimits []httpserver.PathLimit) {
 | 
						|
	sorter := &pathLimitSorter{
 | 
						|
		pathLimits: pathLimits,
 | 
						|
		by:         LengthDescending,
 | 
						|
	}
 | 
						|
	sort.Sort(sorter)
 | 
						|
}
 | 
						|
 | 
						|
// structs and methods implementing the sorting interfaces for PathLimit
 | 
						|
type pathLimitSorter struct {
 | 
						|
	pathLimits []httpserver.PathLimit
 | 
						|
	by         func(p1, p2 *httpserver.PathLimit) bool
 | 
						|
}
 | 
						|
 | 
						|
func (s *pathLimitSorter) Len() int {
 | 
						|
	return len(s.pathLimits)
 | 
						|
}
 | 
						|
 | 
						|
func (s *pathLimitSorter) Swap(i, j int) {
 | 
						|
	s.pathLimits[i], s.pathLimits[j] = s.pathLimits[j], s.pathLimits[i]
 | 
						|
}
 | 
						|
 | 
						|
func (s *pathLimitSorter) Less(i, j int) bool {
 | 
						|
	return s.by(&s.pathLimits[i], &s.pathLimits[j])
 | 
						|
}
 | 
						|
 | 
						|
// LengthDescending is the comparator for SortPathLimits
 | 
						|
func LengthDescending(p1, p2 *httpserver.PathLimit) bool {
 | 
						|
	return len(p1.Path) > len(p2.Path)
 | 
						|
}
 |