diff --git a/caddyhttp/basicauth/basicauth.go b/caddyhttp/basicauth/basicauth.go new file mode 100644 index 000000000..c75cc7d1b --- /dev/null +++ b/caddyhttp/basicauth/basicauth.go @@ -0,0 +1,148 @@ +// Package basicauth implements HTTP Basic Authentication. +package basicauth + +import ( + "bufio" + "crypto/subtle" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "sync" + + "github.com/jimstudt/http-authentication/basic" + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +// BasicAuth is middleware to protect resources with a username and password. +// Note that HTTP Basic Authentication is not secure by itself and should +// not be used to protect important assets without HTTPS. Even then, the +// security of HTTP Basic Auth is disputed. Use discretion when deciding +// what to protect with BasicAuth. +type BasicAuth struct { + Next httpserver.Handler + SiteRoot string + Rules []Rule +} + +// ServeHTTP implements the httpserver.Handler interface. +func (a BasicAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { + + var hasAuth bool + var isAuthenticated bool + + for _, rule := range a.Rules { + for _, res := range rule.Resources { + if !httpserver.Path(r.URL.Path).Matches(res) { + continue + } + + // Path matches; parse auth header + username, password, ok := r.BasicAuth() + hasAuth = true + + // Check credentials + if !ok || + username != rule.Username || + !rule.Password(password) { + //subtle.ConstantTimeCompare([]byte(password), []byte(rule.Password)) != 1 { + continue + } + + // Flag set only on successful authentication + isAuthenticated = true + } + } + + if hasAuth { + if !isAuthenticated { + w.Header().Set("WWW-Authenticate", "Basic") + return http.StatusUnauthorized, nil + } + // "It's an older code, sir, but it checks out. I was about to clear them." + return a.Next.ServeHTTP(w, r) + } + + // Pass-thru when no paths match + return a.Next.ServeHTTP(w, r) +} + +// Rule represents a BasicAuth rule. A username and password +// combination protect the associated resources, which are +// file or directory paths. +type Rule struct { + Username string + Password func(string) bool + Resources []string +} + +// PasswordMatcher determines whether a password matches a rule. +type PasswordMatcher func(pw string) bool + +var ( + htpasswords map[string]map[string]PasswordMatcher + htpasswordsMu sync.Mutex +) + +// GetHtpasswdMatcher matches password rules. +func GetHtpasswdMatcher(filename, username, siteRoot string) (PasswordMatcher, error) { + filename = filepath.Join(siteRoot, filename) + htpasswordsMu.Lock() + if htpasswords == nil { + htpasswords = make(map[string]map[string]PasswordMatcher) + } + pm := htpasswords[filename] + if pm == nil { + fh, err := os.Open(filename) + if err != nil { + return nil, fmt.Errorf("open %q: %v", filename, err) + } + defer fh.Close() + pm = make(map[string]PasswordMatcher) + if err = parseHtpasswd(pm, fh); err != nil { + return nil, fmt.Errorf("parsing htpasswd %q: %v", fh.Name(), err) + } + htpasswords[filename] = pm + } + htpasswordsMu.Unlock() + if pm[username] == nil { + return nil, fmt.Errorf("username %q not found in %q", username, filename) + } + return pm[username], nil +} + +func parseHtpasswd(pm map[string]PasswordMatcher, r io.Reader) error { + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.IndexByte(line, '#') == 0 { + continue + } + i := strings.IndexByte(line, ':') + if i <= 0 { + return fmt.Errorf("malformed line, no color: %q", line) + } + user, encoded := line[:i], line[i+1:] + for _, p := range basic.DefaultSystems { + matcher, err := p(encoded) + if err != nil { + return err + } + if matcher != nil { + pm[user] = matcher.MatchesPassword + break + } + } + } + return scanner.Err() +} + +// PlainMatcher returns a PasswordMatcher that does a constant-time +// byte-wise comparison. +func PlainMatcher(passw string) PasswordMatcher { + return func(pw string) bool { + return subtle.ConstantTimeCompare([]byte(pw), []byte(passw)) == 1 + } +} diff --git a/caddyhttp/basicauth/basicauth_test.go b/caddyhttp/basicauth/basicauth_test.go new file mode 100644 index 000000000..182feabf9 --- /dev/null +++ b/caddyhttp/basicauth/basicauth_test.go @@ -0,0 +1,146 @@ +package basicauth + +import ( + "encoding/base64" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +func TestBasicAuth(t *testing.T) { + rw := BasicAuth{ + Next: httpserver.HandlerFunc(contentHandler), + Rules: []Rule{ + {Username: "test", Password: PlainMatcher("ttest"), Resources: []string{"/testing"}}, + }, + } + + tests := []struct { + from string + result int + cred string + }{ + {"/testing", http.StatusUnauthorized, "ttest:test"}, + {"/testing", http.StatusOK, "test:ttest"}, + {"/testing", http.StatusUnauthorized, ""}, + } + + for i, test := range tests { + + req, err := http.NewRequest("GET", test.from, nil) + if err != nil { + t.Fatalf("Test %d: Could not create HTTP request %v", i, err) + } + auth := "Basic " + base64.StdEncoding.EncodeToString([]byte(test.cred)) + req.Header.Set("Authorization", auth) + + rec := httptest.NewRecorder() + result, err := rw.ServeHTTP(rec, req) + if err != nil { + t.Fatalf("Test %d: Could not ServeHTTP %v", i, err) + } + if result != test.result { + t.Errorf("Test %d: Expected Header '%d' but was '%d'", + i, test.result, result) + } + if result == http.StatusUnauthorized { + headers := rec.Header() + if val, ok := headers["Www-Authenticate"]; ok { + if val[0] != "Basic" { + t.Errorf("Test %d, Www-Authenticate should be %s provided %s", i, "Basic", val[0]) + } + } else { + t.Errorf("Test %d, should provide a header Www-Authenticate", i) + } + } + + } + +} + +func TestMultipleOverlappingRules(t *testing.T) { + rw := BasicAuth{ + Next: httpserver.HandlerFunc(contentHandler), + Rules: []Rule{ + {Username: "t", Password: PlainMatcher("p1"), Resources: []string{"/t"}}, + {Username: "t1", Password: PlainMatcher("p2"), Resources: []string{"/t/t"}}, + }, + } + + tests := []struct { + from string + result int + cred string + }{ + {"/t", http.StatusOK, "t:p1"}, + {"/t/t", http.StatusOK, "t:p1"}, + {"/t/t", http.StatusOK, "t1:p2"}, + {"/a", http.StatusOK, "t1:p2"}, + {"/t/t", http.StatusUnauthorized, "t1:p3"}, + {"/t", http.StatusUnauthorized, "t1:p2"}, + } + + for i, test := range tests { + + req, err := http.NewRequest("GET", test.from, nil) + if err != nil { + t.Fatalf("Test %d: Could not create HTTP request %v", i, err) + } + auth := "Basic " + base64.StdEncoding.EncodeToString([]byte(test.cred)) + req.Header.Set("Authorization", auth) + + rec := httptest.NewRecorder() + result, err := rw.ServeHTTP(rec, req) + if err != nil { + t.Fatalf("Test %d: Could not ServeHTTP %v", i, err) + } + if result != test.result { + t.Errorf("Test %d: Expected Header '%d' but was '%d'", + i, test.result, result) + } + + } + +} + +func contentHandler(w http.ResponseWriter, r *http.Request) (int, error) { + fmt.Fprintf(w, r.URL.String()) + return http.StatusOK, nil +} + +func TestHtpasswd(t *testing.T) { + htpasswdPasswd := "IedFOuGmTpT8" + htpasswdFile := `sha1:{SHA}dcAUljwz99qFjYR0YLTXx0RqLww= +md5:$apr1$l42y8rex$pOA2VJ0x/0TwaFeAF9nX61` + + htfh, err := ioutil.TempFile("", "basicauth-") + if err != nil { + t.Skipf("Error creating temp file (%v), will skip htpassword test") + return + } + defer os.Remove(htfh.Name()) + if _, err = htfh.Write([]byte(htpasswdFile)); err != nil { + t.Fatalf("write htpasswd file %q: %v", htfh.Name(), err) + } + htfh.Close() + + for i, username := range []string{"sha1", "md5"} { + rule := Rule{Username: username, Resources: []string{"/testing"}} + + siteRoot := filepath.Dir(htfh.Name()) + filename := filepath.Base(htfh.Name()) + if rule.Password, err = GetHtpasswdMatcher(filename, rule.Username, siteRoot); err != nil { + t.Fatalf("GetHtpasswdMatcher(%q, %q): %v", htfh.Name(), rule.Username, err) + } + t.Logf("%d. username=%q", i, rule.Username) + if !rule.Password(htpasswdPasswd) || rule.Password(htpasswdPasswd+"!") { + t.Errorf("%d (%s) password does not match.", i, rule.Username) + } + } +} diff --git a/caddyhttp/basicauth/setup.go b/caddyhttp/basicauth/setup.go new file mode 100644 index 000000000..d911f3bd5 --- /dev/null +++ b/caddyhttp/basicauth/setup.go @@ -0,0 +1,83 @@ +package basicauth + +import ( + "strings" + + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +func init() { + caddy.RegisterPlugin(caddy.Plugin{ + Name: "basicauth", + ServerType: "http", + Action: setup, + }) +} + +// setup configures a new BasicAuth middleware instance. +func setup(c *caddy.Controller) error { + cfg := httpserver.GetConfig(c.Key) + root := cfg.Root + + rules, err := basicAuthParse(c) + if err != nil { + return err + } + + basic := BasicAuth{Rules: rules} + + cfg.AddMiddleware(func(next httpserver.Handler) httpserver.Handler { + basic.Next = next + basic.SiteRoot = root + return basic + }) + + return nil +} + +func basicAuthParse(c *caddy.Controller) ([]Rule, error) { + var rules []Rule + cfg := httpserver.GetConfig(c.Key) + + var err error + for c.Next() { + var rule Rule + + args := c.RemainingArgs() + + switch len(args) { + case 2: + rule.Username = args[0] + if rule.Password, err = passwordMatcher(rule.Username, args[1], cfg.Root); err != nil { + return rules, c.Errf("Get password matcher from %s: %v", c.Val(), err) + } + + for c.NextBlock() { + rule.Resources = append(rule.Resources, c.Val()) + if c.NextArg() { + return rules, c.Errf("Expecting only one resource per line (extra '%s')", c.Val()) + } + } + case 3: + rule.Resources = append(rule.Resources, args[0]) + rule.Username = args[1] + if rule.Password, err = passwordMatcher(rule.Username, args[2], cfg.Root); err != nil { + return rules, c.Errf("Get password matcher from %s: %v", c.Val(), err) + } + default: + return rules, c.ArgErr() + } + + rules = append(rules, rule) + } + + return rules, nil +} + +func passwordMatcher(username, passw, siteRoot string) (PasswordMatcher, error) { + if !strings.HasPrefix(passw, "htpasswd=") { + return PlainMatcher(passw), nil + } + return GetHtpasswdMatcher(passw[9:], username, siteRoot) +} diff --git a/caddyhttp/basicauth/setup_test.go b/caddyhttp/basicauth/setup_test.go new file mode 100644 index 000000000..c1245a9e1 --- /dev/null +++ b/caddyhttp/basicauth/setup_test.go @@ -0,0 +1,131 @@ +package basicauth + +import ( + "fmt" + "io/ioutil" + "os" + "strings" + "testing" + + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +func TestSetup(t *testing.T) { + err := setup(caddy.NewTestController(`basicauth user pwd`)) + if err != nil { + t.Errorf("Expected no errors, but got: %v", err) + } + mids := httpserver.GetConfig("").Middleware() + if len(mids) == 0 { + t.Fatal("Expected middleware, got 0 instead") + } + + handler := mids[0](httpserver.EmptyNext) + myHandler, ok := handler.(BasicAuth) + if !ok { + t.Fatalf("Expected handler to be type BasicAuth, got: %#v", handler) + } + + if !httpserver.SameNext(myHandler.Next, httpserver.EmptyNext) { + t.Error("'Next' field of handler was not set properly") + } +} + +func TestBasicAuthParse(t *testing.T) { + htpasswdPasswd := "IedFOuGmTpT8" + htpasswdFile := `sha1:{SHA}dcAUljwz99qFjYR0YLTXx0RqLww= +md5:$apr1$l42y8rex$pOA2VJ0x/0TwaFeAF9nX61` + + var skipHtpassword bool + htfh, err := ioutil.TempFile(".", "basicauth-") + if err != nil { + t.Logf("Error creating temp file (%v), will skip htpassword test", err) + skipHtpassword = true + } else { + if _, err = htfh.Write([]byte(htpasswdFile)); err != nil { + t.Fatalf("write htpasswd file %q: %v", htfh.Name(), err) + } + htfh.Close() + defer os.Remove(htfh.Name()) + } + + tests := []struct { + input string + shouldErr bool + password string + expected []Rule + }{ + {`basicauth user pwd`, false, "pwd", []Rule{ + {Username: "user"}, + }}, + {`basicauth user pwd { + }`, false, "pwd", []Rule{ + {Username: "user"}, + }}, + {`basicauth user pwd { + /resource1 + /resource2 + }`, false, "pwd", []Rule{ + {Username: "user", Resources: []string{"/resource1", "/resource2"}}, + }}, + {`basicauth /resource user pwd`, false, "pwd", []Rule{ + {Username: "user", Resources: []string{"/resource"}}, + }}, + {`basicauth /res1 user1 pwd1 + basicauth /res2 user2 pwd2`, false, "pwd", []Rule{ + {Username: "user1", Resources: []string{"/res1"}}, + {Username: "user2", Resources: []string{"/res2"}}, + }}, + {`basicauth user`, true, "", []Rule{}}, + {`basicauth`, true, "", []Rule{}}, + {`basicauth /resource user pwd asdf`, true, "", []Rule{}}, + + {`basicauth sha1 htpasswd=` + htfh.Name(), false, htpasswdPasswd, []Rule{ + {Username: "sha1"}, + }}, + } + + for i, test := range tests { + actual, err := basicAuthParse(caddy.NewTestController(test.input)) + + if err == nil && test.shouldErr { + t.Errorf("Test %d didn't error, but it should have", i) + } else if err != nil && !test.shouldErr { + t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err) + } + + if len(actual) != len(test.expected) { + t.Fatalf("Test %d expected %d rules, but got %d", + i, len(test.expected), len(actual)) + } + + for j, expectedRule := range test.expected { + actualRule := actual[j] + + if actualRule.Username != expectedRule.Username { + t.Errorf("Test %d, rule %d: Expected username '%s', got '%s'", + i, j, expectedRule.Username, actualRule.Username) + } + + if strings.Contains(test.input, "htpasswd=") && skipHtpassword { + continue + } + pwd := test.password + if len(actual) > 1 { + pwd = fmt.Sprintf("%s%d", pwd, j+1) + } + if !actualRule.Password(pwd) || actualRule.Password(test.password+"!") { + t.Errorf("Test %d, rule %d: Expected password '%v', got '%v'", + i, j, test.password, actualRule.Password("")) + } + + expectedRes := fmt.Sprintf("%v", expectedRule.Resources) + actualRes := fmt.Sprintf("%v", actualRule.Resources) + if actualRes != expectedRes { + t.Errorf("Test %d, rule %d: Expected resource list %s, but got %s", + i, j, expectedRes, actualRes) + } + } + } +} diff --git a/caddyhttp/expvar/expvar.go b/caddyhttp/expvar/expvar.go new file mode 100644 index 000000000..d3107a048 --- /dev/null +++ b/caddyhttp/expvar/expvar.go @@ -0,0 +1,45 @@ +package expvar + +import ( + "expvar" + "fmt" + "net/http" + + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +// ExpVar is a simple struct to hold expvar's configuration +type ExpVar struct { + Next httpserver.Handler + Resource Resource +} + +// ServeHTTP handles requests to expvar's configured entry point with +// expvar, or passes all other requests up the chain. +func (e ExpVar) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { + if httpserver.Path(r.URL.Path).Matches(string(e.Resource)) { + expvarHandler(w, r) + return 0, nil + } + return e.Next.ServeHTTP(w, r) +} + +// expvarHandler returns a JSON object will all the published variables. +// +// This is lifted straight from the expvar package. +func expvarHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + fmt.Fprintf(w, "{\n") + first := true + expvar.Do(func(kv expvar.KeyValue) { + if !first { + fmt.Fprintf(w, ",\n") + } + first = false + fmt.Fprintf(w, "%q: %s", kv.Key, kv.Value) + }) + fmt.Fprintf(w, "\n}\n") +} + +// Resource contains the path to the expvar entry point +type Resource string diff --git a/caddyhttp/expvar/expvar_test.go b/caddyhttp/expvar/expvar_test.go new file mode 100644 index 000000000..dfc7cb311 --- /dev/null +++ b/caddyhttp/expvar/expvar_test.go @@ -0,0 +1,46 @@ +package expvar + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +func TestExpVar(t *testing.T) { + rw := ExpVar{ + Next: httpserver.HandlerFunc(contentHandler), + Resource: "/d/v", + } + + tests := []struct { + from string + result int + }{ + {"/d/v", 0}, + {"/x/y", http.StatusOK}, + } + + for i, test := range tests { + req, err := http.NewRequest("GET", test.from, nil) + if err != nil { + t.Fatalf("Test %d: Could not create HTTP request %v", i, err) + } + rec := httptest.NewRecorder() + result, err := rw.ServeHTTP(rec, req) + if err != nil { + t.Fatalf("Test %d: Could not ServeHTTP %v", i, err) + } + if result != test.result { + t.Errorf("Test %d: Expected Header '%d' but was '%d'", + i, test.result, result) + } + } +} + +func contentHandler(w http.ResponseWriter, r *http.Request) (int, error) { + fmt.Fprintf(w, r.URL.String()) + return http.StatusOK, nil +} diff --git a/caddyhttp/expvar/setup.go b/caddyhttp/expvar/setup.go new file mode 100644 index 000000000..4883d7ef3 --- /dev/null +++ b/caddyhttp/expvar/setup.go @@ -0,0 +1,70 @@ +package expvar + +import ( + "expvar" + "runtime" + "sync" + + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +func init() { + caddy.RegisterPlugin(caddy.Plugin{ + Name: "expvar", + ServerType: "http", + Action: setup, + }) +} + +// setup configures a new ExpVar middleware instance. +func setup(c *caddy.Controller) error { + resource, err := expVarParse(c) + if err != nil { + return err + } + + // publish any extra information/metrics we may want to capture + publishExtraVars() + + ev := ExpVar{Resource: resource} + + httpserver.GetConfig(c.Key).AddMiddleware(func(next httpserver.Handler) httpserver.Handler { + ev.Next = next + return ev + }) + + return nil +} + +func expVarParse(c *caddy.Controller) (Resource, error) { + var resource Resource + var err error + + for c.Next() { + args := c.RemainingArgs() + switch len(args) { + case 0: + resource = Resource(defaultExpvarPath) + case 1: + resource = Resource(args[0]) + default: + return resource, c.ArgErr() + } + } + + return resource, err +} + +func publishExtraVars() { + // By using sync.Once instead of an init() function, we don't clutter + // the app's expvar export unnecessarily, or risk colliding with it. + publishOnce.Do(func() { + expvar.Publish("Goroutines", expvar.Func(func() interface{} { + return runtime.NumGoroutine() + })) + }) +} + +var publishOnce sync.Once // publishing variables should only be done once +var defaultExpvarPath = "/debug/vars" diff --git a/caddyhttp/expvar/setup_test.go b/caddyhttp/expvar/setup_test.go new file mode 100644 index 000000000..96dd4b038 --- /dev/null +++ b/caddyhttp/expvar/setup_test.go @@ -0,0 +1,40 @@ +package expvar + +import ( + "testing" + + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +func TestSetup(t *testing.T) { + err := setup(caddy.NewTestController(`expvar`)) + if err != nil { + t.Errorf("Expected no errors, got: %v", err) + } + mids := httpserver.GetConfig("").Middleware() + if len(mids) == 0 { + t.Fatal("Expected middleware, got 0 instead") + } + + err = setup(caddy.NewTestController(`expvar /d/v`)) + if err != nil { + t.Errorf("Expected no errors, got: %v", err) + } + mids = httpserver.GetConfig("").Middleware() + if len(mids) == 0 { + t.Fatal("Expected middleware, got 0 instead") + } + + handler := mids[1](httpserver.EmptyNext) + myHandler, ok := handler.(ExpVar) + if !ok { + t.Fatalf("Expected handler to be type ExpVar, got: %#v", handler) + } + if myHandler.Resource != "/d/v" { + t.Errorf("Expected /d/v as expvar resource") + } + if !httpserver.SameNext(myHandler.Next, httpserver.EmptyNext) { + t.Error("'Next' field of handler was not set properly") + } +} diff --git a/caddyhttp/extensions/ext.go b/caddyhttp/extensions/ext.go new file mode 100644 index 000000000..46013b406 --- /dev/null +++ b/caddyhttp/extensions/ext.go @@ -0,0 +1,52 @@ +// Package extensions contains middleware for clean URLs. +// +// The root path of the site is passed in as well as possible extensions +// to try internally for paths requested that don't match an existing +// resource. The first path+ext combination that matches a valid file +// will be used. +package extensions + +import ( + "net/http" + "os" + "path" + "strings" + + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +// Ext can assume an extension from clean URLs. +// It tries extensions in the order listed in Extensions. +type Ext struct { + // Next handler in the chain + Next httpserver.Handler + + // Path to ther root of the site + Root string + + // List of extensions to try + Extensions []string +} + +// ServeHTTP implements the httpserver.Handler interface. +func (e Ext) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { + urlpath := strings.TrimSuffix(r.URL.Path, "/") + if path.Ext(urlpath) == "" && len(r.URL.Path) > 0 && r.URL.Path[len(r.URL.Path)-1] != '/' { + for _, ext := range e.Extensions { + if resourceExists(e.Root, urlpath+ext) { + r.URL.Path = urlpath + ext + break + } + } + } + return e.Next.ServeHTTP(w, r) +} + +// resourceExists returns true if the file specified at +// root + path exists; false otherwise. +func resourceExists(root, path string) bool { + _, err := os.Stat(root + path) + // technically we should use os.IsNotExist(err) + // but we don't handle any other kinds of errors anyway + return err == nil +} diff --git a/caddyhttp/extensions/setup.go b/caddyhttp/extensions/setup.go new file mode 100644 index 000000000..3d46e77e8 --- /dev/null +++ b/caddyhttp/extensions/setup.go @@ -0,0 +1,54 @@ +package extensions + +import ( + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +func init() { + caddy.RegisterPlugin(caddy.Plugin{ + Name: "ext", + ServerType: "http", + Action: setup, + }) +} + +// setup configures a new instance of 'extensions' middleware for clean URLs. +func setup(c *caddy.Controller) error { + cfg := httpserver.GetConfig(c.Key) + root := cfg.Root + + exts, err := extParse(c) + if err != nil { + return err + } + + httpserver.GetConfig(c.Key).AddMiddleware(func(next httpserver.Handler) httpserver.Handler { + return Ext{ + Next: next, + Extensions: exts, + Root: root, + } + }) + + return nil +} + +// extParse sets up an instance of extension middleware +// from a middleware controller and returns a list of extensions. +func extParse(c *caddy.Controller) ([]string, error) { + var exts []string + + for c.Next() { + // At least one extension is required + if !c.NextArg() { + return exts, c.ArgErr() + } + exts = append(exts, c.Val()) + + // Tack on any other extensions that may have been listed + exts = append(exts, c.RemainingArgs()...) + } + + return exts, nil +} diff --git a/caddyhttp/extensions/setup_test.go b/caddyhttp/extensions/setup_test.go new file mode 100644 index 000000000..f6248beb5 --- /dev/null +++ b/caddyhttp/extensions/setup_test.go @@ -0,0 +1,74 @@ +package extensions + +import ( + "testing" + + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +func TestSetup(t *testing.T) { + err := setup(caddy.NewTestController(`ext .html .htm .php`)) + if err != nil { + t.Fatalf("Expected no errors, got: %v", err) + } + + mids := httpserver.GetConfig("").Middleware() + if len(mids) == 0 { + t.Fatal("Expected middleware, had 0 instead") + } + + handler := mids[0](httpserver.EmptyNext) + myHandler, ok := handler.(Ext) + + if !ok { + t.Fatalf("Expected handler to be type Ext, got: %#v", handler) + } + + if myHandler.Extensions[0] != ".html" { + t.Errorf("Expected .html in the list of Extensions") + } + if myHandler.Extensions[1] != ".htm" { + t.Errorf("Expected .htm in the list of Extensions") + } + if myHandler.Extensions[2] != ".php" { + t.Errorf("Expected .php in the list of Extensions") + } + if !httpserver.SameNext(myHandler.Next, httpserver.EmptyNext) { + t.Error("'Next' field of handler was not set properly") + } + +} + +func TestExtParse(t *testing.T) { + tests := []struct { + inputExts string + shouldErr bool + expectedExts []string + }{ + {`ext .html .htm .php`, false, []string{".html", ".htm", ".php"}}, + {`ext .php .html .xml`, false, []string{".php", ".html", ".xml"}}, + {`ext .txt .php .xml`, false, []string{".txt", ".php", ".xml"}}, + } + for i, test := range tests { + actualExts, err := extParse(caddy.NewTestController(test.inputExts)) + + if err == nil && test.shouldErr { + t.Errorf("Test %d didn't error, but it should have", i) + } else if err != nil && !test.shouldErr { + t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err) + } + + if len(actualExts) != len(test.expectedExts) { + t.Fatalf("Test %d expected %d rules, but got %d", + i, len(test.expectedExts), len(actualExts)) + } + for j, actualExt := range actualExts { + if actualExt != test.expectedExts[j] { + t.Fatalf("Test %d expected %dth extension to be %s , but got %s", + i, j, test.expectedExts[j], actualExt) + } + } + } + +} diff --git a/caddyhttp/header/header.go b/caddyhttp/header/header.go new file mode 100644 index 000000000..c51199d11 --- /dev/null +++ b/caddyhttp/header/header.go @@ -0,0 +1,51 @@ +// Package header provides middleware that appends headers to +// requests based on a set of configuration rules that define +// which routes receive which headers. +package header + +import ( + "net/http" + "strings" + + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +// Headers is middleware that adds headers to the responses +// for requests matching a certain path. +type Headers struct { + Next httpserver.Handler + Rules []Rule +} + +// ServeHTTP implements the httpserver.Handler interface and serves requests, +// setting headers on the response according to the configured rules. +func (h Headers) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { + replacer := httpserver.NewReplacer(r, nil, "") + for _, rule := range h.Rules { + if httpserver.Path(r.URL.Path).Matches(rule.Path) { + for _, header := range rule.Headers { + if strings.HasPrefix(header.Name, "-") { + w.Header().Del(strings.TrimLeft(header.Name, "-")) + } else { + w.Header().Set(header.Name, replacer.Replace(header.Value)) + } + } + } + } + return h.Next.ServeHTTP(w, r) +} + +type ( + // Rule groups a slice of HTTP headers by a URL pattern. + // TODO: use http.Header type instead? + Rule struct { + Path string + Headers []Header + } + + // Header represents a single HTTP header, simply a name and value. + Header struct { + Name string + Value string + } +) diff --git a/caddyhttp/header/header_test.go b/caddyhttp/header/header_test.go new file mode 100644 index 000000000..dd86a09cf --- /dev/null +++ b/caddyhttp/header/header_test.go @@ -0,0 +1,57 @@ +package header + +import ( + "net/http" + "net/http/httptest" + "os" + "testing" + + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +func TestHeader(t *testing.T) { + hostname, err := os.Hostname() + if err != nil { + t.Fatalf("Could not determine hostname: %v", err) + } + for i, test := range []struct { + from string + name string + value string + }{ + {"/a", "Foo", "Bar"}, + {"/a", "Bar", ""}, + {"/a", "Baz", ""}, + {"/a", "ServerName", hostname}, + {"/b", "Foo", ""}, + {"/b", "Bar", "Removed in /a"}, + } { + he := Headers{ + Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { + return 0, nil + }), + Rules: []Rule{ + {Path: "/a", Headers: []Header{ + {Name: "Foo", Value: "Bar"}, + {Name: "ServerName", Value: "{hostname}"}, + {Name: "-Bar"}, + }}, + }, + } + + req, err := http.NewRequest("GET", test.from, nil) + if err != nil { + t.Fatalf("Test %d: Could not create HTTP request: %v", i, err) + } + + rec := httptest.NewRecorder() + rec.Header().Set("Bar", "Removed in /a") + + he.ServeHTTP(rec, req) + + if got := rec.Header().Get(test.name); got != test.value { + t.Errorf("Test %d: Expected %s header to be %q but was %q", + i, test.name, test.value, got) + } + } +} diff --git a/caddyhttp/header/setup.go b/caddyhttp/header/setup.go new file mode 100644 index 000000000..56034921c --- /dev/null +++ b/caddyhttp/header/setup.go @@ -0,0 +1,94 @@ +package header + +import ( + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +func init() { + caddy.RegisterPlugin(caddy.Plugin{ + Name: "header", + ServerType: "http", + Action: setup, + }) +} + +// setup configures a new Headers middleware instance. +func setup(c *caddy.Controller) error { + rules, err := headersParse(c) + if err != nil { + return err + } + + httpserver.GetConfig(c.Key).AddMiddleware(func(next httpserver.Handler) httpserver.Handler { + return Headers{Next: next, Rules: rules} + }) + + return nil +} + +func headersParse(c *caddy.Controller) ([]Rule, error) { + var rules []Rule + + for c.NextLine() { + var head Rule + var isNewPattern bool + + if !c.NextArg() { + return rules, c.ArgErr() + } + pattern := c.Val() + + // See if we already have a definition for this Path pattern... + for _, h := range rules { + if h.Path == pattern { + head = h + break + } + } + + // ...otherwise, this is a new pattern + if head.Path == "" { + head.Path = pattern + isNewPattern = true + } + + for c.NextBlock() { + // A block of headers was opened... + + h := Header{Name: c.Val()} + + if c.NextArg() { + h.Value = c.Val() + } + + head.Headers = append(head.Headers, h) + } + if c.NextArg() { + // ... or single header was defined as an argument instead. + + h := Header{Name: c.Val()} + + h.Value = c.Val() + + if c.NextArg() { + h.Value = c.Val() + } + + head.Headers = append(head.Headers, h) + } + + if isNewPattern { + rules = append(rules, head) + } else { + for i := 0; i < len(rules); i++ { + if rules[i].Path == pattern { + rules[i] = head + break + } + } + } + } + + return rules, nil +} diff --git a/caddyhttp/header/setup_test.go b/caddyhttp/header/setup_test.go new file mode 100644 index 000000000..e3b6cbf19 --- /dev/null +++ b/caddyhttp/header/setup_test.go @@ -0,0 +1,85 @@ +package header + +import ( + "fmt" + "testing" + + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +func TestSetup(t *testing.T) { + err := setup(caddy.NewTestController(`header / Foo Bar`)) + if err != nil { + t.Errorf("Expected no errors, but got: %v", err) + } + + mids := httpserver.GetConfig("").Middleware() + if len(mids) == 0 { + t.Fatal("Expected middleware, had 0 instead") + } + + handler := mids[0](httpserver.EmptyNext) + myHandler, ok := handler.(Headers) + if !ok { + t.Fatalf("Expected handler to be type Headers, got: %#v", handler) + } + + if !httpserver.SameNext(myHandler.Next, httpserver.EmptyNext) { + t.Error("'Next' field of handler was not set properly") + } +} + +func TestHeadersParse(t *testing.T) { + tests := []struct { + input string + shouldErr bool + expected []Rule + }{ + {`header /foo Foo "Bar Baz"`, + false, []Rule{ + {Path: "/foo", Headers: []Header{ + {Name: "Foo", Value: "Bar Baz"}, + }}, + }}, + {`header /bar { Foo "Bar Baz" Baz Qux }`, + false, []Rule{ + {Path: "/bar", Headers: []Header{ + {Name: "Foo", Value: "Bar Baz"}, + {Name: "Baz", Value: "Qux"}, + }}, + }}, + } + + for i, test := range tests { + actual, err := headersParse(caddy.NewTestController(test.input)) + + if err == nil && test.shouldErr { + t.Errorf("Test %d didn't error, but it should have", i) + } else if err != nil && !test.shouldErr { + t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err) + } + + if len(actual) != len(test.expected) { + t.Fatalf("Test %d expected %d rules, but got %d", + i, len(test.expected), len(actual)) + } + + for j, expectedRule := range test.expected { + actualRule := actual[j] + + if actualRule.Path != expectedRule.Path { + t.Errorf("Test %d, rule %d: Expected path %s, but got %s", + i, j, expectedRule.Path, actualRule.Path) + } + + expectedHeaders := fmt.Sprintf("%v", expectedRule.Headers) + actualHeaders := fmt.Sprintf("%v", actualRule.Headers) + + if actualHeaders != expectedHeaders { + t.Errorf("Test %d, rule %d: Expected headers %s, but got %s", + i, j, expectedHeaders, actualHeaders) + } + } + } +} diff --git a/caddyhttp/internalsrv/internal.go b/caddyhttp/internalsrv/internal.go new file mode 100644 index 000000000..aa7b69a8d --- /dev/null +++ b/caddyhttp/internalsrv/internal.go @@ -0,0 +1,93 @@ +// Package internalsrv provides a simple middleware that (a) prevents access +// to internal locations and (b) allows to return files from internal location +// by setting a special header, e.g. in a proxy response. +// +// The package is named internalsrv so as not to conflict with Go tooling +// convention which treats folders called "internal" differently. +package internalsrv + +import ( + "net/http" + + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +// Internal middleware protects internal locations from external requests - +// but allows access from the inside by using a special HTTP header. +type Internal struct { + Next httpserver.Handler + Paths []string +} + +const ( + redirectHeader string = "X-Accel-Redirect" + maxRedirectCount int = 10 +) + +func isInternalRedirect(w http.ResponseWriter) bool { + return w.Header().Get(redirectHeader) != "" +} + +// ServeHTTP implements the httpserver.Handler interface. +func (i Internal) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { + + // Internal location requested? -> Not found. + for _, prefix := range i.Paths { + if httpserver.Path(r.URL.Path).Matches(prefix) { + return http.StatusNotFound, nil + } + } + + // Use internal response writer to ignore responses that will be + // redirected to internal locations + iw := internalResponseWriter{ResponseWriter: w} + status, err := i.Next.ServeHTTP(iw, r) + + for c := 0; c < maxRedirectCount && isInternalRedirect(iw); c++ { + // Redirect - adapt request URL path and send it again + // "down the chain" + r.URL.Path = iw.Header().Get(redirectHeader) + iw.ClearHeader() + + status, err = i.Next.ServeHTTP(iw, r) + } + + if isInternalRedirect(iw) { + // Too many redirect cycles + iw.ClearHeader() + return http.StatusInternalServerError, nil + } + + return status, err +} + +// internalResponseWriter wraps the underlying http.ResponseWriter and ignores +// calls to Write and WriteHeader if the response should be redirected to an +// internal location. +type internalResponseWriter struct { + http.ResponseWriter +} + +// ClearHeader removes all header fields that are already set. +func (w internalResponseWriter) ClearHeader() { + for k := range w.Header() { + w.Header().Del(k) + } +} + +// WriteHeader ignores the call if the response should be redirected to an +// internal location. +func (w internalResponseWriter) WriteHeader(code int) { + if !isInternalRedirect(w) { + w.ResponseWriter.WriteHeader(code) + } +} + +// Write ignores the call if the response should be redirected to an internal +// location. +func (w internalResponseWriter) Write(b []byte) (int, error) { + if isInternalRedirect(w) { + return 0, nil + } + return w.ResponseWriter.Write(b) +} diff --git a/caddyhttp/internalsrv/internal_test.go b/caddyhttp/internalsrv/internal_test.go new file mode 100644 index 000000000..fa9e05b43 --- /dev/null +++ b/caddyhttp/internalsrv/internal_test.go @@ -0,0 +1,64 @@ +package internalsrv + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +func TestInternal(t *testing.T) { + im := Internal{ + Next: httpserver.HandlerFunc(internalTestHandlerFunc), + Paths: []string{"/internal"}, + } + + tests := []struct { + url string + expectedCode int + expectedBody string + }{ + {"/internal", http.StatusNotFound, ""}, + + {"/public", 0, "/public"}, + {"/public/internal", 0, "/public/internal"}, + + {"/redirect", 0, "/internal"}, + + {"/cycle", http.StatusInternalServerError, ""}, + } + + for i, test := range tests { + req, err := http.NewRequest("GET", test.url, nil) + if err != nil { + t.Fatalf("Test %d: Could not create HTTP request: %v", i, err) + } + + rec := httptest.NewRecorder() + code, _ := im.ServeHTTP(rec, req) + + if code != test.expectedCode { + t.Errorf("Test %d: Expected status code %d for %s, but got %d", + i, test.expectedCode, test.url, code) + } + if rec.Body.String() != test.expectedBody { + t.Errorf("Test %d: Expected body '%s' for %s, but got '%s'", + i, test.expectedBody, test.url, rec.Body.String()) + } + } +} + +func internalTestHandlerFunc(w http.ResponseWriter, r *http.Request) (int, error) { + switch r.URL.Path { + case "/redirect": + w.Header().Set("X-Accel-Redirect", "/internal") + case "/cycle": + w.Header().Set("X-Accel-Redirect", "/cycle") + } + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, r.URL.String()) + + return 0, nil +} diff --git a/caddyhttp/internalsrv/setup.go b/caddyhttp/internalsrv/setup.go new file mode 100644 index 000000000..a77edce90 --- /dev/null +++ b/caddyhttp/internalsrv/setup.go @@ -0,0 +1,41 @@ +package internalsrv + +import ( + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +func init() { + caddy.RegisterPlugin(caddy.Plugin{ + Name: "internal", + ServerType: "http", + Action: setup, + }) +} + +// Internal configures a new Internal middleware instance. +func setup(c *caddy.Controller) error { + paths, err := internalParse(c) + if err != nil { + return err + } + + httpserver.GetConfig(c.Key).AddMiddleware(func(next httpserver.Handler) httpserver.Handler { + return Internal{Next: next, Paths: paths} + }) + + return nil +} + +func internalParse(c *caddy.Controller) ([]string, error) { + var paths []string + + for c.Next() { + if !c.NextArg() { + return paths, c.ArgErr() + } + paths = append(paths, c.Val()) + } + + return paths, nil +} diff --git a/caddyhttp/internalsrv/setup_test.go b/caddyhttp/internalsrv/setup_test.go new file mode 100644 index 000000000..e67982ce0 --- /dev/null +++ b/caddyhttp/internalsrv/setup_test.go @@ -0,0 +1,69 @@ +package internalsrv + +import ( + "testing" + + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +func TestSetup(t *testing.T) { + err := setup(caddy.NewTestController(`internal /internal`)) + if err != nil { + t.Errorf("Expected no errors, got: %v", err) + } + mids := httpserver.GetConfig("").Middleware() + if len(mids) == 0 { + t.Fatal("Expected middleware, got 0 instead") + } + + handler := mids[0](httpserver.EmptyNext) + myHandler, ok := handler.(Internal) + + if !ok { + t.Fatalf("Expected handler to be type Internal, got: %#v", handler) + } + + if myHandler.Paths[0] != "/internal" { + t.Errorf("Expected internal in the list of internal Paths") + } + + if !httpserver.SameNext(myHandler.Next, httpserver.EmptyNext) { + t.Error("'Next' field of handler was not set properly") + } + +} + +func TestInternalParse(t *testing.T) { + tests := []struct { + inputInternalPaths string + shouldErr bool + expectedInternalPaths []string + }{ + {`internal /internal`, false, []string{"/internal"}}, + + {`internal /internal1 + internal /internal2`, false, []string{"/internal1", "/internal2"}}, + } + for i, test := range tests { + actualInternalPaths, err := internalParse(caddy.NewTestController(test.inputInternalPaths)) + + if err == nil && test.shouldErr { + t.Errorf("Test %d didn't error, but it should have", i) + } else if err != nil && !test.shouldErr { + t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err) + } + + if len(actualInternalPaths) != len(test.expectedInternalPaths) { + t.Fatalf("Test %d expected %d InternalPaths, but got %d", + i, len(test.expectedInternalPaths), len(actualInternalPaths)) + } + for j, actualInternalPath := range actualInternalPaths { + if actualInternalPath != test.expectedInternalPaths[j] { + t.Fatalf("Test %d expected %dth Internal Path to be %s , but got %s", + i, j, test.expectedInternalPaths[j], actualInternalPath) + } + } + } + +} diff --git a/caddyhttp/mime/mime.go b/caddyhttp/mime/mime.go new file mode 100644 index 000000000..b215fc8a0 --- /dev/null +++ b/caddyhttp/mime/mime.go @@ -0,0 +1,31 @@ +package mime + +import ( + "net/http" + "path" + + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +// Config represent a mime config. Map from extension to mime-type. +// Note, this should be safe with concurrent read access, as this is +// not modified concurrently. +type Config map[string]string + +// Mime sets Content-Type header of requests based on configurations. +type Mime struct { + Next httpserver.Handler + Configs Config +} + +// ServeHTTP implements the httpserver.Handler interface. +func (e Mime) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { + // Get a clean /-path, grab the extension + ext := path.Ext(path.Clean(r.URL.Path)) + + if contentType, ok := e.Configs[ext]; ok { + w.Header().Set("Content-Type", contentType) + } + + return e.Next.ServeHTTP(w, r) +} diff --git a/caddyhttp/mime/mime_test.go b/caddyhttp/mime/mime_test.go new file mode 100644 index 000000000..f97fffadc --- /dev/null +++ b/caddyhttp/mime/mime_test.go @@ -0,0 +1,69 @@ +package mime + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +func TestMimeHandler(t *testing.T) { + mimes := Config{ + ".html": "text/html", + ".txt": "text/plain", + ".swf": "application/x-shockwave-flash", + } + + m := Mime{Configs: mimes} + + w := httptest.NewRecorder() + exts := []string{ + ".html", ".txt", ".swf", + } + for _, e := range exts { + url := "/file" + e + r, err := http.NewRequest("GET", url, nil) + if err != nil { + t.Error(err) + } + m.Next = nextFunc(true, mimes[e]) + _, err = m.ServeHTTP(w, r) + if err != nil { + t.Error(err) + } + } + + w = httptest.NewRecorder() + exts = []string{ + ".htm1", ".abc", ".mdx", + } + for _, e := range exts { + url := "/file" + e + r, err := http.NewRequest("GET", url, nil) + if err != nil { + t.Error(err) + } + m.Next = nextFunc(false, "") + _, err = m.ServeHTTP(w, r) + if err != nil { + t.Error(err) + } + } +} + +func nextFunc(shouldMime bool, contentType string) httpserver.Handler { + return httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { + if shouldMime { + if w.Header().Get("Content-Type") != contentType { + return 0, fmt.Errorf("expected Content-Type: %v, found %v", contentType, r.Header.Get("Content-Type")) + } + return 0, nil + } + if w.Header().Get("Content-Type") != "" { + return 0, fmt.Errorf("Content-Type header not expected") + } + return 0, nil + }) +} diff --git a/caddyhttp/mime/setup.go b/caddyhttp/mime/setup.go new file mode 100644 index 000000000..bb5c40e0f --- /dev/null +++ b/caddyhttp/mime/setup.go @@ -0,0 +1,75 @@ +package mime + +import ( + "fmt" + "strings" + + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +func init() { + caddy.RegisterPlugin(caddy.Plugin{ + Name: "mime", + ServerType: "http", + Action: setup, + }) +} + +// setup configures a new mime middleware instance. +func setup(c *caddy.Controller) error { + configs, err := mimeParse(c) + if err != nil { + return err + } + + httpserver.GetConfig(c.Key).AddMiddleware(func(next httpserver.Handler) httpserver.Handler { + return Mime{Next: next, Configs: configs} + }) + + return nil +} + +func mimeParse(c *caddy.Controller) (Config, error) { + configs := Config{} + + for c.Next() { + // At least one extension is required + + args := c.RemainingArgs() + switch len(args) { + case 2: + if err := validateExt(configs, args[0]); err != nil { + return configs, err + } + configs[args[0]] = args[1] + case 1: + return configs, c.ArgErr() + case 0: + for c.NextBlock() { + ext := c.Val() + if err := validateExt(configs, ext); err != nil { + return configs, err + } + if !c.NextArg() { + return configs, c.ArgErr() + } + configs[ext] = c.Val() + } + } + + } + + return configs, nil +} + +// validateExt checks for valid file name extension. +func validateExt(configs Config, ext string) error { + if !strings.HasPrefix(ext, ".") { + return fmt.Errorf(`mime: invalid extension "%v" (must start with dot)`, ext) + } + if _, ok := configs[ext]; ok { + return fmt.Errorf(`mime: duplicate extension "%v" found`, ext) + } + return nil +} diff --git a/caddyhttp/mime/setup_test.go b/caddyhttp/mime/setup_test.go new file mode 100644 index 000000000..3d1fce605 --- /dev/null +++ b/caddyhttp/mime/setup_test.go @@ -0,0 +1,62 @@ +package mime + +import ( + "testing" + + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +func TestSetup(t *testing.T) { + err := setup(caddy.NewTestController(`mime .txt text/plain`)) + if err != nil { + t.Errorf("Expected no errors, but got: %v", err) + } + mids := httpserver.GetConfig("").Middleware() + if len(mids) == 0 { + t.Fatal("Expected middleware, but had 0 instead") + } + + handler := mids[0](httpserver.EmptyNext) + myHandler, ok := handler.(Mime) + if !ok { + t.Fatalf("Expected handler to be type Mime, got: %#v", handler) + } + + if !httpserver.SameNext(myHandler.Next, httpserver.EmptyNext) { + t.Error("'Next' field of handler was not set properly") + } + + tests := []struct { + input string + shouldErr bool + }{ + {`mime {`, true}, + {`mime {}`, true}, + {`mime a b`, true}, + {`mime a {`, true}, + {`mime { txt f } `, true}, + {`mime { html } `, true}, + {`mime { + .html text/html + .txt text/plain + } `, false}, + {`mime { + .foo text/foo + .bar text/bar + .foo text/foobar + } `, true}, + {`mime { .html text/html } `, false}, + {`mime { .html + } `, true}, + {`mime .txt text/plain`, false}, + } + for i, test := range tests { + m, err := mimeParse(caddy.NewTestController(test.input)) + if test.shouldErr && err == nil { + t.Errorf("Test %v: Expected error but found nil %v", i, m) + } else if !test.shouldErr && err != nil { + t.Errorf("Test %v: Expected no error but found error: %v", i, err) + } + } +} diff --git a/caddyhttp/pprof/pprof.go b/caddyhttp/pprof/pprof.go new file mode 100644 index 000000000..9ac089423 --- /dev/null +++ b/caddyhttp/pprof/pprof.go @@ -0,0 +1,41 @@ +package pprof + +import ( + "net/http" + pp "net/http/pprof" + + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +// BasePath is the base path to match for all pprof requests. +const BasePath = "/debug/pprof" + +// Handler is a simple struct whose ServeHTTP will delegate pprof +// endpoints to their equivalent net/http/pprof handlers. +type Handler struct { + Next httpserver.Handler + Mux *http.ServeMux +} + +// ServeHTTP handles requests to BasePath with pprof, or passes +// all other requests up the chain. +func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { + if httpserver.Path(r.URL.Path).Matches(BasePath) { + h.Mux.ServeHTTP(w, r) + return 0, nil + } + return h.Next.ServeHTTP(w, r) +} + +// NewMux returns a new http.ServeMux that routes pprof requests. +// It pretty much copies what the std lib pprof does on init: +// https://golang.org/src/net/http/pprof/pprof.go#L67 +func NewMux() *http.ServeMux { + mux := http.NewServeMux() + mux.HandleFunc(BasePath+"/", pp.Index) + mux.HandleFunc(BasePath+"/cmdline", pp.Cmdline) + mux.HandleFunc(BasePath+"/profile", pp.Profile) + mux.HandleFunc(BasePath+"/symbol", pp.Symbol) + mux.HandleFunc(BasePath+"/trace", pp.Trace) + return mux +} diff --git a/caddyhttp/pprof/pprof_test.go b/caddyhttp/pprof/pprof_test.go new file mode 100644 index 000000000..816658694 --- /dev/null +++ b/caddyhttp/pprof/pprof_test.go @@ -0,0 +1,55 @@ +package pprof + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +func TestServeHTTP(t *testing.T) { + h := Handler{ + Next: httpserver.HandlerFunc(nextHandler), + Mux: NewMux(), + } + + w := httptest.NewRecorder() + r, err := http.NewRequest("GET", "/debug/pprof", nil) + if err != nil { + t.Fatal(err) + } + status, err := h.ServeHTTP(w, r) + + if status != 0 { + t.Errorf("Expected status %d but got %d", 0, status) + } + if err != nil { + t.Errorf("Expected nil error, but got: %v", err) + } + if w.Body.String() == "content" { + t.Errorf("Expected pprof to handle request, but it didn't") + } + + w = httptest.NewRecorder() + r, err = http.NewRequest("GET", "/foo", nil) + if err != nil { + t.Fatal(err) + } + status, err = h.ServeHTTP(w, r) + if status != http.StatusNotFound { + t.Errorf("Test two: Expected status %d but got %d", http.StatusNotFound, status) + } + if err != nil { + t.Errorf("Test two: Expected nil error, but got: %v", err) + } + if w.Body.String() != "content" { + t.Errorf("Expected pprof to pass the request thru, but it didn't; got: %s", w.Body.String()) + } +} + +func nextHandler(w http.ResponseWriter, r *http.Request) (int, error) { + fmt.Fprintf(w, "content") + return http.StatusNotFound, nil +} diff --git a/caddyhttp/pprof/setup.go b/caddyhttp/pprof/setup.go new file mode 100644 index 000000000..1c82b856b --- /dev/null +++ b/caddyhttp/pprof/setup.go @@ -0,0 +1,38 @@ +package pprof + +import ( + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +func init() { + caddy.RegisterPlugin(caddy.Plugin{ + Name: "pprof", + ServerType: "http", + Action: setup, + }) +} + +// setup returns a new instance of a pprof handler. It accepts no arguments or options. +func setup(c *caddy.Controller) error { + found := false + + for c.Next() { + if found { + return c.Err("pprof can only be specified once") + } + if len(c.RemainingArgs()) != 0 { + return c.ArgErr() + } + if c.NextBlock() { + return c.ArgErr() + } + found = true + } + + httpserver.GetConfig(c.Key).AddMiddleware(func(next httpserver.Handler) httpserver.Handler { + return &Handler{Next: next, Mux: NewMux()} + }) + + return nil +} diff --git a/caddyhttp/pprof/setup_test.go b/caddyhttp/pprof/setup_test.go new file mode 100644 index 000000000..d53257d2b --- /dev/null +++ b/caddyhttp/pprof/setup_test.go @@ -0,0 +1,31 @@ +package pprof + +import ( + "testing" + + "github.com/mholt/caddy" +) + +func TestSetup(t *testing.T) { + tests := []struct { + input string + shouldErr bool + }{ + {`pprof`, false}, + {`pprof {}`, true}, + {`pprof /foo`, true}, + {`pprof { + a b + }`, true}, + {`pprof + pprof`, true}, + } + for i, test := range tests { + err := setup(caddy.NewTestController(test.input)) + if test.shouldErr && err == nil { + t.Errorf("Test %v: Expected error but found nil", i) + } else if !test.shouldErr && err != nil { + t.Errorf("Test %v: Expected no error but found error: %v", i, err) + } + } +} diff --git a/caddyhttp/proxy/policy.go b/caddyhttp/proxy/policy.go new file mode 100644 index 000000000..96e382a5c --- /dev/null +++ b/caddyhttp/proxy/policy.go @@ -0,0 +1,101 @@ +package proxy + +import ( + "math/rand" + "sync/atomic" +) + +// HostPool is a collection of UpstreamHosts. +type HostPool []*UpstreamHost + +// Policy decides how a host will be selected from a pool. +type Policy interface { + Select(pool HostPool) *UpstreamHost +} + +func init() { + RegisterPolicy("random", func() Policy { return &Random{} }) + RegisterPolicy("least_conn", func() Policy { return &LeastConn{} }) + RegisterPolicy("round_robin", func() Policy { return &RoundRobin{} }) +} + +// Random is a policy that selects up hosts from a pool at random. +type Random struct{} + +// Select selects an up host at random from the specified pool. +func (r *Random) Select(pool HostPool) *UpstreamHost { + // instead of just generating a random index + // this is done to prevent selecting a unavailable host + var randHost *UpstreamHost + count := 0 + for _, host := range pool { + if !host.Available() { + continue + } + count++ + if count == 1 { + randHost = host + } else { + r := rand.Int() % count + if r == (count - 1) { + randHost = host + } + } + } + return randHost +} + +// LeastConn is a policy that selects the host with the least connections. +type LeastConn struct{} + +// Select selects the up host with the least number of connections in the +// pool. If more than one host has the same least number of connections, +// one of the hosts is chosen at random. +func (r *LeastConn) Select(pool HostPool) *UpstreamHost { + var bestHost *UpstreamHost + count := 0 + leastConn := int64(1<<63 - 1) + for _, host := range pool { + if !host.Available() { + continue + } + hostConns := host.Conns + if hostConns < leastConn { + bestHost = host + leastConn = hostConns + count = 1 + } else if hostConns == leastConn { + // randomly select host among hosts with least connections + count++ + if count == 1 { + bestHost = host + } else { + r := rand.Int() % count + if r == (count - 1) { + bestHost = host + } + } + } + } + return bestHost +} + +// RoundRobin is a policy that selects hosts based on round robin ordering. +type RoundRobin struct { + Robin uint32 +} + +// Select selects an up host from the pool using a round robin ordering scheme. +func (r *RoundRobin) Select(pool HostPool) *UpstreamHost { + poolLen := uint32(len(pool)) + selection := atomic.AddUint32(&r.Robin, 1) % poolLen + host := pool[selection] + // if the currently selected host is not available, just ffwd to up host + for i := uint32(1); !host.Available() && i < poolLen; i++ { + host = pool[(selection+i)%poolLen] + } + if !host.Available() { + return nil + } + return host +} diff --git a/caddyhttp/proxy/policy_test.go b/caddyhttp/proxy/policy_test.go new file mode 100644 index 000000000..4cc05f029 --- /dev/null +++ b/caddyhttp/proxy/policy_test.go @@ -0,0 +1,98 @@ +package proxy + +import ( + "net/http" + "net/http/httptest" + "os" + "testing" +) + +var workableServer *httptest.Server + +func TestMain(m *testing.M) { + workableServer = httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + // do nothing + })) + r := m.Run() + workableServer.Close() + os.Exit(r) +} + +type customPolicy struct{} + +func (r *customPolicy) Select(pool HostPool) *UpstreamHost { + return pool[0] +} + +func testPool() HostPool { + pool := []*UpstreamHost{ + { + Name: workableServer.URL, // this should resolve (healthcheck test) + }, + { + Name: "http://shouldnot.resolve", // this shouldn't + }, + { + Name: "http://C", + }, + } + return HostPool(pool) +} + +func TestRoundRobinPolicy(t *testing.T) { + pool := testPool() + rrPolicy := &RoundRobin{} + h := rrPolicy.Select(pool) + // First selected host is 1, because counter starts at 0 + // and increments before host is selected + if h != pool[1] { + t.Error("Expected first round robin host to be second host in the pool.") + } + h = rrPolicy.Select(pool) + if h != pool[2] { + t.Error("Expected second round robin host to be third host in the pool.") + } + h = rrPolicy.Select(pool) + if h != pool[0] { + t.Error("Expected third round robin host to be first host in the pool.") + } + // mark host as down + pool[1].Unhealthy = true + h = rrPolicy.Select(pool) + if h != pool[2] { + t.Error("Expected to skip down host.") + } + // mark host as full + pool[2].Conns = 1 + pool[2].MaxConns = 1 + h = rrPolicy.Select(pool) + if h != pool[0] { + t.Error("Expected to skip full host.") + } +} + +func TestLeastConnPolicy(t *testing.T) { + pool := testPool() + lcPolicy := &LeastConn{} + pool[0].Conns = 10 + pool[1].Conns = 10 + h := lcPolicy.Select(pool) + if h != pool[2] { + t.Error("Expected least connection host to be third host.") + } + pool[2].Conns = 100 + h = lcPolicy.Select(pool) + if h != pool[0] && h != pool[1] { + t.Error("Expected least connection host to be first or second host.") + } +} + +func TestCustomPolicy(t *testing.T) { + pool := testPool() + customPolicy := &customPolicy{} + h := customPolicy.Select(pool) + if h != pool[0] { + t.Error("Expected custom policy host to be the first host.") + } +} diff --git a/caddyhttp/proxy/proxy.go b/caddyhttp/proxy/proxy.go new file mode 100644 index 000000000..4cdce972f --- /dev/null +++ b/caddyhttp/proxy/proxy.go @@ -0,0 +1,242 @@ +// Package proxy is middleware that proxies HTTP requests. +package proxy + +import ( + "errors" + "net" + "net/http" + "net/url" + "strings" + "sync/atomic" + "time" + + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +var errUnreachable = errors.New("unreachable backend") + +// Proxy represents a middleware instance that can proxy requests. +type Proxy struct { + Next httpserver.Handler + Upstreams []Upstream +} + +// Upstream manages a pool of proxy upstream hosts. Select should return a +// suitable upstream host, or nil if no such hosts are available. +type Upstream interface { + // The path this upstream host should be routed on + From() string + // Selects an upstream host to be routed to. + Select() *UpstreamHost + // Checks if subpath is not an ignored path + AllowedPath(string) bool +} + +// UpstreamHostDownFunc can be used to customize how Down behaves. +type UpstreamHostDownFunc func(*UpstreamHost) bool + +// UpstreamHost represents a single proxy upstream +type UpstreamHost struct { + Conns int64 // must be first field to be 64-bit aligned on 32-bit systems + Name string // hostname of this upstream host + ReverseProxy *ReverseProxy + Fails int32 + FailTimeout time.Duration + Unhealthy bool + UpstreamHeaders http.Header + DownstreamHeaders http.Header + CheckDown UpstreamHostDownFunc + WithoutPathPrefix string + MaxConns int64 +} + +// Down checks whether the upstream host is down or not. +// Down will try to use uh.CheckDown first, and will fall +// back to some default criteria if necessary. +func (uh *UpstreamHost) Down() bool { + if uh.CheckDown == nil { + // Default settings + return uh.Unhealthy || uh.Fails > 0 + } + return uh.CheckDown(uh) +} + +// Full checks whether the upstream host has reached its maximum connections +func (uh *UpstreamHost) Full() bool { + return uh.MaxConns > 0 && uh.Conns >= uh.MaxConns +} + +// Available checks whether the upstream host is available for proxying to +func (uh *UpstreamHost) Available() bool { + return !uh.Down() && !uh.Full() +} + +// tryDuration is how long to try upstream hosts; failures result in +// immediate retries until this duration ends or we get a nil host. +var tryDuration = 60 * time.Second + +// ServeHTTP satisfies the httpserver.Handler interface. +func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { + for _, upstream := range p.Upstreams { + if !httpserver.Path(r.URL.Path).Matches(upstream.From()) || + !upstream.AllowedPath(r.URL.Path) { + continue + } + + var replacer httpserver.Replacer + start := time.Now() + + outreq := createUpstreamRequest(r) + + // Since Select() should give us "up" hosts, keep retrying + // hosts until timeout (or until we get a nil host). + for time.Now().Sub(start) < tryDuration { + host := upstream.Select() + if host == nil { + return http.StatusBadGateway, errUnreachable + } + if rr, ok := w.(*httpserver.ResponseRecorder); ok && rr.Replacer != nil { + rr.Replacer.Set("upstream", host.Name) + } + + outreq.Host = host.Name + if host.UpstreamHeaders != nil { + if replacer == nil { + rHost := r.Host + replacer = httpserver.NewReplacer(r, nil, "") + outreq.Host = rHost + } + if v, ok := host.UpstreamHeaders["Host"]; ok { + outreq.Host = replacer.Replace(v[len(v)-1]) + } + // Modify headers for request that will be sent to the upstream host + upHeaders := createHeadersByRules(host.UpstreamHeaders, r.Header, replacer) + for k, v := range upHeaders { + outreq.Header[k] = v + } + } + + var downHeaderUpdateFn respUpdateFn + if host.DownstreamHeaders != nil { + if replacer == nil { + rHost := r.Host + replacer = httpserver.NewReplacer(r, nil, "") + outreq.Host = rHost + } + //Creates a function that is used to update headers the response received by the reverse proxy + downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer) + } + + proxy := host.ReverseProxy + if baseURL, err := url.Parse(host.Name); err == nil { + r.Host = baseURL.Host + if proxy == nil { + proxy = NewSingleHostReverseProxy(baseURL, host.WithoutPathPrefix) + } + } else if proxy == nil { + return http.StatusInternalServerError, err + } + + atomic.AddInt64(&host.Conns, 1) + backendErr := proxy.ServeHTTP(w, outreq, downHeaderUpdateFn) + atomic.AddInt64(&host.Conns, -1) + if backendErr == nil { + return 0, nil + } + timeout := host.FailTimeout + if timeout == 0 { + timeout = 10 * time.Second + } + atomic.AddInt32(&host.Fails, 1) + go func(host *UpstreamHost, timeout time.Duration) { + time.Sleep(timeout) + atomic.AddInt32(&host.Fails, -1) + }(host, timeout) + } + return http.StatusBadGateway, errUnreachable + } + + return p.Next.ServeHTTP(w, r) +} + +// createUpstremRequest shallow-copies r into a new request +// that can be sent upstream. +func createUpstreamRequest(r *http.Request) *http.Request { + outreq := new(http.Request) + *outreq = *r // includes shallow copies of maps, but okay + + // Restore URL Path if it has been modified + if outreq.URL.RawPath != "" { + outreq.URL.Opaque = outreq.URL.RawPath + } + + // Remove hop-by-hop headers to the backend. Especially + // important is "Connection" because we want a persistent + // connection, regardless of what the client sent to us. This + // is modifying the same underlying map from r (shallow + // copied above) so we only copy it if necessary. + for _, h := range hopHeaders { + if outreq.Header.Get(h) != "" { + outreq.Header = make(http.Header) + copyHeader(outreq.Header, r.Header) + outreq.Header.Del(h) + } + } + + if clientIP, _, err := net.SplitHostPort(r.RemoteAddr); err == nil { + // If we aren't the first proxy, retain prior + // X-Forwarded-For information as a comma+space + // separated list and fold multiple headers into one. + if prior, ok := outreq.Header["X-Forwarded-For"]; ok { + clientIP = strings.Join(prior, ", ") + ", " + clientIP + } + outreq.Header.Set("X-Forwarded-For", clientIP) + } + + return outreq +} + +func createRespHeaderUpdateFn(rules http.Header, replacer httpserver.Replacer) respUpdateFn { + return func(resp *http.Response) { + newHeaders := createHeadersByRules(rules, resp.Header, replacer) + for h, v := range newHeaders { + resp.Header[h] = v + } + } +} + +func createHeadersByRules(rules http.Header, base http.Header, repl httpserver.Replacer) http.Header { + newHeaders := make(http.Header) + for header, values := range rules { + if strings.HasPrefix(header, "+") { + header = strings.TrimLeft(header, "+") + add(newHeaders, header, base[header]) + applyEach(values, repl.Replace) + add(newHeaders, header, values) + } else if strings.HasPrefix(header, "-") { + base.Del(strings.TrimLeft(header, "-")) + } else if _, ok := base[header]; ok { + applyEach(values, repl.Replace) + for _, v := range values { + newHeaders.Set(header, v) + } + } else { + applyEach(values, repl.Replace) + add(newHeaders, header, values) + add(newHeaders, header, base[header]) + } + } + return newHeaders +} + +func applyEach(values []string, mapFn func(string) string) { + for i, v := range values { + values[i] = mapFn(v) + } +} + +func add(base http.Header, header string, values []string) { + for _, v := range values { + base.Add(header, v) + } +} diff --git a/caddyhttp/proxy/proxy_test.go b/caddyhttp/proxy/proxy_test.go new file mode 100644 index 000000000..4c04fd2c2 --- /dev/null +++ b/caddyhttp/proxy/proxy_test.go @@ -0,0 +1,583 @@ +package proxy + +import ( + "bufio" + "bytes" + "fmt" + "io" + "io/ioutil" + "log" + "net" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "runtime" + "strings" + "testing" + "time" + + "github.com/mholt/caddy/caddyhttp/httpserver" + + "golang.org/x/net/websocket" +) + +func init() { + tryDuration = 50 * time.Millisecond // prevent tests from hanging +} + +func TestReverseProxy(t *testing.T) { + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stderr) + + var requestReceived bool + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestReceived = true + w.Write([]byte("Hello, client")) + })) + defer backend.Close() + + // set up proxy + p := &Proxy{ + Upstreams: []Upstream{newFakeUpstream(backend.URL, false)}, + } + + // create request and response recorder + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + w := httptest.NewRecorder() + + p.ServeHTTP(w, r) + + if !requestReceived { + t.Error("Expected backend to receive request, but it didn't") + } + + // Make sure {upstream} placeholder is set + rr := httpserver.NewResponseRecorder(httptest.NewRecorder()) + rr.Replacer = httpserver.NewReplacer(r, rr, "-") + + p.ServeHTTP(rr, r) + + if got, want := rr.Replacer.Replace("{upstream}"), backend.URL; got != want { + t.Errorf("Expected custom placeholder {upstream} to be set (%s), but it wasn't; got: %s", want, got) + } +} + +func TestReverseProxyInsecureSkipVerify(t *testing.T) { + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stderr) + + var requestReceived bool + backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestReceived = true + w.Write([]byte("Hello, client")) + })) + defer backend.Close() + + // set up proxy + p := &Proxy{ + Upstreams: []Upstream{newFakeUpstream(backend.URL, true)}, + } + + // create request and response recorder + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + w := httptest.NewRecorder() + + p.ServeHTTP(w, r) + + if !requestReceived { + t.Error("Even with insecure HTTPS, expected backend to receive request, but it didn't") + } +} + +func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) { + // No-op websocket backend simply allows the WS connection to be + // accepted then it will be immediately closed. Perfect for testing. + wsNop := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) {})) + defer wsNop.Close() + + // Get proxy to use for the test + p := newWebSocketTestProxy(wsNop.URL) + + // Create client request + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + r.Header = http.Header{ + "Connection": {"Upgrade"}, + "Upgrade": {"websocket"}, + "Origin": {wsNop.URL}, + "Sec-WebSocket-Key": {"x3JJHMbDL1EzLkh9GBhXDw=="}, + "Sec-WebSocket-Version": {"13"}, + } + + // Capture the request + w := &recorderHijacker{httptest.NewRecorder(), new(fakeConn)} + + // Booya! Do the test. + p.ServeHTTP(w, r) + + // Make sure the backend accepted the WS connection. + // Mostly interested in the Upgrade and Connection response headers + // and the 101 status code. + expected := []byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n\r\n") + actual := w.fakeConn.writeBuf.Bytes() + if !bytes.Equal(actual, expected) { + t.Errorf("Expected backend to accept response:\n'%s'\nActually got:\n'%s'", expected, actual) + } +} + +func TestWebSocketReverseProxyFromWSClient(t *testing.T) { + // Echo server allows us to test that socket bytes are properly + // being proxied. + wsEcho := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { + io.Copy(ws, ws) + })) + defer wsEcho.Close() + + // Get proxy to use for the test + p := newWebSocketTestProxy(wsEcho.URL) + + // This is a full end-end test, so the proxy handler + // has to be part of a server listening on a port. Our + // WS client will connect to this test server, not + // the echo client directly. + echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + p.ServeHTTP(w, r) + })) + defer echoProxy.Close() + + // Set up WebSocket client + url := strings.Replace(echoProxy.URL, "http://", "ws://", 1) + ws, err := websocket.Dial(url, "", echoProxy.URL) + if err != nil { + t.Fatal(err) + } + defer ws.Close() + + // Send test message + trialMsg := "Is it working?" + websocket.Message.Send(ws, trialMsg) + + // It should be echoed back to us + var actualMsg string + websocket.Message.Receive(ws, &actualMsg) + if actualMsg != trialMsg { + t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg) + } +} + +func TestUnixSocketProxy(t *testing.T) { + if runtime.GOOS == "windows" { + return + } + + trialMsg := "Is it working?" + + var proxySuccess bool + + // This is our fake "application" we want to proxy to + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Request was proxied when this is called + proxySuccess = true + + fmt.Fprint(w, trialMsg) + })) + + // Get absolute path for unix: socket + socketPath, err := filepath.Abs("./test_socket") + if err != nil { + t.Fatalf("Unable to get absolute path: %v", err) + } + + // Change httptest.Server listener to listen to unix: socket + ln, err := net.Listen("unix", socketPath) + if err != nil { + t.Fatalf("Unable to listen: %v", err) + } + ts.Listener = ln + + ts.Start() + defer ts.Close() + + url := strings.Replace(ts.URL, "http://", "unix:", 1) + p := newWebSocketTestProxy(url) + + echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + p.ServeHTTP(w, r) + })) + defer echoProxy.Close() + + res, err := http.Get(echoProxy.URL) + if err != nil { + t.Fatalf("Unable to GET: %v", err) + } + + greeting, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Fatalf("Unable to GET: %v", err) + } + + actualMsg := fmt.Sprintf("%s", greeting) + + if !proxySuccess { + t.Errorf("Expected request to be proxied, but it wasn't") + } + + if actualMsg != trialMsg { + t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg) + } +} + +func GetHTTPProxy(messageFormat string, prefix string) (*Proxy, *httptest.Server) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, messageFormat, r.URL.String()) + })) + + return newPrefixedWebSocketTestProxy(ts.URL, prefix), ts +} + +func GetSocketProxy(messageFormat string, prefix string) (*Proxy, *httptest.Server, error) { + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, messageFormat, r.URL.String()) + })) + + socketPath, err := filepath.Abs("./test_socket") + if err != nil { + return nil, nil, fmt.Errorf("Unable to get absolute path: %v", err) + } + + ln, err := net.Listen("unix", socketPath) + if err != nil { + return nil, nil, fmt.Errorf("Unable to listen: %v", err) + } + ts.Listener = ln + + ts.Start() + + tsURL := strings.Replace(ts.URL, "http://", "unix:", 1) + + return newPrefixedWebSocketTestProxy(tsURL, prefix), ts, nil +} + +func GetTestServerMessage(p *Proxy, ts *httptest.Server, path string) (string, error) { + echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + p.ServeHTTP(w, r) + })) + + // *httptest.Server is passed so it can be `defer`red properly + defer ts.Close() + defer echoProxy.Close() + + res, err := http.Get(echoProxy.URL + path) + if err != nil { + return "", fmt.Errorf("Unable to GET: %v", err) + } + + greeting, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + return "", fmt.Errorf("Unable to read body: %v", err) + } + + return fmt.Sprintf("%s", greeting), nil +} + +func TestUnixSocketProxyPaths(t *testing.T) { + greeting := "Hello route %s" + + tests := []struct { + url string + prefix string + expected string + }{ + {"", "", fmt.Sprintf(greeting, "/")}, + {"/hello", "", fmt.Sprintf(greeting, "/hello")}, + {"/foo/bar", "", fmt.Sprintf(greeting, "/foo/bar")}, + {"/foo?bar", "", fmt.Sprintf(greeting, "/foo?bar")}, + {"/greet?name=john", "", fmt.Sprintf(greeting, "/greet?name=john")}, + {"/world?wonderful&colorful", "", fmt.Sprintf(greeting, "/world?wonderful&colorful")}, + {"/proxy/hello", "/proxy", fmt.Sprintf(greeting, "/hello")}, + {"/proxy/foo/bar", "/proxy", fmt.Sprintf(greeting, "/foo/bar")}, + {"/proxy/?foo=bar", "/proxy", fmt.Sprintf(greeting, "/?foo=bar")}, + {"/queues/%2F/fetchtasks", "", fmt.Sprintf(greeting, "/queues/%2F/fetchtasks")}, + {"/queues/%2F/fetchtasks?foo=bar", "", fmt.Sprintf(greeting, "/queues/%2F/fetchtasks?foo=bar")}, + } + + for _, test := range tests { + p, ts := GetHTTPProxy(greeting, test.prefix) + + actualMsg, err := GetTestServerMessage(p, ts, test.url) + + if err != nil { + t.Fatalf("Getting server message failed - %v", err) + } + + if actualMsg != test.expected { + t.Errorf("Expected '%s' but got '%s' instead", test.expected, actualMsg) + } + } + + if runtime.GOOS == "windows" { + return + } + + for _, test := range tests { + p, ts, err := GetSocketProxy(greeting, test.prefix) + + if err != nil { + t.Fatalf("Getting socket proxy failed - %v", err) + } + + actualMsg, err := GetTestServerMessage(p, ts, test.url) + + if err != nil { + t.Fatalf("Getting server message failed - %v", err) + } + + if actualMsg != test.expected { + t.Errorf("Expected '%s' but got '%s' instead", test.expected, actualMsg) + } + } +} + +func TestUpstreamHeadersUpdate(t *testing.T) { + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stderr) + + var actualHeaders http.Header + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("Hello, client")) + actualHeaders = r.Header + })) + defer backend.Close() + + upstream := newFakeUpstream(backend.URL, false) + upstream.host.UpstreamHeaders = http.Header{ + "Connection": {"{>Connection}"}, + "Upgrade": {"{>Upgrade}"}, + "+Merge-Me": {"Merge-Value"}, + "+Add-Me": {"Add-Value"}, + "-Remove-Me": {""}, + "Replace-Me": {"{hostname}"}, + } + // set up proxy + p := &Proxy{ + Upstreams: []Upstream{upstream}, + } + + // create request and response recorder + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + w := httptest.NewRecorder() + + //add initial headers + r.Header.Add("Merge-Me", "Initial") + r.Header.Add("Remove-Me", "Remove-Value") + r.Header.Add("Replace-Me", "Replace-Value") + + p.ServeHTTP(w, r) + + replacer := httpserver.NewReplacer(r, nil, "") + + headerKey := "Merge-Me" + values, ok := actualHeaders[headerKey] + if !ok { + t.Errorf("Request sent to upstream backend does not contain expected %v header. Expected header to be added", headerKey) + } else if len(values) < 2 && (values[0] != "Initial" || values[1] != replacer.Replace("{hostname}")) { + t.Errorf("Values for proxy header `+Merge-Me` should be merged. Got %v", values) + } + + headerKey = "Add-Me" + if _, ok := actualHeaders[headerKey]; !ok { + t.Errorf("Request sent to upstream backend does not contain expected %v header", headerKey) + } + + headerKey = "Remove-Me" + if _, ok := actualHeaders[headerKey]; ok { + t.Errorf("Request sent to upstream backend should not contain %v header", headerKey) + } + + headerKey = "Replace-Me" + headerValue := replacer.Replace("{hostname}") + value, ok := actualHeaders[headerKey] + if !ok { + t.Errorf("Request sent to upstream backend should not remove %v header", headerKey) + } else if len(value) > 0 && headerValue != value[0] { + t.Errorf("Request sent to upstream backend should replace value of %v header with %v. Instead value was %v", headerKey, headerValue, value) + } + +} + +func TestDownstreamHeadersUpdate(t *testing.T) { + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stderr) + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Merge-Me", "Initial") + w.Header().Add("Remove-Me", "Remove-Value") + w.Header().Add("Replace-Me", "Replace-Value") + w.Write([]byte("Hello, client")) + })) + defer backend.Close() + + upstream := newFakeUpstream(backend.URL, false) + upstream.host.DownstreamHeaders = http.Header{ + "+Merge-Me": {"Merge-Value"}, + "+Add-Me": {"Add-Value"}, + "-Remove-Me": {""}, + "Replace-Me": {"{hostname}"}, + } + // set up proxy + p := &Proxy{ + Upstreams: []Upstream{upstream}, + } + + // create request and response recorder + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + w := httptest.NewRecorder() + + p.ServeHTTP(w, r) + + replacer := httpserver.NewReplacer(r, nil, "") + actualHeaders := w.Header() + + headerKey := "Merge-Me" + values, ok := actualHeaders[headerKey] + if !ok { + t.Errorf("Downstream response does not contain expected %v header. Expected header should be added", headerKey) + } else if len(values) < 2 && (values[0] != "Initial" || values[1] != replacer.Replace("{hostname}")) { + t.Errorf("Values for header `+Merge-Me` should be merged. Got %v", values) + } + + headerKey = "Add-Me" + if _, ok := actualHeaders[headerKey]; !ok { + t.Errorf("Downstream response does not contain expected %v header", headerKey) + } + + headerKey = "Remove-Me" + if _, ok := actualHeaders[headerKey]; ok { + t.Errorf("Downstream response should not contain %v header received from upstream", headerKey) + } + + headerKey = "Replace-Me" + headerValue := replacer.Replace("{hostname}") + value, ok := actualHeaders[headerKey] + if !ok { + t.Errorf("Downstream response should contain %v header and not remove it", headerKey) + } else if len(value) > 0 && headerValue != value[0] { + t.Errorf("Downstream response should have header %v with value %v. Instead value was %v", headerKey, headerValue, value) + } + +} + +func newFakeUpstream(name string, insecure bool) *fakeUpstream { + uri, _ := url.Parse(name) + u := &fakeUpstream{ + name: name, + host: &UpstreamHost{ + Name: name, + ReverseProxy: NewSingleHostReverseProxy(uri, ""), + }, + } + if insecure { + u.host.ReverseProxy.Transport = InsecureTransport + } + return u +} + +type fakeUpstream struct { + name string + host *UpstreamHost +} + +func (u *fakeUpstream) From() string { + return "/" +} + +func (u *fakeUpstream) Select() *UpstreamHost { + return u.host +} + +func (u *fakeUpstream) AllowedPath(requestPath string) bool { + return true +} + +// newWebSocketTestProxy returns a test proxy that will +// redirect to the specified backendAddr. The function +// also sets up the rules/environment for testing WebSocket +// proxy. +func newWebSocketTestProxy(backendAddr string) *Proxy { + return &Proxy{ + Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: ""}}, + } +} + +func newPrefixedWebSocketTestProxy(backendAddr string, prefix string) *Proxy { + return &Proxy{ + Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: prefix}}, + } +} + +type fakeWsUpstream struct { + name string + without string +} + +func (u *fakeWsUpstream) From() string { + return "/" +} + +func (u *fakeWsUpstream) Select() *UpstreamHost { + uri, _ := url.Parse(u.name) + return &UpstreamHost{ + Name: u.name, + ReverseProxy: NewSingleHostReverseProxy(uri, u.without), + UpstreamHeaders: http.Header{ + "Connection": {"{>Connection}"}, + "Upgrade": {"{>Upgrade}"}}, + } +} + +func (u *fakeWsUpstream) AllowedPath(requestPath string) bool { + return true +} + +// recorderHijacker is a ResponseRecorder that can +// be hijacked. +type recorderHijacker struct { + *httptest.ResponseRecorder + fakeConn *fakeConn +} + +func (rh *recorderHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return rh.fakeConn, nil, nil +} + +type fakeConn struct { + readBuf bytes.Buffer + writeBuf bytes.Buffer +} + +func (c *fakeConn) LocalAddr() net.Addr { return nil } +func (c *fakeConn) RemoteAddr() net.Addr { return nil } +func (c *fakeConn) SetDeadline(t time.Time) error { return nil } +func (c *fakeConn) SetReadDeadline(t time.Time) error { return nil } +func (c *fakeConn) SetWriteDeadline(t time.Time) error { return nil } +func (c *fakeConn) Close() error { return nil } +func (c *fakeConn) Read(b []byte) (int, error) { return c.readBuf.Read(b) } +func (c *fakeConn) Write(b []byte) (int, error) { return c.writeBuf.Write(b) } diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go new file mode 100644 index 000000000..5a14aea79 --- /dev/null +++ b/caddyhttp/proxy/reverseproxy.go @@ -0,0 +1,269 @@ +// This file is adapted from code in the net/http/httputil +// package of the Go standard library, which is by the +// Go Authors, and bears this copyright and license info: +// +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// This file has been modified from the standard lib to +// meet the needs of the application. + +package proxy + +import ( + "crypto/tls" + "io" + "net" + "net/http" + "net/url" + "strings" + "sync" + "time" +) + +// onExitFlushLoop is a callback set by tests to detect the state of the +// flushLoop() goroutine. +var onExitFlushLoop func() + +// ReverseProxy is an HTTP Handler that takes an incoming request and +// sends it to another server, proxying the response back to the +// client. +type ReverseProxy struct { + // Director must be a function which modifies + // the request into a new request to be sent + // using Transport. Its response is then copied + // back to the original client unmodified. + Director func(*http.Request) + + // The transport used to perform proxy requests. + // If nil, http.DefaultTransport is used. + Transport http.RoundTripper + + // FlushInterval specifies the flush interval + // to flush to the client while copying the + // response body. + // If zero, no periodic flushing is done. + FlushInterval time.Duration +} + +func singleJoiningSlash(a, b string) string { + aslash := strings.HasSuffix(a, "/") + bslash := strings.HasPrefix(b, "/") + switch { + case aslash && bslash: + return a + b[1:] + case !aslash && !bslash: + return a + "/" + b + } + return a + b +} + +// Though the relevant directive prefix is just "unix:", url.Parse +// will - assuming the regular URL scheme - add additional slashes +// as if "unix" was a request protocol. +// What we need is just the path, so if "unix:/var/run/www.socket" +// was the proxy directive, the parsed hostName would be +// "unix:///var/run/www.socket", hence the ambiguous trimming. +func socketDial(hostName string) func(network, addr string) (conn net.Conn, err error) { + return func(network, addr string) (conn net.Conn, err error) { + return net.Dial("unix", hostName[len("unix://"):]) + } +} + +// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites +// URLs to the scheme, host, and base path provided in target. If the +// target's path is "/base" and the incoming request was for "/dir", +// the target request will be for /base/dir. +// Without logic: target's path is "/", incoming is "/api/messages", +// without is "/api", then the target request will be for /messages. +func NewSingleHostReverseProxy(target *url.URL, without string) *ReverseProxy { + targetQuery := target.RawQuery + director := func(req *http.Request) { + if target.Scheme == "unix" { + // to make Dial work with unix URL, + // scheme and host have to be faked + req.URL.Scheme = "http" + req.URL.Host = "socket" + } else { + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + } + req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) + if targetQuery == "" || req.URL.RawQuery == "" { + req.URL.RawQuery = targetQuery + req.URL.RawQuery + } else { + req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery + } + // Trims the path of the socket from the URL path. + // This is done because req.URL passed to your proxied service + // will have the full path of the socket file prefixed to it. + // Calling /test on a server that proxies requests to + // unix:/var/run/www.socket will thus set the requested path + // to /var/run/www.socket/test, rendering paths useless. + if target.Scheme == "unix" { + // See comment on socketDial for the trim + socketPrefix := target.String()[len("unix://"):] + req.URL.Path = strings.TrimPrefix(req.URL.Path, socketPrefix) + } + // We are then safe to remove the `without` prefix. + if without != "" { + req.URL.Path = strings.TrimPrefix(req.URL.Path, without) + } + } + rp := &ReverseProxy{Director: director, FlushInterval: 250 * time.Millisecond} // flushing good for streaming & server-sent events + if target.Scheme == "unix" { + rp.Transport = &http.Transport{ + Dial: socketDial(target.String()), + } + } + return rp +} + +func copyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } +} + +// Hop-by-hop headers. These are removed when sent to the backend. +// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html +var hopHeaders = []string{ + "Connection", + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", // canonicalized version of "TE" + "Trailers", + "Transfer-Encoding", + "Upgrade", +} + +// InsecureTransport is used to facilitate HTTPS proxying +// when it is OK for upstream to be using a bad certificate, +// since this transport skips verification. +var InsecureTransport http.RoundTripper = &http.Transport{ + Proxy: http.ProxyFromEnvironment, + Dial: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).Dial, + TLSHandshakeTimeout: 10 * time.Second, + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, +} + +type respUpdateFn func(resp *http.Response) + +func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, respUpdateFn respUpdateFn) error { + transport := p.Transport + if transport == nil { + transport = http.DefaultTransport + } + + p.Director(outreq) + outreq.Proto = "HTTP/1.1" + outreq.ProtoMajor = 1 + outreq.ProtoMinor = 1 + outreq.Close = false + + res, err := transport.RoundTrip(outreq) + if err != nil { + return err + } else if respUpdateFn != nil { + respUpdateFn(res) + } + + if res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" { + res.Body.Close() + hj, ok := rw.(http.Hijacker) + if !ok { + return nil + } + + conn, _, err := hj.Hijack() + if err != nil { + return err + } + defer conn.Close() + + backendConn, err := net.Dial("tcp", outreq.URL.Host) + if err != nil { + return err + } + defer backendConn.Close() + + outreq.Write(backendConn) + + go func() { + io.Copy(backendConn, conn) // write tcp stream to backend. + }() + io.Copy(conn, backendConn) // read tcp stream from backend. + } else { + defer res.Body.Close() + for _, h := range hopHeaders { + res.Header.Del(h) + } + copyHeader(rw.Header(), res.Header) + rw.WriteHeader(res.StatusCode) + p.copyResponse(rw, res.Body) + } + + return nil +} + +func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { + if p.FlushInterval != 0 { + if wf, ok := dst.(writeFlusher); ok { + mlw := &maxLatencyWriter{ + dst: wf, + latency: p.FlushInterval, + done: make(chan bool), + } + go mlw.flushLoop() + defer mlw.stop() + dst = mlw + } + } + io.Copy(dst, src) +} + +type writeFlusher interface { + io.Writer + http.Flusher +} + +type maxLatencyWriter struct { + dst writeFlusher + latency time.Duration + + lk sync.Mutex // protects Write + Flush + done chan bool +} + +func (m *maxLatencyWriter) Write(p []byte) (int, error) { + m.lk.Lock() + defer m.lk.Unlock() + return m.dst.Write(p) +} + +func (m *maxLatencyWriter) flushLoop() { + t := time.NewTicker(m.latency) + defer t.Stop() + for { + select { + case <-m.done: + if onExitFlushLoop != nil { + onExitFlushLoop() + } + return + case <-t.C: + m.lk.Lock() + m.dst.Flush() + m.lk.Unlock() + } + } +} + +func (m *maxLatencyWriter) stop() { m.done <- true } diff --git a/caddyhttp/proxy/setup.go b/caddyhttp/proxy/setup.go new file mode 100644 index 000000000..07d9ac953 --- /dev/null +++ b/caddyhttp/proxy/setup.go @@ -0,0 +1,26 @@ +package proxy + +import ( + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +func init() { + caddy.RegisterPlugin(caddy.Plugin{ + Name: "proxy", + ServerType: "http", + Action: setup, + }) +} + +// setup configures a new Proxy middleware instance. +func setup(c *caddy.Controller) error { + upstreams, err := NewStaticUpstreams(c.Dispenser) + if err != nil { + return err + } + httpserver.GetConfig(c.Key).AddMiddleware(func(next httpserver.Handler) httpserver.Handler { + return Proxy{Next: next, Upstreams: upstreams} + }) + return nil +} diff --git a/caddyhttp/proxy/setup_test.go b/caddyhttp/proxy/setup_test.go new file mode 100644 index 000000000..c48d3479a --- /dev/null +++ b/caddyhttp/proxy/setup_test.go @@ -0,0 +1,140 @@ +package proxy + +import ( + "reflect" + "testing" + + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +func TestSetup(t *testing.T) { + for i, test := range []struct { + input string + shouldErr bool + expectedHosts map[string]struct{} + }{ + // test #0 test usual to destination still works normally + { + "proxy / localhost:80", + false, + map[string]struct{}{ + "http://localhost:80": {}, + }, + }, + + // test #1 test usual to destination with port range + { + "proxy / localhost:8080-8082", + false, + map[string]struct{}{ + "http://localhost:8080": {}, + "http://localhost:8081": {}, + "http://localhost:8082": {}, + }, + }, + + // test #2 test upstream directive + { + "proxy / {\n upstream localhost:8080\n}", + false, + map[string]struct{}{ + "http://localhost:8080": {}, + }, + }, + + // test #3 test upstream directive with port range + { + "proxy / {\n upstream localhost:8080-8081\n}", + false, + map[string]struct{}{ + "http://localhost:8080": {}, + "http://localhost:8081": {}, + }, + }, + + // test #4 test to destination with upstream directive + { + "proxy / localhost:8080 {\n upstream localhost:8081-8082\n}", + false, + map[string]struct{}{ + "http://localhost:8080": {}, + "http://localhost:8081": {}, + "http://localhost:8082": {}, + }, + }, + + // test #5 test with unix sockets + { + "proxy / localhost:8080 {\n upstream unix:/var/foo\n}", + false, + map[string]struct{}{ + "http://localhost:8080": {}, + "unix:/var/foo": {}, + }, + }, + + // test #6 test fail on malformed port range + { + "proxy / localhost:8090-8080", + true, + nil, + }, + + // test #7 test fail on malformed port range 2 + { + "proxy / {\n upstream localhost:80-A\n}", + true, + nil, + }, + + // test #8 test upstreams without ports work correctly + { + "proxy / http://localhost {\n upstream testendpoint\n}", + false, + map[string]struct{}{ + "http://localhost": {}, + "http://testendpoint": {}, + }, + }, + + // test #9 test several upstream directives + { + "proxy / localhost:8080 {\n upstream localhost:8081-8082\n upstream localhost:8083-8085\n}", + false, + map[string]struct{}{ + "http://localhost:8080": {}, + "http://localhost:8081": {}, + "http://localhost:8082": {}, + "http://localhost:8083": {}, + "http://localhost:8084": {}, + "http://localhost:8085": {}, + }, + }, + } { + err := setup(caddy.NewTestController(test.input)) + if err != nil && !test.shouldErr { + t.Errorf("Test case #%d received an error of %v", i, err) + } else if test.shouldErr { + continue + } + + mids := httpserver.GetConfig("").Middleware() + mid := mids[len(mids)-1] + + upstreams := mid(nil).(Proxy).Upstreams + for _, upstream := range upstreams { + val := reflect.ValueOf(upstream).Elem() + hosts := val.FieldByName("Hosts").Interface().(HostPool) + if len(hosts) != len(test.expectedHosts) { + t.Errorf("Test case #%d expected %d hosts but received %d", i, len(test.expectedHosts), len(hosts)) + } else { + for _, host := range hosts { + if _, found := test.expectedHosts[host.Name]; !found { + t.Errorf("Test case #%d has an unexpected host %s", i, host.Name) + } + } + } + } + } +} diff --git a/caddyhttp/proxy/upstream.go b/caddyhttp/proxy/upstream.go new file mode 100644 index 000000000..4dc78e820 --- /dev/null +++ b/caddyhttp/proxy/upstream.go @@ -0,0 +1,345 @@ +package proxy + +import ( + "fmt" + "io" + "io/ioutil" + "net/http" + "net/url" + "path" + "strconv" + "strings" + "time" + + "github.com/mholt/caddy/caddyfile" + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +var ( + supportedPolicies = make(map[string]func() Policy) +) + +type staticUpstream struct { + from string + upstreamHeaders http.Header + downstreamHeaders http.Header + Hosts HostPool + Policy Policy + insecureSkipVerify bool + + FailTimeout time.Duration + MaxFails int32 + MaxConns int64 + HealthCheck struct { + Path string + Interval time.Duration + } + WithoutPathPrefix string + IgnoredSubPaths []string +} + +// NewStaticUpstreams parses the configuration input and sets up +// static upstreams for the proxy middleware. +func NewStaticUpstreams(c caddyfile.Dispenser) ([]Upstream, error) { + var upstreams []Upstream + for c.Next() { + upstream := &staticUpstream{ + from: "", + upstreamHeaders: make(http.Header), + downstreamHeaders: make(http.Header), + Hosts: nil, + Policy: &Random{}, + FailTimeout: 10 * time.Second, + MaxFails: 1, + MaxConns: 0, + } + + if !c.Args(&upstream.from) { + return upstreams, c.ArgErr() + } + + var to []string + for _, t := range c.RemainingArgs() { + parsed, err := parseUpstream(t) + if err != nil { + return upstreams, err + } + to = append(to, parsed...) + } + + for c.NextBlock() { + switch c.Val() { + case "upstream": + if !c.NextArg() { + return upstreams, c.ArgErr() + } + parsed, err := parseUpstream(c.Val()) + if err != nil { + return upstreams, err + } + to = append(to, parsed...) + default: + if err := parseBlock(&c, upstream); err != nil { + return upstreams, err + } + } + } + + if len(to) == 0 { + return upstreams, c.ArgErr() + } + + upstream.Hosts = make([]*UpstreamHost, len(to)) + for i, host := range to { + uh, err := upstream.NewHost(host) + if err != nil { + return upstreams, err + } + upstream.Hosts[i] = uh + } + + if upstream.HealthCheck.Path != "" { + go upstream.HealthCheckWorker(nil) + } + upstreams = append(upstreams, upstream) + } + return upstreams, nil +} + +// RegisterPolicy adds a custom policy to the proxy. +func RegisterPolicy(name string, policy func() Policy) { + supportedPolicies[name] = policy +} + +func (u *staticUpstream) From() string { + return u.from +} + +func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) { + if !strings.HasPrefix(host, "http") && + !strings.HasPrefix(host, "unix:") { + host = "http://" + host + } + uh := &UpstreamHost{ + Name: host, + Conns: 0, + Fails: 0, + FailTimeout: u.FailTimeout, + Unhealthy: false, + UpstreamHeaders: u.upstreamHeaders, + DownstreamHeaders: u.downstreamHeaders, + CheckDown: func(u *staticUpstream) UpstreamHostDownFunc { + return func(uh *UpstreamHost) bool { + if uh.Unhealthy { + return true + } + if uh.Fails >= u.MaxFails && + u.MaxFails != 0 { + return true + } + return false + } + }(u), + WithoutPathPrefix: u.WithoutPathPrefix, + MaxConns: u.MaxConns, + } + + baseURL, err := url.Parse(uh.Name) + if err != nil { + return nil, err + } + + uh.ReverseProxy = NewSingleHostReverseProxy(baseURL, uh.WithoutPathPrefix) + if u.insecureSkipVerify { + uh.ReverseProxy.Transport = InsecureTransport + } + return uh, nil +} + +func parseUpstream(u string) ([]string, error) { + if !strings.HasPrefix(u, "unix:") { + colonIdx := strings.LastIndex(u, ":") + protoIdx := strings.Index(u, "://") + + if colonIdx != -1 && colonIdx != protoIdx { + us := u[:colonIdx] + ports := u[len(us)+1:] + if separators := strings.Count(ports, "-"); separators > 1 { + return nil, fmt.Errorf("port range [%s] is invalid", ports) + } else if separators == 1 { + portsStr := strings.Split(ports, "-") + pIni, err := strconv.Atoi(portsStr[0]) + if err != nil { + return nil, err + } + + pEnd, err := strconv.Atoi(portsStr[1]) + if err != nil { + return nil, err + } + + if pEnd <= pIni { + return nil, fmt.Errorf("port range [%s] is invalid", ports) + } + + hosts := []string{} + for p := pIni; p <= pEnd; p++ { + hosts = append(hosts, fmt.Sprintf("%s:%d", us, p)) + } + return hosts, nil + } + } + } + + return []string{u}, nil + +} + +func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error { + switch c.Val() { + case "policy": + if !c.NextArg() { + return c.ArgErr() + } + policyCreateFunc, ok := supportedPolicies[c.Val()] + if !ok { + return c.ArgErr() + } + u.Policy = policyCreateFunc() + case "fail_timeout": + if !c.NextArg() { + return c.ArgErr() + } + dur, err := time.ParseDuration(c.Val()) + if err != nil { + return err + } + u.FailTimeout = dur + case "max_fails": + if !c.NextArg() { + return c.ArgErr() + } + n, err := strconv.Atoi(c.Val()) + if err != nil { + return err + } + u.MaxFails = int32(n) + case "max_conns": + if !c.NextArg() { + return c.ArgErr() + } + n, err := strconv.ParseInt(c.Val(), 10, 64) + if err != nil { + return err + } + u.MaxConns = n + case "health_check": + if !c.NextArg() { + return c.ArgErr() + } + u.HealthCheck.Path = c.Val() + u.HealthCheck.Interval = 30 * time.Second + if c.NextArg() { + dur, err := time.ParseDuration(c.Val()) + if err != nil { + return err + } + u.HealthCheck.Interval = dur + } + case "header_upstream": + fallthrough + case "proxy_header": + var header, value string + if !c.Args(&header, &value) { + return c.ArgErr() + } + u.upstreamHeaders.Add(header, value) + case "header_downstream": + var header, value string + if !c.Args(&header, &value) { + return c.ArgErr() + } + u.downstreamHeaders.Add(header, value) + case "websocket": + u.upstreamHeaders.Add("Connection", "{>Connection}") + u.upstreamHeaders.Add("Upgrade", "{>Upgrade}") + case "without": + if !c.NextArg() { + return c.ArgErr() + } + u.WithoutPathPrefix = c.Val() + case "except": + ignoredPaths := c.RemainingArgs() + if len(ignoredPaths) == 0 { + return c.ArgErr() + } + u.IgnoredSubPaths = ignoredPaths + case "insecure_skip_verify": + u.insecureSkipVerify = true + default: + return c.Errf("unknown property '%s'", c.Val()) + } + return nil +} + +func (u *staticUpstream) healthCheck() { + for _, host := range u.Hosts { + hostURL := host.Name + u.HealthCheck.Path + if r, err := http.Get(hostURL); err == nil { + io.Copy(ioutil.Discard, r.Body) + r.Body.Close() + host.Unhealthy = r.StatusCode < 200 || r.StatusCode >= 400 + } else { + host.Unhealthy = true + } + } +} + +func (u *staticUpstream) HealthCheckWorker(stop chan struct{}) { + ticker := time.NewTicker(u.HealthCheck.Interval) + u.healthCheck() + for { + select { + case <-ticker.C: + u.healthCheck() + case <-stop: + // TODO: the library should provide a stop channel and global + // waitgroup to allow goroutines started by plugins a chance + // to clean themselves up. + } + } +} + +func (u *staticUpstream) Select() *UpstreamHost { + pool := u.Hosts + if len(pool) == 1 { + if !pool[0].Available() { + return nil + } + return pool[0] + } + allUnavailable := true + for _, host := range pool { + if host.Available() { + allUnavailable = false + break + } + } + if allUnavailable { + return nil + } + + if u.Policy == nil { + return (&Random{}).Select(pool) + } + return u.Policy.Select(pool) +} + +func (u *staticUpstream) AllowedPath(requestPath string) bool { + for _, ignoredSubPath := range u.IgnoredSubPaths { + if httpserver.Path(path.Clean(requestPath)).Matches(path.Join(u.From(), ignoredSubPath)) { + return false + } + } + return true +} diff --git a/caddyhttp/proxy/upstream_test.go b/caddyhttp/proxy/upstream_test.go new file mode 100644 index 000000000..9d38b785f --- /dev/null +++ b/caddyhttp/proxy/upstream_test.go @@ -0,0 +1,135 @@ +package proxy + +import ( + "testing" + "time" +) + +func TestNewHost(t *testing.T) { + upstream := &staticUpstream{ + FailTimeout: 10 * time.Second, + MaxConns: 1, + MaxFails: 1, + } + + uh, err := upstream.NewHost("example.com") + if err != nil { + t.Error("Expected no error") + } + if uh.Name != "http://example.com" { + t.Error("Expected default schema to be added to Name.") + } + if uh.FailTimeout != upstream.FailTimeout { + t.Error("Expected default FailTimeout to be set.") + } + if uh.MaxConns != upstream.MaxConns { + t.Error("Expected default MaxConns to be set.") + } + if uh.CheckDown == nil { + t.Error("Expected default CheckDown to be set.") + } + if uh.CheckDown(uh) { + t.Error("Expected new host not to be down.") + } + // mark Unhealthy + uh.Unhealthy = true + if !uh.CheckDown(uh) { + t.Error("Expected unhealthy host to be down.") + } + // mark with Fails + uh.Unhealthy = false + uh.Fails = 1 + if !uh.CheckDown(uh) { + t.Error("Expected failed host to be down.") + } +} + +func TestHealthCheck(t *testing.T) { + upstream := &staticUpstream{ + from: "", + Hosts: testPool(), + Policy: &Random{}, + FailTimeout: 10 * time.Second, + MaxFails: 1, + } + upstream.healthCheck() + if upstream.Hosts[0].Down() { + t.Error("Expected first host in testpool to not fail healthcheck.") + } + if !upstream.Hosts[1].Down() { + t.Error("Expected second host in testpool to fail healthcheck.") + } +} + +func TestSelect(t *testing.T) { + upstream := &staticUpstream{ + from: "", + Hosts: testPool()[:3], + Policy: &Random{}, + FailTimeout: 10 * time.Second, + MaxFails: 1, + } + upstream.Hosts[0].Unhealthy = true + upstream.Hosts[1].Unhealthy = true + upstream.Hosts[2].Unhealthy = true + if h := upstream.Select(); h != nil { + t.Error("Expected select to return nil as all host are down") + } + upstream.Hosts[2].Unhealthy = false + if h := upstream.Select(); h == nil { + t.Error("Expected select to not return nil") + } + upstream.Hosts[0].Conns = 1 + upstream.Hosts[0].MaxConns = 1 + upstream.Hosts[1].Conns = 1 + upstream.Hosts[1].MaxConns = 1 + upstream.Hosts[2].Conns = 1 + upstream.Hosts[2].MaxConns = 1 + if h := upstream.Select(); h != nil { + t.Error("Expected select to return nil as all hosts are full") + } + upstream.Hosts[2].Conns = 0 + if h := upstream.Select(); h == nil { + t.Error("Expected select to not return nil") + } +} + +func TestRegisterPolicy(t *testing.T) { + name := "custom" + customPolicy := &customPolicy{} + RegisterPolicy(name, func() Policy { return customPolicy }) + if _, ok := supportedPolicies[name]; !ok { + t.Error("Expected supportedPolicies to have a custom policy.") + } + +} + +func TestAllowedPaths(t *testing.T) { + upstream := &staticUpstream{ + from: "/proxy", + IgnoredSubPaths: []string{"/download", "/static"}, + } + tests := []struct { + url string + expected bool + }{ + {"/proxy", true}, + {"/proxy/dl", true}, + {"/proxy/download", false}, + {"/proxy/download/static", false}, + {"/proxy/static", false}, + {"/proxy/static/download", false}, + {"/proxy/something/download", true}, + {"/proxy/something/static", true}, + {"/proxy//static", false}, + {"/proxy//static//download", false}, + {"/proxy//download", false}, + } + + for i, test := range tests { + allowed := upstream.AllowedPath(test.url) + if test.expected != allowed { + t.Errorf("Test %d: expected %v found %v", i+1, test.expected, allowed) + } + } +} diff --git a/caddyhttp/redirect/redirect.go b/caddyhttp/redirect/redirect.go new file mode 100644 index 000000000..edb7caea5 --- /dev/null +++ b/caddyhttp/redirect/redirect.go @@ -0,0 +1,58 @@ +// Package redirect is middleware for redirecting certain requests +// to other locations. +package redirect + +import ( + "fmt" + "html" + "net/http" + + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +// Redirect is middleware to respond with HTTP redirects +type Redirect struct { + Next httpserver.Handler + Rules []Rule +} + +// ServeHTTP implements the httpserver.Handler interface. +func (rd Redirect) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { + for _, rule := range rd.Rules { + if (rule.FromPath == "/" || r.URL.Path == rule.FromPath) && schemeMatches(rule, r) { + to := httpserver.NewReplacer(r, nil, "").Replace(rule.To) + if rule.Meta { + safeTo := html.EscapeString(to) + fmt.Fprintf(w, metaRedir, safeTo, safeTo) + } else { + http.Redirect(w, r, to, rule.Code) + } + return 0, nil + } + } + return rd.Next.ServeHTTP(w, r) +} + +func schemeMatches(rule Rule, req *http.Request) bool { + return (rule.FromScheme == "https" && req.TLS != nil) || + (rule.FromScheme != "https" && req.TLS == nil) +} + +// Rule describes an HTTP redirect rule. +type Rule struct { + FromScheme, FromPath, To string + Code int + Meta bool +} + +// Script tag comes first since that will better imitate a redirect in the browser's +// history, but the meta tag is a fallback for most non-JS clients. +const metaRedir = ` + +
+ + + + Redirecting... + +` diff --git a/caddyhttp/redirect/redirect_test.go b/caddyhttp/redirect/redirect_test.go new file mode 100644 index 000000000..b6f8f74d0 --- /dev/null +++ b/caddyhttp/redirect/redirect_test.go @@ -0,0 +1,154 @@ +package redirect + +import ( + "bytes" + "crypto/tls" + "io/ioutil" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +func TestRedirect(t *testing.T) { + for i, test := range []struct { + from string + expectedLocation string + expectedCode int + }{ + {"http://localhost/from", "/to", http.StatusMovedPermanently}, + {"http://localhost/a", "/b", http.StatusTemporaryRedirect}, + {"http://localhost/aa", "", http.StatusOK}, + {"http://localhost/", "", http.StatusOK}, + {"http://localhost/a?foo=bar", "/b", http.StatusTemporaryRedirect}, + {"http://localhost/asdf?foo=bar", "", http.StatusOK}, + {"http://localhost/foo#bar", "", http.StatusOK}, + {"http://localhost/a#foo", "/b", http.StatusTemporaryRedirect}, + + // The scheme checks that were added to this package don't actually + // help with redirects because of Caddy's design: a redirect middleware + // for http will always be different than the redirect middleware for + // https because they have to be on different listeners. These tests + // just go to show extra bulletproofing, I guess. + {"http://localhost/scheme", "https://localhost/scheme", http.StatusMovedPermanently}, + {"https://localhost/scheme", "", http.StatusOK}, + {"https://localhost/scheme2", "http://localhost/scheme2", http.StatusMovedPermanently}, + {"http://localhost/scheme2", "", http.StatusOK}, + {"http://localhost/scheme3", "https://localhost/scheme3", http.StatusMovedPermanently}, + {"https://localhost/scheme3", "", http.StatusOK}, + } { + var nextCalled bool + + re := Redirect{ + Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { + nextCalled = true + return 0, nil + }), + Rules: []Rule{ + {FromPath: "/from", To: "/to", Code: http.StatusMovedPermanently}, + {FromPath: "/a", To: "/b", Code: http.StatusTemporaryRedirect}, + + // These http and https schemes would never actually be mixed in the same + // redirect rule with Caddy because http and https schemes have different listeners, + // so they don't share a redirect rule. So although these tests prove something + // impossible with Caddy, it's extra bulletproofing at very little cost. + {FromScheme: "http", FromPath: "/scheme", To: "https://localhost/scheme", Code: http.StatusMovedPermanently}, + {FromScheme: "https", FromPath: "/scheme2", To: "http://localhost/scheme2", Code: http.StatusMovedPermanently}, + {FromScheme: "", FromPath: "/scheme3", To: "https://localhost/scheme3", Code: http.StatusMovedPermanently}, + }, + } + + req, err := http.NewRequest("GET", test.from, nil) + if err != nil { + t.Fatalf("Test %d: Could not create HTTP request: %v", i, err) + } + if strings.HasPrefix(test.from, "https://") { + req.TLS = new(tls.ConnectionState) // faux HTTPS + } + + rec := httptest.NewRecorder() + re.ServeHTTP(rec, req) + + if rec.Header().Get("Location") != test.expectedLocation { + t.Errorf("Test %d: Expected Location header to be %q but was %q", + i, test.expectedLocation, rec.Header().Get("Location")) + } + + if rec.Code != test.expectedCode { + t.Errorf("Test %d: Expected status code to be %d but was %d", + i, test.expectedCode, rec.Code) + } + + if nextCalled && test.expectedLocation != "" { + t.Errorf("Test %d: Next handler was unexpectedly called", i) + } + } +} + +func TestParametersRedirect(t *testing.T) { + re := Redirect{ + Rules: []Rule{ + {FromPath: "/", Meta: false, To: "http://example.com{uri}"}, + }, + } + + req, err := http.NewRequest("GET", "/a?b=c", nil) + if err != nil { + t.Fatalf("Test: Could not create HTTP request: %v", err) + } + + rec := httptest.NewRecorder() + re.ServeHTTP(rec, req) + + if rec.Header().Get("Location") != "http://example.com/a?b=c" { + t.Fatalf("Test: expected location header %q but was %q", "http://example.com/a?b=c", rec.Header().Get("Location")) + } + + re = Redirect{ + Rules: []Rule{ + {FromPath: "/", Meta: false, To: "http://example.com/a{path}?b=c&{query}"}, + }, + } + + req, err = http.NewRequest("GET", "/d?e=f", nil) + if err != nil { + t.Fatalf("Test: Could not create HTTP request: %v", err) + } + + re.ServeHTTP(rec, req) + + if "http://example.com/a/d?b=c&e=f" != rec.Header().Get("Location") { + t.Fatalf("Test: expected location header %q but was %q", "http://example.com/a/d?b=c&e=f", rec.Header().Get("Location")) + } +} + +func TestMetaRedirect(t *testing.T) { + re := Redirect{ + Rules: []Rule{ + {FromPath: "/whatever", Meta: true, To: "/something"}, + {FromPath: "/", Meta: true, To: "https://example.com/"}, + }, + } + + for i, test := range re.Rules { + req, err := http.NewRequest("GET", test.FromPath, nil) + if err != nil { + t.Fatalf("Test %d: Could not create HTTP request: %v", i, err) + } + + rec := httptest.NewRecorder() + re.ServeHTTP(rec, req) + + body, err := ioutil.ReadAll(rec.Body) + if err != nil { + t.Fatalf("Test %d: Could not read HTTP response body: %v", i, err) + } + expectedSnippet := `` + if !bytes.Contains(body, []byte(expectedSnippet)) { + t.Errorf("Test %d: Expected Response Body to contain %q but was %q", + i, expectedSnippet, body) + } + } +} diff --git a/caddyhttp/redirect/setup.go b/caddyhttp/redirect/setup.go new file mode 100644 index 000000000..31fbd7afd --- /dev/null +++ b/caddyhttp/redirect/setup.go @@ -0,0 +1,185 @@ +package redirect + +import ( + "net/http" + + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +func init() { + caddy.RegisterPlugin(caddy.Plugin{ + Name: "redir", + ServerType: "http", + Action: setup, + }) +} + +// setup configures a new Redirect middleware instance. +func setup(c *caddy.Controller) error { + rules, err := redirParse(c) + if err != nil { + return err + } + + httpserver.GetConfig(c.Key).AddMiddleware(func(next httpserver.Handler) httpserver.Handler { + return Redirect{Next: next, Rules: rules} + }) + + return nil +} + +func redirParse(c *caddy.Controller) ([]Rule, error) { + var redirects []Rule + + cfg := httpserver.GetConfig(c.Key) + + // setRedirCode sets the redirect code for rule if it can, or returns an error + setRedirCode := func(code string, rule *Rule) error { + if code == "meta" { + rule.Meta = true + } else if codeNumber, ok := httpRedirs[code]; ok { + rule.Code = codeNumber + } else { + return c.Errf("Invalid redirect code '%v'", code) + } + return nil + } + + // checkAndSaveRule checks the rule for validity (except the redir code) + // and saves it if it's valid, or returns an error. + checkAndSaveRule := func(rule Rule) error { + if rule.FromPath == rule.To { + return c.Err("'from' and 'to' values of redirect rule cannot be the same") + } + + for _, otherRule := range redirects { + if otherRule.FromPath == rule.FromPath { + return c.Errf("rule with duplicate 'from' value: %s -> %s", otherRule.FromPath, otherRule.To) + } + } + + redirects = append(redirects, rule) + return nil + } + + for c.Next() { + args := c.RemainingArgs() + + var hadOptionalBlock bool + for c.NextBlock() { + hadOptionalBlock = true + + var rule Rule + + if cfg.TLS.Enabled { + rule.FromScheme = "https" + } else { + rule.FromScheme = "http" + } + + // Set initial redirect code + // BUG: If the code is specified for a whole block and that code is invalid, + // the line number will appear on the first line inside the block, even if that + // line overwrites the block-level code with a valid redirect code. The program + // still functions correctly, but the line number in the error reporting is + // misleading to the user. + if len(args) == 1 { + err := setRedirCode(args[0], &rule) + if err != nil { + return redirects, err + } + } else { + rule.Code = http.StatusMovedPermanently // default code + } + + // RemainingArgs only gets the values after the current token, but in our + // case we want to include the current token to get an accurate count. + insideArgs := append([]string{c.Val()}, c.RemainingArgs()...) + + switch len(insideArgs) { + case 1: + // To specified (catch-all redirect) + // Not sure why user is doing this in a table, as it causes all other redirects to be ignored. + // As such, this feature remains undocumented. + rule.FromPath = "/" + rule.To = insideArgs[0] + case 2: + // From and To specified + rule.FromPath = insideArgs[0] + rule.To = insideArgs[1] + case 3: + // From, To, and Code specified + rule.FromPath = insideArgs[0] + rule.To = insideArgs[1] + err := setRedirCode(insideArgs[2], &rule) + if err != nil { + return redirects, err + } + default: + return redirects, c.ArgErr() + } + + err := checkAndSaveRule(rule) + if err != nil { + return redirects, err + } + } + + if !hadOptionalBlock { + var rule Rule + + if cfg.TLS.Enabled { + rule.FromScheme = "https" + } else { + rule.FromScheme = "http" + } + + rule.Code = http.StatusMovedPermanently // default + + switch len(args) { + case 1: + // To specified (catch-all redirect) + rule.FromPath = "/" + rule.To = args[0] + case 2: + // To and Code specified (catch-all redirect) + rule.FromPath = "/" + rule.To = args[0] + err := setRedirCode(args[1], &rule) + if err != nil { + return redirects, err + } + case 3: + // From, To, and Code specified + rule.FromPath = args[0] + rule.To = args[1] + err := setRedirCode(args[2], &rule) + if err != nil { + return redirects, err + } + default: + return redirects, c.ArgErr() + } + + err := checkAndSaveRule(rule) + if err != nil { + return redirects, err + } + } + } + + return redirects, nil +} + +// httpRedirs is a list of supported HTTP redirect codes. +var httpRedirs = map[string]int{ + "300": http.StatusMultipleChoices, + "301": http.StatusMovedPermanently, + "302": http.StatusFound, // (NOT CORRECT for "Temporary Redirect", see 307) + "303": http.StatusSeeOther, + "304": http.StatusNotModified, + "305": http.StatusUseProxy, + "307": http.StatusTemporaryRedirect, + "308": 308, // Permanent Redirect +} diff --git a/caddyhttp/redirect/setup_test.go b/caddyhttp/redirect/setup_test.go new file mode 100644 index 000000000..c4774cfaf --- /dev/null +++ b/caddyhttp/redirect/setup_test.go @@ -0,0 +1,69 @@ +package redirect + +import ( + "testing" + + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +func TestSetup(t *testing.T) { + + for j, test := range []struct { + input string + shouldErr bool + expectedRules []Rule + }{ + // test case #0 tests the recognition of a valid HTTP status code defined outside of block statement + {"redir 300 {\n/ /foo\n}", false, []Rule{{FromPath: "/", To: "/foo", Code: 300}}}, + + // test case #1 tests the recognition of an invalid HTTP status code defined outside of block statement + {"redir 9000 {\n/ /foo\n}", true, []Rule{{}}}, + + // test case #2 tests the detection of a valid HTTP status code outside of a block statement being overriden by an invalid HTTP status code inside statement of a block statement + {"redir 300 {\n/ /foo 9000\n}", true, []Rule{{}}}, + + // test case #3 tests the detection of an invalid HTTP status code outside of a block statement being overriden by a valid HTTP status code inside statement of a block statement + {"redir 9000 {\n/ /foo 300\n}", true, []Rule{{}}}, + + // test case #4 tests the recognition of a TO redirection in a block statement.The HTTP status code is set to the default of 301 - MovedPermanently + {"redir 302 {\n/foo\n}", false, []Rule{{FromPath: "/", To: "/foo", Code: 302}}}, + + // test case #5 tests the recognition of a TO and From redirection in a block statement + {"redir {\n/bar /foo 303\n}", false, []Rule{{FromPath: "/bar", To: "/foo", Code: 303}}}, + + // test case #6 tests the recognition of a TO redirection in a non-block statement. The HTTP status code is set to the default of 301 - MovedPermanently + {"redir /foo", false, []Rule{{FromPath: "/", To: "/foo", Code: 301}}}, + + // test case #7 tests the recognition of a TO and From redirection in a non-block statement + {"redir /bar /foo 303", false, []Rule{{FromPath: "/bar", To: "/foo", Code: 303}}}, + + // test case #8 tests the recognition of multiple redirections + {"redir {\n / /foo 304 \n} \n redir {\n /bar /foobar 305 \n}", false, []Rule{{FromPath: "/", To: "/foo", Code: 304}, {FromPath: "/bar", To: "/foobar", Code: 305}}}, + + // test case #9 tests the detection of duplicate redirections + {"redir {\n /bar /foo 304 \n} redir {\n /bar /foo 304 \n}", true, []Rule{{}}}, + } { + err := setup(caddy.NewTestController(test.input)) + if err != nil && !test.shouldErr { + t.Errorf("Test case #%d recieved an error of %v", j, err) + } else if test.shouldErr { + continue + } + mids := httpserver.GetConfig("").Middleware() + recievedRules := mids[len(mids)-1](nil).(Redirect).Rules + + for i, recievedRule := range recievedRules { + if recievedRule.FromPath != test.expectedRules[i].FromPath { + t.Errorf("Test case #%d.%d expected a from path of %s, but recieved a from path of %s", j, i, test.expectedRules[i].FromPath, recievedRule.FromPath) + } + if recievedRule.To != test.expectedRules[i].To { + t.Errorf("Test case #%d.%d expected a TO path of %s, but recieved a TO path of %s", j, i, test.expectedRules[i].To, recievedRule.To) + } + if recievedRule.Code != test.expectedRules[i].Code { + t.Errorf("Test case #%d.%d expected a HTTP status code of %d, but recieved a code of %d", j, i, test.expectedRules[i].Code, recievedRule.Code) + } + } + } + +} diff --git a/caddyhttp/rewrite/condition.go b/caddyhttp/rewrite/condition.go new file mode 100644 index 000000000..97b0e96aa --- /dev/null +++ b/caddyhttp/rewrite/condition.go @@ -0,0 +1,130 @@ +package rewrite + +import ( + "fmt" + "net/http" + "regexp" + "strings" + + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +// Operators +const ( + Is = "is" + Not = "not" + Has = "has" + NotHas = "not_has" + StartsWith = "starts_with" + EndsWith = "ends_with" + Match = "match" + NotMatch = "not_match" +) + +func operatorError(operator string) error { + return fmt.Errorf("Invalid operator %v", operator) +} + +func newReplacer(r *http.Request) httpserver.Replacer { + return httpserver.NewReplacer(r, nil, "") +} + +// condition is a rewrite condition. +type condition func(string, string) bool + +var conditions = map[string]condition{ + Is: isFunc, + Not: notFunc, + Has: hasFunc, + NotHas: notHasFunc, + StartsWith: startsWithFunc, + EndsWith: endsWithFunc, + Match: matchFunc, + NotMatch: notMatchFunc, +} + +// isFunc is condition for Is operator. +// It checks for equality. +func isFunc(a, b string) bool { + return a == b +} + +// notFunc is condition for Not operator. +// It checks for inequality. +func notFunc(a, b string) bool { + return a != b +} + +// hasFunc is condition for Has operator. +// It checks if b is a substring of a. +func hasFunc(a, b string) bool { + return strings.Contains(a, b) +} + +// notHasFunc is condition for NotHas operator. +// It checks if b is not a substring of a. +func notHasFunc(a, b string) bool { + return !strings.Contains(a, b) +} + +// startsWithFunc is condition for StartsWith operator. +// It checks if b is a prefix of a. +func startsWithFunc(a, b string) bool { + return strings.HasPrefix(a, b) +} + +// endsWithFunc is condition for EndsWith operator. +// It checks if b is a suffix of a. +func endsWithFunc(a, b string) bool { + return strings.HasSuffix(a, b) +} + +// matchFunc is condition for Match operator. +// It does regexp matching of a against pattern in b +// and returns if they match. +func matchFunc(a, b string) bool { + matched, _ := regexp.MatchString(b, a) + return matched +} + +// notMatchFunc is condition for NotMatch operator. +// It does regexp matching of a against pattern in b +// and returns if they do not match. +func notMatchFunc(a, b string) bool { + matched, _ := regexp.MatchString(b, a) + return !matched +} + +// If is statement for a rewrite condition. +type If struct { + A string + Operator string + B string +} + +// True returns true if the condition is true and false otherwise. +// If r is not nil, it replaces placeholders before comparison. +func (i If) True(r *http.Request) bool { + if c, ok := conditions[i.Operator]; ok { + a, b := i.A, i.B + if r != nil { + replacer := newReplacer(r) + a = replacer.Replace(i.A) + b = replacer.Replace(i.B) + } + return c(a, b) + } + return false +} + +// NewIf creates a new If condition. +func NewIf(a, operator, b string) (If, error) { + if _, ok := conditions[operator]; !ok { + return If{}, operatorError(operator) + } + return If{ + A: a, + Operator: operator, + B: b, + }, nil +} diff --git a/caddyhttp/rewrite/condition_test.go b/caddyhttp/rewrite/condition_test.go new file mode 100644 index 000000000..3c3b6053a --- /dev/null +++ b/caddyhttp/rewrite/condition_test.go @@ -0,0 +1,106 @@ +package rewrite + +import ( + "net/http" + "strings" + "testing" +) + +func TestConditions(t *testing.T) { + tests := []struct { + condition string + isTrue bool + }{ + {"a is b", false}, + {"a is a", true}, + {"a not b", true}, + {"a not a", false}, + {"a has a", true}, + {"a has b", false}, + {"ba has b", true}, + {"bab has b", true}, + {"bab has bb", false}, + {"a not_has a", false}, + {"a not_has b", true}, + {"ba not_has b", false}, + {"bab not_has b", false}, + {"bab not_has bb", true}, + {"bab starts_with bb", false}, + {"bab starts_with ba", true}, + {"bab starts_with bab", true}, + {"bab ends_with bb", false}, + {"bab ends_with bab", true}, + {"bab ends_with ab", true}, + {"a match *", false}, + {"a match a", true}, + {"a match .*", true}, + {"a match a.*", true}, + {"a match b.*", false}, + {"ba match b.*", true}, + {"ba match b[a-z]", true}, + {"b0 match b[a-z]", false}, + {"b0a match b[a-z]", false}, + {"b0a match b[a-z]+", false}, + {"b0a match b[a-z0-9]+", true}, + {"a not_match *", true}, + {"a not_match a", false}, + {"a not_match .*", false}, + {"a not_match a.*", false}, + {"a not_match b.*", true}, + {"ba not_match b.*", false}, + {"ba not_match b[a-z]", false}, + {"b0 not_match b[a-z]", true}, + {"b0a not_match b[a-z]", true}, + {"b0a not_match b[a-z]+", true}, + {"b0a not_match b[a-z0-9]+", false}, + } + + for i, test := range tests { + str := strings.Fields(test.condition) + ifCond, err := NewIf(str[0], str[1], str[2]) + if err != nil { + t.Error(err) + } + isTrue := ifCond.True(nil) + if isTrue != test.isTrue { + t.Errorf("Test %v: expected %v found %v", i, test.isTrue, isTrue) + } + } + + invalidOperators := []string{"ss", "and", "if"} + for _, op := range invalidOperators { + _, err := NewIf("a", op, "b") + if err == nil { + t.Errorf("Invalid operator %v used, expected error.", op) + } + } + + replaceTests := []struct { + url string + condition string + isTrue bool + }{ + {"/home", "{uri} match /home", true}, + {"/hom", "{uri} match /home", false}, + {"/hom", "{uri} starts_with /home", false}, + {"/hom", "{uri} starts_with /h", true}, + {"/home/.hiddenfile", `{uri} match \/\.(.*)`, true}, + {"/home/.hiddendir/afile", `{uri} match \/\.(.*)`, true}, + } + + for i, test := range replaceTests { + r, err := http.NewRequest("GET", test.url, nil) + if err != nil { + t.Error(err) + } + str := strings.Fields(test.condition) + ifCond, err := NewIf(str[0], str[1], str[2]) + if err != nil { + t.Error(err) + } + isTrue := ifCond.True(r) + if isTrue != test.isTrue { + t.Errorf("Test %v: expected %v found %v", i, test.isTrue, isTrue) + } + } +} diff --git a/caddyhttp/rewrite/rewrite.go b/caddyhttp/rewrite/rewrite.go new file mode 100644 index 000000000..7567f5d85 --- /dev/null +++ b/caddyhttp/rewrite/rewrite.go @@ -0,0 +1,236 @@ +// Package rewrite is middleware for rewriting requests internally to +// a different path. +package rewrite + +import ( + "fmt" + "net/http" + "net/url" + "path" + "path/filepath" + "regexp" + "strings" + + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +// Result is the result of a rewrite +type Result int + +const ( + // RewriteIgnored is returned when rewrite is not done on request. + RewriteIgnored Result = iota + // RewriteDone is returned when rewrite is done on request. + RewriteDone + // RewriteStatus is returned when rewrite is not needed and status code should be set + // for the request. + RewriteStatus +) + +// Rewrite is middleware to rewrite request locations internally before being handled. +type Rewrite struct { + Next httpserver.Handler + FileSys http.FileSystem + Rules []Rule +} + +// ServeHTTP implements the httpserver.Handler interface. +func (rw Rewrite) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { +outer: + for _, rule := range rw.Rules { + switch result := rule.Rewrite(rw.FileSys, r); result { + case RewriteDone: + break outer + case RewriteIgnored: + break + case RewriteStatus: + // only valid for complex rules. + if cRule, ok := rule.(*ComplexRule); ok && cRule.Status != 0 { + return cRule.Status, nil + } + } + } + return rw.Next.ServeHTTP(w, r) +} + +// Rule describes an internal location rewrite rule. +type Rule interface { + // Rewrite rewrites the internal location of the current request. + Rewrite(http.FileSystem, *http.Request) Result +} + +// SimpleRule is a simple rewrite rule. +type SimpleRule struct { + From, To string +} + +// NewSimpleRule creates a new Simple Rule +func NewSimpleRule(from, to string) SimpleRule { + return SimpleRule{from, to} +} + +// Rewrite rewrites the internal location of the current request. +func (s SimpleRule) Rewrite(fs http.FileSystem, r *http.Request) Result { + if s.From == r.URL.Path { + // take note of this rewrite for internal use by fastcgi + // all we need is the URI, not full URL + r.Header.Set(headerFieldName, r.URL.RequestURI()) + + // attempt rewrite + return To(fs, r, s.To, newReplacer(r)) + } + return RewriteIgnored +} + +// ComplexRule is a rewrite rule based on a regular expression +type ComplexRule struct { + // Path base. Request to this path and subpaths will be rewritten + Base string + + // Path to rewrite to + To string + + // If set, neither performs rewrite nor proceeds + // with request. Only returns code. + Status int + + // Extensions to filter by + Exts []string + + // Rewrite conditions + Ifs []If + + *regexp.Regexp +} + +// NewComplexRule creates a new RegexpRule. It returns an error if regexp +// pattern (pattern) or extensions (ext) are invalid. +func NewComplexRule(base, pattern, to string, status int, ext []string, ifs []If) (*ComplexRule, error) { + // validate regexp if present + var r *regexp.Regexp + if pattern != "" { + var err error + r, err = regexp.Compile(pattern) + if err != nil { + return nil, err + } + } + + // validate extensions if present + for _, v := range ext { + if len(v) < 2 || (len(v) < 3 && v[0] == '!') { + // check if no extension is specified + if v != "/" && v != "!/" { + return nil, fmt.Errorf("invalid extension %v", v) + } + } + } + + return &ComplexRule{ + Base: base, + To: to, + Status: status, + Exts: ext, + Ifs: ifs, + Regexp: r, + }, nil +} + +// Rewrite rewrites the internal location of the current request. +func (r *ComplexRule) Rewrite(fs http.FileSystem, req *http.Request) (re Result) { + rPath := req.URL.Path + replacer := newReplacer(req) + + // validate base + if !httpserver.Path(rPath).Matches(r.Base) { + return + } + + // validate extensions + if !r.matchExt(rPath) { + return + } + + // validate regexp if present + if r.Regexp != nil { + // include trailing slash in regexp if present + start := len(r.Base) + if strings.HasSuffix(r.Base, "/") { + start-- + } + + matches := r.FindStringSubmatch(rPath[start:]) + switch len(matches) { + case 0: + // no match + return + default: + // set regexp match variables {1}, {2} ... + + // url escaped values of ? and #. + q, f := url.QueryEscape("?"), url.QueryEscape("#") + + for i := 1; i < len(matches); i++ { + // Special case of unescaped # and ? by stdlib regexp. + // Reverse the unescape. + if strings.ContainsAny(matches[i], "?#") { + matches[i] = strings.NewReplacer("?", q, "#", f).Replace(matches[i]) + } + + replacer.Set(fmt.Sprint(i), matches[i]) + } + } + } + + // validate rewrite conditions + for _, i := range r.Ifs { + if !i.True(req) { + return + } + } + + // if status is present, stop rewrite and return it. + if r.Status != 0 { + return RewriteStatus + } + + // attempt rewrite + return To(fs, req, r.To, replacer) +} + +// matchExt matches rPath against registered file extensions. +// Returns true if a match is found and false otherwise. +func (r *ComplexRule) matchExt(rPath string) bool { + f := filepath.Base(rPath) + ext := path.Ext(f) + if ext == "" { + ext = "/" + } + + mustUse := false + for _, v := range r.Exts { + use := true + if v[0] == '!' { + use = false + v = v[1:] + } + + if use { + mustUse = true + } + + if ext == v { + return use + } + } + + if mustUse { + return false + } + return true +} + +// When a rewrite is performed, this header is added to the request +// and is for internal use only, specifically the fastcgi middleware. +// It contains the original request URI before the rewrite. +const headerFieldName = "Caddy-Rewrite-Original-URI" diff --git a/caddyhttp/rewrite/rewrite_test.go b/caddyhttp/rewrite/rewrite_test.go new file mode 100644 index 000000000..c2c59afa1 --- /dev/null +++ b/caddyhttp/rewrite/rewrite_test.go @@ -0,0 +1,163 @@ +package rewrite + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +func TestRewrite(t *testing.T) { + rw := Rewrite{ + Next: httpserver.HandlerFunc(urlPrinter), + Rules: []Rule{ + NewSimpleRule("/from", "/to"), + NewSimpleRule("/a", "/b"), + NewSimpleRule("/b", "/b{uri}"), + }, + FileSys: http.Dir("."), + } + + regexps := [][]string{ + {"/reg/", ".*", "/to", ""}, + {"/r/", "[a-z]+", "/toaz", "!.html|"}, + {"/url/", "a([a-z0-9]*)s([A-Z]{2})", "/to/{path}", ""}, + {"/ab/", "ab", "/ab?{query}", ".txt|"}, + {"/ab/", "ab", "/ab?type=html&{query}", ".html|"}, + {"/abc/", "ab", "/abc/{file}", ".html|"}, + {"/abcd/", "ab", "/a/{dir}/{file}", ".html|"}, + {"/abcde/", "ab", "/a#{fragment}", ".html|"}, + {"/ab/", `.*\.jpg`, "/ajpg", ""}, + {"/reggrp", `/ad/([0-9]+)([a-z]*)`, "/a{1}/{2}", ""}, + {"/reg2grp", `(.*)`, "/{1}", ""}, + {"/reg3grp", `(.*)/(.*)/(.*)`, "/{1}{2}{3}", ""}, + {"/hashtest", "(.*)", "/{1}", ""}, + } + + for _, regexpRule := range regexps { + var ext []string + if s := strings.Split(regexpRule[3], "|"); len(s) > 1 { + ext = s[:len(s)-1] + } + rule, err := NewComplexRule(regexpRule[0], regexpRule[1], regexpRule[2], 0, ext, nil) + if err != nil { + t.Fatal(err) + } + rw.Rules = append(rw.Rules, rule) + } + + tests := []struct { + from string + expectedTo string + }{ + {"/from", "/to"}, + {"/a", "/b"}, + {"/b", "/b/b"}, + {"/aa", "/aa"}, + {"/", "/"}, + {"/a?foo=bar", "/b?foo=bar"}, + {"/asdf?foo=bar", "/asdf?foo=bar"}, + {"/foo#bar", "/foo#bar"}, + {"/a#foo", "/b#foo"}, + {"/reg/foo", "/to"}, + {"/re", "/re"}, + {"/r/", "/r/"}, + {"/r/123", "/r/123"}, + {"/r/a123", "/toaz"}, + {"/r/abcz", "/toaz"}, + {"/r/z", "/toaz"}, + {"/r/z.html", "/r/z.html"}, + {"/r/z.js", "/toaz"}, + {"/url/asAB", "/to/url/asAB"}, + {"/url/aBsAB", "/url/aBsAB"}, + {"/url/a00sAB", "/to/url/a00sAB"}, + {"/url/a0z0sAB", "/to/url/a0z0sAB"}, + {"/ab/aa", "/ab/aa"}, + {"/ab/ab", "/ab/ab"}, + {"/ab/ab.txt", "/ab"}, + {"/ab/ab.txt?name=name", "/ab?name=name"}, + {"/ab/ab.html?name=name", "/ab?type=html&name=name"}, + {"/abc/ab.html", "/abc/ab.html"}, + {"/abcd/abcd.html", "/a/abcd/abcd.html"}, + {"/abcde/abcde.html", "/a"}, + {"/abcde/abcde.html#1234", "/a#1234"}, + {"/ab/ab.jpg", "/ajpg"}, + {"/reggrp/ad/12", "/a12"}, + {"/reggrp/ad/124a", "/a124/a"}, + {"/reggrp/ad/124abc", "/a124/abc"}, + {"/reg2grp/ad/124abc", "/ad/124abc"}, + {"/reg3grp/ad/aa/66", "/adaa66"}, + {"/reg3grp/ad612/n1n/ab", "/ad612n1nab"}, + {"/hashtest/a%20%23%20test", "/a%20%23%20test"}, + {"/hashtest/a%20%3F%20test", "/a%20%3F%20test"}, + {"/hashtest/a%20%3F%23test", "/a%20%3F%23test"}, + } + + for i, test := range tests { + req, err := http.NewRequest("GET", test.from, nil) + if err != nil { + t.Fatalf("Test %d: Could not create HTTP request: %v", i, err) + } + + rec := httptest.NewRecorder() + rw.ServeHTTP(rec, req) + + if rec.Body.String() != test.expectedTo { + t.Errorf("Test %d: Expected URL to be '%s' but was '%s'", + i, test.expectedTo, rec.Body.String()) + } + } + + statusTests := []struct { + status int + base string + to string + regexp string + statusExpected bool + }{ + {400, "/status", "", "", true}, + {400, "/ignore", "", "", false}, + {400, "/", "", "^/ignore", false}, + {400, "/", "", "(.*)", true}, + {400, "/status", "", "", true}, + } + + for i, s := range statusTests { + urlPath := fmt.Sprintf("/status%d", i) + rule, err := NewComplexRule(s.base, s.regexp, s.to, s.status, nil, nil) + if err != nil { + t.Fatalf("Test %d: No error expected for rule but found %v", i, err) + } + rw.Rules = []Rule{rule} + req, err := http.NewRequest("GET", urlPath, nil) + if err != nil { + t.Fatalf("Test %d: Could not create HTTP request: %v", i, err) + } + + rec := httptest.NewRecorder() + code, err := rw.ServeHTTP(rec, req) + if err != nil { + t.Fatalf("Test %d: No error expected for handler but found %v", i, err) + } + if s.statusExpected { + if rec.Body.String() != "" { + t.Errorf("Test %d: Expected empty body but found %s", i, rec.Body.String()) + } + if code != s.status { + t.Errorf("Test %d: Expected status code %d found %d", i, s.status, code) + } + } else { + if code != 0 { + t.Errorf("Test %d: Expected no status code found %d", i, code) + } + } + } +} + +func urlPrinter(w http.ResponseWriter, r *http.Request) (int, error) { + fmt.Fprint(w, r.URL.String()) + return 0, nil +} diff --git a/caddyhttp/rewrite/setup.go b/caddyhttp/rewrite/setup.go new file mode 100644 index 000000000..317b21d4d --- /dev/null +++ b/caddyhttp/rewrite/setup.go @@ -0,0 +1,121 @@ +package rewrite + +import ( + "net/http" + "strconv" + "strings" + + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +func init() { + caddy.RegisterPlugin(caddy.Plugin{ + Name: "rewrite", + ServerType: "http", + Action: setup, + }) +} + +// setup configures a new Rewrite middleware instance. +func setup(c *caddy.Controller) error { + rewrites, err := rewriteParse(c) + if err != nil { + return err + } + + cfg := httpserver.GetConfig(c.Key) + + cfg.AddMiddleware(func(next httpserver.Handler) httpserver.Handler { + return Rewrite{ + Next: next, + FileSys: http.Dir(cfg.Root), + Rules: rewrites, + } + }) + + return nil +} + +func rewriteParse(c *caddy.Controller) ([]Rule, error) { + var simpleRules []Rule + var regexpRules []Rule + + for c.Next() { + var rule Rule + var err error + var base = "/" + var pattern, to string + var status int + var ext []string + + args := c.RemainingArgs() + + var ifs []If + + switch len(args) { + case 1: + base = args[0] + fallthrough + case 0: + for c.NextBlock() { + switch c.Val() { + case "r", "regexp": + if !c.NextArg() { + return nil, c.ArgErr() + } + pattern = c.Val() + case "to": + args1 := c.RemainingArgs() + if len(args1) == 0 { + return nil, c.ArgErr() + } + to = strings.Join(args1, " ") + case "ext": + args1 := c.RemainingArgs() + if len(args1) == 0 { + return nil, c.ArgErr() + } + ext = args1 + case "if": + args1 := c.RemainingArgs() + if len(args1) != 3 { + return nil, c.ArgErr() + } + ifCond, err := NewIf(args1[0], args1[1], args1[2]) + if err != nil { + return nil, err + } + ifs = append(ifs, ifCond) + case "status": + if !c.NextArg() { + return nil, c.ArgErr() + } + status, _ = strconv.Atoi(c.Val()) + if status < 200 || (status > 299 && status < 400) || status > 499 { + return nil, c.Err("status must be 2xx or 4xx") + } + default: + return nil, c.ArgErr() + } + } + // ensure to or status is specified + if to == "" && status == 0 { + return nil, c.ArgErr() + } + if rule, err = NewComplexRule(base, pattern, to, status, ext, ifs); err != nil { + return nil, err + } + regexpRules = append(regexpRules, rule) + + // the only unhandled case is 2 and above + default: + rule = NewSimpleRule(args[0], strings.Join(args[1:], " ")) + simpleRules = append(simpleRules, rule) + } + + } + + // put simple rules in front to avoid regexp computation for them + return append(simpleRules, regexpRules...), nil +} diff --git a/caddyhttp/rewrite/setup_test.go b/caddyhttp/rewrite/setup_test.go new file mode 100644 index 000000000..ec22aa3c3 --- /dev/null +++ b/caddyhttp/rewrite/setup_test.go @@ -0,0 +1,239 @@ +package rewrite + +import ( + "fmt" + "regexp" + "testing" + + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +func TestSetup(t *testing.T) { + err := setup(caddy.NewTestController(`rewrite /from /to`)) + if err != nil { + t.Errorf("Expected no errors, but got: %v", err) + } + mids := httpserver.GetConfig("").Middleware() + if len(mids) == 0 { + t.Fatal("Expected middleware, had 0 instead") + } + + handler := mids[0](httpserver.EmptyNext) + myHandler, ok := handler.(Rewrite) + if !ok { + t.Fatalf("Expected handler to be type Rewrite, got: %#v", handler) + } + + if !httpserver.SameNext(myHandler.Next, httpserver.EmptyNext) { + t.Error("'Next' field of handler was not set properly") + } + + if len(myHandler.Rules) != 1 { + t.Errorf("Expected handler to have %d rule, has %d instead", 1, len(myHandler.Rules)) + } +} + +func TestRewriteParse(t *testing.T) { + simpleTests := []struct { + input string + shouldErr bool + expected []Rule + }{ + {`rewrite /from /to`, false, []Rule{ + SimpleRule{From: "/from", To: "/to"}, + }}, + {`rewrite /from /to + rewrite a b`, false, []Rule{ + SimpleRule{From: "/from", To: "/to"}, + SimpleRule{From: "a", To: "b"}, + }}, + {`rewrite a`, true, []Rule{}}, + {`rewrite`, true, []Rule{}}, + {`rewrite a b c`, false, []Rule{ + SimpleRule{From: "a", To: "b c"}, + }}, + } + + for i, test := range simpleTests { + actual, err := rewriteParse(caddy.NewTestController(test.input)) + + if err == nil && test.shouldErr { + t.Errorf("Test %d didn't error, but it should have", i) + } else if err != nil && !test.shouldErr { + t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err) + } else if err != nil && test.shouldErr { + continue + } + + if len(actual) != len(test.expected) { + t.Fatalf("Test %d expected %d rules, but got %d", + i, len(test.expected), len(actual)) + } + + for j, e := range test.expected { + actualRule := actual[j].(SimpleRule) + expectedRule := e.(SimpleRule) + + if actualRule.From != expectedRule.From { + t.Errorf("Test %d, rule %d: Expected From=%s, got %s", + i, j, expectedRule.From, actualRule.From) + } + + if actualRule.To != expectedRule.To { + t.Errorf("Test %d, rule %d: Expected To=%s, got %s", + i, j, expectedRule.To, actualRule.To) + } + } + } + + regexpTests := []struct { + input string + shouldErr bool + expected []Rule + }{ + {`rewrite { + r .* + to /to /index.php? + }`, false, []Rule{ + &ComplexRule{Base: "/", To: "/to /index.php?", Regexp: regexp.MustCompile(".*")}, + }}, + {`rewrite { + regexp .* + to /to + ext / html txt + }`, false, []Rule{ + &ComplexRule{Base: "/", To: "/to", Exts: []string{"/", "html", "txt"}, Regexp: regexp.MustCompile(".*")}, + }}, + {`rewrite /path { + r rr + to /dest + } + rewrite / { + regexp [a-z]+ + to /to /to2 + } + `, false, []Rule{ + &ComplexRule{Base: "/path", To: "/dest", Regexp: regexp.MustCompile("rr")}, + &ComplexRule{Base: "/", To: "/to /to2", Regexp: regexp.MustCompile("[a-z]+")}, + }}, + {`rewrite { + r .* + }`, true, []Rule{ + &ComplexRule{}, + }}, + {`rewrite { + + }`, true, []Rule{ + &ComplexRule{}, + }}, + {`rewrite /`, true, []Rule{ + &ComplexRule{}, + }}, + {`rewrite { + to /to + if {path} is a + }`, false, []Rule{ + &ComplexRule{Base: "/", To: "/to", Ifs: []If{{A: "{path}", Operator: "is", B: "a"}}}, + }}, + {`rewrite { + status 500 + }`, true, []Rule{ + &ComplexRule{}, + }}, + {`rewrite { + status 400 + }`, false, []Rule{ + &ComplexRule{Base: "/", Status: 400}, + }}, + {`rewrite { + to /to + status 400 + }`, false, []Rule{ + &ComplexRule{Base: "/", To: "/to", Status: 400}, + }}, + {`rewrite { + status 399 + }`, true, []Rule{ + &ComplexRule{}, + }}, + {`rewrite { + status 200 + }`, false, []Rule{ + &ComplexRule{Base: "/", Status: 200}, + }}, + {`rewrite { + to /to + status 200 + }`, false, []Rule{ + &ComplexRule{Base: "/", To: "/to", Status: 200}, + }}, + {`rewrite { + status 199 + }`, true, []Rule{ + &ComplexRule{}, + }}, + {`rewrite { + status 0 + }`, true, []Rule{ + &ComplexRule{}, + }}, + {`rewrite { + to /to + status 0 + }`, true, []Rule{ + &ComplexRule{}, + }}, + } + + for i, test := range regexpTests { + actual, err := rewriteParse(caddy.NewTestController(test.input)) + + if err == nil && test.shouldErr { + t.Errorf("Test %d didn't error, but it should have", i) + } else if err != nil && !test.shouldErr { + t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err) + } else if err != nil && test.shouldErr { + continue + } + + if len(actual) != len(test.expected) { + t.Fatalf("Test %d expected %d rules, but got %d", + i, len(test.expected), len(actual)) + } + + for j, e := range test.expected { + actualRule := actual[j].(*ComplexRule) + expectedRule := e.(*ComplexRule) + + if actualRule.Base != expectedRule.Base { + t.Errorf("Test %d, rule %d: Expected Base=%s, got %s", + i, j, expectedRule.Base, actualRule.Base) + } + + if actualRule.To != expectedRule.To { + t.Errorf("Test %d, rule %d: Expected To=%s, got %s", + i, j, expectedRule.To, actualRule.To) + } + + if fmt.Sprint(actualRule.Exts) != fmt.Sprint(expectedRule.Exts) { + t.Errorf("Test %d, rule %d: Expected Ext=%v, got %v", + i, j, expectedRule.To, actualRule.To) + } + + if actualRule.Regexp != nil { + if actualRule.String() != expectedRule.String() { + t.Errorf("Test %d, rule %d: Expected Pattern=%s, got %s", + i, j, expectedRule.String(), actualRule.String()) + } + } + + if fmt.Sprint(actualRule.Ifs) != fmt.Sprint(expectedRule.Ifs) { + t.Errorf("Test %d, rule %d: Expected Pattern=%s, got %s", + i, j, fmt.Sprint(expectedRule.Ifs), fmt.Sprint(actualRule.Ifs)) + } + + } + } + +} diff --git a/caddyhttp/rewrite/testdata/testdir/empty b/caddyhttp/rewrite/testdata/testdir/empty new file mode 100644 index 000000000..e69de29bb diff --git a/caddyhttp/rewrite/testdata/testfile b/caddyhttp/rewrite/testdata/testfile new file mode 100644 index 000000000..7b4d68d70 --- /dev/null +++ b/caddyhttp/rewrite/testdata/testfile @@ -0,0 +1 @@ +empty \ No newline at end of file diff --git a/caddyhttp/rewrite/to.go b/caddyhttp/rewrite/to.go new file mode 100644 index 000000000..2cfe1de46 --- /dev/null +++ b/caddyhttp/rewrite/to.go @@ -0,0 +1,87 @@ +package rewrite + +import ( + "log" + "net/http" + "net/url" + "path" + "strings" + + "github.com/mholt/caddy/caddyhttp/httpserver" +) + +// To attempts rewrite. It attempts to rewrite to first valid path +// or the last path if none of the paths are valid. +// Returns true if rewrite is successful and false otherwise. +func To(fs http.FileSystem, r *http.Request, to string, replacer httpserver.Replacer) Result { + tos := strings.Fields(to) + + // try each rewrite paths + t := "" + for _, v := range tos { + t = path.Clean(replacer.Replace(v)) + + // add trailing slash for directories, if present + if strings.HasSuffix(v, "/") && !strings.HasSuffix(t, "/") { + t += "/" + } + + // validate file + if isValidFile(fs, t) { + break + } + } + + // validate resulting path + u, err := url.Parse(t) + if err != nil { + // Let the user know we got here. Rewrite is expected but + // the resulting url is invalid. + log.Printf("[ERROR] rewrite: resulting path '%v' is invalid. error: %v", t, err) + return RewriteIgnored + } + + // take note of this rewrite for internal use by fastcgi + // all we need is the URI, not full URL + r.Header.Set(headerFieldName, r.URL.RequestURI()) + + // perform rewrite + r.URL.Path = u.Path + if u.RawQuery != "" { + // overwrite query string if present + r.URL.RawQuery = u.RawQuery + } + if u.Fragment != "" { + // overwrite fragment if present + r.URL.Fragment = u.Fragment + } + + return RewriteDone +} + +// isValidFile checks if file exists on the filesystem. +// if file ends with `/`, it is validated as a directory. +func isValidFile(fs http.FileSystem, file string) bool { + if fs == nil { + return false + } + + f, err := fs.Open(file) + if err != nil { + return false + } + defer f.Close() + + stat, err := f.Stat() + if err != nil { + return false + } + + // directory + if strings.HasSuffix(file, "/") { + return stat.IsDir() + } + + // file + return !stat.IsDir() +} diff --git a/caddyhttp/rewrite/to_test.go b/caddyhttp/rewrite/to_test.go new file mode 100644 index 000000000..6133c0b63 --- /dev/null +++ b/caddyhttp/rewrite/to_test.go @@ -0,0 +1,44 @@ +package rewrite + +import ( + "net/http" + "net/url" + "testing" +) + +func TestTo(t *testing.T) { + fs := http.Dir("testdata") + tests := []struct { + url string + to string + expected string + }{ + {"/", "/somefiles", "/somefiles"}, + {"/somefiles", "/somefiles /index.php{uri}", "/index.php/somefiles"}, + {"/somefiles", "/testfile /index.php{uri}", "/testfile"}, + {"/somefiles", "/testfile/ /index.php{uri}", "/index.php/somefiles"}, + {"/somefiles", "/somefiles /index.php{uri}", "/index.php/somefiles"}, + {"/?a=b", "/somefiles /index.php?{query}", "/index.php?a=b"}, + {"/?a=b", "/testfile /index.php?{query}", "/testfile?a=b"}, + {"/?a=b", "/testdir /index.php?{query}", "/index.php?a=b"}, + {"/?a=b", "/testdir/ /index.php?{query}", "/testdir/?a=b"}, + } + + uri := func(r *url.URL) string { + uri := r.Path + if r.RawQuery != "" { + uri += "?" + r.RawQuery + } + return uri + } + for i, test := range tests { + r, err := http.NewRequest("GET", test.url, nil) + if err != nil { + t.Error(err) + } + To(fs, r, test.to, newReplacer(r)) + if uri(r.URL) != test.expected { + t.Errorf("Test %v: expected %v found %v", i, test.expected, uri(r.URL)) + } + } +}