mirror of
				https://github.com/caddyserver/caddy.git
				synced 2025-11-03 19:17:29 -05:00 
			
		
		
		
	Merge pull request #1656 from tw4452852/1587-limits
Introduce `limits` middleware
This commit is contained in:
		
						commit
						f06b825f44
					
				@ -16,9 +16,9 @@ import (
 | 
				
			|||||||
	_ "github.com/mholt/caddy/caddyhttp/header"
 | 
						_ "github.com/mholt/caddy/caddyhttp/header"
 | 
				
			||||||
	_ "github.com/mholt/caddy/caddyhttp/index"
 | 
						_ "github.com/mholt/caddy/caddyhttp/index"
 | 
				
			||||||
	_ "github.com/mholt/caddy/caddyhttp/internalsrv"
 | 
						_ "github.com/mholt/caddy/caddyhttp/internalsrv"
 | 
				
			||||||
 | 
						_ "github.com/mholt/caddy/caddyhttp/limits"
 | 
				
			||||||
	_ "github.com/mholt/caddy/caddyhttp/log"
 | 
						_ "github.com/mholt/caddy/caddyhttp/log"
 | 
				
			||||||
	_ "github.com/mholt/caddy/caddyhttp/markdown"
 | 
						_ "github.com/mholt/caddy/caddyhttp/markdown"
 | 
				
			||||||
	_ "github.com/mholt/caddy/caddyhttp/maxrequestbody"
 | 
					 | 
				
			||||||
	_ "github.com/mholt/caddy/caddyhttp/mime"
 | 
						_ "github.com/mholt/caddy/caddyhttp/mime"
 | 
				
			||||||
	_ "github.com/mholt/caddy/caddyhttp/pprof"
 | 
						_ "github.com/mholt/caddy/caddyhttp/pprof"
 | 
				
			||||||
	_ "github.com/mholt/caddy/caddyhttp/proxy"
 | 
						_ "github.com/mholt/caddy/caddyhttp/proxy"
 | 
				
			||||||
 | 
				
			|||||||
@ -436,7 +436,7 @@ var directives = []string{
 | 
				
			|||||||
	"root",
 | 
						"root",
 | 
				
			||||||
	"index",
 | 
						"index",
 | 
				
			||||||
	"bind",
 | 
						"bind",
 | 
				
			||||||
	"maxrequestbody", // TODO: 'limits'
 | 
						"limits",
 | 
				
			||||||
	"timeouts",
 | 
						"timeouts",
 | 
				
			||||||
	"tls",
 | 
						"tls",
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -302,7 +302,7 @@ func (r *replacer) getSubstitution(key string) string {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
		_, err := ioutil.ReadAll(r.request.Body)
 | 
							_, err := ioutil.ReadAll(r.request.Body)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			if _, ok := err.(MaxBytesExceeded); ok {
 | 
								if err == MaxBytesExceededErr {
 | 
				
			||||||
				return r.emptyValue
 | 
									return r.emptyValue
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
				
			|||||||
@ -4,8 +4,8 @@ package httpserver
 | 
				
			|||||||
import (
 | 
					import (
 | 
				
			||||||
	"context"
 | 
						"context"
 | 
				
			||||||
	"crypto/tls"
 | 
						"crypto/tls"
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"io"
 | 
					 | 
				
			||||||
	"log"
 | 
						"log"
 | 
				
			||||||
	"net"
 | 
						"net"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
@ -66,6 +66,7 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) {
 | 
				
			|||||||
		sites:       group,
 | 
							sites:       group,
 | 
				
			||||||
		connTimeout: GracefulTimeout,
 | 
							connTimeout: GracefulTimeout,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						s.Server = makeHTTPServerWithHeaderLimit(s.Server, group)
 | 
				
			||||||
	s.Server.Handler = s // this is weird, but whatever
 | 
						s.Server.Handler = s // this is weird, but whatever
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// extract TLS settings from each site config to build
 | 
						// extract TLS settings from each site config to build
 | 
				
			||||||
@ -127,6 +128,32 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) {
 | 
				
			|||||||
	return s, nil
 | 
						return s, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// makeHTTPServerWithHeaderLimit apply minimum header limit within a group to given http.Server
 | 
				
			||||||
 | 
					func makeHTTPServerWithHeaderLimit(s *http.Server, group []*SiteConfig) *http.Server {
 | 
				
			||||||
 | 
						var min int64
 | 
				
			||||||
 | 
						for _, cfg := range group {
 | 
				
			||||||
 | 
							limit := cfg.Limits.MaxRequestHeaderSize
 | 
				
			||||||
 | 
							if limit == 0 {
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// not set yet
 | 
				
			||||||
 | 
							if min == 0 {
 | 
				
			||||||
 | 
								min = limit
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// find a better one
 | 
				
			||||||
 | 
							if limit < min {
 | 
				
			||||||
 | 
								min = limit
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if min > 0 {
 | 
				
			||||||
 | 
							s.MaxHeaderBytes = int(min)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return s
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// makeHTTPServerWithTimeouts makes an http.Server from the group of
 | 
					// makeHTTPServerWithTimeouts makes an http.Server from the group of
 | 
				
			||||||
// configs in a way that configures timeouts (or, if not set, it uses
 | 
					// configs in a way that configures timeouts (or, if not set, it uses
 | 
				
			||||||
// the default timeouts) by combining the configuration of each
 | 
					// the default timeouts) by combining the configuration of each
 | 
				
			||||||
@ -359,20 +386,6 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error)
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Apply the path-based request body size limit
 | 
					 | 
				
			||||||
	// The error returned by MaxBytesReader is meant to be handled
 | 
					 | 
				
			||||||
	// by whichever middleware/plugin that receives it when calling
 | 
					 | 
				
			||||||
	// .Read() or a similar method on the request body
 | 
					 | 
				
			||||||
	// TODO: Make this middleware instead?
 | 
					 | 
				
			||||||
	if r.Body != nil {
 | 
					 | 
				
			||||||
		for _, pathlimit := range vhost.MaxRequestBodySizes {
 | 
					 | 
				
			||||||
			if Path(r.URL.Path).Matches(pathlimit.Path) {
 | 
					 | 
				
			||||||
				r.Body = MaxBytesReader(w, r.Body, pathlimit.Limit)
 | 
					 | 
				
			||||||
				break
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return vhost.middlewareChain.ServeHTTP(w, r)
 | 
						return vhost.middlewareChain.ServeHTTP(w, r)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -465,73 +478,9 @@ func (ln tcpKeepAliveListener) File() (*os.File, error) {
 | 
				
			|||||||
	return ln.TCPListener.File()
 | 
						return ln.TCPListener.File()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// MaxBytesExceeded is the error type returned by MaxBytesReader
 | 
					// MaxBytesExceeded is the error returned by MaxBytesReader
 | 
				
			||||||
// when the request body exceeds the limit imposed
 | 
					// when the request body exceeds the limit imposed
 | 
				
			||||||
type MaxBytesExceeded struct{}
 | 
					var MaxBytesExceededErr = errors.New("http: request body too large")
 | 
				
			||||||
 | 
					 | 
				
			||||||
func (err MaxBytesExceeded) Error() string {
 | 
					 | 
				
			||||||
	return "http: request body too large"
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// MaxBytesReader and its associated methods are borrowed from the
 | 
					 | 
				
			||||||
// Go Standard library (comments intact). The only difference is that
 | 
					 | 
				
			||||||
// it returns a MaxBytesExceeded error instead of a generic error message
 | 
					 | 
				
			||||||
// when the request body has exceeded the requested limit
 | 
					 | 
				
			||||||
func MaxBytesReader(w http.ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser {
 | 
					 | 
				
			||||||
	return &maxBytesReader{w: w, r: r, n: n}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type maxBytesReader struct {
 | 
					 | 
				
			||||||
	w   http.ResponseWriter
 | 
					 | 
				
			||||||
	r   io.ReadCloser // underlying reader
 | 
					 | 
				
			||||||
	n   int64         // max bytes remaining
 | 
					 | 
				
			||||||
	err error         // sticky error
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (l *maxBytesReader) Read(p []byte) (n int, err error) {
 | 
					 | 
				
			||||||
	if l.err != nil {
 | 
					 | 
				
			||||||
		return 0, l.err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if len(p) == 0 {
 | 
					 | 
				
			||||||
		return 0, nil
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	// If they asked for a 32KB byte read but only 5 bytes are
 | 
					 | 
				
			||||||
	// remaining, no need to read 32KB. 6 bytes will answer the
 | 
					 | 
				
			||||||
	// question of the whether we hit the limit or go past it.
 | 
					 | 
				
			||||||
	if int64(len(p)) > l.n+1 {
 | 
					 | 
				
			||||||
		p = p[:l.n+1]
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	n, err = l.r.Read(p)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if int64(n) <= l.n {
 | 
					 | 
				
			||||||
		l.n -= int64(n)
 | 
					 | 
				
			||||||
		l.err = err
 | 
					 | 
				
			||||||
		return n, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	n = int(l.n)
 | 
					 | 
				
			||||||
	l.n = 0
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// The server code and client code both use
 | 
					 | 
				
			||||||
	// maxBytesReader. This "requestTooLarge" check is
 | 
					 | 
				
			||||||
	// only used by the server code. To prevent binaries
 | 
					 | 
				
			||||||
	// which only using the HTTP Client code (such as
 | 
					 | 
				
			||||||
	// cmd/go) from also linking in the HTTP server, don't
 | 
					 | 
				
			||||||
	// use a static type assertion to the server
 | 
					 | 
				
			||||||
	// "*response" type. Check this interface instead:
 | 
					 | 
				
			||||||
	type requestTooLarger interface {
 | 
					 | 
				
			||||||
		requestTooLarge()
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if res, ok := l.w.(requestTooLarger); ok {
 | 
					 | 
				
			||||||
		res.requestTooLarge()
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	l.err = MaxBytesExceeded{}
 | 
					 | 
				
			||||||
	return n, l.err
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (l *maxBytesReader) Close() error {
 | 
					 | 
				
			||||||
	return l.r.Close()
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
// DefaultErrorFunc responds to an HTTP request with a simple description
 | 
					// DefaultErrorFunc responds to an HTTP request with a simple description
 | 
				
			||||||
// of the specified HTTP status code.
 | 
					// of the specified HTTP status code.
 | 
				
			||||||
 | 
				
			|||||||
@ -15,7 +15,7 @@ func TestAddress(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestMakeHTTPServer(t *testing.T) {
 | 
					func TestMakeHTTPServerWithTimeouts(t *testing.T) {
 | 
				
			||||||
	for i, tc := range []struct {
 | 
						for i, tc := range []struct {
 | 
				
			||||||
		group    []*SiteConfig
 | 
							group    []*SiteConfig
 | 
				
			||||||
		expected Timeouts
 | 
							expected Timeouts
 | 
				
			||||||
@ -111,3 +111,36 @@ func TestMakeHTTPServer(t *testing.T) {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestMakeHTTPServerWithHeaderLimit(t *testing.T) {
 | 
				
			||||||
 | 
						for name, c := range map[string]struct {
 | 
				
			||||||
 | 
							group  []*SiteConfig
 | 
				
			||||||
 | 
							expect int
 | 
				
			||||||
 | 
						}{
 | 
				
			||||||
 | 
							"disable": {
 | 
				
			||||||
 | 
								group:  []*SiteConfig{{}},
 | 
				
			||||||
 | 
								expect: 0,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							"oneSite": {
 | 
				
			||||||
 | 
								group: []*SiteConfig{{Limits: Limits{
 | 
				
			||||||
 | 
									MaxRequestHeaderSize: 100,
 | 
				
			||||||
 | 
								}}},
 | 
				
			||||||
 | 
								expect: 100,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							"multiSites": {
 | 
				
			||||||
 | 
								group: []*SiteConfig{
 | 
				
			||||||
 | 
									{Limits: Limits{MaxRequestHeaderSize: 100}},
 | 
				
			||||||
 | 
									{Limits: Limits{MaxRequestHeaderSize: 50}},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								expect: 50,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						} {
 | 
				
			||||||
 | 
							c := c
 | 
				
			||||||
 | 
							t.Run(name, func(t *testing.T) {
 | 
				
			||||||
 | 
								actual := makeHTTPServerWithHeaderLimit(&http.Server{}, c.group)
 | 
				
			||||||
 | 
								if got := actual.MaxHeaderBytes; got != c.expect {
 | 
				
			||||||
 | 
									t.Errorf("Expect %d, but got %d", c.expect, got)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -38,8 +38,8 @@ type SiteConfig struct {
 | 
				
			|||||||
	// for a request.
 | 
						// for a request.
 | 
				
			||||||
	HiddenFiles []string
 | 
						HiddenFiles []string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Max amount of bytes a request can send on a given path
 | 
						// Max request's header/body size
 | 
				
			||||||
	MaxRequestBodySizes []PathLimit
 | 
						Limits Limits
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// The path to the Caddyfile used to generate this site config
 | 
						// The path to the Caddyfile used to generate this site config
 | 
				
			||||||
	originCaddyfile string
 | 
						originCaddyfile string
 | 
				
			||||||
@ -71,6 +71,12 @@ type Timeouts struct {
 | 
				
			|||||||
	IdleTimeoutSet       bool
 | 
						IdleTimeoutSet       bool
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Limits specify size limit of request's header and body.
 | 
				
			||||||
 | 
					type Limits struct {
 | 
				
			||||||
 | 
						MaxRequestHeaderSize int64
 | 
				
			||||||
 | 
						MaxRequestBodySizes  []PathLimit
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// PathLimit is a mapping from a site's path to its corresponding
 | 
					// PathLimit is a mapping from a site's path to its corresponding
 | 
				
			||||||
// maximum request body size (in bytes)
 | 
					// maximum request body size (in bytes)
 | 
				
			||||||
type PathLimit struct {
 | 
					type PathLimit struct {
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										90
									
								
								caddyhttp/limits/handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										90
									
								
								caddyhttp/limits/handler.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,90 @@
 | 
				
			|||||||
 | 
					package limits
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/mholt/caddy/caddyhttp/httpserver"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Limit is a middleware to control request body size
 | 
				
			||||||
 | 
					type Limit struct {
 | 
				
			||||||
 | 
						Next       httpserver.Handler
 | 
				
			||||||
 | 
						BodyLimits []httpserver.PathLimit
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (l Limit) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
 | 
				
			||||||
 | 
						if r.Body == nil {
 | 
				
			||||||
 | 
							return l.Next.ServeHTTP(w, r)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// apply the path-based request body size limit.
 | 
				
			||||||
 | 
						for _, bl := range l.BodyLimits {
 | 
				
			||||||
 | 
							if httpserver.Path(r.URL.Path).Matches(bl.Path) {
 | 
				
			||||||
 | 
								r.Body = MaxBytesReader(w, r.Body, bl.Limit)
 | 
				
			||||||
 | 
								break
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return l.Next.ServeHTTP(w, r)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// MaxBytesReader and its associated methods are borrowed from the
 | 
				
			||||||
 | 
					// Go Standard library (comments intact). The only difference is that
 | 
				
			||||||
 | 
					// it returns a MaxBytesExceeded error instead of a generic error message
 | 
				
			||||||
 | 
					// when the request body has exceeded the requested limit
 | 
				
			||||||
 | 
					func MaxBytesReader(w http.ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser {
 | 
				
			||||||
 | 
						return &maxBytesReader{w: w, r: r, n: n}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type maxBytesReader struct {
 | 
				
			||||||
 | 
						w   http.ResponseWriter
 | 
				
			||||||
 | 
						r   io.ReadCloser // underlying reader
 | 
				
			||||||
 | 
						n   int64         // max bytes remaining
 | 
				
			||||||
 | 
						err error         // sticky error
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (l *maxBytesReader) Read(p []byte) (n int, err error) {
 | 
				
			||||||
 | 
						if l.err != nil {
 | 
				
			||||||
 | 
							return 0, l.err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if len(p) == 0 {
 | 
				
			||||||
 | 
							return 0, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						// If they asked for a 32KB byte read but only 5 bytes are
 | 
				
			||||||
 | 
						// remaining, no need to read 32KB. 6 bytes will answer the
 | 
				
			||||||
 | 
						// question of the whether we hit the limit or go past it.
 | 
				
			||||||
 | 
						if int64(len(p)) > l.n+1 {
 | 
				
			||||||
 | 
							p = p[:l.n+1]
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						n, err = l.r.Read(p)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if int64(n) <= l.n {
 | 
				
			||||||
 | 
							l.n -= int64(n)
 | 
				
			||||||
 | 
							l.err = err
 | 
				
			||||||
 | 
							return n, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						n = int(l.n)
 | 
				
			||||||
 | 
						l.n = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// The server code and client code both use
 | 
				
			||||||
 | 
						// maxBytesReader. This "requestTooLarge" check is
 | 
				
			||||||
 | 
						// only used by the server code. To prevent binaries
 | 
				
			||||||
 | 
						// which only using the HTTP Client code (such as
 | 
				
			||||||
 | 
						// cmd/go) from also linking in the HTTP server, don't
 | 
				
			||||||
 | 
						// use a static type assertion to the server
 | 
				
			||||||
 | 
						// "*response" type. Check this interface instead:
 | 
				
			||||||
 | 
						type requestTooLarger interface {
 | 
				
			||||||
 | 
							requestTooLarge()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if res, ok := l.w.(requestTooLarger); ok {
 | 
				
			||||||
 | 
							res.requestTooLarge()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						l.err = httpserver.MaxBytesExceededErr
 | 
				
			||||||
 | 
						return n, l.err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (l *maxBytesReader) Close() error {
 | 
				
			||||||
 | 
						return l.r.Close()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										35
									
								
								caddyhttp/limits/handler_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								caddyhttp/limits/handler_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,35 @@
 | 
				
			|||||||
 | 
					package limits
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"io/ioutil"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"net/http/httptest"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/mholt/caddy/caddyhttp/httpserver"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestBodySizeLimit(t *testing.T) {
 | 
				
			||||||
 | 
						var (
 | 
				
			||||||
 | 
							gotContent    []byte
 | 
				
			||||||
 | 
							gotError      error
 | 
				
			||||||
 | 
							expectContent = "hello"
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						l := Limit{
 | 
				
			||||||
 | 
							Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
 | 
				
			||||||
 | 
								gotContent, gotError = ioutil.ReadAll(r.Body)
 | 
				
			||||||
 | 
								return 0, nil
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							BodyLimits: []httpserver.PathLimit{{Path: "/", Limit: int64(len(expectContent))}},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						r := httptest.NewRequest("GET", "/", strings.NewReader(expectContent+expectContent))
 | 
				
			||||||
 | 
						l.ServeHTTP(httptest.NewRecorder(), r)
 | 
				
			||||||
 | 
						if got := string(gotContent); got != expectContent {
 | 
				
			||||||
 | 
							t.Errorf("expected content[%s], got[%s]", expectContent, got)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if gotError != httpserver.MaxBytesExceededErr {
 | 
				
			||||||
 | 
							t.Errorf("expect error %v, got %v", httpserver.MaxBytesExceededErr, gotError)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -1,4 +1,4 @@
 | 
				
			|||||||
package maxrequestbody
 | 
					package limits
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
@ -12,13 +12,13 @@ import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
	serverType = "http"
 | 
						serverType = "http"
 | 
				
			||||||
	pluginName = "maxrequestbody"
 | 
						pluginName = "limits"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func init() {
 | 
					func init() {
 | 
				
			||||||
	caddy.RegisterPlugin(pluginName, caddy.Plugin{
 | 
						caddy.RegisterPlugin(pluginName, caddy.Plugin{
 | 
				
			||||||
		ServerType: serverType,
 | 
							ServerType: serverType,
 | 
				
			||||||
		Action:     setupMaxRequestBody,
 | 
							Action:     setupLimits,
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -28,56 +28,97 @@ type pathLimitUnparsed struct {
 | 
				
			|||||||
	Limit string
 | 
						Limit string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func setupMaxRequestBody(c *caddy.Controller) error {
 | 
					func setupLimits(c *caddy.Controller) error {
 | 
				
			||||||
 | 
						bls, err := parseLimits(c)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						httpserver.GetConfig(c).AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
 | 
				
			||||||
 | 
							return Limit{Next: next, BodyLimits: bls}
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func parseLimits(c *caddy.Controller) ([]httpserver.PathLimit, error) {
 | 
				
			||||||
	config := httpserver.GetConfig(c)
 | 
						config := httpserver.GetConfig(c)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if !c.Next() {
 | 
						if !c.Next() {
 | 
				
			||||||
		return c.ArgErr()
 | 
							return nil, c.ArgErr()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	args := c.RemainingArgs()
 | 
						args := c.RemainingArgs()
 | 
				
			||||||
	argList := []pathLimitUnparsed{}
 | 
						argList := []pathLimitUnparsed{}
 | 
				
			||||||
 | 
						headerLimit := ""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	switch len(args) {
 | 
						switch len(args) {
 | 
				
			||||||
	case 0:
 | 
						case 0:
 | 
				
			||||||
		// Format: { <path> <limit> ... }
 | 
							// Format: limits {
 | 
				
			||||||
 | 
							//	header <limit>
 | 
				
			||||||
 | 
							//	body <path> <limit>
 | 
				
			||||||
 | 
							//	body <limit>
 | 
				
			||||||
 | 
							//	...
 | 
				
			||||||
 | 
							// }
 | 
				
			||||||
		for c.NextBlock() {
 | 
							for c.NextBlock() {
 | 
				
			||||||
			path := c.Val()
 | 
								kind := c.Val()
 | 
				
			||||||
			if !c.NextArg() {
 | 
								pathOrLimit := c.RemainingArgs()
 | 
				
			||||||
				// Uneven pairing of path/limit
 | 
								switch kind {
 | 
				
			||||||
				return c.ArgErr()
 | 
								case "header":
 | 
				
			||||||
 | 
									if len(pathOrLimit) != 1 {
 | 
				
			||||||
 | 
										return nil, c.ArgErr()
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
									headerLimit = pathOrLimit[0]
 | 
				
			||||||
 | 
								case "body":
 | 
				
			||||||
 | 
									if len(pathOrLimit) == 1 {
 | 
				
			||||||
					argList = append(argList, pathLimitUnparsed{
 | 
										argList = append(argList, pathLimitUnparsed{
 | 
				
			||||||
				Path:  path,
 | 
											Path:  "/",
 | 
				
			||||||
				Limit: c.Val(),
 | 
											Limit: pathOrLimit[0],
 | 
				
			||||||
					})
 | 
										})
 | 
				
			||||||
 | 
										break
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									if len(pathOrLimit) == 2 {
 | 
				
			||||||
 | 
										argList = append(argList, pathLimitUnparsed{
 | 
				
			||||||
 | 
											Path:  pathOrLimit[0],
 | 
				
			||||||
 | 
											Limit: pathOrLimit[1],
 | 
				
			||||||
 | 
										})
 | 
				
			||||||
 | 
										break
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									fallthrough
 | 
				
			||||||
 | 
								default:
 | 
				
			||||||
 | 
									return nil, c.ArgErr()
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	case 1:
 | 
						case 1:
 | 
				
			||||||
		// Format: <limit>
 | 
							// Format: limits <limit>
 | 
				
			||||||
 | 
							headerLimit = args[0]
 | 
				
			||||||
		argList = []pathLimitUnparsed{{
 | 
							argList = []pathLimitUnparsed{{
 | 
				
			||||||
			Path:  "/",
 | 
								Path:  "/",
 | 
				
			||||||
			Limit: args[0],
 | 
								Limit: args[0],
 | 
				
			||||||
		}}
 | 
							}}
 | 
				
			||||||
	case 2:
 | 
					 | 
				
			||||||
		// Format: <path> <limit>
 | 
					 | 
				
			||||||
		argList = []pathLimitUnparsed{{
 | 
					 | 
				
			||||||
			Path:  args[0],
 | 
					 | 
				
			||||||
			Limit: args[1],
 | 
					 | 
				
			||||||
		}}
 | 
					 | 
				
			||||||
	default:
 | 
						default:
 | 
				
			||||||
		return c.ArgErr()
 | 
							return nil, c.ArgErr()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if headerLimit != "" {
 | 
				
			||||||
 | 
							size := parseSize(headerLimit)
 | 
				
			||||||
 | 
							if size < 1 { // also disallow size = 0
 | 
				
			||||||
 | 
								return nil, c.ArgErr()
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							config.Limits.MaxRequestHeaderSize = size
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if len(argList) > 0 {
 | 
				
			||||||
		pathLimit, err := parseArguments(argList)
 | 
							pathLimit, err := parseArguments(argList)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
		return c.ArgErr()
 | 
								return nil, c.ArgErr()
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							SortPathLimits(pathLimit)
 | 
				
			||||||
 | 
							config.Limits.MaxRequestBodySizes = pathLimit
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	SortPathLimits(pathLimit)
 | 
						return config.Limits.MaxRequestBodySizes, nil
 | 
				
			||||||
 | 
					 | 
				
			||||||
	config.MaxRequestBodySizes = pathLimit
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return nil
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func parseArguments(args []pathLimitUnparsed) ([]httpserver.PathLimit, error) {
 | 
					func parseArguments(args []pathLimitUnparsed) ([]httpserver.PathLimit, error) {
 | 
				
			||||||
@ -1,4 +1,4 @@
 | 
				
			|||||||
package maxrequestbody
 | 
					package limits
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
@ -14,32 +14,98 @@ const (
 | 
				
			|||||||
	GB = 1024 * 1024 * 1024
 | 
						GB = 1024 * 1024 * 1024
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestSetupMaxRequestBody(t *testing.T) {
 | 
					func TestParseLimits(t *testing.T) {
 | 
				
			||||||
	cases := []struct {
 | 
						for name, c := range map[string]struct {
 | 
				
			||||||
		input     string
 | 
							input     string
 | 
				
			||||||
		hasError bool
 | 
							shouldErr bool
 | 
				
			||||||
 | 
							expect    httpserver.Limits
 | 
				
			||||||
	}{
 | 
						}{
 | 
				
			||||||
		// Format: { <path> <limit> ... }
 | 
							"catchAll": {
 | 
				
			||||||
		{input: "maxrequestbody / 20MB", hasError: false},
 | 
								input: `limits 2kb`,
 | 
				
			||||||
		// Format: <limit>
 | 
								expect: httpserver.Limits{
 | 
				
			||||||
		{input: "maxrequestbody 999KB", hasError: false},
 | 
									MaxRequestHeaderSize: 2 * KB,
 | 
				
			||||||
		// Format: { <path> <limit> ... }
 | 
									MaxRequestBodySizes:  []httpserver.PathLimit{{Path: "/", Limit: 2 * KB}},
 | 
				
			||||||
		{input: "maxrequestbody { /images 50MB /upload 10MB\n/test 10KB }", hasError: false},
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
		// Wrong formats
 | 
							"onlyHeader": {
 | 
				
			||||||
		{input: "maxrequestbody typo { /images 50MB }", hasError: true},
 | 
								input: `limits {
 | 
				
			||||||
		{input: "maxrequestbody 999MB /home 20KB", hasError: true},
 | 
									header 2kb
 | 
				
			||||||
	}
 | 
								}`,
 | 
				
			||||||
	for caseNum, c := range cases {
 | 
								expect: httpserver.Limits{
 | 
				
			||||||
 | 
									MaxRequestHeaderSize: 2 * KB,
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							"onlyBody": {
 | 
				
			||||||
 | 
								input: `limits {
 | 
				
			||||||
 | 
									body 2kb
 | 
				
			||||||
 | 
								}`,
 | 
				
			||||||
 | 
								expect: httpserver.Limits{
 | 
				
			||||||
 | 
									MaxRequestBodySizes: []httpserver.PathLimit{{Path: "/", Limit: 2 * KB}},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							"onlyBodyWithPath": {
 | 
				
			||||||
 | 
								input: `limits {
 | 
				
			||||||
 | 
									body /test 2kb
 | 
				
			||||||
 | 
								}`,
 | 
				
			||||||
 | 
								expect: httpserver.Limits{
 | 
				
			||||||
 | 
									MaxRequestBodySizes: []httpserver.PathLimit{{Path: "/test", Limit: 2 * KB}},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							"mixture": {
 | 
				
			||||||
 | 
								input: `limits {
 | 
				
			||||||
 | 
									header 1kb
 | 
				
			||||||
 | 
									body 2kb
 | 
				
			||||||
 | 
									body /bar 3kb
 | 
				
			||||||
 | 
								}`,
 | 
				
			||||||
 | 
								expect: httpserver.Limits{
 | 
				
			||||||
 | 
									MaxRequestHeaderSize: 1 * KB,
 | 
				
			||||||
 | 
									MaxRequestBodySizes: []httpserver.PathLimit{
 | 
				
			||||||
 | 
										{Path: "/bar", Limit: 3 * KB},
 | 
				
			||||||
 | 
										{Path: "/", Limit: 2 * KB},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							"invalidFormat": {
 | 
				
			||||||
 | 
								input:     `limits a b`,
 | 
				
			||||||
 | 
								shouldErr: true,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							"invalidHeaderFormat": {
 | 
				
			||||||
 | 
								input: `limits {
 | 
				
			||||||
 | 
									header / 100
 | 
				
			||||||
 | 
								}`,
 | 
				
			||||||
 | 
								shouldErr: true,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							"invalidBodyFormat": {
 | 
				
			||||||
 | 
								input: `limits {
 | 
				
			||||||
 | 
									body / 100 200
 | 
				
			||||||
 | 
								}`,
 | 
				
			||||||
 | 
								shouldErr: true,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							"invalidKind": {
 | 
				
			||||||
 | 
								input: `limits {
 | 
				
			||||||
 | 
									head 100
 | 
				
			||||||
 | 
								}`,
 | 
				
			||||||
 | 
								shouldErr: true,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							"invalidLimitSize": {
 | 
				
			||||||
 | 
								input:     `limits 10bk`,
 | 
				
			||||||
 | 
								shouldErr: true,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						} {
 | 
				
			||||||
 | 
							c := c
 | 
				
			||||||
 | 
							t.Run(name, func(t *testing.T) {
 | 
				
			||||||
			controller := caddy.NewTestController("", c.input)
 | 
								controller := caddy.NewTestController("", c.input)
 | 
				
			||||||
		err := setupMaxRequestBody(controller)
 | 
								_, err := parseLimits(controller)
 | 
				
			||||||
 | 
								if c.shouldErr && err == nil {
 | 
				
			||||||
		if c.hasError && (err == nil) {
 | 
									t.Error("failed to get expected error")
 | 
				
			||||||
			t.Errorf("Expecting error for case %v but none encountered", caseNum)
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		if !c.hasError && (err != nil) {
 | 
								if !c.shouldErr && err != nil {
 | 
				
			||||||
			t.Errorf("Expecting no error for case %v but encountered %v", caseNum, err)
 | 
									t.Errorf("got unexpected error: %v", err)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
								if got := httpserver.GetConfig(controller).Limits; !reflect.DeepEqual(got, c.expect) {
 | 
				
			||||||
 | 
									t.Errorf("expect %#v, but got %#v", c.expect, got)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -228,7 +228,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
 | 
				
			|||||||
			return 0, nil
 | 
								return 0, nil
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if _, ok := backendErr.(httpserver.MaxBytesExceeded); ok {
 | 
							if backendErr == httpserver.MaxBytesExceededErr {
 | 
				
			||||||
			return http.StatusRequestEntityTooLarge, backendErr
 | 
								return http.StatusRequestEntityTooLarge, backendErr
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user