diff --git a/middleware/basicauth/basicauth.go b/middleware/basicauth/basicauth.go index de29b8d9a..25d1d5046 100644 --- a/middleware/basicauth/basicauth.go +++ b/middleware/basicauth/basicauth.go @@ -19,6 +19,10 @@ type BasicAuth struct { // ServeHTTP implements the middleware.Handler interface. func (a BasicAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { + + var hasAuth bool + var isAuthenticated bool + for _, rule := range a.Rules { for _, res := range rule.Resources { if !middleware.Path(r.URL.Path).Matches(res) { @@ -27,16 +31,26 @@ func (a BasicAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error // Path matches; parse auth header username, password, ok := r.BasicAuth() + hasAuth = true // Check credentials if !ok || username != rule.Username || password != rule.Password { - w.Header().Set("WWW-Authenticate", "Basic") - return http.StatusUnauthorized, nil + continue } - + // flag set only on success authentication + isAuthenticated = true + } + } + + if hasAuth { + if !isAuthenticated { + w.Header().Set("WWW-Authenticate", "Basic") + return http.StatusUnauthorized, nil + } else { // "It's an older code, sir, but it checks out. I was about to clear them." return a.Next.ServeHTTP(w, r) } + } // Pass-thru when no paths match diff --git a/middleware/basicauth/basicauth_test.go b/middleware/basicauth/basicauth_test.go new file mode 100644 index 000000000..b590bc35b --- /dev/null +++ b/middleware/basicauth/basicauth_test.go @@ -0,0 +1,121 @@ +package basicauth + +import ( + "encoding/base64" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/mholt/caddy/middleware" +) + +func TestBasicAuth(t *testing.T) { + + rw := BasicAuth{ + Next: middleware.HandlerFunc(contentHandler), + Rules: []Rule{ + {Username: "test", Password: "ttest", Resources: []string{"/testing"}}, + }, + } + + tests := []struct { + from string + result int + cred string + }{ + {"/testing", http.StatusUnauthorized, "ttest:test"}, + {"/testing", http.StatusOK, "test:ttest"}, + {"/testing", http.StatusUnauthorized, ""}, + + } + + + for i, test := range tests { + + + req, err := http.NewRequest("GET", test.from, nil) + if err != nil { + t.Fatalf("Test %d: Could not create HTTP request %v", i, err) + } + auth := "Basic " + base64.StdEncoding.EncodeToString([]byte(test.cred)) + req.Header.Set("Authorization", auth) + + rec := httptest.NewRecorder() + result, err := rw.ServeHTTP(rec, req) + if err != nil { + t.Fatalf("Test %d: Could not ServeHTTP %v", i, err) + } + if result != test.result { + t.Errorf("Test %d: Expected Header '%d' but was '%d'", + i, test.result, result) + } + if result == http.StatusUnauthorized { + headers := rec.Header() + if val, ok := headers["Www-Authenticate"]; ok { + if val[0] != "Basic" { + t.Errorf("Test %d, Www-Authenticate should be %s provided %s", i, "Basic", val[0]) + } + } else { + t.Errorf("Test %d, should provide a header Www-Authenticate", i) + } + } + + + } + +} + + +func TestMultipleOverlappingRules(t *testing.T) { + rw := BasicAuth{ + Next: middleware.HandlerFunc(contentHandler), + Rules: []Rule{ + {Username: "t", Password: "p1", Resources: []string{"/t"}}, + {Username: "t1", Password: "p2", Resources: []string{"/t/t"}}, + }, + } + + tests := []struct { + from string + result int + cred string + }{ + {"/t", http.StatusOK, "t:p1"}, + {"/t/t", http.StatusOK, "t:p1"}, + {"/t/t", http.StatusOK, "t1:p2"}, + {"/a", http.StatusOK, "t1:p2"}, + {"/t/t", http.StatusUnauthorized, "t1:p3"}, + {"/t", http.StatusUnauthorized, "t1:p2"}, + } + + + for i, test := range tests { + + req, err := http.NewRequest("GET", test.from, nil) + if err != nil { + t.Fatalf("Test %d: Could not create HTTP request %v", i, err) + } + auth := "Basic " + base64.StdEncoding.EncodeToString([]byte(test.cred)) + req.Header.Set("Authorization", auth) + + rec := httptest.NewRecorder() + result, err := rw.ServeHTTP(rec, req) + if err != nil { + t.Fatalf("Test %d: Could not ServeHTTP %v", i, err) + } + if result != test.result { + t.Errorf("Test %d: Expected Header '%d' but was '%d'", + i, test.result, result) + } + + } + +} + + + +func contentHandler(w http.ResponseWriter, r *http.Request) (int, error) { + fmt.Fprintf(w, r.URL.String()) + return http.StatusOK, nil +} \ No newline at end of file