mirror of
				https://github.com/caddyserver/caddy.git
				synced 2025-10-31 10:37:24 -04:00 
			
		
		
		
	Merge branch 'master' into md_changes
This commit is contained in:
		
						commit
						19a85d08c6
					
				| @ -83,10 +83,30 @@ func Setup(c *setup.Controller) (middleware.Middleware, error) { | ||||
| 					c.TLS.Ciphers = append(c.TLS.Ciphers, value) | ||||
| 				} | ||||
| 			case "clients": | ||||
| 				c.TLS.ClientCerts = c.RemainingArgs() | ||||
| 				if len(c.TLS.ClientCerts) == 0 { | ||||
| 				clientCertList := c.RemainingArgs() | ||||
| 				if len(clientCertList) == 0 { | ||||
| 					return nil, c.ArgErr() | ||||
| 				} | ||||
| 
 | ||||
| 				listStart, mustProvideCA := 1, true | ||||
| 				switch clientCertList[0] { | ||||
| 				case "request": | ||||
| 					c.TLS.ClientAuth = tls.RequestClientCert | ||||
| 					mustProvideCA = false | ||||
| 				case "require": | ||||
| 					c.TLS.ClientAuth = tls.RequireAnyClientCert | ||||
| 					mustProvideCA = false | ||||
| 				case "verify_if_given": | ||||
| 					c.TLS.ClientAuth = tls.VerifyClientCertIfGiven | ||||
| 				default: | ||||
| 					c.TLS.ClientAuth = tls.RequireAndVerifyClientCert | ||||
| 					listStart = 0 | ||||
| 				} | ||||
| 				if mustProvideCA && len(clientCertList) <= listStart { | ||||
| 					return nil, c.ArgErr() | ||||
| 				} | ||||
| 
 | ||||
| 				c.TLS.ClientCerts = clientCertList[listStart:] | ||||
| 			case "load": | ||||
| 				c.Args(&loadDir) | ||||
| 				c.TLS.Manual = true | ||||
|  | ||||
| @ -189,34 +189,69 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) { | ||||
| } | ||||
| 
 | ||||
| func TestSetupParseWithClientAuth(t *testing.T) { | ||||
| 	// Test missing client cert file | ||||
| 	params := `tls ` + certFile + ` ` + keyFile + ` { | ||||
| 			clients client_ca.crt client2_ca.crt | ||||
| 			clients | ||||
| 		}` | ||||
| 	c := setup.NewTestController(params) | ||||
| 	_, err := Setup(c) | ||||
| 	if err != nil { | ||||
| 		t.Errorf("Expected no errors, got: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if count := len(c.TLS.ClientCerts); count != 2 { | ||||
| 		t.Fatalf("Expected two client certs, had %d", count) | ||||
| 	} | ||||
| 	if actual := c.TLS.ClientCerts[0]; actual != "client_ca.crt" { | ||||
| 		t.Errorf("Expected first client cert file to be '%s', but was '%s'", "client_ca.crt", actual) | ||||
| 	} | ||||
| 	if actual := c.TLS.ClientCerts[1]; actual != "client2_ca.crt" { | ||||
| 		t.Errorf("Expected second client cert file to be '%s', but was '%s'", "client2_ca.crt", actual) | ||||
| 	} | ||||
| 
 | ||||
| 	// Test missing client cert file | ||||
| 	params = `tls ` + certFile + ` ` + keyFile + ` { | ||||
| 			clients | ||||
| 		}` | ||||
| 	c = setup.NewTestController(params) | ||||
| 	_, err = Setup(c) | ||||
| 	if err == nil { | ||||
| 		t.Errorf("Expected an error, but no error returned") | ||||
| 	} | ||||
| 
 | ||||
| 	noCAs, twoCAs := []string{}, []string{"client_ca.crt", "client2_ca.crt"} | ||||
| 	for caseNumber, caseData := range []struct { | ||||
| 		params         string | ||||
| 		clientAuthType tls.ClientAuthType | ||||
| 		expectedErr    bool | ||||
| 		expectedCAs    []string | ||||
| 	}{ | ||||
| 		{"", tls.NoClientCert, false, noCAs}, | ||||
| 		{`tls ` + certFile + ` ` + keyFile + ` { | ||||
| 			clients client_ca.crt client2_ca.crt | ||||
| 		}`, tls.RequireAndVerifyClientCert, false, twoCAs}, | ||||
| 		// now come modifier | ||||
| 		{`tls ` + certFile + ` ` + keyFile + ` { | ||||
| 			clients request | ||||
| 		}`, tls.RequestClientCert, false, noCAs}, | ||||
| 		{`tls ` + certFile + ` ` + keyFile + ` { | ||||
| 			clients require | ||||
| 		}`, tls.RequireAnyClientCert, false, noCAs}, | ||||
| 		{`tls ` + certFile + ` ` + keyFile + ` { | ||||
| 			clients verify_if_given client_ca.crt client2_ca.crt | ||||
| 		}`, tls.VerifyClientCertIfGiven, false, twoCAs}, | ||||
| 		{`tls ` + certFile + ` ` + keyFile + ` { | ||||
| 			clients verify_if_given | ||||
| 		}`, tls.VerifyClientCertIfGiven, true, noCAs}, | ||||
| 	} { | ||||
| 		c := setup.NewTestController(caseData.params) | ||||
| 		_, err := Setup(c) | ||||
| 		if caseData.expectedErr { | ||||
| 			if err == nil { | ||||
| 				t.Errorf("In case %d: Expected an error, got: %v", caseNumber, err) | ||||
| 			} | ||||
| 			continue | ||||
| 		} | ||||
| 		if err != nil { | ||||
| 			t.Errorf("In case %d: Expected no errors, got: %v", caseNumber, err) | ||||
| 		} | ||||
| 
 | ||||
| 		if caseData.clientAuthType != c.TLS.ClientAuth { | ||||
| 			t.Errorf("In case %d: Expected TLS client auth type %v, got: %v", | ||||
| 				caseNumber, caseData.clientAuthType, c.TLS.ClientAuth) | ||||
| 		} | ||||
| 
 | ||||
| 		if count := len(c.TLS.ClientCerts); count < len(caseData.expectedCAs) { | ||||
| 			t.Fatalf("In case %d: Expected %d client certs, had %d", caseNumber, len(caseData.expectedCAs), count) | ||||
| 		} | ||||
| 
 | ||||
| 		for idx, expected := range caseData.expectedCAs { | ||||
| 			if actual := c.TLS.ClientCerts[idx]; actual != expected { | ||||
| 				t.Errorf("In case %d: Expected %dth client cert file to be '%s', but was '%s'", | ||||
| 					caseNumber, idx, expected, actual) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSetupParseWithKeyType(t *testing.T) { | ||||
|  | ||||
| @ -210,7 +210,8 @@ td:first-child svg { | ||||
| 	position: absolute; | ||||
| } | ||||
| 
 | ||||
| td .name { | ||||
| td .name, | ||||
| td .goup { | ||||
| 	margin-left: 1.75em; | ||||
| 	word-break: break-all; | ||||
| 	overflow-wrap: break-word; | ||||
| @ -263,7 +264,6 @@ footer { | ||||
| 					</g> | ||||
| 				</g> | ||||
| 
 | ||||
| 
 | ||||
| 				<!-- File --> | ||||
| 				<linearGradient id="a"> | ||||
| 					<stop stop-color="#cbcbcb" offset="0"/> | ||||
| @ -299,10 +299,10 @@ footer { | ||||
| 			</defs> | ||||
| 		</svg> | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| 		<header> | ||||
| 			<h1>{{.LinkedPath}}</h1> | ||||
| 			<h1> | ||||
| 				{{range $url, $name := .BreadcrumbMap}}<a href="{{$url}}">{{$name}}</a>{{if ne $url "/"}}/{{end}}{{end}} | ||||
| 			</h1> | ||||
| 		</header> | ||||
| 		<main> | ||||
| 			<div class="meta"> | ||||
| @ -342,6 +342,17 @@ footer { | ||||
| 							{{end}} | ||||
| 						</th> | ||||
| 					</tr> | ||||
| 					{{if .CanGoUp}} | ||||
| 					<tr> | ||||
| 						<td> | ||||
| 							<a href=".."> | ||||
| 								<span class="goup">Go up</span> | ||||
| 							</a> | ||||
| 						</td> | ||||
| 						<td>—</td> | ||||
| 						<td>—</td> | ||||
| 					</tr> | ||||
| 					{{end}} | ||||
| 					{{range .Items}} | ||||
| 					<tr> | ||||
| 						<td> | ||||
|  | ||||
							
								
								
									
										11
									
								
								dist/automate.go
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										11
									
								
								dist/automate.go
									
									
									
									
										vendored
									
									
								
							| @ -66,10 +66,7 @@ func main() { | ||||
| 			if p.arch == "arm" { | ||||
| 				baseFilename += p.arm | ||||
| 			} | ||||
| 			binFilename = baseFilename | ||||
| 			if p.os == "windows" { | ||||
| 				binFilename += ".exe" | ||||
| 			} | ||||
| 			binFilename = baseFilename + p.binExt | ||||
| 
 | ||||
| 			binPath := filepath.Join(buildDir, binFilename) | ||||
| 			archive := filepath.Join(releaseDir, fmt.Sprintf("%s.%s", baseFilename, p.archive)) | ||||
| @ -126,7 +123,7 @@ func (p platform) String() string { | ||||
| func numProcs() int { | ||||
| 	n := runtime.GOMAXPROCS(0) | ||||
| 	if n == runtime.NumCPU() && n > 1 { | ||||
| 		n -= 1 | ||||
| 		n-- | ||||
| 	} | ||||
| 	return n | ||||
| } | ||||
| @ -151,8 +148,8 @@ var platforms = []platform{ | ||||
| 	{os: "openbsd", arch: "386", archive: "tar.gz"}, | ||||
| 	{os: "openbsd", arch: "amd64", archive: "tar.gz"}, | ||||
| 	{os: "solaris", arch: "amd64", archive: "tar.gz"}, | ||||
| 	{os: "windows", arch: "386", archive: "zip"}, | ||||
| 	{os: "windows", arch: "amd64", archive: "zip"}, | ||||
| 	{os: "windows", arch: "386", binExt: ".exe", archive: "zip"}, | ||||
| 	{os: "windows", arch: "amd64", binExt: ".exe", archive: "zip"}, | ||||
| } | ||||
| 
 | ||||
| var distContents = []string{ | ||||
|  | ||||
| @ -6,7 +6,6 @@ import ( | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"os" | ||||
| @ -69,11 +68,13 @@ type Listing struct { | ||||
| 	middleware.Context | ||||
| } | ||||
| 
 | ||||
| // LinkedPath returns l.Path where every element is a clickable | ||||
| // link to the path up to that point so far. | ||||
| func (l Listing) LinkedPath() string { | ||||
| // BreadcrumbMap returns l.Path where every element is a map | ||||
| // of URLs and path segment names. | ||||
| func (l Listing) BreadcrumbMap() map[string]string { | ||||
| 	result := map[string]string{} | ||||
| 
 | ||||
| 	if len(l.Path) == 0 { | ||||
| 		return "" | ||||
| 		return result | ||||
| 	} | ||||
| 
 | ||||
| 	// skip trailing slash | ||||
| @ -83,14 +84,13 @@ func (l Listing) LinkedPath() string { | ||||
| 	} | ||||
| 
 | ||||
| 	parts := strings.Split(lpath, "/") | ||||
| 	var result string | ||||
| 	for i, part := range parts { | ||||
| 		if i == 0 && part == "" { | ||||
| 			// Leading slash (root) | ||||
| 			result += `<a href="/">/</a>` | ||||
| 			result["/"] = "/" | ||||
| 			continue | ||||
| 		} | ||||
| 		result += fmt.Sprintf(`<a href="%s/">%s</a>/`, strings.Join(parts[:i+1], "/"), part) | ||||
| 		result[strings.Join(parts[:i+1], "/")] = part | ||||
| 	} | ||||
| 
 | ||||
| 	return result | ||||
| @ -241,6 +241,11 @@ func (b Browse) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { | ||||
| 		if !middleware.Path(r.URL.Path).Matches(bc.PathScope) { | ||||
| 			continue | ||||
| 		} | ||||
| 		switch r.Method { | ||||
| 		case http.MethodGet, http.MethodHead: | ||||
| 		default: | ||||
| 			return http.StatusMethodNotAllowed, nil | ||||
| 		} | ||||
| 
 | ||||
| 		// Browsing navigation gets messed up if browsing a directory | ||||
| 		// that doesn't end in "/" (which it should, anyway) | ||||
|  | ||||
| @ -104,6 +104,51 @@ func TestSort(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestBrowseHTTPMethods(t *testing.T) { | ||||
| 	tmpl, err := template.ParseFiles("testdata/photos.tpl") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("An error occured while parsing the template: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	b := Browse{ | ||||
| 		Next: middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { | ||||
| 			t.Fatalf("Next shouldn't be called") | ||||
| 			return 0, nil | ||||
| 		}), | ||||
| 		Root: "./testdata", | ||||
| 		Configs: []Config{ | ||||
| 			{ | ||||
| 				PathScope: "/photos", | ||||
| 				Template:  tmpl, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	rec := httptest.NewRecorder() | ||||
| 	for method, expected := range map[string]int{ | ||||
| 		http.MethodGet:     http.StatusOK, | ||||
| 		http.MethodHead:    http.StatusOK, | ||||
| 		http.MethodOptions: http.StatusMethodNotAllowed, | ||||
| 		http.MethodPost:    http.StatusMethodNotAllowed, | ||||
| 		http.MethodPut:     http.StatusMethodNotAllowed, | ||||
| 		http.MethodPatch:   http.StatusMethodNotAllowed, | ||||
| 		http.MethodDelete:  http.StatusMethodNotAllowed, | ||||
| 		"COPY":             http.StatusMethodNotAllowed, | ||||
| 		"MOVE":             http.StatusMethodNotAllowed, | ||||
| 		"MKCOL":            http.StatusMethodNotAllowed, | ||||
| 	} { | ||||
| 		req, err := http.NewRequest(method, "/photos/", nil) | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("Test: Could not create HTTP request: %v", err) | ||||
| 		} | ||||
| 
 | ||||
| 		code, _ := b.ServeHTTP(rec, req) | ||||
| 		if code != expected { | ||||
| 			t.Errorf("Wrong status with HTTP Method %s: expected %d, got %d", method, expected, code) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestBrowseTemplate(t *testing.T) { | ||||
| 	tmpl, err := template.ParseFiles("testdata/photos.tpl") | ||||
| 	if err != nil { | ||||
|  | ||||
| @ -52,7 +52,13 @@ func TestInclude(t *testing.T) { | ||||
| 			fileContent:          `str1 {{ .InvalidField }} str2`, | ||||
| 			expectedContent:      "", | ||||
| 			shouldErr:            true, | ||||
| 			expectedErrorContent: `InvalidField is not a field of struct type middleware.Context`, | ||||
| 			expectedErrorContent: `InvalidField`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			fileContent:          `str1 {{ .InvalidField }} str2`, | ||||
| 			expectedContent:      "", | ||||
| 			shouldErr:            true, | ||||
| 			expectedErrorContent: `type middleware.Context`, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
|  | ||||
| @ -141,3 +141,22 @@ func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { | ||||
| 	} | ||||
| 	return nil, nil, fmt.Errorf("not a Hijacker") | ||||
| } | ||||
| 
 | ||||
| // Flush implements http.Flusher. It simply wraps the underlying | ||||
| // ResponseWriter's Flush method if there is one, or panics. | ||||
| func (w *gzipResponseWriter) Flush() { | ||||
| 	if f, ok := w.ResponseWriter.(http.Flusher); ok { | ||||
| 		f.Flush() | ||||
| 	} else { | ||||
| 		panic("not a Flusher") // should be recovered at the beginning of middleware stack | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // CloseNotify implements http.CloseNotifier. | ||||
| // It just inherits the underlying ResponseWriter's CloseNotify method. | ||||
| func (w *gzipResponseWriter) CloseNotify() <-chan bool { | ||||
| 	if cn, ok := w.ResponseWriter.(http.CloseNotifier); ok { | ||||
| 		return cn.CloseNotify() | ||||
| 	} | ||||
| 	panic("not a CloseNotifier") | ||||
| } | ||||
|  | ||||
| @ -87,3 +87,12 @@ func (r *ResponseRecorder) Flush() { | ||||
| 		panic("not a Flusher") // should be recovered at the beginning of middleware stack | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // CloseNotify implements http.CloseNotifier. | ||||
| // It just inherits the underlying ResponseWriter's CloseNotify method. | ||||
| func (r *ResponseRecorder) CloseNotify() <-chan bool { | ||||
| 	if cn, ok := r.ResponseWriter.(http.CloseNotifier); ok { | ||||
| 		return cn.CloseNotify() | ||||
| 	} | ||||
| 	panic("not a CloseNotifier") | ||||
| } | ||||
|  | ||||
| @ -1,6 +1,7 @@ | ||||
| package server | ||||
| 
 | ||||
| import ( | ||||
| 	"crypto/tls" | ||||
| 	"net" | ||||
| 
 | ||||
| 	"github.com/mholt/caddy/middleware" | ||||
| @ -75,4 +76,5 @@ type TLSConfig struct { | ||||
| 	ProtocolMaxVersion       uint16 | ||||
| 	PreferServerCipherSuites bool | ||||
| 	ClientCerts              []string | ||||
| 	ClientAuth               tls.ClientAuthType | ||||
| } | ||||
|  | ||||
							
								
								
									
										104
									
								
								server/server.go
									
									
									
									
									
								
							
							
						
						
									
										104
									
								
								server/server.go
									
									
									
									
									
								
							| @ -4,20 +4,28 @@ | ||||
| package server | ||||
| 
 | ||||
| import ( | ||||
| 	"crypto/rand" | ||||
| 	"crypto/tls" | ||||
| 	"crypto/x509" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"io/ioutil" | ||||
| 	"log" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"os" | ||||
| 	"path/filepath" | ||||
| 	"runtime" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	tlsNewTicketEvery = time.Hour * 10 // generate a new ticket for TLS PFS encryption every so often | ||||
| 	tlsNumTickets     = 4              // hold and consider that many tickets to decrypt TLS sessions | ||||
| ) | ||||
| 
 | ||||
| // Server represents an instance of a server, which serves | ||||
| // HTTP requests at a particular address (host and port). A | ||||
| // server is capable of serving numerous virtual hosts on | ||||
| @ -28,6 +36,7 @@ type Server struct { | ||||
| 	HTTP2       bool                   // whether to enable HTTP/2 | ||||
| 	tls         bool                   // whether this server is serving all HTTPS hosts or not | ||||
| 	OnDemandTLS bool                   // whether this server supports on-demand TLS (load certs at handshake-time) | ||||
| 	tlsGovChan  chan struct{}          // close to stop the TLS maintenance goroutine | ||||
| 	vhosts      map[string]virtualHost // virtual hosts keyed by their address | ||||
| 	listener    ListenerFile           // the listener which is bound to the socket | ||||
| 	listenerMu  sync.Mutex             // protects listener | ||||
| @ -216,6 +225,11 @@ func serveTLS(s *Server, ln net.Listener, tlsConfigs []TLSConfig) error { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	// Setup any goroutines governing over TLS settings | ||||
| 	s.tlsGovChan = make(chan struct{}) | ||||
| 	timer := time.NewTicker(tlsNewTicketEvery) | ||||
| 	go runTLSTicketKeyRotation(s.TLSConfig, timer, s.tlsGovChan) | ||||
| 
 | ||||
| 	// Create TLS listener - note that we do not replace s.listener | ||||
| 	// with this TLS listener; tls.listener is unexported and does | ||||
| 	// not implement the File() method we need for graceful restarts | ||||
| @ -258,6 +272,11 @@ func (s *Server) Stop() (err error) { | ||||
| 	} | ||||
| 	s.listenerMu.Unlock() | ||||
| 
 | ||||
| 	// Closing this signals any TLS governor goroutines to exit | ||||
| 	if s.tlsGovChan != nil { | ||||
| 		close(s.tlsGovChan) | ||||
| 	} | ||||
| 
 | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| @ -314,6 +333,16 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// Use URL.RawPath If you need the original, "raw" URL.Path in your middleware. | ||||
| 	// Collapse any ./ ../ /// madness here instead of doing that in every plugin. | ||||
| 	if r.URL.Path != "/" { | ||||
| 		path := filepath.Clean(r.URL.Path) | ||||
| 		if !strings.HasPrefix(path, "/") { | ||||
| 			path = "/" + path | ||||
| 		} | ||||
| 		r.URL.Path = path | ||||
| 	} | ||||
| 
 | ||||
| 	// Execute the optional request callback if it exists and it's not disabled | ||||
| 	if s.ReqCallback != nil && !s.vhosts[host].config.TLS.Manual && s.ReqCallback(w, r) { | ||||
| 		return | ||||
| @ -350,17 +379,19 @@ func DefaultErrorFunc(w http.ResponseWriter, r *http.Request, status int) { | ||||
| // setupClientAuth sets up TLS client authentication only if | ||||
| // any of the TLS configs specified at least one cert file. | ||||
| func setupClientAuth(tlsConfigs []TLSConfig, config *tls.Config) error { | ||||
| 	var clientAuth bool | ||||
| 	whatClientAuth := tls.NoClientCert | ||||
| 	for _, cfg := range tlsConfigs { | ||||
| 		if len(cfg.ClientCerts) > 0 { | ||||
| 			clientAuth = true | ||||
| 			break | ||||
| 		if whatClientAuth < cfg.ClientAuth { // Use the most restrictive. | ||||
| 			whatClientAuth = cfg.ClientAuth | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if clientAuth { | ||||
| 	if whatClientAuth != tls.NoClientCert { | ||||
| 		pool := x509.NewCertPool() | ||||
| 		for _, cfg := range tlsConfigs { | ||||
| 			if len(cfg.ClientCerts) == 0 { | ||||
| 				continue | ||||
| 			} | ||||
| 			for _, caFile := range cfg.ClientCerts { | ||||
| 				caCrt, err := ioutil.ReadFile(caFile) // Anyone that gets a cert from this CA can connect | ||||
| 				if err != nil { | ||||
| @ -372,12 +403,73 @@ func setupClientAuth(tlsConfigs []TLSConfig, config *tls.Config) error { | ||||
| 			} | ||||
| 		} | ||||
| 		config.ClientCAs = pool | ||||
| 		config.ClientAuth = tls.RequireAndVerifyClientCert | ||||
| 		config.ClientAuth = whatClientAuth | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| var runTLSTicketKeyRotation = standaloneTLSTicketKeyRotation | ||||
| 
 | ||||
| var setSessionTicketKeysTestHook = func(keys [][32]byte) [][32]byte { | ||||
| 	return keys | ||||
| } | ||||
| 
 | ||||
| // standaloneTLSTicketKeyRotation governs over the array of TLS ticket keys used to de/crypt TLS tickets. | ||||
| // It periodically sets a new ticket key as the first one, used to encrypt (and decrypt), | ||||
| // pushing any old ticket keys to the back, where they are considered for decryption only. | ||||
| // | ||||
| // Lack of entropy for the very first ticket key results in the feature being disabled (as does Go), | ||||
| // later lack of entropy temporarily disables ticket key rotation. | ||||
| // Old ticket keys are still phased out, though. | ||||
| // | ||||
| // Stops the timer when returning. | ||||
| func standaloneTLSTicketKeyRotation(c *tls.Config, timer *time.Ticker, exitChan chan struct{}) { | ||||
| 	defer timer.Stop() | ||||
| 	// The entire page should be marked as sticky, but Go cannot do that | ||||
| 	// without resorting to syscall#Mlock. And, we don't have madvise (for NODUMP), too. ☹ | ||||
| 	keys := make([][32]byte, 1, tlsNumTickets) | ||||
| 
 | ||||
| 	rng := c.Rand | ||||
| 	if rng == nil { | ||||
| 		rng = rand.Reader | ||||
| 	} | ||||
| 	if _, err := io.ReadFull(rng, keys[0][:]); err != nil { | ||||
| 		c.SessionTicketsDisabled = true // bail if we don't have the entropy for the first one | ||||
| 		return | ||||
| 	} | ||||
| 	c.SetSessionTicketKeys(setSessionTicketKeysTestHook(keys)) | ||||
| 
 | ||||
| 	for { | ||||
| 		select { | ||||
| 		case _, isOpen := <-exitChan: | ||||
| 			if !isOpen { | ||||
| 				return | ||||
| 			} | ||||
| 		case <-timer.C: | ||||
| 			rng = c.Rand // could've changed since the start | ||||
| 			if rng == nil { | ||||
| 				rng = rand.Reader | ||||
| 			} | ||||
| 			var newTicketKey [32]byte | ||||
| 			_, err := io.ReadFull(rng, newTicketKey[:]) | ||||
| 
 | ||||
| 			if len(keys) < tlsNumTickets { | ||||
| 				keys = append(keys, keys[0]) // manipulates the internal length | ||||
| 			} | ||||
| 			for idx := len(keys) - 1; idx >= 1; idx-- { | ||||
| 				keys[idx] = keys[idx-1] // yes, this makes copies | ||||
| 			} | ||||
| 
 | ||||
| 			if err == nil { | ||||
| 				keys[0] = newTicketKey | ||||
| 			} | ||||
| 			// pushes the last key out, doesn't matter that we don't have a new one | ||||
| 			c.SetSessionTicketKeys(setSessionTicketKeysTestHook(keys)) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // RunFirstStartupFuncs runs all of the server's FirstStartup | ||||
| // callback functions unless one of them returns an error first. | ||||
| // It is the caller's responsibility to call this only once and | ||||
|  | ||||
							
								
								
									
										60
									
								
								server/server_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										60
									
								
								server/server_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,60 @@ | ||||
| package server | ||||
| 
 | ||||
| import ( | ||||
| 	"crypto/tls" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| func TestStandaloneTLSTicketKeyRotation(t *testing.T) { | ||||
| 	tlsGovChan := make(chan struct{}) | ||||
| 	defer close(tlsGovChan) | ||||
| 	callSync := make(chan bool, 1) | ||||
| 	defer close(callSync) | ||||
| 
 | ||||
| 	oldHook := setSessionTicketKeysTestHook | ||||
| 	defer func() { | ||||
| 		setSessionTicketKeysTestHook = oldHook | ||||
| 	}() | ||||
| 	var keysInUse [][32]byte | ||||
| 	setSessionTicketKeysTestHook = func(keys [][32]byte) [][32]byte { | ||||
| 		keysInUse = keys | ||||
| 		callSync <- true | ||||
| 		return keys | ||||
| 	} | ||||
| 
 | ||||
| 	c := new(tls.Config) | ||||
| 	timer := time.NewTicker(time.Millisecond * 1) | ||||
| 
 | ||||
| 	go standaloneTLSTicketKeyRotation(c, timer, tlsGovChan) | ||||
| 
 | ||||
| 	rounds := 0 | ||||
| 	var lastTicketKey [32]byte | ||||
| 	for { | ||||
| 		select { | ||||
| 		case <-callSync: | ||||
| 			if lastTicketKey == keysInUse[0] { | ||||
| 				close(tlsGovChan) | ||||
| 				t.Errorf("The same TLS ticket key has been used again (not rotated): %x.", lastTicketKey) | ||||
| 				return | ||||
| 			} | ||||
| 			lastTicketKey = keysInUse[0] | ||||
| 			rounds++ | ||||
| 			if rounds <= tlsNumTickets && len(keysInUse) != rounds { | ||||
| 				close(tlsGovChan) | ||||
| 				t.Errorf("Expected TLS ticket keys in use: %d; Got instead: %d.", rounds, len(keysInUse)) | ||||
| 				return | ||||
| 			} | ||||
| 			if c.SessionTicketsDisabled == true { | ||||
| 				t.Error("Session tickets have been disabled unexpectedly.") | ||||
| 				return | ||||
| 			} | ||||
| 			if rounds >= tlsNumTickets+1 { | ||||
| 				return | ||||
| 			} | ||||
| 		case <-time.After(time.Second * 1): | ||||
| 			t.Errorf("Timeout after %d rounds.", rounds) | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user