mirror of
				https://github.com/caddyserver/caddy.git
				synced 2025-11-03 19:17:29 -05:00 
			
		
		
		
	basicauth: Remove Authorization header on successful authz (issue #1324)
If a site owner protects a path with basicauth, no need to use the Authorization header elsewhere upstream, especially since it contains credentials. If this breaks anyone, it means they're double-dipping. It's usually good practice to clear out credentials as soon as they're not needed anymore. (Note that we only clear credentials after they're used, they stay for any other reason.)
This commit is contained in:
		
							parent
							
								
									a1a8d0f655
								
							
						
					
					
						commit
						54acb9b2de
					
				@ -34,8 +34,7 @@ type BasicAuth struct {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// ServeHTTP implements the httpserver.Handler interface.
 | 
					// ServeHTTP implements the httpserver.Handler interface.
 | 
				
			||||||
func (a BasicAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
 | 
					func (a BasicAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
 | 
				
			||||||
	var hasAuth bool
 | 
						var protected, isAuthenticated bool
 | 
				
			||||||
	var isAuthenticated bool
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, rule := range a.Rules {
 | 
						for _, rule := range a.Rules {
 | 
				
			||||||
		for _, res := range rule.Resources {
 | 
							for _, res := range rule.Resources {
 | 
				
			||||||
@ -43,30 +42,34 @@ func (a BasicAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error
 | 
				
			|||||||
				continue
 | 
									continue
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			// Path matches; parse auth header
 | 
								// path matches; this endpoint is protected
 | 
				
			||||||
			username, password, ok := r.BasicAuth()
 | 
								protected = true
 | 
				
			||||||
			hasAuth = true
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
			// Check credentials
 | 
								// parse auth header
 | 
				
			||||||
 | 
								username, password, ok := r.BasicAuth()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// check credentials
 | 
				
			||||||
			if !ok ||
 | 
								if !ok ||
 | 
				
			||||||
				username != rule.Username ||
 | 
									username != rule.Username ||
 | 
				
			||||||
				!rule.Password(password) {
 | 
									!rule.Password(password) {
 | 
				
			||||||
				continue
 | 
									continue
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			// Flag set only on successful authentication
 | 
								// by this point, authentication was successful
 | 
				
			||||||
			isAuthenticated = true
 | 
								isAuthenticated = true
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// remove credentials from request to avoid leaking upstream
 | 
				
			||||||
 | 
								r.Header.Del("Authorization")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if hasAuth {
 | 
						if protected && !isAuthenticated {
 | 
				
			||||||
		if !isAuthenticated {
 | 
							// browsers show a message that says something like:
 | 
				
			||||||
 | 
							// "The website says: <realm>"
 | 
				
			||||||
 | 
							// which is kinda dumb, but whatever.
 | 
				
			||||||
		w.Header().Set("WWW-Authenticate", "Basic realm=\"Restricted\"")
 | 
							w.Header().Set("WWW-Authenticate", "Basic realm=\"Restricted\"")
 | 
				
			||||||
		return http.StatusUnauthorized, nil
 | 
							return http.StatusUnauthorized, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
		// "It's an older code, sir, but it checks out. I was about to clear them."
 | 
					 | 
				
			||||||
		return a.Next.ServeHTTP(w, r)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Pass-through when no paths match
 | 
						// Pass-through when no paths match
 | 
				
			||||||
	return a.Next.ServeHTTP(w, r)
 | 
						return a.Next.ServeHTTP(w, r)
 | 
				
			||||||
 | 
				
			|||||||
@ -17,51 +17,57 @@ func TestBasicAuth(t *testing.T) {
 | 
				
			|||||||
	rw := BasicAuth{
 | 
						rw := BasicAuth{
 | 
				
			||||||
		Next: httpserver.HandlerFunc(contentHandler),
 | 
							Next: httpserver.HandlerFunc(contentHandler),
 | 
				
			||||||
		Rules: []Rule{
 | 
							Rules: []Rule{
 | 
				
			||||||
			{Username: "test", Password: PlainMatcher("ttest"), Resources: []string{"/testing"}},
 | 
								{Username: "okuser", Password: PlainMatcher("okpass"), Resources: []string{"/testing"}},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	tests := []struct {
 | 
						tests := []struct {
 | 
				
			||||||
		from     string
 | 
							from     string
 | 
				
			||||||
		result   int
 | 
							result   int
 | 
				
			||||||
		cred   string
 | 
							user     string
 | 
				
			||||||
 | 
							password string
 | 
				
			||||||
	}{
 | 
						}{
 | 
				
			||||||
		{"/testing", http.StatusUnauthorized, "ttest:test"},
 | 
							{"/testing", http.StatusOK, "okuser", "okpass"},
 | 
				
			||||||
		{"/testing", http.StatusOK, "test:ttest"},
 | 
							{"/testing", http.StatusUnauthorized, "baduser", "okpass"},
 | 
				
			||||||
		{"/testing", http.StatusUnauthorized, ""},
 | 
							{"/testing", http.StatusUnauthorized, "okuser", "badpass"},
 | 
				
			||||||
 | 
							{"/testing", http.StatusUnauthorized, "OKuser", "okpass"},
 | 
				
			||||||
 | 
							{"/testing", http.StatusUnauthorized, "OKuser", "badPASS"},
 | 
				
			||||||
 | 
							{"/testing", http.StatusUnauthorized, "", "okpass"},
 | 
				
			||||||
 | 
							{"/testing", http.StatusUnauthorized, "okuser", ""},
 | 
				
			||||||
 | 
							{"/testing", http.StatusUnauthorized, "", ""},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for i, test := range tests {
 | 
						for i, test := range tests {
 | 
				
			||||||
 | 
					 | 
				
			||||||
		req, err := http.NewRequest("GET", test.from, nil)
 | 
							req, err := http.NewRequest("GET", test.from, nil)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			t.Fatalf("Test %d: Could not create HTTP request %v", i, err)
 | 
								t.Fatalf("Test %d: Could not create HTTP request: %v", i, err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		auth := "Basic " + base64.StdEncoding.EncodeToString([]byte(test.cred))
 | 
							req.SetBasicAuth(test.user, test.password)
 | 
				
			||||||
		req.Header.Set("Authorization", auth)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
		rec := httptest.NewRecorder()
 | 
							rec := httptest.NewRecorder()
 | 
				
			||||||
		result, err := rw.ServeHTTP(rec, req)
 | 
							result, err := rw.ServeHTTP(rec, req)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			t.Fatalf("Test %d: Could not ServeHTTP %v", i, err)
 | 
								t.Fatalf("Test %d: Could not ServeHTTP: %v", i, err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if result != test.result {
 | 
							if result != test.result {
 | 
				
			||||||
			t.Errorf("Test %d: Expected Header '%d' but was '%d'",
 | 
								t.Errorf("Test %d: Expected status code %d but was %d",
 | 
				
			||||||
				i, test.result, result)
 | 
									i, test.result, result)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if result == http.StatusUnauthorized {
 | 
							if test.result == http.StatusUnauthorized {
 | 
				
			||||||
			headers := rec.Header()
 | 
								headers := rec.Header()
 | 
				
			||||||
			if val, ok := headers["Www-Authenticate"]; ok {
 | 
								if val, ok := headers["Www-Authenticate"]; ok {
 | 
				
			||||||
				if val[0] != "Basic realm=\"Restricted\"" {
 | 
									if got, want := val[0], "Basic realm=\"Restricted\""; got != want {
 | 
				
			||||||
					t.Errorf("Test %d, Www-Authenticate should be %s provided %s", i, "Basic", val[0])
 | 
										t.Errorf("Test %d: Www-Authenticate header should be '%s', got: '%s'", i, want, got)
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			} else {
 | 
								} else {
 | 
				
			||||||
				t.Errorf("Test %d, should provide a header Www-Authenticate", i)
 | 
									t.Errorf("Test %d: response should have a 'Www-Authenticate' header", i)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								if got, want := req.Header.Get("Authorization"), ""; got != want {
 | 
				
			||||||
 | 
									t.Errorf("Test %d: Expected Authorization header to be stripped from request after successful authentication, but is: %s", i, got)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestMultipleOverlappingRules(t *testing.T) {
 | 
					func TestMultipleOverlappingRules(t *testing.T) {
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user