mirror of
				https://github.com/caddyserver/caddy.git
				synced 2025-11-03 19:17:29 -05:00 
			
		
		
		
	Merge pull request #112 from abiosoft/master
Gzip: Added compression level, extension and path filters.
This commit is contained in:
		
						commit
						995a2ea618
					
				@ -1,13 +1,84 @@
 | 
				
			|||||||
package setup
 | 
					package setup
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"strconv"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/mholt/caddy/middleware"
 | 
						"github.com/mholt/caddy/middleware"
 | 
				
			||||||
	"github.com/mholt/caddy/middleware/gzip"
 | 
						"github.com/mholt/caddy/middleware/gzip"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Gzip configures a new gzip middleware instance.
 | 
					// Gzip configures a new gzip middleware instance.
 | 
				
			||||||
func Gzip(c *Controller) (middleware.Middleware, error) {
 | 
					func Gzip(c *Controller) (middleware.Middleware, error) {
 | 
				
			||||||
 | 
						configs, err := gzipParse(c)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return func(next middleware.Handler) middleware.Handler {
 | 
						return func(next middleware.Handler) middleware.Handler {
 | 
				
			||||||
		return gzip.Gzip{Next: next}
 | 
							return gzip.Gzip{Next: next, Configs: configs}
 | 
				
			||||||
	}, nil
 | 
						}, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func gzipParse(c *Controller) ([]gzip.Config, error) {
 | 
				
			||||||
 | 
						var configs []gzip.Config
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for c.Next() {
 | 
				
			||||||
 | 
							config := gzip.Config{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							pathFilter := gzip.PathFilter{make(gzip.Set)}
 | 
				
			||||||
 | 
							extFilter := gzip.DefaultExtFilter()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// no extra args expected
 | 
				
			||||||
 | 
							if len(c.RemainingArgs()) > 0 {
 | 
				
			||||||
 | 
								return configs, c.ArgErr()
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							for c.NextBlock() {
 | 
				
			||||||
 | 
								switch c.Val() {
 | 
				
			||||||
 | 
								case "ext":
 | 
				
			||||||
 | 
									exts := c.RemainingArgs()
 | 
				
			||||||
 | 
									if len(exts) == 0 {
 | 
				
			||||||
 | 
										return configs, c.ArgErr()
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									for _, e := range exts {
 | 
				
			||||||
 | 
										if !strings.HasPrefix(e, ".") {
 | 
				
			||||||
 | 
											return configs, fmt.Errorf(`Invalid extension %v. Should start with "."`, e)
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
										extFilter.Exts.Add(e)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								case "not":
 | 
				
			||||||
 | 
									paths := c.RemainingArgs()
 | 
				
			||||||
 | 
									if len(paths) == 0 {
 | 
				
			||||||
 | 
										return configs, c.ArgErr()
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									for _, p := range paths {
 | 
				
			||||||
 | 
										if !strings.HasPrefix(p, "/") {
 | 
				
			||||||
 | 
											return configs, fmt.Errorf(`Invalid path %v. Should start with "/"`, p)
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
										pathFilter.IgnoredPaths.Add(p)
 | 
				
			||||||
 | 
										// Warn user if / is used
 | 
				
			||||||
 | 
										if p == "/" {
 | 
				
			||||||
 | 
											fmt.Println("Warning: Paths ignored by gzip includes wildcard(/). No request will be gzipped.\nRemoving gzip directive from Caddyfile is preferred if this is intended.")
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								case "level":
 | 
				
			||||||
 | 
									if !c.NextArg() {
 | 
				
			||||||
 | 
										return configs, c.ArgErr()
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									level, _ := strconv.Atoi(c.Val())
 | 
				
			||||||
 | 
									config.Level = level
 | 
				
			||||||
 | 
								default:
 | 
				
			||||||
 | 
									return configs, c.ArgErr()
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// put pathFilter in front to filter with path first
 | 
				
			||||||
 | 
							config.Filters = []gzip.Filter{pathFilter, extFilter}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							configs = append(configs, config)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return configs, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -26,4 +26,47 @@ func TestGzip(t *testing.T) {
 | 
				
			|||||||
	if !sameNext(myHandler.Next, emptyNext) {
 | 
						if !sameNext(myHandler.Next, emptyNext) {
 | 
				
			||||||
		t.Error("'Next' field of handler was not set properly")
 | 
							t.Error("'Next' field of handler was not set properly")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						tests := []struct {
 | 
				
			||||||
 | 
							input     string
 | 
				
			||||||
 | 
							shouldErr bool
 | 
				
			||||||
 | 
						}{
 | 
				
			||||||
 | 
							{`gzip {`, true},
 | 
				
			||||||
 | 
							{`gzip {}`, true},
 | 
				
			||||||
 | 
							{`gzip a b`, true},
 | 
				
			||||||
 | 
							{`gzip a {`, true},
 | 
				
			||||||
 | 
							{`gzip { not f } `, true},
 | 
				
			||||||
 | 
							{`gzip { not } `, true},
 | 
				
			||||||
 | 
							{`gzip { not /file
 | 
				
			||||||
 | 
							 ext .html
 | 
				
			||||||
 | 
							 level 1
 | 
				
			||||||
 | 
							} `, false},
 | 
				
			||||||
 | 
							{`gzip { level 9 } `, false},
 | 
				
			||||||
 | 
							{`gzip { ext } `, true},
 | 
				
			||||||
 | 
							{`gzip { ext /f
 | 
				
			||||||
 | 
							} `, true},
 | 
				
			||||||
 | 
							{`gzip { not /file
 | 
				
			||||||
 | 
							 ext .html
 | 
				
			||||||
 | 
							 level 1
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							gzip`, false},
 | 
				
			||||||
 | 
							{`gzip { not /file
 | 
				
			||||||
 | 
							 ext .html
 | 
				
			||||||
 | 
							 level 1
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							gzip { not /file1
 | 
				
			||||||
 | 
							 ext .htm
 | 
				
			||||||
 | 
							 level 3
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							`, false},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for i, test := range tests {
 | 
				
			||||||
 | 
							c := newTestController(test.input)
 | 
				
			||||||
 | 
							_, err := gzipParse(c)
 | 
				
			||||||
 | 
							if test.shouldErr && err == nil {
 | 
				
			||||||
 | 
								t.Errorf("Text %v: Expected error but found nil", i)
 | 
				
			||||||
 | 
							} else if !test.shouldErr && err != nil {
 | 
				
			||||||
 | 
								t.Errorf("Text %v: Expected no error but found error: ", i, err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										90
									
								
								middleware/gzip/filter.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										90
									
								
								middleware/gzip/filter.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,90 @@
 | 
				
			|||||||
 | 
					package gzip
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"path"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/mholt/caddy/middleware"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Filter determines if a request should be gzipped.
 | 
				
			||||||
 | 
					type Filter interface {
 | 
				
			||||||
 | 
						// ShouldCompress tells if compression gzip compression
 | 
				
			||||||
 | 
						// should be done on the request.
 | 
				
			||||||
 | 
						ShouldCompress(*http.Request) bool
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ExtFilter is Filter for file name extensions.
 | 
				
			||||||
 | 
					type ExtFilter struct {
 | 
				
			||||||
 | 
						// Exts is the file name extensions to accept
 | 
				
			||||||
 | 
						Exts Set
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// textExts is a list of extensions for text related files.
 | 
				
			||||||
 | 
					var textExts = []string{
 | 
				
			||||||
 | 
						".html", ".htm", ".css", ".json", ".php", ".js", ".txt", ".md", ".xml",
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// extWildCard is the wildcard for extensions.
 | 
				
			||||||
 | 
					const extWildCard = "*"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// DefaultExtFilter creates a default ExtFilter with
 | 
				
			||||||
 | 
					// file extensions for text types.
 | 
				
			||||||
 | 
					func DefaultExtFilter() ExtFilter {
 | 
				
			||||||
 | 
						e := ExtFilter{make(Set)}
 | 
				
			||||||
 | 
						for _, ext := range textExts {
 | 
				
			||||||
 | 
							e.Exts.Add(ext)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return e
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (e ExtFilter) ShouldCompress(r *http.Request) bool {
 | 
				
			||||||
 | 
						ext := path.Ext(r.URL.Path)
 | 
				
			||||||
 | 
						return e.Exts.Contains(extWildCard) || e.Exts.Contains(ext)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// PathFilter is Filter for request path.
 | 
				
			||||||
 | 
					type PathFilter struct {
 | 
				
			||||||
 | 
						// IgnoredPaths is the paths to ignore
 | 
				
			||||||
 | 
						IgnoredPaths Set
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ShouldCompress checks if the request path matches any of the
 | 
				
			||||||
 | 
					// registered paths to ignore. If returns false if an ignored path
 | 
				
			||||||
 | 
					// is found and true otherwise.
 | 
				
			||||||
 | 
					func (p PathFilter) ShouldCompress(r *http.Request) bool {
 | 
				
			||||||
 | 
						return !p.IgnoredPaths.ContainsFunc(func(value string) bool {
 | 
				
			||||||
 | 
							return middleware.Path(r.URL.Path).Matches(value)
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Set stores distinct strings.
 | 
				
			||||||
 | 
					type Set map[string]struct{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Add adds an element to the set.
 | 
				
			||||||
 | 
					func (s Set) Add(value string) {
 | 
				
			||||||
 | 
						s[value] = struct{}{}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Remove removes an element from the set.
 | 
				
			||||||
 | 
					func (s Set) Remove(value string) {
 | 
				
			||||||
 | 
						delete(s, value)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Contains check if the set contains value.
 | 
				
			||||||
 | 
					func (s Set) Contains(value string) bool {
 | 
				
			||||||
 | 
						_, ok := s[value]
 | 
				
			||||||
 | 
						return ok
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ContainsFunc is similar to Contains. It iterates all the
 | 
				
			||||||
 | 
					// elements in the set and passes each to f. It returns true
 | 
				
			||||||
 | 
					// on the first call to f that returns true and false otherwise.
 | 
				
			||||||
 | 
					func (s Set) ContainsFunc(f func(string) bool) bool {
 | 
				
			||||||
 | 
						for k, _ := range s {
 | 
				
			||||||
 | 
							if f(k) {
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										106
									
								
								middleware/gzip/filter_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										106
									
								
								middleware/gzip/filter_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,106 @@
 | 
				
			|||||||
 | 
					package gzip
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestSet(t *testing.T) {
 | 
				
			||||||
 | 
						set := make(Set)
 | 
				
			||||||
 | 
						set.Add("a")
 | 
				
			||||||
 | 
						if len(set) != 1 {
 | 
				
			||||||
 | 
							t.Errorf("Expected 1 found %v", len(set))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						set.Add("a")
 | 
				
			||||||
 | 
						if len(set) != 1 {
 | 
				
			||||||
 | 
							t.Errorf("Expected 1 found %v", len(set))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						set.Add("b")
 | 
				
			||||||
 | 
						if len(set) != 2 {
 | 
				
			||||||
 | 
							t.Errorf("Expected 2 found %v", len(set))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if !set.Contains("a") {
 | 
				
			||||||
 | 
							t.Errorf("Set should contain a")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if !set.Contains("b") {
 | 
				
			||||||
 | 
							t.Errorf("Set should contain a")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						set.Add("c")
 | 
				
			||||||
 | 
						if len(set) != 3 {
 | 
				
			||||||
 | 
							t.Errorf("Expected 3 found %v", len(set))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if !set.Contains("c") {
 | 
				
			||||||
 | 
							t.Errorf("Set should contain c")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						set.Remove("a")
 | 
				
			||||||
 | 
						if len(set) != 2 {
 | 
				
			||||||
 | 
							t.Errorf("Expected 2 found %v", len(set))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if set.Contains("a") {
 | 
				
			||||||
 | 
							t.Errorf("Set should not contain a")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if !set.ContainsFunc(func(v string) bool {
 | 
				
			||||||
 | 
							return v == "c"
 | 
				
			||||||
 | 
						}) {
 | 
				
			||||||
 | 
							t.Errorf("ContainsFunc should return true")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestExtFilter(t *testing.T) {
 | 
				
			||||||
 | 
						var filter Filter = DefaultExtFilter()
 | 
				
			||||||
 | 
						_ = filter.(ExtFilter)
 | 
				
			||||||
 | 
						for i, e := range textExts {
 | 
				
			||||||
 | 
							r := urlRequest("file" + e)
 | 
				
			||||||
 | 
							if !filter.ShouldCompress(r) {
 | 
				
			||||||
 | 
								t.Errorf("Test %v: Should be valid filter", i)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						var exts = []string{
 | 
				
			||||||
 | 
							".html", ".css", ".md",
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for i, e := range exts {
 | 
				
			||||||
 | 
							r := urlRequest("file" + e)
 | 
				
			||||||
 | 
							if !filter.ShouldCompress(r) {
 | 
				
			||||||
 | 
								t.Errorf("Test %v: Should be valid filter", i)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						exts = []string{
 | 
				
			||||||
 | 
							".htm1", ".abc", ".mdx",
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for i, e := range exts {
 | 
				
			||||||
 | 
							r := urlRequest("file" + e)
 | 
				
			||||||
 | 
							if filter.ShouldCompress(r) {
 | 
				
			||||||
 | 
								t.Errorf("Test %v: Should not be valid filter", i)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestPathFilter(t *testing.T) {
 | 
				
			||||||
 | 
						paths := []string{
 | 
				
			||||||
 | 
							"/a", "/b", "/c", "/de",
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						var filter Filter = PathFilter{make(Set)}
 | 
				
			||||||
 | 
						for _, p := range paths {
 | 
				
			||||||
 | 
							filter.(PathFilter).IgnoredPaths.Add(p)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for i, p := range paths {
 | 
				
			||||||
 | 
							r := urlRequest(p)
 | 
				
			||||||
 | 
							if filter.ShouldCompress(r) {
 | 
				
			||||||
 | 
								t.Errorf("Test %v: Should not be valid filter", i)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						paths = []string{
 | 
				
			||||||
 | 
							"/f", "/g", "/h", "/ed",
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for i, p := range paths {
 | 
				
			||||||
 | 
							r := urlRequest(p)
 | 
				
			||||||
 | 
							if !filter.ShouldCompress(r) {
 | 
				
			||||||
 | 
								t.Errorf("Test %v: Should be valid filter", i)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func urlRequest(url string) *http.Request {
 | 
				
			||||||
 | 
						r, _ := http.NewRequest("GET", url, nil)
 | 
				
			||||||
 | 
						return r
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -17,7 +17,14 @@ import (
 | 
				
			|||||||
// specifies the Content-Type, otherwise some clients will assume
 | 
					// specifies the Content-Type, otherwise some clients will assume
 | 
				
			||||||
// application/x-gzip and try to download a file.
 | 
					// application/x-gzip and try to download a file.
 | 
				
			||||||
type Gzip struct {
 | 
					type Gzip struct {
 | 
				
			||||||
	Next middleware.Handler
 | 
						Next    middleware.Handler
 | 
				
			||||||
 | 
						Configs []Config
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Config holds the configuration for Gzip middleware
 | 
				
			||||||
 | 
					type Config struct {
 | 
				
			||||||
 | 
						Filters []Filter // Filters to use
 | 
				
			||||||
 | 
						Level   int      // Compression level
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// ServeHTTP serves a gzipped response if the client supports it.
 | 
					// ServeHTTP serves a gzipped response if the client supports it.
 | 
				
			||||||
@ -26,27 +33,56 @@ func (g Gzip) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
 | 
				
			|||||||
		return g.Next.ServeHTTP(w, r)
 | 
							return g.Next.ServeHTTP(w, r)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Delete this header so gzipping isn't repeated later in the chain
 | 
					outer:
 | 
				
			||||||
	r.Header.Del("Accept-Encoding")
 | 
						for _, c := range g.Configs {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	w.Header().Set("Content-Encoding", "gzip")
 | 
							// Check filters to determine if gzipping is permitted for this
 | 
				
			||||||
	gzipWriter := gzip.NewWriter(w)
 | 
							// request
 | 
				
			||||||
	defer gzipWriter.Close()
 | 
							for _, filter := range c.Filters {
 | 
				
			||||||
	gz := gzipResponseWriter{Writer: gzipWriter, ResponseWriter: w}
 | 
								if !filter.ShouldCompress(r) {
 | 
				
			||||||
 | 
									continue outer
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Any response in forward middleware will now be compressed
 | 
							// Delete this header so gzipping is not repeated later in the chain
 | 
				
			||||||
	status, err := g.Next.ServeHTTP(gz, r)
 | 
							r.Header.Del("Accept-Encoding")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// If there was an error that remained unhandled, we need
 | 
							w.Header().Set("Content-Encoding", "gzip")
 | 
				
			||||||
	// to send something back before gzipWriter gets closed at
 | 
							gzipWriter, err := newWriter(c, w)
 | 
				
			||||||
	// the return of this method!
 | 
							if err != nil {
 | 
				
			||||||
	if status >= 400 {
 | 
								// should not happen
 | 
				
			||||||
		gz.Header().Set("Content-Type", "text/plain") // very necessary
 | 
								return http.StatusInternalServerError, err
 | 
				
			||||||
		gz.WriteHeader(status)
 | 
							}
 | 
				
			||||||
		fmt.Fprintf(gz, "%d %s", status, http.StatusText(status))
 | 
							defer gzipWriter.Close()
 | 
				
			||||||
		return 0, err
 | 
							gz := gzipResponseWriter{Writer: gzipWriter, ResponseWriter: w}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Any response in forward middleware will now be compressed
 | 
				
			||||||
 | 
							status, err := g.Next.ServeHTTP(gz, r)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// If there was an error that remained unhandled, we need
 | 
				
			||||||
 | 
							// to send something back before gzipWriter gets closed at
 | 
				
			||||||
 | 
							// the return of this method!
 | 
				
			||||||
 | 
							if status >= 400 {
 | 
				
			||||||
 | 
								gz.Header().Set("Content-Type", "text/plain") // very necessary
 | 
				
			||||||
 | 
								gz.WriteHeader(status)
 | 
				
			||||||
 | 
								fmt.Fprintf(gz, "%d %s", status, http.StatusText(status))
 | 
				
			||||||
 | 
								return 0, err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return status, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return status, err
 | 
					
 | 
				
			||||||
 | 
						// no matching filter
 | 
				
			||||||
 | 
						return g.Next.ServeHTTP(w, r)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// newWriter create a new Gzip Writer based on the compression level.
 | 
				
			||||||
 | 
					// If the level is valid (i.e. between 1 and 9), it uses the level.
 | 
				
			||||||
 | 
					// Otherwise, it uses default compression level.
 | 
				
			||||||
 | 
					func newWriter(c Config, w http.ResponseWriter) (*gzip.Writer, error) {
 | 
				
			||||||
 | 
						if c.Level >= gzip.BestSpeed && c.Level <= gzip.BestCompression {
 | 
				
			||||||
 | 
							return gzip.NewWriterLevel(w, c.Level)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return gzip.NewWriter(w), nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// gzipResponeWriter wraps the underlying Write method
 | 
					// gzipResponeWriter wraps the underlying Write method
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										100
									
								
								middleware/gzip/gzip_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										100
									
								
								middleware/gzip/gzip_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,100 @@
 | 
				
			|||||||
 | 
					package gzip
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"net/http/httptest"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/mholt/caddy/middleware"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestGzipHandler(t *testing.T) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pathFilter := PathFilter{make(Set)}
 | 
				
			||||||
 | 
						badPaths := []string{"/bad", "/nogzip", "/nongzip"}
 | 
				
			||||||
 | 
						for _, p := range badPaths {
 | 
				
			||||||
 | 
							pathFilter.IgnoredPaths.Add(p)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						gz := Gzip{Configs: []Config{
 | 
				
			||||||
 | 
							Config{Filters: []Filter{DefaultExtFilter(), pathFilter}},
 | 
				
			||||||
 | 
						}}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						w := httptest.NewRecorder()
 | 
				
			||||||
 | 
						gz.Next = nextFunc(true)
 | 
				
			||||||
 | 
						for _, e := range textExts {
 | 
				
			||||||
 | 
							url := "/file" + e
 | 
				
			||||||
 | 
							r, err := http.NewRequest("GET", url, nil)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Error(err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							r.Header.Set("Accept-Encoding", "gzip")
 | 
				
			||||||
 | 
							_, err = gz.ServeHTTP(w, r)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Error(err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						w = httptest.NewRecorder()
 | 
				
			||||||
 | 
						gz.Next = nextFunc(false)
 | 
				
			||||||
 | 
						for _, p := range badPaths {
 | 
				
			||||||
 | 
							for _, e := range textExts {
 | 
				
			||||||
 | 
								url := p + "/file" + e
 | 
				
			||||||
 | 
								r, err := http.NewRequest("GET", url, nil)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									t.Error(err)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								r.Header.Set("Accept-Encoding", "gzip")
 | 
				
			||||||
 | 
								_, err = gz.ServeHTTP(w, r)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									t.Error(err)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						w = httptest.NewRecorder()
 | 
				
			||||||
 | 
						gz.Next = nextFunc(false)
 | 
				
			||||||
 | 
						exts := []string{
 | 
				
			||||||
 | 
							".htm1", ".abc", ".mdx",
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for _, e := range exts {
 | 
				
			||||||
 | 
							url := "/file" + e
 | 
				
			||||||
 | 
							r, err := http.NewRequest("GET", url, nil)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Error(err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							r.Header.Set("Accept-Encoding", "gzip")
 | 
				
			||||||
 | 
							_, err = gz.ServeHTTP(w, r)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Error(err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func nextFunc(shouldGzip bool) middleware.Handler {
 | 
				
			||||||
 | 
						return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
 | 
				
			||||||
 | 
							if shouldGzip {
 | 
				
			||||||
 | 
								if r.Header.Get("Accept-Encoding") != "" {
 | 
				
			||||||
 | 
									return 0, fmt.Errorf("Accept-Encoding header not expected")
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								if w.Header().Get("Content-Encoding") != "gzip" {
 | 
				
			||||||
 | 
									return 0, fmt.Errorf("Content-Encoding must be gzip, found %v", r.Header.Get("Content-Encoding"))
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								if _, ok := w.(gzipResponseWriter); !ok {
 | 
				
			||||||
 | 
									return 0, fmt.Errorf("ResponseWriter should be gzipResponseWriter, found %T", w)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								return 0, nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if r.Header.Get("Accept-Encoding") == "" {
 | 
				
			||||||
 | 
								return 0, fmt.Errorf("Accept-Encoding header expected")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if w.Header().Get("Content-Encoding") == "gzip" {
 | 
				
			||||||
 | 
								return 0, fmt.Errorf("Content-Encoding must not be gzip, found gzip")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if _, ok := w.(gzipResponseWriter); ok {
 | 
				
			||||||
 | 
								return 0, fmt.Errorf("ResponseWriter should not be gzipResponseWriter")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return 0, nil
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user