mirror of
https://github.com/caddyserver/caddy.git
synced 2025-05-31 12:15:56 -04:00
tls: Refactor internals related to TLS configurations (#1466)
* tls: Refactor TLS config innards with a few minor syntax changes muststaple -> must_staple "http2 off" -> "alpn" with list of ALPN values * Fix typo * Fix QUIC handler * Inline struct field assignments
This commit is contained in:
parent
4b877eebc4
commit
73794f2a2c
@ -79,7 +79,7 @@ func enableAutoHTTPS(configs []*SiteConfig, loadCertificates bool) error {
|
|||||||
cfg.TLS.Enabled = true
|
cfg.TLS.Enabled = true
|
||||||
cfg.Addr.Scheme = "https"
|
cfg.Addr.Scheme = "https"
|
||||||
if loadCertificates && caddytls.HostQualifies(cfg.Addr.Host) {
|
if loadCertificates && caddytls.HostQualifies(cfg.Addr.Host) {
|
||||||
_, err := caddytls.CacheManagedCertificate(cfg.Addr.Host, cfg.TLS)
|
_, err := cfg.TLS.CacheManagedCertificate(cfg.Addr.Host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -35,6 +35,11 @@ type tlsHandler struct {
|
|||||||
// Halderman, et. al. in "The Security Impact of HTTPS Interception" (NDSS '17):
|
// Halderman, et. al. in "The Security Impact of HTTPS Interception" (NDSS '17):
|
||||||
// https://jhalderm.com/pub/papers/interception-ndss17.pdf
|
// https://jhalderm.com/pub/papers/interception-ndss17.pdf
|
||||||
func (h *tlsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (h *tlsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if h.listener == nil {
|
||||||
|
h.next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
h.listener.helloInfosMu.RLock()
|
h.listener.helloInfosMu.RLock()
|
||||||
info := h.listener.helloInfos[r.RemoteAddr]
|
info := h.listener.helloInfos[r.RemoteAddr]
|
||||||
h.listener.helloInfosMu.RUnlock()
|
h.listener.helloInfosMu.RUnlock()
|
||||||
@ -78,63 +83,62 @@ func (h *tlsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
h.next.ServeHTTP(w, r)
|
h.next.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// clientHelloConn reads the ClientHello
|
||||||
|
// and stores it in the attached listener.
|
||||||
type clientHelloConn struct {
|
type clientHelloConn struct {
|
||||||
net.Conn
|
net.Conn
|
||||||
readHello bool
|
|
||||||
listener *tlsHelloListener
|
listener *tlsHelloListener
|
||||||
|
readHello bool // whether ClientHello has been read
|
||||||
|
buf *bytes.Buffer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Read reads from c.Conn (by letting the standard library
|
||||||
|
// do the reading off the wire), with the exception of
|
||||||
|
// getting a copy of the ClientHello so it can parse it.
|
||||||
func (c *clientHelloConn) Read(b []byte) (n int, err error) {
|
func (c *clientHelloConn) Read(b []byte) (n int, err error) {
|
||||||
if !c.readHello {
|
// if we've already read the ClientHello, pass thru
|
||||||
// Read the header bytes.
|
if c.readHello {
|
||||||
hdr := make([]byte, 5)
|
return c.Conn.Read(b)
|
||||||
n, err := io.ReadFull(c.Conn, hdr)
|
|
||||||
if err != nil {
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the length of the ClientHello message and read it as well.
|
|
||||||
length := uint16(hdr[3])<<8 | uint16(hdr[4])
|
|
||||||
hello := make([]byte, int(length))
|
|
||||||
n, err = io.ReadFull(c.Conn, hello)
|
|
||||||
if err != nil {
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse the ClientHello and store it in the map.
|
|
||||||
rawParsed := parseRawClientHello(hello)
|
|
||||||
c.listener.helloInfosMu.Lock()
|
|
||||||
c.listener.helloInfos[c.Conn.RemoteAddr().String()] = rawParsed
|
|
||||||
c.listener.helloInfosMu.Unlock()
|
|
||||||
|
|
||||||
// Since we buffered the header and ClientHello, pretend we were
|
|
||||||
// never here by lining up the buffered values to be read with a
|
|
||||||
// custom connection type, followed by the rest of the actual
|
|
||||||
// underlying connection.
|
|
||||||
mr := io.MultiReader(bytes.NewReader(hdr), bytes.NewReader(hello), c.Conn)
|
|
||||||
mc := multiConn{Conn: c.Conn, reader: mr}
|
|
||||||
|
|
||||||
c.Conn = mc
|
|
||||||
|
|
||||||
c.readHello = true
|
|
||||||
}
|
}
|
||||||
return c.Conn.Read(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
// multiConn is a net.Conn that reads from the
|
// we let the standard lib read off the wire for us, and
|
||||||
// given reader instead of the wire directly. This
|
// tee that into our buffer so we can read the ClientHello
|
||||||
// is useful when some of the connection has already
|
tee := io.TeeReader(c.Conn, c.buf)
|
||||||
// been read (like the TLS Client Hello) and the
|
n, err = tee.Read(b)
|
||||||
// reader is a io.MultiReader that starts with
|
if err != nil {
|
||||||
// the contents of the buffer.
|
return
|
||||||
type multiConn struct {
|
}
|
||||||
net.Conn
|
if c.buf.Len() < 5 {
|
||||||
reader io.Reader
|
return // need to read more bytes for header
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read reads from mc.reader.
|
// read the header bytes
|
||||||
func (mc multiConn) Read(b []byte) (n int, err error) {
|
hdr := make([]byte, 5)
|
||||||
return mc.reader.Read(b)
|
_, err = io.ReadFull(c.buf, hdr)
|
||||||
|
if err != nil {
|
||||||
|
return // this would be highly unusual and sad
|
||||||
|
}
|
||||||
|
|
||||||
|
// get length of the ClientHello message and read it
|
||||||
|
length := int(uint16(hdr[3])<<8 | uint16(hdr[4]))
|
||||||
|
if c.buf.Len() < length {
|
||||||
|
return // need to read more bytes
|
||||||
|
}
|
||||||
|
hello := make([]byte, length)
|
||||||
|
_, err = io.ReadFull(c.buf, hello)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.buf = nil // buffer no longer needed
|
||||||
|
|
||||||
|
// parse the ClientHello and store it in the map
|
||||||
|
rawParsed := parseRawClientHello(hello)
|
||||||
|
c.listener.helloInfosMu.Lock()
|
||||||
|
c.listener.helloInfos[c.Conn.RemoteAddr().String()] = rawParsed
|
||||||
|
c.listener.helloInfosMu.Unlock()
|
||||||
|
|
||||||
|
c.readHello = true
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseRawClientHello parses data which contains the raw
|
// parseRawClientHello parses data which contains the raw
|
||||||
@ -279,7 +283,7 @@ func (l *tlsHelloListener) Accept() (net.Conn, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
helloConn := &clientHelloConn{Conn: conn, listener: l}
|
helloConn := &clientHelloConn{Conn: conn, listener: l, buf: new(bytes.Buffer)}
|
||||||
return tls.Server(helloConn, l.config), nil
|
return tls.Server(helloConn, l.config), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -84,7 +84,7 @@ func TestHeuristicFunctions(t *testing.T) {
|
|||||||
// clientHello pairs a User-Agent string to its ClientHello message.
|
// clientHello pairs a User-Agent string to its ClientHello message.
|
||||||
type clientHello struct {
|
type clientHello struct {
|
||||||
userAgent string
|
userAgent string
|
||||||
helloHex string
|
helloHex string // do NOT include the header, just the ClientHello message
|
||||||
}
|
}
|
||||||
|
|
||||||
// clientHellos groups samples of true (real) ClientHellos by the
|
// clientHellos groups samples of true (real) ClientHellos by the
|
||||||
@ -158,7 +158,12 @@ func TestHeuristicFunctions(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
// IE 11 on Windows 7, this connection was intercepted by Blue Coat
|
// IE 11 on Windows 7, this connection was intercepted by Blue Coat
|
||||||
helloHex: "010000b1030358a3f3bae627f464da8cb35976b88e9119640032d41e62a107d608ed8d3e62b9000034c028c027c014c013009f009e009d009cc02cc02bc024c023c00ac009003d003c0035002f006a004000380032000a0013000500040100005400000014001200000f66696e6572706978656c732e636f6d000500050100000000000a00080006001700180019000b00020100000d0014001206010603040105010201040305030203020200170000ff01000100",
|
helloHex: `010000b1030358a3f3bae627f464da8cb35976b88e9119640032d41e62a107d608ed8d3e62b9000034c028c027c014c013009f009e009d009cc02cc02bc024c023c00ac009003d003c0035002f006a004000380032000a0013000500040100005400000014001200000f66696e6572706978656c732e636f6d000500050100000000000a00080006001700180019000b00020100000d0014001206010603040105010201040305030203020200170000ff01000100`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Firefox 51.0.1 being intercepted by burp 1.7.17
|
||||||
|
userAgent: "(TODO)",
|
||||||
|
helloHex: `010000d8030358a92f4daca95acc2f6a10a9c50d736135eae39406d3090238464540d482677600003ac023c027003cc025c02900670040c009c013002fc004c00e00330032c02bc02f009cc02dc031009e00a2c008c012000ac003c00d0016001300ff01000075000a0034003200170001000300130015000600070009000a0018000b000c0019000d000e000f001000110002001200040005001400080016000b00020100000d00180016060306010503050104030401040202030201020201010000001700150000126a61677561722e6b796877616e612e6f7267`,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -31,40 +31,47 @@ type Server struct {
|
|||||||
connTimeout time.Duration // max time to wait for a connection before force stop
|
connTimeout time.Duration // max time to wait for a connection before force stop
|
||||||
tlsGovChan chan struct{} // close to stop the TLS maintenance goroutine
|
tlsGovChan chan struct{} // close to stop the TLS maintenance goroutine
|
||||||
vhosts *vhostTrie
|
vhosts *vhostTrie
|
||||||
tlsConfig caddytls.ConfigGroup
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ensure it satisfies the interface
|
// ensure it satisfies the interface
|
||||||
var _ caddy.GracefulServer = new(Server)
|
var _ caddy.GracefulServer = new(Server)
|
||||||
|
|
||||||
|
var defaultALPN = []string{"h2", "http/1.1"}
|
||||||
|
|
||||||
|
// makeTLSConfig extracts TLS settings from each site config to
|
||||||
|
// build a tls.Config usable in Caddy HTTP servers. The returned
|
||||||
|
// config will be nil if TLS is disabled for these sites.
|
||||||
|
func makeTLSConfig(group []*SiteConfig) (*tls.Config, error) {
|
||||||
|
var tlsConfigs []*caddytls.Config
|
||||||
|
for i := range group {
|
||||||
|
if HTTP2 && len(group[i].TLS.ALPN) == 0 {
|
||||||
|
// if no application-level protocol was configured up to now,
|
||||||
|
// default to HTTP/2, then HTTP/1.1 if necessary
|
||||||
|
group[i].TLS.ALPN = defaultALPN
|
||||||
|
}
|
||||||
|
tlsConfigs = append(tlsConfigs, group[i].TLS)
|
||||||
|
}
|
||||||
|
return caddytls.MakeTLSConfig(tlsConfigs)
|
||||||
|
}
|
||||||
|
|
||||||
// NewServer creates a new Server instance that will listen on addr
|
// NewServer creates a new Server instance that will listen on addr
|
||||||
// and will serve the sites configured in group.
|
// and will serve the sites configured in group.
|
||||||
func NewServer(addr string, group []*SiteConfig) (*Server, error) {
|
func NewServer(addr string, group []*SiteConfig) (*Server, error) {
|
||||||
s := &Server{
|
s := &Server{
|
||||||
Server: makeHTTPServer(addr, group),
|
Server: makeHTTPServerWithTimeouts(addr, group),
|
||||||
vhosts: newVHostTrie(),
|
vhosts: newVHostTrie(),
|
||||||
sites: group,
|
sites: group,
|
||||||
connTimeout: GracefulTimeout,
|
connTimeout: GracefulTimeout,
|
||||||
}
|
}
|
||||||
|
|
||||||
s.Server.Handler = s // this is weird, but whatever
|
s.Server.Handler = s // this is weird, but whatever
|
||||||
tlsh := &tlsHandler{next: s.Server.Handler}
|
|
||||||
s.Server.ConnState = func(c net.Conn, cs http.ConnState) {
|
|
||||||
// when a connection closes or is hijacked, delete its entry
|
|
||||||
// in the map, because we are done with it.
|
|
||||||
if tlsh.listener != nil {
|
|
||||||
if cs == http.StateHijacked || cs == http.StateClosed {
|
|
||||||
tlsh.listener.helloInfosMu.Lock()
|
|
||||||
delete(tlsh.listener.helloInfos, c.RemoteAddr().String())
|
|
||||||
tlsh.listener.helloInfosMu.Unlock()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Disable HTTP/2 if desired
|
// extract TLS settings from each site config to build
|
||||||
if !HTTP2 {
|
// a tls.Config, which will not be nil if TLS is enabled
|
||||||
s.Server.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
|
tlsConfig, err := makeTLSConfig(group)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
s.Server.TLSConfig = tlsConfig
|
||||||
|
|
||||||
// Enable QUIC if desired
|
// Enable QUIC if desired
|
||||||
if QUIC {
|
if QUIC {
|
||||||
@ -72,41 +79,36 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) {
|
|||||||
s.Server.Handler = s.wrapWithSvcHeaders(s.Server.Handler)
|
s.Server.Handler = s.wrapWithSvcHeaders(s.Server.Handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set up TLS configuration
|
// if TLS is enabled, make sure we prepare the Server accordingly
|
||||||
tlsConfigs := make(caddytls.ConfigGroup)
|
|
||||||
var allConfigs []*caddytls.Config
|
|
||||||
|
|
||||||
for _, site := range group {
|
|
||||||
|
|
||||||
if err := site.TLS.Build(tlsConfigs); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
tlsConfigs[site.TLS.Hostname] = site.TLS
|
|
||||||
allConfigs = append(allConfigs, site.TLS)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if configs are valid
|
|
||||||
if err := caddytls.CheckConfigs(allConfigs); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
s.tlsConfig = tlsConfigs
|
|
||||||
|
|
||||||
if caddytls.HasTLSEnabled(allConfigs) {
|
|
||||||
s.Server.TLSConfig = &tls.Config{
|
|
||||||
GetConfigForClient: s.tlsConfig.GetConfigForClient,
|
|
||||||
GetCertificate: s.tlsConfig.GetCertificate,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// As of Go 1.7, HTTP/2 is enabled only if NextProtos includes the string "h2"
|
|
||||||
if HTTP2 && s.Server.TLSConfig != nil && len(s.Server.TLSConfig.NextProtos) == 0 {
|
|
||||||
s.Server.TLSConfig.NextProtos = []string{"h2"}
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.Server.TLSConfig != nil {
|
if s.Server.TLSConfig != nil {
|
||||||
s.Server.Handler = tlsh
|
// wrap the HTTP handler with a handler that does MITM detection
|
||||||
|
tlsh := &tlsHandler{next: s.Server.Handler}
|
||||||
|
s.Server.Handler = tlsh // this needs to be the "outer" handler when Serve() is called, for type assertion
|
||||||
|
|
||||||
|
// when Serve() creates the TLS listener later, that listener should
|
||||||
|
// be adding a reference the ClientHello info to a map; this callback
|
||||||
|
// will be sure to clear out that entry when the connection closes.
|
||||||
|
s.Server.ConnState = func(c net.Conn, cs http.ConnState) {
|
||||||
|
// when a connection closes or is hijacked, delete its entry
|
||||||
|
// in the map, because we are done with it.
|
||||||
|
if tlsh.listener != nil {
|
||||||
|
if cs == http.StateHijacked || cs == http.StateClosed {
|
||||||
|
tlsh.listener.helloInfosMu.Lock()
|
||||||
|
delete(tlsh.listener.helloInfos, c.RemoteAddr().String())
|
||||||
|
tlsh.listener.helloInfosMu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// As of Go 1.7, if the Server's TLSConfig is not nil, HTTP/2 is enabled only
|
||||||
|
// if TLSConfig.NextProtos includes the string "h2"
|
||||||
|
if HTTP2 && len(s.Server.TLSConfig.NextProtos) == 0 {
|
||||||
|
// some experimenting shows that this NextProtos must have at least
|
||||||
|
// one value that overlaps with the NextProtos of any other tls.Config
|
||||||
|
// that is returned from GetConfigForClient; if there is no overlap,
|
||||||
|
// the connection will fail (as of Go 1.8, Feb. 2017).
|
||||||
|
s.Server.TLSConfig.NextProtos = defaultALPN
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compile custom middleware for every site (enables virtual hosting)
|
// Compile custom middleware for every site (enables virtual hosting)
|
||||||
@ -122,6 +124,61 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) {
|
|||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// makeHTTPServerWithTimeouts makes an http.Server from the group of
|
||||||
|
// configs in a way that configures timeouts (or, if not set, it uses
|
||||||
|
// the default timeouts) by combining the configuration of each
|
||||||
|
// SiteConfig in the group. (Timeouts are important for mitigating
|
||||||
|
// slowloris attacks.)
|
||||||
|
func makeHTTPServerWithTimeouts(addr string, group []*SiteConfig) *http.Server {
|
||||||
|
// find the minimum duration configured for each timeout
|
||||||
|
var min Timeouts
|
||||||
|
for _, cfg := range group {
|
||||||
|
if cfg.Timeouts.ReadTimeoutSet &&
|
||||||
|
(!min.ReadTimeoutSet || cfg.Timeouts.ReadTimeout < min.ReadTimeout) {
|
||||||
|
min.ReadTimeoutSet = true
|
||||||
|
min.ReadTimeout = cfg.Timeouts.ReadTimeout
|
||||||
|
}
|
||||||
|
if cfg.Timeouts.ReadHeaderTimeoutSet &&
|
||||||
|
(!min.ReadHeaderTimeoutSet || cfg.Timeouts.ReadHeaderTimeout < min.ReadHeaderTimeout) {
|
||||||
|
min.ReadHeaderTimeoutSet = true
|
||||||
|
min.ReadHeaderTimeout = cfg.Timeouts.ReadHeaderTimeout
|
||||||
|
}
|
||||||
|
if cfg.Timeouts.WriteTimeoutSet &&
|
||||||
|
(!min.WriteTimeoutSet || cfg.Timeouts.WriteTimeout < min.WriteTimeout) {
|
||||||
|
min.WriteTimeoutSet = true
|
||||||
|
min.WriteTimeout = cfg.Timeouts.WriteTimeout
|
||||||
|
}
|
||||||
|
if cfg.Timeouts.IdleTimeoutSet &&
|
||||||
|
(!min.IdleTimeoutSet || cfg.Timeouts.IdleTimeout < min.IdleTimeout) {
|
||||||
|
min.IdleTimeoutSet = true
|
||||||
|
min.IdleTimeout = cfg.Timeouts.IdleTimeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// for the values that were not set, use defaults
|
||||||
|
if !min.ReadTimeoutSet {
|
||||||
|
min.ReadTimeout = defaultTimeouts.ReadTimeout
|
||||||
|
}
|
||||||
|
if !min.ReadHeaderTimeoutSet {
|
||||||
|
min.ReadHeaderTimeout = defaultTimeouts.ReadHeaderTimeout
|
||||||
|
}
|
||||||
|
if !min.WriteTimeoutSet {
|
||||||
|
min.WriteTimeout = defaultTimeouts.WriteTimeout
|
||||||
|
}
|
||||||
|
if !min.IdleTimeoutSet {
|
||||||
|
min.IdleTimeout = defaultTimeouts.IdleTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
// set the final values on the server and return it
|
||||||
|
return &http.Server{
|
||||||
|
Addr: addr,
|
||||||
|
ReadTimeout: min.ReadTimeout,
|
||||||
|
ReadHeaderTimeout: min.ReadHeaderTimeout,
|
||||||
|
WriteTimeout: min.WriteTimeout,
|
||||||
|
IdleTimeout: min.IdleTimeout,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) wrapWithSvcHeaders(previousHandler http.Handler) http.HandlerFunc {
|
func (s *Server) wrapWithSvcHeaders(previousHandler http.Handler) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
s.quicServer.SetQuicHeaders(w.Header())
|
s.quicServer.SetQuicHeaders(w.Header())
|
||||||
@ -390,62 +447,6 @@ var defaultTimeouts = Timeouts{
|
|||||||
IdleTimeout: 2 * time.Minute,
|
IdleTimeout: 2 * time.Minute,
|
||||||
}
|
}
|
||||||
|
|
||||||
// makeHTTPServer makes an http.Server from the group of configs
|
|
||||||
// in a way that configures timeouts (or, if not set, it uses the
|
|
||||||
// default timeouts) and other http.Server properties by combining
|
|
||||||
// the configuration of each SiteConfig in the group. (Timeouts
|
|
||||||
// are important for mitigating slowloris attacks.)
|
|
||||||
func makeHTTPServer(addr string, group []*SiteConfig) *http.Server {
|
|
||||||
s := &http.Server{Addr: addr}
|
|
||||||
|
|
||||||
// find the minimum duration configured for each timeout
|
|
||||||
var min Timeouts
|
|
||||||
for _, cfg := range group {
|
|
||||||
if cfg.Timeouts.ReadTimeoutSet &&
|
|
||||||
(!min.ReadTimeoutSet || cfg.Timeouts.ReadTimeout < min.ReadTimeout) {
|
|
||||||
min.ReadTimeoutSet = true
|
|
||||||
min.ReadTimeout = cfg.Timeouts.ReadTimeout
|
|
||||||
}
|
|
||||||
if cfg.Timeouts.ReadHeaderTimeoutSet &&
|
|
||||||
(!min.ReadHeaderTimeoutSet || cfg.Timeouts.ReadHeaderTimeout < min.ReadHeaderTimeout) {
|
|
||||||
min.ReadHeaderTimeoutSet = true
|
|
||||||
min.ReadHeaderTimeout = cfg.Timeouts.ReadHeaderTimeout
|
|
||||||
}
|
|
||||||
if cfg.Timeouts.WriteTimeoutSet &&
|
|
||||||
(!min.WriteTimeoutSet || cfg.Timeouts.WriteTimeout < min.WriteTimeout) {
|
|
||||||
min.WriteTimeoutSet = true
|
|
||||||
min.WriteTimeout = cfg.Timeouts.WriteTimeout
|
|
||||||
}
|
|
||||||
if cfg.Timeouts.IdleTimeoutSet &&
|
|
||||||
(!min.IdleTimeoutSet || cfg.Timeouts.IdleTimeout < min.IdleTimeout) {
|
|
||||||
min.IdleTimeoutSet = true
|
|
||||||
min.IdleTimeout = cfg.Timeouts.IdleTimeout
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// for the values that were not set, use defaults
|
|
||||||
if !min.ReadTimeoutSet {
|
|
||||||
min.ReadTimeout = defaultTimeouts.ReadTimeout
|
|
||||||
}
|
|
||||||
if !min.ReadHeaderTimeoutSet {
|
|
||||||
min.ReadHeaderTimeout = defaultTimeouts.ReadHeaderTimeout
|
|
||||||
}
|
|
||||||
if !min.WriteTimeoutSet {
|
|
||||||
min.WriteTimeout = defaultTimeouts.WriteTimeout
|
|
||||||
}
|
|
||||||
if !min.IdleTimeoutSet {
|
|
||||||
min.IdleTimeout = defaultTimeouts.IdleTimeout
|
|
||||||
}
|
|
||||||
|
|
||||||
// set the final values on the server
|
|
||||||
s.ReadTimeout = min.ReadTimeout
|
|
||||||
s.ReadHeaderTimeout = min.ReadHeaderTimeout
|
|
||||||
s.WriteTimeout = min.WriteTimeout
|
|
||||||
s.IdleTimeout = min.IdleTimeout
|
|
||||||
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
|
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
|
||||||
// connections. It's used by ListenAndServe and ListenAndServeTLS so
|
// connections. It's used by ListenAndServe and ListenAndServeTLS so
|
||||||
// dead TCP connections (e.g. closing laptop mid-download) eventually
|
// dead TCP connections (e.g. closing laptop mid-download) eventually
|
||||||
|
@ -92,7 +92,7 @@ func TestMakeHTTPServer(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
} {
|
} {
|
||||||
actual := makeHTTPServer("127.0.0.1:9005", tc.group)
|
actual := makeHTTPServerWithTimeouts("127.0.0.1:9005", tc.group)
|
||||||
|
|
||||||
if got, want := actual.Addr, "127.0.0.1:9005"; got != want {
|
if got, want := actual.Addr, "127.0.0.1:9005"; got != want {
|
||||||
t.Errorf("Test %d: Expected Addr=%s, but was %s", i, want, got)
|
t.Errorf("Test %d: Expected Addr=%s, but was %s", i, want, got)
|
||||||
|
@ -89,8 +89,8 @@ func getCertificate(name string) (cert Certificate, matched, defaulted bool) {
|
|||||||
// cache, flagging it as Managed and, if onDemand is true, as "OnDemand"
|
// cache, flagging it as Managed and, if onDemand is true, as "OnDemand"
|
||||||
// (meaning that it was obtained or loaded during a TLS handshake).
|
// (meaning that it was obtained or loaded during a TLS handshake).
|
||||||
//
|
//
|
||||||
// This function is safe for concurrent use.
|
// This method is safe for concurrent use.
|
||||||
func CacheManagedCertificate(domain string, cfg *Config) (Certificate, error) {
|
func (cfg *Config) CacheManagedCertificate(domain string) (Certificate, error) {
|
||||||
storage, err := cfg.StorageFor(cfg.CAUrl)
|
storage, err := cfg.StorageFor(cfg.CAUrl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Certificate{}, err
|
return Certificate{}, err
|
||||||
|
@ -109,11 +109,11 @@ type Config struct {
|
|||||||
// Add the must staple TLS extension to the CSR generated by lego/acme
|
// Add the must staple TLS extension to the CSR generated by lego/acme
|
||||||
MustStaple bool
|
MustStaple bool
|
||||||
|
|
||||||
// Disables HTTP2 completely
|
// The list of protocols to choose from for Application Layer
|
||||||
DisableHTTP2 bool
|
// Protocol Negotiation (ALPN).
|
||||||
|
ALPN []string
|
||||||
|
|
||||||
// Holds final tls.Config
|
tlsConfig *tls.Config // the final tls.Config created with buildStandardTLSConfig()
|
||||||
tlsConfig *tls.Config
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnDemandState contains some state relevant for providing
|
// OnDemandState contains some state relevant for providing
|
||||||
@ -223,33 +223,20 @@ func (c *Config) StorageFor(caURL string) (Storage, error) {
|
|||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cfg *Config) Build(group ConfigGroup) error {
|
// buildStandardTLSConfig converts cfg (*caddytls.Config) to a *tls.Config
|
||||||
config, err := cfg.build()
|
// and stores it in cfg so it can be used in servers. If TLS is disabled,
|
||||||
|
// no tls.Config is created.
|
||||||
if err != nil {
|
func (cfg *Config) buildStandardTLSConfig() error {
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if config != nil {
|
|
||||||
cfg.tlsConfig = config
|
|
||||||
cfg.tlsConfig.GetCertificate = group.GetCertificate
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cfg *Config) build() (*tls.Config, error) {
|
|
||||||
config := new(tls.Config)
|
|
||||||
|
|
||||||
if !cfg.Enabled {
|
if !cfg.Enabled {
|
||||||
return nil, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
config := new(tls.Config)
|
||||||
|
|
||||||
ciphersAdded := make(map[uint16]struct{})
|
ciphersAdded := make(map[uint16]struct{})
|
||||||
curvesAdded := make(map[tls.CurveID]struct{})
|
curvesAdded := make(map[tls.CurveID]struct{})
|
||||||
|
|
||||||
// Add cipher suites
|
// add cipher suites
|
||||||
for _, ciph := range cfg.Ciphers {
|
for _, ciph := range cfg.Ciphers {
|
||||||
if _, ok := ciphersAdded[ciph]; !ok {
|
if _, ok := ciphersAdded[ciph]; !ok {
|
||||||
ciphersAdded[ciph] = struct{}{}
|
ciphersAdded[ciph] = struct{}{}
|
||||||
@ -259,7 +246,7 @@ func (cfg *Config) build() (*tls.Config, error) {
|
|||||||
|
|
||||||
config.PreferServerCipherSuites = cfg.PreferServerCipherSuites
|
config.PreferServerCipherSuites = cfg.PreferServerCipherSuites
|
||||||
|
|
||||||
// Union curves
|
// add curve preferences
|
||||||
for _, curv := range cfg.CurvePreferences {
|
for _, curv := range cfg.CurvePreferences {
|
||||||
if _, ok := curvesAdded[curv]; !ok {
|
if _, ok := curvesAdded[curv]; !ok {
|
||||||
curvesAdded[curv] = struct{}{}
|
curvesAdded[curv] = struct{}{}
|
||||||
@ -270,8 +257,10 @@ func (cfg *Config) build() (*tls.Config, error) {
|
|||||||
config.MinVersion = cfg.ProtocolMinVersion
|
config.MinVersion = cfg.ProtocolMinVersion
|
||||||
config.MaxVersion = cfg.ProtocolMaxVersion
|
config.MaxVersion = cfg.ProtocolMaxVersion
|
||||||
config.ClientAuth = cfg.ClientAuth
|
config.ClientAuth = cfg.ClientAuth
|
||||||
|
config.NextProtos = cfg.ALPN
|
||||||
|
config.GetCertificate = cfg.GetCertificate
|
||||||
|
|
||||||
// Set up client authentication if enabled
|
// set up client authentication if enabled
|
||||||
if config.ClientAuth != tls.NoClientCert {
|
if config.ClientAuth != tls.NoClientCert {
|
||||||
pool := x509.NewCertPool()
|
pool := x509.NewCertPool()
|
||||||
clientCertsAdded := make(map[string]struct{})
|
clientCertsAdded := make(map[string]struct{})
|
||||||
@ -286,45 +275,51 @@ func (cfg *Config) build() (*tls.Config, error) {
|
|||||||
// Any client with a certificate from this CA will be allowed to connect
|
// Any client with a certificate from this CA will be allowed to connect
|
||||||
caCrt, err := ioutil.ReadFile(caFile)
|
caCrt, err := ioutil.ReadFile(caFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !pool.AppendCertsFromPEM(caCrt) {
|
if !pool.AppendCertsFromPEM(caCrt) {
|
||||||
return nil, fmt.Errorf("error loading client certificate '%s': no certificates were successfully parsed", caFile)
|
return fmt.Errorf("error loading client certificate '%s': no certificates were successfully parsed", caFile)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
config.ClientCAs = pool
|
config.ClientCAs = pool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Default cipher suites
|
// default cipher suites
|
||||||
if len(config.CipherSuites) == 0 {
|
if len(config.CipherSuites) == 0 {
|
||||||
config.CipherSuites = defaultCiphers
|
config.CipherSuites = defaultCiphers
|
||||||
}
|
}
|
||||||
|
|
||||||
// For security, ensure TLS_FALLBACK_SCSV is always included first
|
// for security, ensure TLS_FALLBACK_SCSV is always included first
|
||||||
if len(config.CipherSuites) == 0 || config.CipherSuites[0] != tls.TLS_FALLBACK_SCSV {
|
if len(config.CipherSuites) == 0 || config.CipherSuites[0] != tls.TLS_FALLBACK_SCSV {
|
||||||
config.CipherSuites = append([]uint16{tls.TLS_FALLBACK_SCSV}, config.CipherSuites...)
|
config.CipherSuites = append([]uint16{tls.TLS_FALLBACK_SCSV}, config.CipherSuites...)
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.DisableHTTP2 {
|
// store the resulting new tls.Config
|
||||||
config.NextProtos = []string{}
|
cfg.tlsConfig = config
|
||||||
} else {
|
|
||||||
config.NextProtos = []string{"h2"}
|
|
||||||
}
|
|
||||||
|
|
||||||
return config, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckConfigs checks if multiple TLS configs does not collide with each other
|
// MakeTLSConfig makes a tls.Config from configs. The returned
|
||||||
func CheckConfigs(configs []*Config) error {
|
// tls.Config is programmed to load the matching caddytls.Config
|
||||||
|
// based on the hostname in SNI, but that's all.
|
||||||
|
func MakeTLSConfig(configs []*Config) (*tls.Config, error) {
|
||||||
if len(configs) == 0 {
|
if len(configs) == 0 {
|
||||||
return nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, cfg := range configs {
|
configMap := make(configGroup)
|
||||||
|
|
||||||
// Can't serve TLS and not-TLS on same port
|
for i, cfg := range configs {
|
||||||
|
if cfg == nil {
|
||||||
|
// avoid nil pointer dereference below this loop
|
||||||
|
configs[i] = new(Config)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// can't serve TLS and non-TLS on same port
|
||||||
if i > 0 && cfg.Enabled != configs[i-1].Enabled {
|
if i > 0 && cfg.Enabled != configs[i-1].Enabled {
|
||||||
thisConfProto, lastConfProto := "not TLS", "not TLS"
|
thisConfProto, lastConfProto := "not TLS", "not TLS"
|
||||||
if cfg.Enabled {
|
if cfg.Enabled {
|
||||||
@ -333,26 +328,33 @@ func CheckConfigs(configs []*Config) error {
|
|||||||
if configs[i-1].Enabled {
|
if configs[i-1].Enabled {
|
||||||
lastConfProto = "TLS"
|
lastConfProto = "TLS"
|
||||||
}
|
}
|
||||||
return fmt.Errorf("cannot multiplex %s (%s) and %s (%s) on same listener",
|
return nil, fmt.Errorf("cannot multiplex %s (%s) and %s (%s) on same listener",
|
||||||
configs[i-1].Hostname, lastConfProto, cfg.Hostname, thisConfProto)
|
configs[i-1].Hostname, lastConfProto, cfg.Hostname, thisConfProto)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !cfg.Enabled {
|
// convert each caddytls.Config into a tls.Config
|
||||||
continue
|
if err := cfg.buildStandardTLSConfig(); err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Key this config by its hostname (overwriting
|
||||||
|
// configs with the same hostname pattern); during
|
||||||
|
// TLS handshakes, configs are loaded based on
|
||||||
|
// the hostname pattern, according to client's SNI.
|
||||||
|
configMap[cfg.Hostname] = cfg
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
// Is TLS disabled? By now, we know that all
|
||||||
}
|
// configs agree whether it is or not, so we
|
||||||
|
// can just look at the first one. If so,
|
||||||
func HasTLSEnabled(configs []*Config) bool {
|
// we're done here.
|
||||||
for _, config := range configs {
|
if len(configs) == 0 || !configs[0].Enabled {
|
||||||
if config.Enabled {
|
return nil, nil
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return &tls.Config{
|
||||||
|
GetConfigForClient: configMap.GetConfigForClient,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConfigGetter gets a Config keyed by key.
|
// ConfigGetter gets a Config keyed by key.
|
||||||
|
@ -8,50 +8,50 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMakeTLSConfigProtocolVersions(t *testing.T) {
|
func TestConvertTLSConfigProtocolVersions(t *testing.T) {
|
||||||
// same min and max protocol versions
|
// same min and max protocol versions
|
||||||
config := Config{
|
config := &Config{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
ProtocolMinVersion: tls.VersionTLS12,
|
ProtocolMinVersion: tls.VersionTLS12,
|
||||||
ProtocolMaxVersion: tls.VersionTLS12,
|
ProtocolMaxVersion: tls.VersionTLS12,
|
||||||
}
|
}
|
||||||
result, err := config.build()
|
err := config.buildStandardTLSConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Did not expect an error, but got %v", err)
|
t.Fatalf("Did not expect an error, but got %v", err)
|
||||||
}
|
}
|
||||||
if got, want := result.MinVersion, uint16(tls.VersionTLS12); got != want {
|
if got, want := config.tlsConfig.MinVersion, uint16(tls.VersionTLS12); got != want {
|
||||||
t.Errorf("Expected min version to be %x, got %x", want, got)
|
t.Errorf("Expected min version to be %x, got %x", want, got)
|
||||||
}
|
}
|
||||||
if got, want := result.MaxVersion, uint16(tls.VersionTLS12); got != want {
|
if got, want := config.tlsConfig.MaxVersion, uint16(tls.VersionTLS12); got != want {
|
||||||
t.Errorf("Expected max version to be %x, got %x", want, got)
|
t.Errorf("Expected max version to be %x, got %x", want, got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMakeTLSConfigPreferServerCipherSuites(t *testing.T) {
|
func TestConvertTLSConfigPreferServerCipherSuites(t *testing.T) {
|
||||||
// prefer server cipher suites
|
// prefer server cipher suites
|
||||||
config := Config{Enabled: true, PreferServerCipherSuites: true}
|
config := Config{Enabled: true, PreferServerCipherSuites: true}
|
||||||
result, err := config.build()
|
err := config.buildStandardTLSConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Did not expect an error, but got %v", err)
|
t.Fatalf("Did not expect an error, but got %v", err)
|
||||||
}
|
}
|
||||||
if got, want := result.PreferServerCipherSuites, true; got != want {
|
if got, want := config.tlsConfig.PreferServerCipherSuites, true; got != want {
|
||||||
t.Errorf("Expected PreferServerCipherSuites==%v but got %v", want, got)
|
t.Errorf("Expected PreferServerCipherSuites==%v but got %v", want, got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMakeTLSConfigTLSEnabledDisabled(t *testing.T) {
|
func TestMakeTLSConfigTLSEnabledDisabledError(t *testing.T) {
|
||||||
// verify handling when Enabled is true and false
|
// verify handling when Enabled is true and false
|
||||||
configs := []*Config{
|
configs := []*Config{
|
||||||
{Enabled: true},
|
{Enabled: true},
|
||||||
{Enabled: false},
|
{Enabled: false},
|
||||||
}
|
}
|
||||||
err := CheckConfigs(configs)
|
_, err := MakeTLSConfig(configs)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("Expected an error, but got %v", err)
|
t.Fatalf("Expected an error, but got %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMakeTLSConfigCipherSuites(t *testing.T) {
|
func TestConvertTLSConfigCipherSuites(t *testing.T) {
|
||||||
// ensure cipher suites are unioned and
|
// ensure cipher suites are unioned and
|
||||||
// that TLS_FALLBACK_SCSV is prepended
|
// that TLS_FALLBACK_SCSV is prepended
|
||||||
configs := []*Config{
|
configs := []*Config{
|
||||||
@ -67,10 +67,13 @@ func TestMakeTLSConfigCipherSuites(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i, config := range configs {
|
for i, config := range configs {
|
||||||
cfg, _ := config.build()
|
err := config.buildStandardTLSConfig()
|
||||||
|
if err != nil {
|
||||||
if !reflect.DeepEqual(cfg.CipherSuites, expectedCiphers[i]) {
|
t.Errorf("Test %d: Expected no error, got: %v", i, err)
|
||||||
t.Errorf("Expected ciphers %v but got %v", expectedCiphers[i], cfg.CipherSuites)
|
}
|
||||||
|
if !reflect.DeepEqual(config.tlsConfig.CipherSuites, expectedCiphers[i]) {
|
||||||
|
t.Errorf("Test %d: Expected ciphers %v but got %v",
|
||||||
|
i, expectedCiphers[i], config.tlsConfig.CipherSuites)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -13,18 +13,19 @@ import (
|
|||||||
|
|
||||||
// configGroup is a type that keys configs by their hostname
|
// configGroup is a type that keys configs by their hostname
|
||||||
// (hostnames can have wildcard characters; use the getConfig
|
// (hostnames can have wildcard characters; use the getConfig
|
||||||
// method to get a config by matching its hostname). Its
|
// method to get a config by matching its hostname).
|
||||||
// GetCertificate function can be used with tls.Config.
|
type configGroup map[string]*Config
|
||||||
type ConfigGroup map[string]*Config
|
|
||||||
|
|
||||||
// getConfig gets the config by the first key match for name.
|
// getConfig gets the config by the first key match for name.
|
||||||
// In other words, "sub.foo.bar" will get the config for "*.foo.bar"
|
// In other words, "sub.foo.bar" will get the config for "*.foo.bar"
|
||||||
// if that is the closest match. This function MAY return nil
|
// if that is the closest match. If no match is found, the first
|
||||||
// if no match is found.
|
// (random) config will be loaded, which will defer any TLS alerts
|
||||||
|
// to the certificate validation (this may or may not be ideal;
|
||||||
|
// let's talk about it if this becomes problematic).
|
||||||
//
|
//
|
||||||
// This function follows nearly the same logic to lookup
|
// This function follows nearly the same logic to lookup
|
||||||
// a hostname as the getCertificate function uses.
|
// a hostname as the getCertificate function uses.
|
||||||
func (cg ConfigGroup) getConfig(name string) *Config {
|
func (cg configGroup) getConfig(name string) *Config {
|
||||||
name = strings.ToLower(name)
|
name = strings.ToLower(name)
|
||||||
|
|
||||||
// exact match? great, let's use it
|
// exact match? great, let's use it
|
||||||
@ -42,14 +43,36 @@ func (cg ConfigGroup) getConfig(name string) *Config {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// as last resort, try a config that serves all names
|
// as a fallback, try a config that serves all names
|
||||||
if config, ok := cg[""]; ok {
|
if config, ok := cg[""]; ok {
|
||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// as a last resort, use a random config
|
||||||
|
// (even if the config isn't for that hostname,
|
||||||
|
// it should help us serve clients without SNI
|
||||||
|
// or at least defer TLS alerts to the cert)
|
||||||
|
for _, config := range cg {
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetConfigForClient gets a TLS configuration satisfying clientHello.
|
||||||
|
// In getting the configuration, it abides the rules and settings
|
||||||
|
// defined in the Config that matches clientHello.ServerName. If no
|
||||||
|
// tls.Config is set on the matching Config, a nil value is returned.
|
||||||
|
//
|
||||||
|
// This method is safe for use as a tls.Config.GetConfigForClient callback.
|
||||||
|
func (cg configGroup) GetConfigForClient(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
|
config := cg.getConfig(clientHello.ServerName)
|
||||||
|
if config != nil {
|
||||||
|
return config.tlsConfig, nil
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetCertificate gets a certificate to satisfy clientHello. In getting
|
// GetCertificate gets a certificate to satisfy clientHello. In getting
|
||||||
// the certificate, it abides the rules and settings defined in the
|
// the certificate, it abides the rules and settings defined in the
|
||||||
// Config that matches clientHello.ServerName. It first checks the in-
|
// Config that matches clientHello.ServerName. It first checks the in-
|
||||||
@ -58,27 +81,11 @@ func (cg ConfigGroup) getConfig(name string) *Config {
|
|||||||
// via ACME.
|
// via ACME.
|
||||||
//
|
//
|
||||||
// This method is safe for use as a tls.Config.GetCertificate callback.
|
// This method is safe for use as a tls.Config.GetCertificate callback.
|
||||||
func (cg ConfigGroup) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
cert, err := cg.getCertDuringHandshake(strings.ToLower(clientHello.ServerName), true, true)
|
cert, err := cfg.getCertDuringHandshake(strings.ToLower(clientHello.ServerName), true, true)
|
||||||
return &cert.Certificate, err
|
return &cert.Certificate, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetConfigForClient gets a TLS configuration satisfying clientHello. In getting
|
|
||||||
// the configuration, it abides the rules and settings defined in the
|
|
||||||
// Config that matches clientHello.ServerName.
|
|
||||||
//
|
|
||||||
// This method is safe for use as a tls.Config.GetConfigForClient callback.
|
|
||||||
func (cg ConfigGroup) GetConfigForClient(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
|
|
||||||
|
|
||||||
config := cg.getConfig(clientHello.ServerName)
|
|
||||||
|
|
||||||
if config != nil {
|
|
||||||
return config.tlsConfig, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getCertDuringHandshake will get a certificate for name. It first tries
|
// getCertDuringHandshake will get a certificate for name. It first tries
|
||||||
// the in-memory cache. If no certificate for name is in the cache, the
|
// the in-memory cache. If no certificate for name is in the cache, the
|
||||||
// config most closely corresponding to name will be loaded. If that config
|
// config most closely corresponding to name will be loaded. If that config
|
||||||
@ -90,21 +97,20 @@ func (cg ConfigGroup) GetConfigForClient(clientHello *tls.ClientHelloInfo) (*tls
|
|||||||
// certificate is available.
|
// certificate is available.
|
||||||
//
|
//
|
||||||
// This function is safe for concurrent use.
|
// This function is safe for concurrent use.
|
||||||
func (cg ConfigGroup) getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) {
|
func (cfg *Config) getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) {
|
||||||
// First check our in-memory cache to see if we've already loaded it
|
// First check our in-memory cache to see if we've already loaded it
|
||||||
cert, matched, defaulted := getCertificate(name)
|
cert, matched, defaulted := getCertificate(name)
|
||||||
if matched {
|
if matched {
|
||||||
return cert, nil
|
return cert, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the relevant TLS config for this name. If OnDemand is enabled,
|
// If OnDemand is enabled, then we might be able to load or
|
||||||
// then we might be able to load or obtain a needed certificate.
|
// obtain a needed certificate
|
||||||
cfg := cg.getConfig(name)
|
if cfg.OnDemand && loadIfNecessary {
|
||||||
if cfg != nil && cfg.OnDemand && loadIfNecessary {
|
|
||||||
// Then check to see if we have one on disk
|
// Then check to see if we have one on disk
|
||||||
loadedCert, err := CacheManagedCertificate(name, cfg)
|
loadedCert, err := cfg.CacheManagedCertificate(name)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
loadedCert, err = cg.handshakeMaintenance(name, loadedCert)
|
loadedCert, err = cfg.handshakeMaintenance(name, loadedCert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[ERROR] Maintaining newly-loaded certificate for %s: %v", name, err)
|
log.Printf("[ERROR] Maintaining newly-loaded certificate for %s: %v", name, err)
|
||||||
}
|
}
|
||||||
@ -116,7 +122,7 @@ func (cg ConfigGroup) getCertDuringHandshake(name string, loadIfNecessary, obtai
|
|||||||
name = strings.ToLower(name)
|
name = strings.ToLower(name)
|
||||||
|
|
||||||
// Make sure aren't over any applicable limits
|
// Make sure aren't over any applicable limits
|
||||||
err := cg.checkLimitsForObtainingNewCerts(name, cfg)
|
err := cfg.checkLimitsForObtainingNewCerts(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Certificate{}, err
|
return Certificate{}, err
|
||||||
}
|
}
|
||||||
@ -127,7 +133,7 @@ func (cg ConfigGroup) getCertDuringHandshake(name string, loadIfNecessary, obtai
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Obtain certificate from the CA
|
// Obtain certificate from the CA
|
||||||
return cg.obtainOnDemandCertificate(name, cfg)
|
return cfg.obtainOnDemandCertificate(name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -143,7 +149,7 @@ func (cg ConfigGroup) getCertDuringHandshake(name string, loadIfNecessary, obtai
|
|||||||
// now according to mitigating factors we keep track of and preferences the
|
// now according to mitigating factors we keep track of and preferences the
|
||||||
// user has set. If a non-nil error is returned, do not issue a new certificate
|
// user has set. If a non-nil error is returned, do not issue a new certificate
|
||||||
// for name.
|
// for name.
|
||||||
func (cg ConfigGroup) checkLimitsForObtainingNewCerts(name string, cfg *Config) error {
|
func (cfg *Config) checkLimitsForObtainingNewCerts(name string) error {
|
||||||
// User can set hard limit for number of certs for the process to issue
|
// User can set hard limit for number of certs for the process to issue
|
||||||
if cfg.OnDemandState.MaxObtain > 0 &&
|
if cfg.OnDemandState.MaxObtain > 0 &&
|
||||||
atomic.LoadInt32(&cfg.OnDemandState.ObtainedCount) >= cfg.OnDemandState.MaxObtain {
|
atomic.LoadInt32(&cfg.OnDemandState.ObtainedCount) >= cfg.OnDemandState.MaxObtain {
|
||||||
@ -167,7 +173,7 @@ func (cg ConfigGroup) checkLimitsForObtainingNewCerts(name string, cfg *Config)
|
|||||||
return fmt.Errorf("%s: throttled; last certificate was obtained %v ago", name, since)
|
return fmt.Errorf("%s: throttled; last certificate was obtained %v ago", name, since)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 👍Good to go
|
// Good to go 👍
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -176,7 +182,7 @@ func (cg ConfigGroup) checkLimitsForObtainingNewCerts(name string, cfg *Config)
|
|||||||
// name, it will wait and use what the other goroutine obtained.
|
// name, it will wait and use what the other goroutine obtained.
|
||||||
//
|
//
|
||||||
// This function is safe for use by multiple concurrent goroutines.
|
// This function is safe for use by multiple concurrent goroutines.
|
||||||
func (cg ConfigGroup) obtainOnDemandCertificate(name string, cfg *Config) (Certificate, error) {
|
func (cfg *Config) obtainOnDemandCertificate(name string) (Certificate, error) {
|
||||||
// We must protect this process from happening concurrently, so synchronize.
|
// We must protect this process from happening concurrently, so synchronize.
|
||||||
obtainCertWaitChansMu.Lock()
|
obtainCertWaitChansMu.Lock()
|
||||||
wait, ok := obtainCertWaitChans[name]
|
wait, ok := obtainCertWaitChans[name]
|
||||||
@ -185,7 +191,7 @@ func (cg ConfigGroup) obtainOnDemandCertificate(name string, cfg *Config) (Certi
|
|||||||
// wait for it to finish obtaining the cert and then we'll use it.
|
// wait for it to finish obtaining the cert and then we'll use it.
|
||||||
obtainCertWaitChansMu.Unlock()
|
obtainCertWaitChansMu.Unlock()
|
||||||
<-wait
|
<-wait
|
||||||
return cg.getCertDuringHandshake(name, true, false)
|
return cfg.getCertDuringHandshake(name, true, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
// looks like it's up to us to do all the work and obtain the cert.
|
// looks like it's up to us to do all the work and obtain the cert.
|
||||||
@ -228,19 +234,19 @@ func (cg ConfigGroup) obtainOnDemandCertificate(name string, cfg *Config) (Certi
|
|||||||
lastIssueTimeMu.Unlock()
|
lastIssueTimeMu.Unlock()
|
||||||
|
|
||||||
// certificate is already on disk; now just start over to load it and serve it
|
// certificate is already on disk; now just start over to load it and serve it
|
||||||
return cg.getCertDuringHandshake(name, true, false)
|
return cfg.getCertDuringHandshake(name, true, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
// handshakeMaintenance performs a check on cert for expiration and OCSP
|
// handshakeMaintenance performs a check on cert for expiration and OCSP
|
||||||
// validity.
|
// validity.
|
||||||
//
|
//
|
||||||
// This function is safe for use by multiple concurrent goroutines.
|
// This function is safe for use by multiple concurrent goroutines.
|
||||||
func (cg ConfigGroup) handshakeMaintenance(name string, cert Certificate) (Certificate, error) {
|
func (cfg *Config) handshakeMaintenance(name string, cert Certificate) (Certificate, error) {
|
||||||
// Check cert expiration
|
// Check cert expiration
|
||||||
timeLeft := cert.NotAfter.Sub(time.Now().UTC())
|
timeLeft := cert.NotAfter.Sub(time.Now().UTC())
|
||||||
if timeLeft < RenewDurationBefore {
|
if timeLeft < RenewDurationBefore {
|
||||||
log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft)
|
log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft)
|
||||||
return cg.renewDynamicCertificate(name, cert.Config)
|
return cfg.renewDynamicCertificate(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check OCSP staple validity
|
// Check OCSP staple validity
|
||||||
@ -268,7 +274,7 @@ func (cg ConfigGroup) handshakeMaintenance(name string, cert Certificate) (Certi
|
|||||||
// usable. name should already be lower-cased before calling this function.
|
// usable. name should already be lower-cased before calling this function.
|
||||||
//
|
//
|
||||||
// This function is safe for use by multiple concurrent goroutines.
|
// This function is safe for use by multiple concurrent goroutines.
|
||||||
func (cg ConfigGroup) renewDynamicCertificate(name string, cfg *Config) (Certificate, error) {
|
func (cfg *Config) renewDynamicCertificate(name string) (Certificate, error) {
|
||||||
obtainCertWaitChansMu.Lock()
|
obtainCertWaitChansMu.Lock()
|
||||||
wait, ok := obtainCertWaitChans[name]
|
wait, ok := obtainCertWaitChans[name]
|
||||||
if ok {
|
if ok {
|
||||||
@ -276,7 +282,7 @@ func (cg ConfigGroup) renewDynamicCertificate(name string, cfg *Config) (Certifi
|
|||||||
// wait for it to finish, then we'll use the new one.
|
// wait for it to finish, then we'll use the new one.
|
||||||
obtainCertWaitChansMu.Unlock()
|
obtainCertWaitChansMu.Unlock()
|
||||||
<-wait
|
<-wait
|
||||||
return cg.getCertDuringHandshake(name, true, false)
|
return cfg.getCertDuringHandshake(name, true, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
// looks like it's up to us to do all the work and renew the cert
|
// looks like it's up to us to do all the work and renew the cert
|
||||||
@ -300,7 +306,7 @@ func (cg ConfigGroup) renewDynamicCertificate(name string, cfg *Config) (Certifi
|
|||||||
return Certificate{}, err
|
return Certificate{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return cg.getCertDuringHandshake(name, true, false)
|
return cfg.getCertDuringHandshake(name, true, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
// obtainCertWaitChans is used to coordinate obtaining certs for each hostname.
|
// obtainCertWaitChans is used to coordinate obtaining certs for each hostname.
|
||||||
|
@ -9,7 +9,7 @@ import (
|
|||||||
func TestGetCertificate(t *testing.T) {
|
func TestGetCertificate(t *testing.T) {
|
||||||
defer func() { certCache = make(map[string]Certificate) }()
|
defer func() { certCache = make(map[string]Certificate) }()
|
||||||
|
|
||||||
cg := make(ConfigGroup)
|
cfg := new(Config)
|
||||||
|
|
||||||
hello := &tls.ClientHelloInfo{ServerName: "example.com"}
|
hello := &tls.ClientHelloInfo{ServerName: "example.com"}
|
||||||
helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"}
|
helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"}
|
||||||
@ -17,10 +17,10 @@ func TestGetCertificate(t *testing.T) {
|
|||||||
helloNoMatch := &tls.ClientHelloInfo{ServerName: "nomatch"}
|
helloNoMatch := &tls.ClientHelloInfo{ServerName: "nomatch"}
|
||||||
|
|
||||||
// When cache is empty
|
// When cache is empty
|
||||||
if cert, err := cg.GetCertificate(hello); err == nil {
|
if cert, err := cfg.GetCertificate(hello); err == nil {
|
||||||
t.Errorf("GetCertificate should return error when cache is empty, got: %v", cert)
|
t.Errorf("GetCertificate should return error when cache is empty, got: %v", cert)
|
||||||
}
|
}
|
||||||
if cert, err := cg.GetCertificate(helloNoSNI); err == nil {
|
if cert, err := cfg.GetCertificate(helloNoSNI); err == nil {
|
||||||
t.Errorf("GetCertificate should return error when cache is empty even if server name is blank, got: %v", cert)
|
t.Errorf("GetCertificate should return error when cache is empty even if server name is blank, got: %v", cert)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -28,12 +28,12 @@ func TestGetCertificate(t *testing.T) {
|
|||||||
defaultCert := Certificate{Names: []string{"example.com", ""}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"example.com"}}}}
|
defaultCert := Certificate{Names: []string{"example.com", ""}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"example.com"}}}}
|
||||||
certCache[""] = defaultCert
|
certCache[""] = defaultCert
|
||||||
certCache["example.com"] = defaultCert
|
certCache["example.com"] = defaultCert
|
||||||
if cert, err := cg.GetCertificate(hello); err != nil {
|
if cert, err := cfg.GetCertificate(hello); err != nil {
|
||||||
t.Errorf("Got an error but shouldn't have, when cert exists in cache: %v", err)
|
t.Errorf("Got an error but shouldn't have, when cert exists in cache: %v", err)
|
||||||
} else if cert.Leaf.DNSNames[0] != "example.com" {
|
} else if cert.Leaf.DNSNames[0] != "example.com" {
|
||||||
t.Errorf("Got wrong certificate with exact match; expected 'example.com', got: %v", cert)
|
t.Errorf("Got wrong certificate with exact match; expected 'example.com', got: %v", cert)
|
||||||
}
|
}
|
||||||
if cert, err := cg.GetCertificate(helloNoSNI); err != nil {
|
if cert, err := cfg.GetCertificate(helloNoSNI); err != nil {
|
||||||
t.Errorf("Got an error with no SNI but shouldn't have, when cert exists in cache: %v", err)
|
t.Errorf("Got an error with no SNI but shouldn't have, when cert exists in cache: %v", err)
|
||||||
} else if cert.Leaf.DNSNames[0] != "example.com" {
|
} else if cert.Leaf.DNSNames[0] != "example.com" {
|
||||||
t.Errorf("Got wrong certificate for no SNI; expected 'example.com' as default, got: %v", cert)
|
t.Errorf("Got wrong certificate for no SNI; expected 'example.com' as default, got: %v", cert)
|
||||||
@ -41,14 +41,14 @@ func TestGetCertificate(t *testing.T) {
|
|||||||
|
|
||||||
// When retrieving wildcard certificate
|
// When retrieving wildcard certificate
|
||||||
certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"*.example.com"}}}}
|
certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"*.example.com"}}}}
|
||||||
if cert, err := cg.GetCertificate(helloSub); err != nil {
|
if cert, err := cfg.GetCertificate(helloSub); err != nil {
|
||||||
t.Errorf("Didn't get wildcard cert, got: cert=%v, err=%v ", cert, err)
|
t.Errorf("Didn't get wildcard cert, got: cert=%v, err=%v ", cert, err)
|
||||||
} else if cert.Leaf.DNSNames[0] != "*.example.com" {
|
} else if cert.Leaf.DNSNames[0] != "*.example.com" {
|
||||||
t.Errorf("Got wrong certificate, expected wildcard: %v", cert)
|
t.Errorf("Got wrong certificate, expected wildcard: %v", cert)
|
||||||
}
|
}
|
||||||
|
|
||||||
// When no certificate matches, the default is returned
|
// When no certificate matches, the default is returned
|
||||||
if cert, err := cg.GetCertificate(helloNoMatch); err != nil {
|
if cert, err := cfg.GetCertificate(helloNoMatch); err != nil {
|
||||||
t.Errorf("Expected default certificate with no error when no matches, got err: %v", err)
|
t.Errorf("Expected default certificate with no error when no matches, got err: %v", err)
|
||||||
} else if cert.Leaf.DNSNames[0] != "example.com" {
|
} else if cert.Leaf.DNSNames[0] != "example.com" {
|
||||||
t.Errorf("Expected default cert with no matches, got: %v", cert)
|
t.Errorf("Expected default cert with no matches, got: %v", cert)
|
||||||
|
@ -152,7 +152,7 @@ func RenewManagedCertificates(allowPrompts bool) (err error) {
|
|||||||
delete(certCache, "")
|
delete(certCache, "")
|
||||||
certCacheMu.Unlock()
|
certCacheMu.Unlock()
|
||||||
}
|
}
|
||||||
_, err := CacheManagedCertificate(cert.Names[0], cert.Config)
|
_, err := cert.Config.CacheManagedCertificate(cert.Names[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if allowPrompts {
|
if allowPrompts {
|
||||||
return err // operator is present, so report error immediately
|
return err // operator is present, so report error immediately
|
||||||
|
@ -164,21 +164,15 @@ func setupTLS(c *caddy.Controller) error {
|
|||||||
return c.Errf("Unsupported Storage provider '%s'", args[0])
|
return c.Errf("Unsupported Storage provider '%s'", args[0])
|
||||||
}
|
}
|
||||||
config.StorageProvider = args[0]
|
config.StorageProvider = args[0]
|
||||||
|
case "alpn":
|
||||||
case "http2":
|
|
||||||
args := c.RemainingArgs()
|
args := c.RemainingArgs()
|
||||||
if len(args) != 1 {
|
if len(args) == 0 {
|
||||||
return c.ArgErr()
|
return c.ArgErr()
|
||||||
}
|
}
|
||||||
|
for _, arg := range args {
|
||||||
switch args[0] {
|
config.ALPN = append(config.ALPN, arg)
|
||||||
case "off":
|
|
||||||
config.DisableHTTP2 = true
|
|
||||||
default:
|
|
||||||
c.ArgErr()
|
|
||||||
}
|
}
|
||||||
|
case "must_staple":
|
||||||
case "muststaple":
|
|
||||||
config.MustStaple = true
|
config.MustStaple = true
|
||||||
default:
|
default:
|
||||||
return c.Errf("Unknown keyword '%s'", c.Val())
|
return c.Errf("Unknown keyword '%s'", c.Val())
|
||||||
|
@ -91,8 +91,8 @@ func TestSetupParseBasic(t *testing.T) {
|
|||||||
t.Error("Expected PreferServerCipherSuites = true, but was false")
|
t.Error("Expected PreferServerCipherSuites = true, but was false")
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.DisableHTTP2 {
|
if len(cfg.ALPN) != 0 {
|
||||||
t.Error("Expected HTTP2 to be enabled by default")
|
t.Error("Expected ALPN empty by default")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure curve count is correct
|
// Ensure curve count is correct
|
||||||
@ -121,8 +121,8 @@ func TestSetupParseWithOptionalParams(t *testing.T) {
|
|||||||
params := `tls ` + certFile + ` ` + keyFile + ` {
|
params := `tls ` + certFile + ` ` + keyFile + ` {
|
||||||
protocols tls1.0 tls1.2
|
protocols tls1.0 tls1.2
|
||||||
ciphers RSA-AES256-CBC-SHA ECDHE-RSA-AES128-GCM-SHA256 ECDHE-ECDSA-AES256-GCM-SHA384
|
ciphers RSA-AES256-CBC-SHA ECDHE-RSA-AES128-GCM-SHA256 ECDHE-ECDSA-AES256-GCM-SHA384
|
||||||
muststaple
|
must_staple
|
||||||
http2 off
|
alpn http/1.1
|
||||||
}`
|
}`
|
||||||
cfg := new(Config)
|
cfg := new(Config)
|
||||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||||
@ -149,8 +149,8 @@ func TestSetupParseWithOptionalParams(t *testing.T) {
|
|||||||
t.Error("Expected must staple to be true")
|
t.Error("Expected must staple to be true")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !cfg.DisableHTTP2 {
|
if len(cfg.ALPN) != 1 || cfg.ALPN[0] != "http/1.1" {
|
||||||
t.Error("Expected HTTP2 to be disabled")
|
t.Errorf("Expected ALPN to contain only 'http/1.1' but got: %v", cfg.ALPN)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user