diff --git a/caddy/https/certificates.go b/caddy/https/certificates.go index b123d4c32..0dc3db523 100644 --- a/caddy/https/certificates.go +++ b/caddy/https/certificates.go @@ -50,23 +50,21 @@ type Certificate struct { OCSP *ocsp.Response } -// getCertificate gets a certificate from the in-memory cache that -// matches name (a certificate name). Note that if name does not have -// an exact match, it will be checked against names of the form -// '*.example.com' (wildcard certificates) according to RFC 6125. -// -// If cert was found by matching name, matched will be returned true. -// If no match is found, the default certificate will be returned and -// matched will be returned as false. (The default certificate is the -// first one that entered the cache.) If the cache is empty (or there -// is no default certificate for some reason), matched will still be -// false, but cert.Certificate will be nil. +// getCertificate gets a certificate that matches name (a server name) +// from the in-memory cache. If there is no exact match for name, it +// will be checked against names of the form '*.example.com' (wildcard +// certificates) according to RFC 6125. If a match is found, matched will +// be true. If no matches are found, matched will be false and a default +// certificate will be returned with defaulted set to true. If no default +// certificate is set, defaulted will be set to false. // // The logic in this function is adapted from the Go standard library, // which is by the Go Authors. // // This function is safe for concurrent use. -func getCertificate(name string) (cert Certificate, matched bool) { +func getCertificate(name string) (cert Certificate, matched, defaulted bool) { + var ok bool + // Not going to trim trailing dots here since RFC 3546 says, // "The hostname is represented ... without a trailing dot." // Just normalize to lowercase. @@ -76,8 +74,9 @@ func getCertificate(name string) (cert Certificate, matched bool) { defer certCacheMu.RUnlock() // exact match? great, let's use it - if cert, ok := certCache[name]; ok { - return cert, true + if cert, ok = certCache[name]; ok { + matched = true + return } // try replacing labels in the name with wildcards until we get a match @@ -85,14 +84,15 @@ func getCertificate(name string) (cert Certificate, matched bool) { for i := range labels { labels[i] = "*" candidate := strings.Join(labels, ".") - if cert, ok := certCache[candidate]; ok { - return cert, true + if cert, ok = certCache[candidate]; ok { + matched = true + return } } - // if nothing matches, return the default certificate - cert = certCache[""] - return cert, false + // if nothing matches, use the default certificate or bust + cert, defaulted = certCache[""] + return } // cacheManagedCertificate loads the certificate for domain into the @@ -214,8 +214,8 @@ func cacheCertificate(cert Certificate) { certCacheMu.Lock() if _, ok := certCache[""]; !ok { // use as default - certCache[""] = cert cert.Names = append(cert.Names, "") + certCache[""] = cert } for len(certCache)+len(cert.Names) > 10000 { // for simplicity, just remove random elements diff --git a/caddy/https/certificates_test.go b/caddy/https/certificates_test.go new file mode 100644 index 000000000..dbfb4efc1 --- /dev/null +++ b/caddy/https/certificates_test.go @@ -0,0 +1,59 @@ +package https + +import "testing" + +func TestUnexportedGetCertificate(t *testing.T) { + defer func() { certCache = make(map[string]Certificate) }() + + // When cache is empty + if _, matched, defaulted := getCertificate("example.com"); matched || defaulted { + t.Errorf("Got a certificate when cache was empty; matched=%v, defaulted=%v", matched, defaulted) + } + + // When cache has one certificate in it (also is default) + defaultCert := Certificate{Names: []string{"example.com", ""}} + certCache[""] = defaultCert + certCache["example.com"] = defaultCert + if cert, matched, defaulted := getCertificate("Example.com"); !matched || defaulted || cert.Names[0] != "example.com" { + t.Errorf("Didn't get a cert for 'Example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted) + } + if cert, matched, defaulted := getCertificate(""); !matched || defaulted || cert.Names[0] != "example.com" { + t.Errorf("Didn't get a cert for '' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted) + } + + // When retrieving wildcard certificate + certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}} + if cert, matched, defaulted := getCertificate("sub.example.com"); !matched || defaulted || cert.Names[0] != "*.example.com" { + t.Errorf("Didn't get wildcard cert for 'sub.example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted) + } + + // When no certificate matches, the default is returned + if cert, matched, defaulted := getCertificate("nomatch"); matched || !defaulted { + t.Errorf("Expected matched=false, defaulted=true; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert) + } else if cert.Names[0] != "example.com" { + t.Errorf("Expected default cert, got: %v", cert) + } +} + +func TestCacheCertificate(t *testing.T) { + defer func() { certCache = make(map[string]Certificate) }() + + cacheCertificate(Certificate{Names: []string{"example.com", "sub.example.com"}}) + if _, ok := certCache["example.com"]; !ok { + t.Error("Expected first cert to be cached by key 'example.com', but it wasn't") + } + if _, ok := certCache["sub.example.com"]; !ok { + t.Error("Expected first cert to be cached by key 'sub.exmaple.com', but it wasn't") + } + if cert, ok := certCache[""]; !ok || cert.Names[2] != "" { + t.Error("Expected first cert to be cached additionally as the default certificate with empty name added, but it wasn't") + } + + cacheCertificate(Certificate{Names: []string{"example2.com"}}) + if _, ok := certCache["example2.com"]; !ok { + t.Error("Expected second cert to be cached by key 'exmaple2.com', but it wasn't") + } + if cert, ok := certCache[""]; ok && cert.Names[0] == "example2.com" { + t.Error("Expected second cert to NOT be cached as default, but it was") + } +} diff --git a/caddy/https/handshake.go b/caddy/https/handshake.go index 38f9afb55..fc6ef809e 100644 --- a/caddy/https/handshake.go +++ b/caddy/https/handshake.go @@ -39,31 +39,30 @@ func GetOrObtainCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, } // getCertDuringHandshake will get a certificate for name. It first tries -// the in-memory cache. If no certificate for name is in the cach and if +// the in-memory cache. If no certificate for name is in the cache and if // loadIfNecessary == true, it goes to disk to load it into the cache and // serve it. If it's not on disk and if obtainIfNecessary == true, the // certificate will be obtained from the CA, cached, and served. If // obtainIfNecessary is true, then loadIfNecessary must also be set to true. +// An error will be returned if and only if no certificate is available. // // This function is safe for concurrent use. func getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) { // First check our in-memory cache to see if we've already loaded it - cert, ok := getCertificate(name) - if ok { + cert, matched, defaulted := getCertificate(name) + if matched { return cert, nil } if loadIfNecessary { - var err error - // Then check to see if we have one on disk - cert, err = cacheManagedCertificate(name, true) + loadedCert, err := cacheManagedCertificate(name, true) if err == nil { - cert, err = handshakeMaintenance(name, cert) + loadedCert, err = handshakeMaintenance(name, loadedCert) if err != nil { log.Printf("[ERROR] Maintaining newly-loaded certificate for %s: %v", name, err) } - return cert, nil + return loadedCert, nil } if obtainIfNecessary { @@ -87,7 +86,11 @@ func getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool } } - return Certificate{}, nil + if defaulted { + return cert, nil + } + + return Certificate{}, errors.New("no certificate for " + name) } // checkLimitsForObtainingNewCerts checks to see if name can be issued right diff --git a/caddy/https/handshake_test.go b/caddy/https/handshake_test.go new file mode 100644 index 000000000..cf70eb17d --- /dev/null +++ b/caddy/https/handshake_test.go @@ -0,0 +1,54 @@ +package https + +import ( + "crypto/tls" + "crypto/x509" + "testing" +) + +func TestGetCertificate(t *testing.T) { + defer func() { certCache = make(map[string]Certificate) }() + + hello := &tls.ClientHelloInfo{ServerName: "example.com"} + helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"} + helloNoSNI := &tls.ClientHelloInfo{} + helloNoMatch := &tls.ClientHelloInfo{ServerName: "nomatch"} + + // When cache is empty + if cert, err := GetCertificate(hello); err == nil { + t.Errorf("GetCertificate should return error when cache is empty, got: %v", cert) + } + if cert, err := GetCertificate(helloNoSNI); err == nil { + t.Errorf("GetCertificate should return error when cache is empty even if server name is blank, got: %v", cert) + } + + // When cache has one certificate in it (also is default) + defaultCert := Certificate{Names: []string{"example.com", ""}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"example.com"}}}} + certCache[""] = defaultCert + certCache["example.com"] = defaultCert + if cert, err := GetCertificate(hello); err != nil { + t.Errorf("Got an error but shouldn't have, when cert exists in cache: %v", err) + } else if cert.Leaf.DNSNames[0] != "example.com" { + t.Errorf("Got wrong certificate with exact match; expected 'example.com', got: %v", cert) + } + if cert, err := GetCertificate(helloNoSNI); err != nil { + t.Errorf("Got an error with no SNI but shouldn't have, when cert exists in cache: %v", err) + } else if cert.Leaf.DNSNames[0] != "example.com" { + t.Errorf("Got wrong certificate for no SNI; expected 'example.com' as default, got: %v", cert) + } + + // When retrieving wildcard certificate + certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"*.example.com"}}}} + if cert, err := GetCertificate(helloSub); err != nil { + t.Errorf("Didn't get wildcard cert, got: cert=%v, err=%v ", cert, err) + } else if cert.Leaf.DNSNames[0] != "*.example.com" { + t.Errorf("Got wrong certificate, expected wildcard: %v", cert) + } + + // When no certificate matches, the default is returned + if cert, err := GetCertificate(helloNoMatch); err != nil { + t.Errorf("Expected default certificate with no error when no matches, got err: %v", err) + } else if cert.Leaf.DNSNames[0] != "example.com" { + t.Errorf("Expected default cert with no matches, got: %v", cert) + } +}