mirror of
				https://github.com/caddyserver/caddy.git
				synced 2025-10-31 02:27:19 -04:00 
			
		
		
		
	Merge pull request #1373 from mholt/go18shutdown
Replace our old faithful gracefulListener with Go 1.8's Shutdown()
This commit is contained in:
		
						commit
						524dcee9f6
					
				| @ -1,80 +0,0 @@ | |||||||
| package httpserver |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"net" |  | ||||||
| 	"sync" |  | ||||||
| 	"syscall" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| // TODO: Should this be a generic graceful listener available in its own package or something? |  | ||||||
| // Also, passing in a WaitGroup is a little awkward. Why can't this listener just keep |  | ||||||
| // the waitgroup internal to itself? |  | ||||||
| 
 |  | ||||||
| // newGracefulListener returns a gracefulListener that wraps l and |  | ||||||
| // uses wg (stored in the host server) to count connections. |  | ||||||
| func newGracefulListener(l net.Listener, wg *sync.WaitGroup) *gracefulListener { |  | ||||||
| 	gl := &gracefulListener{Listener: l, stop: make(chan error), connWg: wg} |  | ||||||
| 	go func() { |  | ||||||
| 		<-gl.stop |  | ||||||
| 		gl.Lock() |  | ||||||
| 		gl.stopped = true |  | ||||||
| 		gl.Unlock() |  | ||||||
| 		gl.stop <- gl.Listener.Close() |  | ||||||
| 	}() |  | ||||||
| 	return gl |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // gracefuListener is a net.Listener which can |  | ||||||
| // count the number of connections on it. Its |  | ||||||
| // methods mainly wrap net.Listener to be graceful. |  | ||||||
| type gracefulListener struct { |  | ||||||
| 	net.Listener |  | ||||||
| 	stop       chan error |  | ||||||
| 	stopped    bool |  | ||||||
| 	sync.Mutex                 // protects the stopped flag |  | ||||||
| 	connWg     *sync.WaitGroup // pointer to the host's wg used for counting connections |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Accept accepts a connection. |  | ||||||
| func (gl *gracefulListener) Accept() (c net.Conn, err error) { |  | ||||||
| 	c, err = gl.Listener.Accept() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	c = gracefulConn{Conn: c, connWg: gl.connWg} |  | ||||||
| 	gl.connWg.Add(1) |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Close immediately closes the listener. |  | ||||||
| func (gl *gracefulListener) Close() error { |  | ||||||
| 	gl.Lock() |  | ||||||
| 	if gl.stopped { |  | ||||||
| 		gl.Unlock() |  | ||||||
| 		return syscall.EINVAL |  | ||||||
| 	} |  | ||||||
| 	gl.Unlock() |  | ||||||
| 	gl.stop <- nil |  | ||||||
| 	return <-gl.stop |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // gracefulConn represents a connection on a |  | ||||||
| // gracefulListener so that we can keep track |  | ||||||
| // of the number of connections, thus facilitating |  | ||||||
| // a graceful shutdown. |  | ||||||
| type gracefulConn struct { |  | ||||||
| 	net.Conn |  | ||||||
| 	connWg *sync.WaitGroup // pointer to the host server's connection waitgroup |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Close closes c's underlying connection while updating the wg count. |  | ||||||
| func (c gracefulConn) Close() error { |  | ||||||
| 	err := c.Conn.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 	// close can fail on http2 connections (as of Oct. 2015, before http2 in std lib) |  | ||||||
| 	// so don't decrement count unless close succeeds |  | ||||||
| 	c.connWg.Done() |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| @ -2,6 +2,7 @@ | |||||||
| package httpserver | package httpserver | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"crypto/tls" | 	"crypto/tls" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io" | 	"io" | ||||||
| @ -28,7 +29,6 @@ type Server struct { | |||||||
| 	listenerMu  sync.Mutex | 	listenerMu  sync.Mutex | ||||||
| 	sites       []*SiteConfig | 	sites       []*SiteConfig | ||||||
| 	connTimeout time.Duration // max time to wait for a connection before force stop | 	connTimeout time.Duration // max time to wait for a connection before force stop | ||||||
| 	connWg      sync.WaitGroup // one increment per connection |  | ||||||
| 	tlsGovChan  chan struct{} // close to stop the TLS maintenance goroutine | 	tlsGovChan  chan struct{} // close to stop the TLS maintenance goroutine | ||||||
| 	vhosts      *vhostTrie | 	vhosts      *vhostTrie | ||||||
| } | } | ||||||
| @ -46,16 +46,6 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) { | |||||||
| 		connTimeout: GracefulTimeout, | 		connTimeout: GracefulTimeout, | ||||||
| 	} | 	} | ||||||
| 	s.Server.Handler = s // this is weird, but whatever | 	s.Server.Handler = s // this is weird, but whatever | ||||||
| 	s.Server.ConnState = func(c net.Conn, cs http.ConnState) { |  | ||||||
| 		if cs == http.StateIdle { |  | ||||||
| 			s.listenerMu.Lock() |  | ||||||
| 			// server stopped, close idle connection |  | ||||||
| 			if s.listener == nil { |  | ||||||
| 				c.Close() |  | ||||||
| 			} |  | ||||||
| 			s.listenerMu.Unlock() |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	// Disable HTTP/2 if desired | 	// Disable HTTP/2 if desired | ||||||
| 	if !HTTP2 { | 	if !HTTP2 { | ||||||
| @ -68,14 +58,6 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) { | |||||||
| 		s.Server.Handler = s.wrapWithSvcHeaders(s.Server.Handler) | 		s.Server.Handler = s.wrapWithSvcHeaders(s.Server.Handler) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// We have to bound our wg with one increment |  | ||||||
| 	// to prevent a "race condition" that is hard-coded |  | ||||||
| 	// into sync.WaitGroup.Wait() - basically, an add |  | ||||||
| 	// with a positive delta must be guaranteed to |  | ||||||
| 	// occur before Wait() is called on the wg. |  | ||||||
| 	// In a way, this kind of acts as a safety barrier. |  | ||||||
| 	s.connWg.Add(1) |  | ||||||
| 
 |  | ||||||
| 	// Set up TLS configuration | 	// Set up TLS configuration | ||||||
| 	var tlsConfigs []*caddytls.Config | 	var tlsConfigs []*caddytls.Config | ||||||
| 	for _, site := range group { | 	for _, site := range group { | ||||||
| @ -163,8 +145,6 @@ func (s *Server) Serve(ln net.Listener) error { | |||||||
| 		ln = tcpKeepAliveListener{TCPListener: tcpLn} | 		ln = tcpKeepAliveListener{TCPListener: tcpLn} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	ln = newGracefulListener(ln, &s.connWg) |  | ||||||
| 
 |  | ||||||
| 	s.listenerMu.Lock() | 	s.listenerMu.Lock() | ||||||
| 	s.listener = ln | 	s.listener = ln | ||||||
| 	s.listenerMu.Unlock() | 	s.listenerMu.Unlock() | ||||||
| @ -306,40 +286,21 @@ func (s *Server) Address() string { | |||||||
| 
 | 
 | ||||||
| // Stop stops s gracefully (or forcefully after timeout) and | // Stop stops s gracefully (or forcefully after timeout) and | ||||||
| // closes its listener. | // closes its listener. | ||||||
| func (s *Server) Stop() (err error) { | func (s *Server) Stop() error { | ||||||
| 	s.Server.SetKeepAlivesEnabled(false) | 	ctx, cancel := context.WithTimeout(context.Background(), s.connTimeout) | ||||||
|  | 	defer cancel() | ||||||
| 
 | 
 | ||||||
| 	if runtime.GOOS != "windows" { | 	err := s.Server.Shutdown(ctx) | ||||||
| 		// force connections to close after timeout | 	if err != nil { | ||||||
| 		done := make(chan struct{}) | 		return err | ||||||
| 		go func() { |  | ||||||
| 			s.connWg.Done() // decrement our initial increment used as a barrier |  | ||||||
| 			s.connWg.Wait() |  | ||||||
| 			close(done) |  | ||||||
| 		}() |  | ||||||
| 
 |  | ||||||
| 		// Wait for remaining connections to finish or |  | ||||||
| 		// force them all to close after timeout |  | ||||||
| 		select { |  | ||||||
| 		case <-time.After(s.connTimeout): |  | ||||||
| 		case <-done: |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Close the listener now; this stops the server without delay | 	// signal any TLS governor goroutines to exit | ||||||
| 	s.listenerMu.Lock() |  | ||||||
| 	if s.listener != nil { |  | ||||||
| 		err = s.listener.Close() |  | ||||||
| 		s.listener = nil |  | ||||||
| 	} |  | ||||||
| 	s.listenerMu.Unlock() |  | ||||||
| 
 |  | ||||||
| 	// Closing this signals any TLS governor goroutines to exit |  | ||||||
| 	if s.tlsGovChan != nil { | 	if s.tlsGovChan != nil { | ||||||
| 		close(s.tlsGovChan) | 		close(s.tlsGovChan) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // sanitizePath collapses any ./ ../ /// madness | // sanitizePath collapses any ./ ../ /// madness | ||||||
| @ -439,11 +400,10 @@ func makeHTTPServer(addr string, group []*SiteConfig) *http.Server { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// set the final values on the server | 	// set the final values on the server | ||||||
| 	// TODO: ReadHeaderTimeout and IdleTimeout require Go 1.8 |  | ||||||
| 	s.ReadTimeout = min.ReadTimeout | 	s.ReadTimeout = min.ReadTimeout | ||||||
| 	// s.ReadHeaderTimeout = min.ReadHeaderTimeout | 	s.ReadHeaderTimeout = min.ReadHeaderTimeout | ||||||
| 	s.WriteTimeout = min.WriteTimeout | 	s.WriteTimeout = min.WriteTimeout | ||||||
| 	// s.IdleTimeout = min.IdleTimeout | 	s.IdleTimeout = min.IdleTimeout | ||||||
| 
 | 
 | ||||||
| 	return s | 	return s | ||||||
| } | } | ||||||
|  | |||||||
| @ -100,15 +100,14 @@ func TestMakeHTTPServer(t *testing.T) { | |||||||
| 		if got, want := actual.ReadTimeout, tc.expected.ReadTimeout; got != want { | 		if got, want := actual.ReadTimeout, tc.expected.ReadTimeout; got != want { | ||||||
| 			t.Errorf("Test %d: Expected ReadTimeout=%v, but was %v", i, want, got) | 			t.Errorf("Test %d: Expected ReadTimeout=%v, but was %v", i, want, got) | ||||||
| 		} | 		} | ||||||
| 		// TODO: ReadHeaderTimeout and IdleTimeout require Go 1.8 | 		if got, want := actual.ReadHeaderTimeout, tc.expected.ReadHeaderTimeout; got != want { | ||||||
| 		// if got, want := actual.ReadHeaderTimeout, tc.expected.ReadHeaderTimeout; got != want { | 			t.Errorf("Test %d: Expected ReadHeaderTimeout=%v, but was %v", i, want, got) | ||||||
| 		// 	t.Errorf("Test %d: Expected ReadHeaderTimeout=%v, but was %v", i, want, got) | 		} | ||||||
| 		// } |  | ||||||
| 		if got, want := actual.WriteTimeout, tc.expected.WriteTimeout; got != want { | 		if got, want := actual.WriteTimeout, tc.expected.WriteTimeout; got != want { | ||||||
| 			t.Errorf("Test %d: Expected WriteTimeout=%v, but was %v", i, want, got) | 			t.Errorf("Test %d: Expected WriteTimeout=%v, but was %v", i, want, got) | ||||||
| 		} | 		} | ||||||
| 		// if got, want := actual.IdleTimeout, tc.expected.IdleTimeout; got != want { | 		if got, want := actual.IdleTimeout, tc.expected.IdleTimeout; got != want { | ||||||
| 		// 	t.Errorf("Test %d: Expected IdleTimeout=%v, but was %v", i, want, got) | 			t.Errorf("Test %d: Expected IdleTimeout=%v, but was %v", i, want, got) | ||||||
| 		// } | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user