mirror of
				https://github.com/caddyserver/caddy.git
				synced 2025-10-31 10:37:24 -04:00 
			
		
		
		
	Merge branch 'master' into md_changes
This commit is contained in:
		
						commit
						19a85d08c6
					
				| @ -83,10 +83,30 @@ func Setup(c *setup.Controller) (middleware.Middleware, error) { | |||||||
| 					c.TLS.Ciphers = append(c.TLS.Ciphers, value) | 					c.TLS.Ciphers = append(c.TLS.Ciphers, value) | ||||||
| 				} | 				} | ||||||
| 			case "clients": | 			case "clients": | ||||||
| 				c.TLS.ClientCerts = c.RemainingArgs() | 				clientCertList := c.RemainingArgs() | ||||||
| 				if len(c.TLS.ClientCerts) == 0 { | 				if len(clientCertList) == 0 { | ||||||
| 					return nil, c.ArgErr() | 					return nil, c.ArgErr() | ||||||
| 				} | 				} | ||||||
|  | 
 | ||||||
|  | 				listStart, mustProvideCA := 1, true | ||||||
|  | 				switch clientCertList[0] { | ||||||
|  | 				case "request": | ||||||
|  | 					c.TLS.ClientAuth = tls.RequestClientCert | ||||||
|  | 					mustProvideCA = false | ||||||
|  | 				case "require": | ||||||
|  | 					c.TLS.ClientAuth = tls.RequireAnyClientCert | ||||||
|  | 					mustProvideCA = false | ||||||
|  | 				case "verify_if_given": | ||||||
|  | 					c.TLS.ClientAuth = tls.VerifyClientCertIfGiven | ||||||
|  | 				default: | ||||||
|  | 					c.TLS.ClientAuth = tls.RequireAndVerifyClientCert | ||||||
|  | 					listStart = 0 | ||||||
|  | 				} | ||||||
|  | 				if mustProvideCA && len(clientCertList) <= listStart { | ||||||
|  | 					return nil, c.ArgErr() | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				c.TLS.ClientCerts = clientCertList[listStart:] | ||||||
| 			case "load": | 			case "load": | ||||||
| 				c.Args(&loadDir) | 				c.Args(&loadDir) | ||||||
| 				c.TLS.Manual = true | 				c.TLS.Manual = true | ||||||
|  | |||||||
| @ -189,34 +189,69 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestSetupParseWithClientAuth(t *testing.T) { | func TestSetupParseWithClientAuth(t *testing.T) { | ||||||
|  | 	// Test missing client cert file | ||||||
| 	params := `tls ` + certFile + ` ` + keyFile + ` { | 	params := `tls ` + certFile + ` ` + keyFile + ` { | ||||||
| 			clients client_ca.crt client2_ca.crt | 			clients | ||||||
| 		}` | 		}` | ||||||
| 	c := setup.NewTestController(params) | 	c := setup.NewTestController(params) | ||||||
| 	_, err := Setup(c) | 	_, err := Setup(c) | ||||||
| 	if err != nil { |  | ||||||
| 		t.Errorf("Expected no errors, got: %v", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if count := len(c.TLS.ClientCerts); count != 2 { |  | ||||||
| 		t.Fatalf("Expected two client certs, had %d", count) |  | ||||||
| 	} |  | ||||||
| 	if actual := c.TLS.ClientCerts[0]; actual != "client_ca.crt" { |  | ||||||
| 		t.Errorf("Expected first client cert file to be '%s', but was '%s'", "client_ca.crt", actual) |  | ||||||
| 	} |  | ||||||
| 	if actual := c.TLS.ClientCerts[1]; actual != "client2_ca.crt" { |  | ||||||
| 		t.Errorf("Expected second client cert file to be '%s', but was '%s'", "client2_ca.crt", actual) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// Test missing client cert file |  | ||||||
| 	params = `tls ` + certFile + ` ` + keyFile + ` { |  | ||||||
| 			clients |  | ||||||
| 		}` |  | ||||||
| 	c = setup.NewTestController(params) |  | ||||||
| 	_, err = Setup(c) |  | ||||||
| 	if err == nil { | 	if err == nil { | ||||||
| 		t.Errorf("Expected an error, but no error returned") | 		t.Errorf("Expected an error, but no error returned") | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
|  | 	noCAs, twoCAs := []string{}, []string{"client_ca.crt", "client2_ca.crt"} | ||||||
|  | 	for caseNumber, caseData := range []struct { | ||||||
|  | 		params         string | ||||||
|  | 		clientAuthType tls.ClientAuthType | ||||||
|  | 		expectedErr    bool | ||||||
|  | 		expectedCAs    []string | ||||||
|  | 	}{ | ||||||
|  | 		{"", tls.NoClientCert, false, noCAs}, | ||||||
|  | 		{`tls ` + certFile + ` ` + keyFile + ` { | ||||||
|  | 			clients client_ca.crt client2_ca.crt | ||||||
|  | 		}`, tls.RequireAndVerifyClientCert, false, twoCAs}, | ||||||
|  | 		// now come modifier | ||||||
|  | 		{`tls ` + certFile + ` ` + keyFile + ` { | ||||||
|  | 			clients request | ||||||
|  | 		}`, tls.RequestClientCert, false, noCAs}, | ||||||
|  | 		{`tls ` + certFile + ` ` + keyFile + ` { | ||||||
|  | 			clients require | ||||||
|  | 		}`, tls.RequireAnyClientCert, false, noCAs}, | ||||||
|  | 		{`tls ` + certFile + ` ` + keyFile + ` { | ||||||
|  | 			clients verify_if_given client_ca.crt client2_ca.crt | ||||||
|  | 		}`, tls.VerifyClientCertIfGiven, false, twoCAs}, | ||||||
|  | 		{`tls ` + certFile + ` ` + keyFile + ` { | ||||||
|  | 			clients verify_if_given | ||||||
|  | 		}`, tls.VerifyClientCertIfGiven, true, noCAs}, | ||||||
|  | 	} { | ||||||
|  | 		c := setup.NewTestController(caseData.params) | ||||||
|  | 		_, err := Setup(c) | ||||||
|  | 		if caseData.expectedErr { | ||||||
|  | 			if err == nil { | ||||||
|  | 				t.Errorf("In case %d: Expected an error, got: %v", caseNumber, err) | ||||||
|  | 			} | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Errorf("In case %d: Expected no errors, got: %v", caseNumber, err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if caseData.clientAuthType != c.TLS.ClientAuth { | ||||||
|  | 			t.Errorf("In case %d: Expected TLS client auth type %v, got: %v", | ||||||
|  | 				caseNumber, caseData.clientAuthType, c.TLS.ClientAuth) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if count := len(c.TLS.ClientCerts); count < len(caseData.expectedCAs) { | ||||||
|  | 			t.Fatalf("In case %d: Expected %d client certs, had %d", caseNumber, len(caseData.expectedCAs), count) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		for idx, expected := range caseData.expectedCAs { | ||||||
|  | 			if actual := c.TLS.ClientCerts[idx]; actual != expected { | ||||||
|  | 				t.Errorf("In case %d: Expected %dth client cert file to be '%s', but was '%s'", | ||||||
|  | 					caseNumber, idx, expected, actual) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestSetupParseWithKeyType(t *testing.T) { | func TestSetupParseWithKeyType(t *testing.T) { | ||||||
|  | |||||||
| @ -210,7 +210,8 @@ td:first-child svg { | |||||||
| 	position: absolute; | 	position: absolute; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| td .name { | td .name, | ||||||
|  | td .goup { | ||||||
| 	margin-left: 1.75em; | 	margin-left: 1.75em; | ||||||
| 	word-break: break-all; | 	word-break: break-all; | ||||||
| 	overflow-wrap: break-word; | 	overflow-wrap: break-word; | ||||||
| @ -263,7 +264,6 @@ footer { | |||||||
| 					</g> | 					</g> | ||||||
| 				</g> | 				</g> | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| 				<!-- File --> | 				<!-- File --> | ||||||
| 				<linearGradient id="a"> | 				<linearGradient id="a"> | ||||||
| 					<stop stop-color="#cbcbcb" offset="0"/> | 					<stop stop-color="#cbcbcb" offset="0"/> | ||||||
| @ -299,14 +299,14 @@ footer { | |||||||
| 			</defs> | 			</defs> | ||||||
| 		</svg> | 		</svg> | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 		<header> | 		<header> | ||||||
| 			<h1>{{.LinkedPath}}</h1> | 			<h1> | ||||||
|  | 				{{range $url, $name := .BreadcrumbMap}}<a href="{{$url}}">{{$name}}</a>{{if ne $url "/"}}/{{end}}{{end}} | ||||||
|  | 			</h1> | ||||||
| 		</header> | 		</header> | ||||||
| 		<main> | 		<main> | ||||||
| 			<div class="meta"> | 			<div class="meta"> | ||||||
| 				<div class="content">	 | 				<div class="content"> | ||||||
| 					<span class="meta-item"><b>{{.NumDirs}}</b> director{{if eq 1 .NumDirs}}y{{else}}ies{{end}}</span> | 					<span class="meta-item"><b>{{.NumDirs}}</b> director{{if eq 1 .NumDirs}}y{{else}}ies{{end}}</span> | ||||||
| 					<span class="meta-item"><b>{{.NumFiles}}</b> file{{if ne 1 .NumFiles}}s{{end}}</span> | 					<span class="meta-item"><b>{{.NumFiles}}</b> file{{if ne 1 .NumFiles}}s{{end}}</span> | ||||||
| 				</div> | 				</div> | ||||||
| @ -342,6 +342,17 @@ footer { | |||||||
| 							{{end}} | 							{{end}} | ||||||
| 						</th> | 						</th> | ||||||
| 					</tr> | 					</tr> | ||||||
|  | 					{{if .CanGoUp}} | ||||||
|  | 					<tr> | ||||||
|  | 						<td> | ||||||
|  | 							<a href=".."> | ||||||
|  | 								<span class="goup">Go up</span> | ||||||
|  | 							</a> | ||||||
|  | 						</td> | ||||||
|  | 						<td>—</td> | ||||||
|  | 						<td>—</td> | ||||||
|  | 					</tr> | ||||||
|  | 					{{end}} | ||||||
| 					{{range .Items}} | 					{{range .Items}} | ||||||
| 					<tr> | 					<tr> | ||||||
| 						<td> | 						<td> | ||||||
|  | |||||||
							
								
								
									
										11
									
								
								dist/automate.go
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										11
									
								
								dist/automate.go
									
									
									
									
										vendored
									
									
								
							| @ -66,10 +66,7 @@ func main() { | |||||||
| 			if p.arch == "arm" { | 			if p.arch == "arm" { | ||||||
| 				baseFilename += p.arm | 				baseFilename += p.arm | ||||||
| 			} | 			} | ||||||
| 			binFilename = baseFilename | 			binFilename = baseFilename + p.binExt | ||||||
| 			if p.os == "windows" { |  | ||||||
| 				binFilename += ".exe" |  | ||||||
| 			} |  | ||||||
| 
 | 
 | ||||||
| 			binPath := filepath.Join(buildDir, binFilename) | 			binPath := filepath.Join(buildDir, binFilename) | ||||||
| 			archive := filepath.Join(releaseDir, fmt.Sprintf("%s.%s", baseFilename, p.archive)) | 			archive := filepath.Join(releaseDir, fmt.Sprintf("%s.%s", baseFilename, p.archive)) | ||||||
| @ -126,7 +123,7 @@ func (p platform) String() string { | |||||||
| func numProcs() int { | func numProcs() int { | ||||||
| 	n := runtime.GOMAXPROCS(0) | 	n := runtime.GOMAXPROCS(0) | ||||||
| 	if n == runtime.NumCPU() && n > 1 { | 	if n == runtime.NumCPU() && n > 1 { | ||||||
| 		n -= 1 | 		n-- | ||||||
| 	} | 	} | ||||||
| 	return n | 	return n | ||||||
| } | } | ||||||
| @ -151,8 +148,8 @@ var platforms = []platform{ | |||||||
| 	{os: "openbsd", arch: "386", archive: "tar.gz"}, | 	{os: "openbsd", arch: "386", archive: "tar.gz"}, | ||||||
| 	{os: "openbsd", arch: "amd64", archive: "tar.gz"}, | 	{os: "openbsd", arch: "amd64", archive: "tar.gz"}, | ||||||
| 	{os: "solaris", arch: "amd64", archive: "tar.gz"}, | 	{os: "solaris", arch: "amd64", archive: "tar.gz"}, | ||||||
| 	{os: "windows", arch: "386", archive: "zip"}, | 	{os: "windows", arch: "386", binExt: ".exe", archive: "zip"}, | ||||||
| 	{os: "windows", arch: "amd64", archive: "zip"}, | 	{os: "windows", arch: "amd64", binExt: ".exe", archive: "zip"}, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| var distContents = []string{ | var distContents = []string{ | ||||||
|  | |||||||
| @ -6,7 +6,6 @@ import ( | |||||||
| 	"bytes" | 	"bytes" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"os" | 	"os" | ||||||
| @ -69,11 +68,13 @@ type Listing struct { | |||||||
| 	middleware.Context | 	middleware.Context | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // LinkedPath returns l.Path where every element is a clickable | // BreadcrumbMap returns l.Path where every element is a map | ||||||
| // link to the path up to that point so far. | // of URLs and path segment names. | ||||||
| func (l Listing) LinkedPath() string { | func (l Listing) BreadcrumbMap() map[string]string { | ||||||
|  | 	result := map[string]string{} | ||||||
|  | 
 | ||||||
| 	if len(l.Path) == 0 { | 	if len(l.Path) == 0 { | ||||||
| 		return "" | 		return result | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// skip trailing slash | 	// skip trailing slash | ||||||
| @ -83,14 +84,13 @@ func (l Listing) LinkedPath() string { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	parts := strings.Split(lpath, "/") | 	parts := strings.Split(lpath, "/") | ||||||
| 	var result string |  | ||||||
| 	for i, part := range parts { | 	for i, part := range parts { | ||||||
| 		if i == 0 && part == "" { | 		if i == 0 && part == "" { | ||||||
| 			// Leading slash (root) | 			// Leading slash (root) | ||||||
| 			result += `<a href="/">/</a>` | 			result["/"] = "/" | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
| 		result += fmt.Sprintf(`<a href="%s/">%s</a>/`, strings.Join(parts[:i+1], "/"), part) | 		result[strings.Join(parts[:i+1], "/")] = part | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return result | 	return result | ||||||
| @ -241,6 +241,11 @@ func (b Browse) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { | |||||||
| 		if !middleware.Path(r.URL.Path).Matches(bc.PathScope) { | 		if !middleware.Path(r.URL.Path).Matches(bc.PathScope) { | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
|  | 		switch r.Method { | ||||||
|  | 		case http.MethodGet, http.MethodHead: | ||||||
|  | 		default: | ||||||
|  | 			return http.StatusMethodNotAllowed, nil | ||||||
|  | 		} | ||||||
| 
 | 
 | ||||||
| 		// Browsing navigation gets messed up if browsing a directory | 		// Browsing navigation gets messed up if browsing a directory | ||||||
| 		// that doesn't end in "/" (which it should, anyway) | 		// that doesn't end in "/" (which it should, anyway) | ||||||
|  | |||||||
| @ -104,6 +104,51 @@ func TestSort(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func TestBrowseHTTPMethods(t *testing.T) { | ||||||
|  | 	tmpl, err := template.ParseFiles("testdata/photos.tpl") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("An error occured while parsing the template: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	b := Browse{ | ||||||
|  | 		Next: middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { | ||||||
|  | 			t.Fatalf("Next shouldn't be called") | ||||||
|  | 			return 0, nil | ||||||
|  | 		}), | ||||||
|  | 		Root: "./testdata", | ||||||
|  | 		Configs: []Config{ | ||||||
|  | 			{ | ||||||
|  | 				PathScope: "/photos", | ||||||
|  | 				Template:  tmpl, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	rec := httptest.NewRecorder() | ||||||
|  | 	for method, expected := range map[string]int{ | ||||||
|  | 		http.MethodGet:     http.StatusOK, | ||||||
|  | 		http.MethodHead:    http.StatusOK, | ||||||
|  | 		http.MethodOptions: http.StatusMethodNotAllowed, | ||||||
|  | 		http.MethodPost:    http.StatusMethodNotAllowed, | ||||||
|  | 		http.MethodPut:     http.StatusMethodNotAllowed, | ||||||
|  | 		http.MethodPatch:   http.StatusMethodNotAllowed, | ||||||
|  | 		http.MethodDelete:  http.StatusMethodNotAllowed, | ||||||
|  | 		"COPY":             http.StatusMethodNotAllowed, | ||||||
|  | 		"MOVE":             http.StatusMethodNotAllowed, | ||||||
|  | 		"MKCOL":            http.StatusMethodNotAllowed, | ||||||
|  | 	} { | ||||||
|  | 		req, err := http.NewRequest(method, "/photos/", nil) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatalf("Test: Could not create HTTP request: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		code, _ := b.ServeHTTP(rec, req) | ||||||
|  | 		if code != expected { | ||||||
|  | 			t.Errorf("Wrong status with HTTP Method %s: expected %d, got %d", method, expected, code) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func TestBrowseTemplate(t *testing.T) { | func TestBrowseTemplate(t *testing.T) { | ||||||
| 	tmpl, err := template.ParseFiles("testdata/photos.tpl") | 	tmpl, err := template.ParseFiles("testdata/photos.tpl") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | |||||||
| @ -52,7 +52,13 @@ func TestInclude(t *testing.T) { | |||||||
| 			fileContent:          `str1 {{ .InvalidField }} str2`, | 			fileContent:          `str1 {{ .InvalidField }} str2`, | ||||||
| 			expectedContent:      "", | 			expectedContent:      "", | ||||||
| 			shouldErr:            true, | 			shouldErr:            true, | ||||||
| 			expectedErrorContent: `InvalidField is not a field of struct type middleware.Context`, | 			expectedErrorContent: `InvalidField`, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			fileContent:          `str1 {{ .InvalidField }} str2`, | ||||||
|  | 			expectedContent:      "", | ||||||
|  | 			shouldErr:            true, | ||||||
|  | 			expectedErrorContent: `type middleware.Context`, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -141,3 +141,22 @@ func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { | |||||||
| 	} | 	} | ||||||
| 	return nil, nil, fmt.Errorf("not a Hijacker") | 	return nil, nil, fmt.Errorf("not a Hijacker") | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // Flush implements http.Flusher. It simply wraps the underlying | ||||||
|  | // ResponseWriter's Flush method if there is one, or panics. | ||||||
|  | func (w *gzipResponseWriter) Flush() { | ||||||
|  | 	if f, ok := w.ResponseWriter.(http.Flusher); ok { | ||||||
|  | 		f.Flush() | ||||||
|  | 	} else { | ||||||
|  | 		panic("not a Flusher") // should be recovered at the beginning of middleware stack | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CloseNotify implements http.CloseNotifier. | ||||||
|  | // It just inherits the underlying ResponseWriter's CloseNotify method. | ||||||
|  | func (w *gzipResponseWriter) CloseNotify() <-chan bool { | ||||||
|  | 	if cn, ok := w.ResponseWriter.(http.CloseNotifier); ok { | ||||||
|  | 		return cn.CloseNotify() | ||||||
|  | 	} | ||||||
|  | 	panic("not a CloseNotifier") | ||||||
|  | } | ||||||
|  | |||||||
| @ -87,3 +87,12 @@ func (r *ResponseRecorder) Flush() { | |||||||
| 		panic("not a Flusher") // should be recovered at the beginning of middleware stack | 		panic("not a Flusher") // should be recovered at the beginning of middleware stack | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // CloseNotify implements http.CloseNotifier. | ||||||
|  | // It just inherits the underlying ResponseWriter's CloseNotify method. | ||||||
|  | func (r *ResponseRecorder) CloseNotify() <-chan bool { | ||||||
|  | 	if cn, ok := r.ResponseWriter.(http.CloseNotifier); ok { | ||||||
|  | 		return cn.CloseNotify() | ||||||
|  | 	} | ||||||
|  | 	panic("not a CloseNotifier") | ||||||
|  | } | ||||||
|  | |||||||
| @ -1,6 +1,7 @@ | |||||||
| package server | package server | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"crypto/tls" | ||||||
| 	"net" | 	"net" | ||||||
| 
 | 
 | ||||||
| 	"github.com/mholt/caddy/middleware" | 	"github.com/mholt/caddy/middleware" | ||||||
| @ -75,4 +76,5 @@ type TLSConfig struct { | |||||||
| 	ProtocolMaxVersion       uint16 | 	ProtocolMaxVersion       uint16 | ||||||
| 	PreferServerCipherSuites bool | 	PreferServerCipherSuites bool | ||||||
| 	ClientCerts              []string | 	ClientCerts              []string | ||||||
|  | 	ClientAuth               tls.ClientAuthType | ||||||
| } | } | ||||||
|  | |||||||
							
								
								
									
										104
									
								
								server/server.go
									
									
									
									
									
								
							
							
						
						
									
										104
									
								
								server/server.go
									
									
									
									
									
								
							| @ -4,20 +4,28 @@ | |||||||
| package server | package server | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"crypto/rand" | ||||||
| 	"crypto/tls" | 	"crypto/tls" | ||||||
| 	"crypto/x509" | 	"crypto/x509" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"io" | ||||||
| 	"io/ioutil" | 	"io/ioutil" | ||||||
| 	"log" | 	"log" | ||||||
| 	"net" | 	"net" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"os" | 	"os" | ||||||
|  | 	"path/filepath" | ||||||
| 	"runtime" | 	"runtime" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | const ( | ||||||
|  | 	tlsNewTicketEvery = time.Hour * 10 // generate a new ticket for TLS PFS encryption every so often | ||||||
|  | 	tlsNumTickets     = 4              // hold and consider that many tickets to decrypt TLS sessions | ||||||
|  | ) | ||||||
|  | 
 | ||||||
| // Server represents an instance of a server, which serves | // Server represents an instance of a server, which serves | ||||||
| // HTTP requests at a particular address (host and port). A | // HTTP requests at a particular address (host and port). A | ||||||
| // server is capable of serving numerous virtual hosts on | // server is capable of serving numerous virtual hosts on | ||||||
| @ -28,6 +36,7 @@ type Server struct { | |||||||
| 	HTTP2       bool                   // whether to enable HTTP/2 | 	HTTP2       bool                   // whether to enable HTTP/2 | ||||||
| 	tls         bool                   // whether this server is serving all HTTPS hosts or not | 	tls         bool                   // whether this server is serving all HTTPS hosts or not | ||||||
| 	OnDemandTLS bool                   // whether this server supports on-demand TLS (load certs at handshake-time) | 	OnDemandTLS bool                   // whether this server supports on-demand TLS (load certs at handshake-time) | ||||||
|  | 	tlsGovChan  chan struct{}          // close to stop the TLS maintenance goroutine | ||||||
| 	vhosts      map[string]virtualHost // virtual hosts keyed by their address | 	vhosts      map[string]virtualHost // virtual hosts keyed by their address | ||||||
| 	listener    ListenerFile           // the listener which is bound to the socket | 	listener    ListenerFile           // the listener which is bound to the socket | ||||||
| 	listenerMu  sync.Mutex             // protects listener | 	listenerMu  sync.Mutex             // protects listener | ||||||
| @ -216,6 +225,11 @@ func serveTLS(s *Server, ln net.Listener, tlsConfigs []TLSConfig) error { | |||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// Setup any goroutines governing over TLS settings | ||||||
|  | 	s.tlsGovChan = make(chan struct{}) | ||||||
|  | 	timer := time.NewTicker(tlsNewTicketEvery) | ||||||
|  | 	go runTLSTicketKeyRotation(s.TLSConfig, timer, s.tlsGovChan) | ||||||
|  | 
 | ||||||
| 	// Create TLS listener - note that we do not replace s.listener | 	// Create TLS listener - note that we do not replace s.listener | ||||||
| 	// with this TLS listener; tls.listener is unexported and does | 	// with this TLS listener; tls.listener is unexported and does | ||||||
| 	// not implement the File() method we need for graceful restarts | 	// not implement the File() method we need for graceful restarts | ||||||
| @ -258,6 +272,11 @@ func (s *Server) Stop() (err error) { | |||||||
| 	} | 	} | ||||||
| 	s.listenerMu.Unlock() | 	s.listenerMu.Unlock() | ||||||
| 
 | 
 | ||||||
|  | 	// Closing this signals any TLS governor goroutines to exit | ||||||
|  | 	if s.tlsGovChan != nil { | ||||||
|  | 		close(s.tlsGovChan) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -314,6 +333,16 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// Use URL.RawPath If you need the original, "raw" URL.Path in your middleware. | ||||||
|  | 	// Collapse any ./ ../ /// madness here instead of doing that in every plugin. | ||||||
|  | 	if r.URL.Path != "/" { | ||||||
|  | 		path := filepath.Clean(r.URL.Path) | ||||||
|  | 		if !strings.HasPrefix(path, "/") { | ||||||
|  | 			path = "/" + path | ||||||
|  | 		} | ||||||
|  | 		r.URL.Path = path | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	// Execute the optional request callback if it exists and it's not disabled | 	// Execute the optional request callback if it exists and it's not disabled | ||||||
| 	if s.ReqCallback != nil && !s.vhosts[host].config.TLS.Manual && s.ReqCallback(w, r) { | 	if s.ReqCallback != nil && !s.vhosts[host].config.TLS.Manual && s.ReqCallback(w, r) { | ||||||
| 		return | 		return | ||||||
| @ -350,17 +379,19 @@ func DefaultErrorFunc(w http.ResponseWriter, r *http.Request, status int) { | |||||||
| // setupClientAuth sets up TLS client authentication only if | // setupClientAuth sets up TLS client authentication only if | ||||||
| // any of the TLS configs specified at least one cert file. | // any of the TLS configs specified at least one cert file. | ||||||
| func setupClientAuth(tlsConfigs []TLSConfig, config *tls.Config) error { | func setupClientAuth(tlsConfigs []TLSConfig, config *tls.Config) error { | ||||||
| 	var clientAuth bool | 	whatClientAuth := tls.NoClientCert | ||||||
| 	for _, cfg := range tlsConfigs { | 	for _, cfg := range tlsConfigs { | ||||||
| 		if len(cfg.ClientCerts) > 0 { | 		if whatClientAuth < cfg.ClientAuth { // Use the most restrictive. | ||||||
| 			clientAuth = true | 			whatClientAuth = cfg.ClientAuth | ||||||
| 			break |  | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if clientAuth { | 	if whatClientAuth != tls.NoClientCert { | ||||||
| 		pool := x509.NewCertPool() | 		pool := x509.NewCertPool() | ||||||
| 		for _, cfg := range tlsConfigs { | 		for _, cfg := range tlsConfigs { | ||||||
|  | 			if len(cfg.ClientCerts) == 0 { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
| 			for _, caFile := range cfg.ClientCerts { | 			for _, caFile := range cfg.ClientCerts { | ||||||
| 				caCrt, err := ioutil.ReadFile(caFile) // Anyone that gets a cert from this CA can connect | 				caCrt, err := ioutil.ReadFile(caFile) // Anyone that gets a cert from this CA can connect | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| @ -372,12 +403,73 @@ func setupClientAuth(tlsConfigs []TLSConfig, config *tls.Config) error { | |||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 		config.ClientCAs = pool | 		config.ClientCAs = pool | ||||||
| 		config.ClientAuth = tls.RequireAndVerifyClientCert | 		config.ClientAuth = whatClientAuth | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | var runTLSTicketKeyRotation = standaloneTLSTicketKeyRotation | ||||||
|  | 
 | ||||||
|  | var setSessionTicketKeysTestHook = func(keys [][32]byte) [][32]byte { | ||||||
|  | 	return keys | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // standaloneTLSTicketKeyRotation governs over the array of TLS ticket keys used to de/crypt TLS tickets. | ||||||
|  | // It periodically sets a new ticket key as the first one, used to encrypt (and decrypt), | ||||||
|  | // pushing any old ticket keys to the back, where they are considered for decryption only. | ||||||
|  | // | ||||||
|  | // Lack of entropy for the very first ticket key results in the feature being disabled (as does Go), | ||||||
|  | // later lack of entropy temporarily disables ticket key rotation. | ||||||
|  | // Old ticket keys are still phased out, though. | ||||||
|  | // | ||||||
|  | // Stops the timer when returning. | ||||||
|  | func standaloneTLSTicketKeyRotation(c *tls.Config, timer *time.Ticker, exitChan chan struct{}) { | ||||||
|  | 	defer timer.Stop() | ||||||
|  | 	// The entire page should be marked as sticky, but Go cannot do that | ||||||
|  | 	// without resorting to syscall#Mlock. And, we don't have madvise (for NODUMP), too. ☹ | ||||||
|  | 	keys := make([][32]byte, 1, tlsNumTickets) | ||||||
|  | 
 | ||||||
|  | 	rng := c.Rand | ||||||
|  | 	if rng == nil { | ||||||
|  | 		rng = rand.Reader | ||||||
|  | 	} | ||||||
|  | 	if _, err := io.ReadFull(rng, keys[0][:]); err != nil { | ||||||
|  | 		c.SessionTicketsDisabled = true // bail if we don't have the entropy for the first one | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	c.SetSessionTicketKeys(setSessionTicketKeysTestHook(keys)) | ||||||
|  | 
 | ||||||
|  | 	for { | ||||||
|  | 		select { | ||||||
|  | 		case _, isOpen := <-exitChan: | ||||||
|  | 			if !isOpen { | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 		case <-timer.C: | ||||||
|  | 			rng = c.Rand // could've changed since the start | ||||||
|  | 			if rng == nil { | ||||||
|  | 				rng = rand.Reader | ||||||
|  | 			} | ||||||
|  | 			var newTicketKey [32]byte | ||||||
|  | 			_, err := io.ReadFull(rng, newTicketKey[:]) | ||||||
|  | 
 | ||||||
|  | 			if len(keys) < tlsNumTickets { | ||||||
|  | 				keys = append(keys, keys[0]) // manipulates the internal length | ||||||
|  | 			} | ||||||
|  | 			for idx := len(keys) - 1; idx >= 1; idx-- { | ||||||
|  | 				keys[idx] = keys[idx-1] // yes, this makes copies | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if err == nil { | ||||||
|  | 				keys[0] = newTicketKey | ||||||
|  | 			} | ||||||
|  | 			// pushes the last key out, doesn't matter that we don't have a new one | ||||||
|  | 			c.SetSessionTicketKeys(setSessionTicketKeysTestHook(keys)) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // RunFirstStartupFuncs runs all of the server's FirstStartup | // RunFirstStartupFuncs runs all of the server's FirstStartup | ||||||
| // callback functions unless one of them returns an error first. | // callback functions unless one of them returns an error first. | ||||||
| // It is the caller's responsibility to call this only once and | // It is the caller's responsibility to call this only once and | ||||||
|  | |||||||
							
								
								
									
										60
									
								
								server/server_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										60
									
								
								server/server_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,60 @@ | |||||||
|  | package server | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"crypto/tls" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func TestStandaloneTLSTicketKeyRotation(t *testing.T) { | ||||||
|  | 	tlsGovChan := make(chan struct{}) | ||||||
|  | 	defer close(tlsGovChan) | ||||||
|  | 	callSync := make(chan bool, 1) | ||||||
|  | 	defer close(callSync) | ||||||
|  | 
 | ||||||
|  | 	oldHook := setSessionTicketKeysTestHook | ||||||
|  | 	defer func() { | ||||||
|  | 		setSessionTicketKeysTestHook = oldHook | ||||||
|  | 	}() | ||||||
|  | 	var keysInUse [][32]byte | ||||||
|  | 	setSessionTicketKeysTestHook = func(keys [][32]byte) [][32]byte { | ||||||
|  | 		keysInUse = keys | ||||||
|  | 		callSync <- true | ||||||
|  | 		return keys | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	c := new(tls.Config) | ||||||
|  | 	timer := time.NewTicker(time.Millisecond * 1) | ||||||
|  | 
 | ||||||
|  | 	go standaloneTLSTicketKeyRotation(c, timer, tlsGovChan) | ||||||
|  | 
 | ||||||
|  | 	rounds := 0 | ||||||
|  | 	var lastTicketKey [32]byte | ||||||
|  | 	for { | ||||||
|  | 		select { | ||||||
|  | 		case <-callSync: | ||||||
|  | 			if lastTicketKey == keysInUse[0] { | ||||||
|  | 				close(tlsGovChan) | ||||||
|  | 				t.Errorf("The same TLS ticket key has been used again (not rotated): %x.", lastTicketKey) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 			lastTicketKey = keysInUse[0] | ||||||
|  | 			rounds++ | ||||||
|  | 			if rounds <= tlsNumTickets && len(keysInUse) != rounds { | ||||||
|  | 				close(tlsGovChan) | ||||||
|  | 				t.Errorf("Expected TLS ticket keys in use: %d; Got instead: %d.", rounds, len(keysInUse)) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 			if c.SessionTicketsDisabled == true { | ||||||
|  | 				t.Error("Session tickets have been disabled unexpectedly.") | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 			if rounds >= tlsNumTickets+1 { | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 		case <-time.After(time.Second * 1): | ||||||
|  | 			t.Errorf("Timeout after %d rounds.", rounds) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user