mirror of
https://github.com/caddyserver/caddy.git
synced 2025-05-24 02:02:26 -04:00
Merge branch 'master' into diagnostics
# Conflicts: # plugins.go # vendor/manifest
This commit is contained in:
commit
269a8b5fce
2
.gitignore
vendored
2
.gitignore
vendored
@ -17,3 +17,5 @@ Caddyfile
|
|||||||
og_static/
|
og_static/
|
||||||
|
|
||||||
.vscode/
|
.vscode/
|
||||||
|
|
||||||
|
*.bat
|
@ -1,5 +1,5 @@
|
|||||||
<p align="center">
|
<p align="center">
|
||||||
<a href="https://caddyserver.com"><img src="https://cloud.githubusercontent.com/assets/1128849/25305033/12916fce-2731-11e7-86ec-580d4d31cb16.png" alt="Caddy" width="400"></a>
|
<a href="https://caddyserver.com"><img src="https://user-images.githubusercontent.com/1128849/36137292-bebc223a-1051-11e8-9a81-4ea9054c96ac.png" alt="Caddy" width="400"></a>
|
||||||
</p>
|
</p>
|
||||||
<h3 align="center">Every Site on HTTPS <!-- Serve Confidently --></h3>
|
<h3 align="center">Every Site on HTTPS <!-- Serve Confidently --></h3>
|
||||||
<p align="center">Caddy is a general-purpose HTTP/2 web server that serves HTTPS by default.</p>
|
<p align="center">Caddy is a general-purpose HTTP/2 web server that serves HTTPS by default.</p>
|
||||||
@ -59,7 +59,7 @@ customize your build in the browser
|
|||||||
pre-built, vanilla binaries
|
pre-built, vanilla binaries
|
||||||
|
|
||||||
## Build
|
## Build
|
||||||
To build from source you need **[Git](https://git-scm.com/downloads)** and **[Go](https://golang.org/doc/install)** (1.8 or newer). Follow these instruction for fast building:
|
To build from source you need **[Git](https://git-scm.com/downloads)** and **[Go](https://golang.org/doc/install)** (1.9 or newer). Follow these instruction for fast building:
|
||||||
|
|
||||||
- Get source `go get github.com/mholt/caddy/caddy` and then run `go get github.com/caddyserver/builds`
|
- Get source `go get github.com/mholt/caddy/caddy` and then run `go get github.com/caddyserver/builds`
|
||||||
- Now `cd` to `$GOPATH/src/github.com/mholt/caddy/caddy` and run `go run build.go`
|
- Now `cd` to `$GOPATH/src/github.com/mholt/caddy/caddy` and run `go run build.go`
|
||||||
|
66
caddy.go
66
caddy.go
@ -78,8 +78,18 @@ var (
|
|||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
OnProcessExit = append(OnProcessExit, func() {
|
||||||
|
if PidFile != "" {
|
||||||
|
os.Remove(PidFile)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Instance contains the state of servers created as a result of
|
// Instance contains the state of servers created as a result of
|
||||||
// calling Start and can be used to access or control those servers.
|
// calling Start and can be used to access or control those servers.
|
||||||
|
// It is literally an instance of a server type. Instance values
|
||||||
|
// should NOT be copied. Use *Instance for safety.
|
||||||
type Instance struct {
|
type Instance struct {
|
||||||
// serverType is the name of the instance's server type
|
// serverType is the name of the instance's server type
|
||||||
serverType string
|
serverType string
|
||||||
@ -90,10 +100,11 @@ type Instance struct {
|
|||||||
// wg is used to wait for all servers to shut down
|
// wg is used to wait for all servers to shut down
|
||||||
wg *sync.WaitGroup
|
wg *sync.WaitGroup
|
||||||
|
|
||||||
// context is the context created for this instance.
|
// context is the context created for this instance,
|
||||||
|
// used to coordinate the setting up of the server type
|
||||||
context Context
|
context Context
|
||||||
|
|
||||||
// servers is the list of servers with their listeners.
|
// servers is the list of servers with their listeners
|
||||||
servers []ServerListener
|
servers []ServerListener
|
||||||
|
|
||||||
// these callbacks execute when certain events occur
|
// these callbacks execute when certain events occur
|
||||||
@ -102,6 +113,18 @@ type Instance struct {
|
|||||||
onRestart []func() error // before restart commences
|
onRestart []func() error // before restart commences
|
||||||
onShutdown []func() error // stopping, even as part of a restart
|
onShutdown []func() error // stopping, even as part of a restart
|
||||||
onFinalShutdown []func() error // stopping, not as part of a restart
|
onFinalShutdown []func() error // stopping, not as part of a restart
|
||||||
|
|
||||||
|
// storing values on an instance is preferable to
|
||||||
|
// global state because these will get garbage-
|
||||||
|
// collected after in-process reloads when the
|
||||||
|
// old instances are destroyed; use StorageMu
|
||||||
|
// to access this value safely
|
||||||
|
Storage map[interface{}]interface{}
|
||||||
|
StorageMu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func Instances() []*Instance {
|
||||||
|
return instances
|
||||||
}
|
}
|
||||||
|
|
||||||
// Servers returns the ServerListeners in i.
|
// Servers returns the ServerListeners in i.
|
||||||
@ -197,7 +220,7 @@ func (i *Instance) Restart(newCaddyfile Input) (*Instance, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create new instance; if the restart fails, it is simply discarded
|
// create new instance; if the restart fails, it is simply discarded
|
||||||
newInst := &Instance{serverType: newCaddyfile.ServerType(), wg: i.wg}
|
newInst := &Instance{serverType: newCaddyfile.ServerType(), wg: i.wg, Storage: make(map[interface{}]interface{})}
|
||||||
|
|
||||||
// attempt to start new instance
|
// attempt to start new instance
|
||||||
err := startWithListenerFds(newCaddyfile, newInst, restartFds)
|
err := startWithListenerFds(newCaddyfile, newInst, restartFds)
|
||||||
@ -456,7 +479,7 @@ func (i *Instance) Caddyfile() Input {
|
|||||||
//
|
//
|
||||||
// This function blocks until all the servers are listening.
|
// This function blocks until all the servers are listening.
|
||||||
func Start(cdyfile Input) (*Instance, error) {
|
func Start(cdyfile Input) (*Instance, error) {
|
||||||
inst := &Instance{serverType: cdyfile.ServerType(), wg: new(sync.WaitGroup)}
|
inst := &Instance{serverType: cdyfile.ServerType(), wg: new(sync.WaitGroup), Storage: make(map[interface{}]interface{})}
|
||||||
err := startWithListenerFds(cdyfile, inst, nil)
|
err := startWithListenerFds(cdyfile, inst, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return inst, err
|
return inst, err
|
||||||
@ -469,11 +492,34 @@ func Start(cdyfile Input) (*Instance, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]restartTriple) error {
|
func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]restartTriple) error {
|
||||||
|
// save this instance in the list now so that
|
||||||
|
// plugins can access it if need be, for example
|
||||||
|
// the caddytls package, so it can perform cert
|
||||||
|
// renewals while starting up; we just have to
|
||||||
|
// remove the instance from the list later if
|
||||||
|
// it fails
|
||||||
|
instancesMu.Lock()
|
||||||
|
instances = append(instances, inst)
|
||||||
|
instancesMu.Unlock()
|
||||||
|
var err error
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
instancesMu.Lock()
|
||||||
|
for i, otherInst := range instances {
|
||||||
|
if otherInst == inst {
|
||||||
|
instances = append(instances[:i], instances[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
instancesMu.Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
if cdyfile == nil {
|
if cdyfile == nil {
|
||||||
cdyfile = CaddyfileInput{}
|
cdyfile = CaddyfileInput{}
|
||||||
}
|
}
|
||||||
|
|
||||||
err := ValidateAndExecuteDirectives(cdyfile, inst, false)
|
err = ValidateAndExecuteDirectives(cdyfile, inst, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -505,10 +551,6 @@ func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]r
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
instancesMu.Lock()
|
|
||||||
instances = append(instances, inst)
|
|
||||||
instancesMu.Unlock()
|
|
||||||
|
|
||||||
// run any AfterStartup callbacks if this is not
|
// run any AfterStartup callbacks if this is not
|
||||||
// part of a restart; then show file descriptor notice
|
// part of a restart; then show file descriptor notice
|
||||||
if restartFds == nil {
|
if restartFds == nil {
|
||||||
@ -547,7 +589,7 @@ func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]r
|
|||||||
func ValidateAndExecuteDirectives(cdyfile Input, inst *Instance, justValidate bool) error {
|
func ValidateAndExecuteDirectives(cdyfile Input, inst *Instance, justValidate bool) error {
|
||||||
// If parsing only inst will be nil, create an instance for this function call only.
|
// If parsing only inst will be nil, create an instance for this function call only.
|
||||||
if justValidate {
|
if justValidate {
|
||||||
inst = &Instance{serverType: cdyfile.ServerType(), wg: new(sync.WaitGroup)}
|
inst = &Instance{serverType: cdyfile.ServerType(), wg: new(sync.WaitGroup), Storage: make(map[interface{}]interface{})}
|
||||||
}
|
}
|
||||||
|
|
||||||
stypeName := cdyfile.ServerType()
|
stypeName := cdyfile.ServerType()
|
||||||
@ -564,14 +606,14 @@ func ValidateAndExecuteDirectives(cdyfile Input, inst *Instance, justValidate bo
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
inst.context = stype.NewContext()
|
inst.context = stype.NewContext(inst)
|
||||||
if inst.context == nil {
|
if inst.context == nil {
|
||||||
return fmt.Errorf("server type %s produced a nil Context", stypeName)
|
return fmt.Errorf("server type %s produced a nil Context", stypeName)
|
||||||
}
|
}
|
||||||
|
|
||||||
sblocks, err = inst.context.InspectServerBlocks(cdyfile.Path(), sblocks)
|
sblocks, err = inst.context.InspectServerBlocks(cdyfile.Path(), sblocks)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("error inspecting server blocks: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
diagnostics.Set("num_server_blocks", len(sblocks))
|
diagnostics.Set("num_server_blocks", len(sblocks))
|
||||||
|
@ -148,7 +148,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
|
|||||||
case "HEAD":
|
case "HEAD":
|
||||||
resp, err = fcgiBackend.Head(env)
|
resp, err = fcgiBackend.Head(env)
|
||||||
case "GET":
|
case "GET":
|
||||||
resp, err = fcgiBackend.Get(env)
|
resp, err = fcgiBackend.Get(env, r.Body, contentLength)
|
||||||
case "OPTIONS":
|
case "OPTIONS":
|
||||||
resp, err = fcgiBackend.Options(env)
|
resp, err = fcgiBackend.Options(env)
|
||||||
default:
|
default:
|
||||||
|
@ -460,12 +460,12 @@ func (c *FCGIClient) Request(p map[string]string, req io.Reader) (resp *http.Res
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get issues a GET request to the fcgi responder.
|
// Get issues a GET request to the fcgi responder.
|
||||||
func (c *FCGIClient) Get(p map[string]string) (resp *http.Response, err error) {
|
func (c *FCGIClient) Get(p map[string]string, body io.Reader, l int64) (resp *http.Response, err error) {
|
||||||
|
|
||||||
p["REQUEST_METHOD"] = "GET"
|
p["REQUEST_METHOD"] = "GET"
|
||||||
p["CONTENT_LENGTH"] = "0"
|
p["CONTENT_LENGTH"] = strconv.FormatInt(l, 10)
|
||||||
|
|
||||||
return c.Request(p, nil)
|
return c.Request(p, body)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Head issues a HEAD request to the fcgi responder.
|
// Head issues a HEAD request to the fcgi responder.
|
||||||
|
@ -140,7 +140,8 @@ func sendFcgi(reqType int, fcgiParams map[string]string, data []byte, posts map[
|
|||||||
}
|
}
|
||||||
resp, err = fcgi.PostForm(fcgiParams, values)
|
resp, err = fcgi.PostForm(fcgiParams, values)
|
||||||
} else {
|
} else {
|
||||||
resp, err = fcgi.Get(fcgiParams)
|
rd := bytes.NewReader(data)
|
||||||
|
resp, err = fcgi.Get(fcgiParams, rd, int64(rd.Len()))
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
|
@ -27,7 +27,7 @@ func activateHTTPS(cctx caddy.Context) error {
|
|||||||
operatorPresent := !caddy.Started()
|
operatorPresent := !caddy.Started()
|
||||||
|
|
||||||
if !caddy.Quiet && operatorPresent {
|
if !caddy.Quiet && operatorPresent {
|
||||||
fmt.Print("Activating privacy features...")
|
fmt.Print("Activating privacy features... ")
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := cctx.(*httpContext)
|
ctx := cctx.(*httpContext)
|
||||||
@ -69,7 +69,7 @@ func activateHTTPS(cctx caddy.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !caddy.Quiet && operatorPresent {
|
if !caddy.Quiet && operatorPresent {
|
||||||
fmt.Println(" done.")
|
fmt.Println("done.")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -160,23 +160,37 @@ func hostHasOtherPort(allConfigs []*SiteConfig, thisConfigIdx int, otherPort str
|
|||||||
// to listen on HTTPPort. The TLS field of cfg must not be nil.
|
// to listen on HTTPPort. The TLS field of cfg must not be nil.
|
||||||
func redirPlaintextHost(cfg *SiteConfig) *SiteConfig {
|
func redirPlaintextHost(cfg *SiteConfig) *SiteConfig {
|
||||||
redirPort := cfg.Addr.Port
|
redirPort := cfg.Addr.Port
|
||||||
if redirPort == DefaultHTTPSPort {
|
if redirPort == HTTPSPort {
|
||||||
redirPort = "" // default port is redundant
|
// By default, HTTPSPort should be DefaultHTTPSPort,
|
||||||
|
// which of course doesn't need to be explicitly stated
|
||||||
|
// in the Location header. Even if HTTPSPort is changed
|
||||||
|
// so that it is no longer DefaultHTTPSPort, we shouldn't
|
||||||
|
// append it to the URL in the Location because changing
|
||||||
|
// the HTTPS port is assumed to be an internal-only change
|
||||||
|
// (in other words, we assume port forwarding is going on);
|
||||||
|
// but redirects go back to a presumably-external client.
|
||||||
|
// (If redirect clients are also internal, that is more
|
||||||
|
// advanced, and the user should configure HTTP->HTTPS
|
||||||
|
// redirects themselves.)
|
||||||
|
redirPort = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
redirMiddleware := func(next Handler) Handler {
|
redirMiddleware := func(next Handler) Handler {
|
||||||
return HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
|
return HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||||
// Construct the URL to which to redirect. Note that the Host in a request might
|
// Construct the URL to which to redirect. Note that the Host in a
|
||||||
// contain a port, but we just need the hostname; we'll set the port if needed.
|
// request might contain a port, but we just need the hostname from
|
||||||
|
// it; and we'll set the port if needed.
|
||||||
toURL := "https://"
|
toURL := "https://"
|
||||||
requestHost, _, err := net.SplitHostPort(r.Host)
|
requestHost, _, err := net.SplitHostPort(r.Host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
requestHost = r.Host // Host did not contain a port; great
|
requestHost = r.Host // Host did not contain a port, so use the whole value
|
||||||
}
|
}
|
||||||
if redirPort == "" {
|
if redirPort == "" {
|
||||||
toURL += requestHost
|
toURL += requestHost
|
||||||
} else {
|
} else {
|
||||||
toURL += net.JoinHostPort(requestHost, redirPort)
|
toURL += net.JoinHostPort(requestHost, redirPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
toURL += r.URL.RequestURI()
|
toURL += r.URL.RequestURI()
|
||||||
|
|
||||||
w.Header().Set("Connection", "close")
|
w.Header().Set("Connection", "close")
|
||||||
@ -184,9 +198,11 @@ func redirPlaintextHost(cfg *SiteConfig) *SiteConfig {
|
|||||||
return 0, nil
|
return 0, nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
host := cfg.Addr.Host
|
host := cfg.Addr.Host
|
||||||
port := HTTPPort
|
port := HTTPPort
|
||||||
addr := net.JoinHostPort(host, port)
|
addr := net.JoinHostPort(host, port)
|
||||||
|
|
||||||
return &SiteConfig{
|
return &SiteConfig{
|
||||||
Addr: Address{Original: addr, Host: host, Port: port},
|
Addr: Address{Original: addr, Host: host, Port: port},
|
||||||
ListenHost: cfg.ListenHost,
|
ListenHost: cfg.ListenHost,
|
||||||
|
@ -53,7 +53,7 @@ func TestRedirPlaintextHost(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
Host: "foohost",
|
Host: "foohost",
|
||||||
Port: "443", // since this is the default HTTPS port, should not be included in Location value
|
Port: HTTPSPort, // since this is the 'default' HTTPS port, should not be included in Location value
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Host: "*.example.com",
|
Host: "*.example.com",
|
||||||
|
@ -91,11 +91,13 @@ func hideCaddyfile(cctx caddy.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newContext() caddy.Context {
|
func newContext(inst *caddy.Instance) caddy.Context {
|
||||||
return &httpContext{keysToSiteConfigs: make(map[string]*SiteConfig)}
|
return &httpContext{instance: inst, keysToSiteConfigs: make(map[string]*SiteConfig)}
|
||||||
}
|
}
|
||||||
|
|
||||||
type httpContext struct {
|
type httpContext struct {
|
||||||
|
instance *caddy.Instance
|
||||||
|
|
||||||
// keysToSiteConfigs maps an address at the top of a
|
// keysToSiteConfigs maps an address at the top of a
|
||||||
// server block (a "key") to its SiteConfig. Not all
|
// server block (a "key") to its SiteConfig. Not all
|
||||||
// SiteConfigs will be represented here, only ones
|
// SiteConfigs will be represented here, only ones
|
||||||
@ -115,12 +117,14 @@ func (h *httpContext) saveConfig(key string, cfg *SiteConfig) {
|
|||||||
// executing directives and otherwise prepares the directives to
|
// executing directives and otherwise prepares the directives to
|
||||||
// be parsed and executed.
|
// be parsed and executed.
|
||||||
func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []caddyfile.ServerBlock) ([]caddyfile.ServerBlock, error) {
|
func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []caddyfile.ServerBlock) ([]caddyfile.ServerBlock, error) {
|
||||||
|
siteAddrs := make(map[string]string)
|
||||||
|
|
||||||
// For each address in each server block, make a new config
|
// For each address in each server block, make a new config
|
||||||
for _, sb := range serverBlocks {
|
for _, sb := range serverBlocks {
|
||||||
for _, key := range sb.Keys {
|
for _, key := range sb.Keys {
|
||||||
key = strings.ToLower(key)
|
key = strings.ToLower(key)
|
||||||
if _, dup := h.keysToSiteConfigs[key]; dup {
|
if _, dup := h.keysToSiteConfigs[key]; dup {
|
||||||
return serverBlocks, fmt.Errorf("duplicate site address: %s", key)
|
return serverBlocks, fmt.Errorf("duplicate site key: %s", key)
|
||||||
}
|
}
|
||||||
addr, err := standardizeAddress(key)
|
addr, err := standardizeAddress(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -136,6 +140,23 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd
|
|||||||
addr.Port = Port
|
addr.Port = Port
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Make sure the adjusted site address is distinct
|
||||||
|
addrCopy := addr // make copy so we don't disturb the original, carefully-parsed address struct
|
||||||
|
if addrCopy.Port == "" && Port == DefaultPort {
|
||||||
|
addrCopy.Port = Port
|
||||||
|
}
|
||||||
|
addrStr := strings.ToLower(addrCopy.String())
|
||||||
|
if otherSiteKey, dup := siteAddrs[addrStr]; dup {
|
||||||
|
err := fmt.Errorf("duplicate site address: %s", addrStr)
|
||||||
|
if (addrCopy.Host == Host && Host != DefaultHost) ||
|
||||||
|
(addrCopy.Port == Port && Port != DefaultPort) {
|
||||||
|
err = fmt.Errorf("site defined as %s is a duplicate of %s because of modified "+
|
||||||
|
"default host and/or port values (usually via -host or -port flags)", key, otherSiteKey)
|
||||||
|
}
|
||||||
|
return serverBlocks, err
|
||||||
|
}
|
||||||
|
siteAddrs[addrStr] = key
|
||||||
|
|
||||||
// If default HTTP or HTTPS ports have been customized,
|
// If default HTTP or HTTPS ports have been customized,
|
||||||
// make sure the ACME challenge ports match
|
// make sure the ACME challenge ports match
|
||||||
var altHTTPPort, altTLSSNIPort string
|
var altHTTPPort, altTLSSNIPort string
|
||||||
@ -146,15 +167,19 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd
|
|||||||
altTLSSNIPort = HTTPSPort
|
altTLSSNIPort = HTTPSPort
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Make our caddytls.Config, which has a pointer to the
|
||||||
|
// instance's certificate cache and enough information
|
||||||
|
// to use automatic HTTPS when the time comes
|
||||||
|
caddytlsConfig := caddytls.NewConfig(h.instance)
|
||||||
|
caddytlsConfig.Hostname = addr.Host
|
||||||
|
caddytlsConfig.AltHTTPPort = altHTTPPort
|
||||||
|
caddytlsConfig.AltTLSSNIPort = altTLSSNIPort
|
||||||
|
|
||||||
// Save the config to our master list, and key it for lookups
|
// Save the config to our master list, and key it for lookups
|
||||||
cfg := &SiteConfig{
|
cfg := &SiteConfig{
|
||||||
Addr: addr,
|
Addr: addr,
|
||||||
Root: Root,
|
Root: Root,
|
||||||
TLS: &caddytls.Config{
|
TLS: caddytlsConfig,
|
||||||
Hostname: addr.Host,
|
|
||||||
AltHTTPPort: altHTTPPort,
|
|
||||||
AltTLSSNIPort: altTLSSNIPort,
|
|
||||||
},
|
|
||||||
originCaddyfile: sourceFile,
|
originCaddyfile: sourceFile,
|
||||||
IndexPages: staticfiles.DefaultIndexPages,
|
IndexPages: staticfiles.DefaultIndexPages,
|
||||||
}
|
}
|
||||||
|
@ -137,7 +137,7 @@ func TestAddressString(t *testing.T) {
|
|||||||
func TestInspectServerBlocksWithCustomDefaultPort(t *testing.T) {
|
func TestInspectServerBlocksWithCustomDefaultPort(t *testing.T) {
|
||||||
Port = "9999"
|
Port = "9999"
|
||||||
filename := "Testfile"
|
filename := "Testfile"
|
||||||
ctx := newContext().(*httpContext)
|
ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext)
|
||||||
input := strings.NewReader(`localhost`)
|
input := strings.NewReader(`localhost`)
|
||||||
sblocks, err := caddyfile.Parse(filename, input, nil)
|
sblocks, err := caddyfile.Parse(filename, input, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -153,9 +153,26 @@ func TestInspectServerBlocksWithCustomDefaultPort(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// See discussion on PR #2015
|
||||||
|
func TestInspectServerBlocksWithAdjustedAddress(t *testing.T) {
|
||||||
|
Port = DefaultPort
|
||||||
|
Host = "example.com"
|
||||||
|
filename := "Testfile"
|
||||||
|
ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext)
|
||||||
|
input := strings.NewReader("example.com {\n}\n:2015 {\n}")
|
||||||
|
sblocks, err := caddyfile.Parse(filename, input, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Expected no error setting up test, got: %v", err)
|
||||||
|
}
|
||||||
|
_, err = ctx.InspectServerBlocks(filename, sblocks)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("Expected an error because site definitions should overlap, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestInspectServerBlocksCaseInsensitiveKey(t *testing.T) {
|
func TestInspectServerBlocksCaseInsensitiveKey(t *testing.T) {
|
||||||
filename := "Testfile"
|
filename := "Testfile"
|
||||||
ctx := newContext().(*httpContext)
|
ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext)
|
||||||
input := strings.NewReader("localhost {\n}\nLOCALHOST {\n}")
|
input := strings.NewReader("localhost {\n}\nLOCALHOST {\n}")
|
||||||
sblocks, err := caddyfile.Parse(filename, input, nil)
|
sblocks, err := caddyfile.Parse(filename, input, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -207,7 +224,7 @@ func TestDirectivesList(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestContextSaveConfig(t *testing.T) {
|
func TestContextSaveConfig(t *testing.T) {
|
||||||
ctx := newContext().(*httpContext)
|
ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext)
|
||||||
ctx.saveConfig("foo", new(SiteConfig))
|
ctx.saveConfig("foo", new(SiteConfig))
|
||||||
if _, ok := ctx.keysToSiteConfigs["foo"]; !ok {
|
if _, ok := ctx.keysToSiteConfigs["foo"]; !ok {
|
||||||
t.Error("Expected config to be saved, but it wasn't")
|
t.Error("Expected config to be saved, but it wasn't")
|
||||||
@ -226,7 +243,7 @@ func TestContextSaveConfig(t *testing.T) {
|
|||||||
|
|
||||||
// Test to make sure we are correctly hiding the Caddyfile
|
// Test to make sure we are correctly hiding the Caddyfile
|
||||||
func TestHideCaddyfile(t *testing.T) {
|
func TestHideCaddyfile(t *testing.T) {
|
||||||
ctx := newContext().(*httpContext)
|
ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext)
|
||||||
ctx.saveConfig("test", &SiteConfig{
|
ctx.saveConfig("test", &SiteConfig{
|
||||||
Root: Root,
|
Root: Root,
|
||||||
originCaddyfile: "Testfile",
|
originCaddyfile: "Testfile",
|
||||||
|
@ -392,7 +392,7 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error)
|
|||||||
if vhost == nil {
|
if vhost == nil {
|
||||||
// check for ACME challenge even if vhost is nil;
|
// check for ACME challenge even if vhost is nil;
|
||||||
// could be a new host coming online soon
|
// could be a new host coming online soon
|
||||||
if caddytls.HTTPChallengeHandler(w, r, "localhost", caddytls.DefaultHTTPAlternatePort) {
|
if caddytls.HTTPChallengeHandler(w, r, "localhost") {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
// otherwise, log the error and write a message to the client
|
// otherwise, log the error and write a message to the client
|
||||||
@ -408,7 +408,7 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error)
|
|||||||
|
|
||||||
// we still check for ACME challenge if the vhost exists,
|
// we still check for ACME challenge if the vhost exists,
|
||||||
// because we must apply its HTTP challenge config settings
|
// because we must apply its HTTP challenge config settings
|
||||||
if s.proxyHTTPChallenge(vhost, w, r) {
|
if caddytls.HTTPChallengeHandler(w, r, vhost.ListenHost) {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -416,31 +416,25 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error)
|
|||||||
// the URL path, so a request to example.com/foo/blog on the site
|
// the URL path, so a request to example.com/foo/blog on the site
|
||||||
// defined as example.com/foo appears as /blog instead of /foo/blog.
|
// defined as example.com/foo appears as /blog instead of /foo/blog.
|
||||||
if pathPrefix != "/" {
|
if pathPrefix != "/" {
|
||||||
r.URL.Path = strings.TrimPrefix(r.URL.Path, pathPrefix)
|
r.URL = trimPathPrefix(r.URL, pathPrefix)
|
||||||
if !strings.HasPrefix(r.URL.Path, "/") {
|
|
||||||
r.URL.Path = "/" + r.URL.Path
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return vhost.middlewareChain.ServeHTTP(w, r)
|
return vhost.middlewareChain.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
// proxyHTTPChallenge solves the ACME HTTP challenge if r is the HTTP
|
func trimPathPrefix(u *url.URL, prefix string) *url.URL {
|
||||||
// request for the challenge. If it is, and if the request has been
|
// We need to use URL.EscapedPath() when trimming the pathPrefix as
|
||||||
// fulfilled (response written), true is returned; false otherwise.
|
// URL.Path is ambiguous about / or %2f - see docs. See #1927
|
||||||
// If you don't have a vhost, just call the challenge handler directly.
|
trimmed := strings.TrimPrefix(u.EscapedPath(), prefix)
|
||||||
func (s *Server) proxyHTTPChallenge(vhost *SiteConfig, w http.ResponseWriter, r *http.Request) bool {
|
if !strings.HasPrefix(trimmed, "/") {
|
||||||
if vhost.Addr.Port != caddytls.HTTPChallengePort {
|
trimmed = "/" + trimmed
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
if vhost.TLS != nil && vhost.TLS.Manual {
|
trimmedURL, err := url.Parse(trimmed)
|
||||||
return false
|
if err != nil {
|
||||||
|
log.Printf("[ERROR] Unable to parse trimmed URL %s: %v", trimmed, err)
|
||||||
|
return u
|
||||||
}
|
}
|
||||||
altPort := caddytls.DefaultHTTPAlternatePort
|
return trimmedURL
|
||||||
if vhost.TLS != nil && vhost.TLS.AltHTTPPort != "" {
|
|
||||||
altPort = vhost.TLS.AltHTTPPort
|
|
||||||
}
|
|
||||||
return caddytls.HTTPChallengeHandler(w, r, vhost.ListenHost, altPort)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Address returns the address s was assigned to listen on.
|
// Address returns the address s was assigned to listen on.
|
||||||
|
@ -16,6 +16,7 @@ package httpserver
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -126,6 +127,94 @@ func TestMakeHTTPServerWithTimeouts(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTrimPathPrefix(t *testing.T) {
|
||||||
|
for i, pt := range []struct {
|
||||||
|
path string
|
||||||
|
prefix string
|
||||||
|
expected string
|
||||||
|
shouldFail bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
path: "/my/path",
|
||||||
|
prefix: "/my",
|
||||||
|
expected: "/path",
|
||||||
|
shouldFail: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: "/my/%2f/path",
|
||||||
|
prefix: "/my",
|
||||||
|
expected: "/%2f/path",
|
||||||
|
shouldFail: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: "/my/path",
|
||||||
|
prefix: "/my/",
|
||||||
|
expected: "/path",
|
||||||
|
shouldFail: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: "/my///path",
|
||||||
|
prefix: "/my",
|
||||||
|
expected: "/path",
|
||||||
|
shouldFail: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: "/my///path",
|
||||||
|
prefix: "/my",
|
||||||
|
expected: "///path",
|
||||||
|
shouldFail: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: "/my/path///slash",
|
||||||
|
prefix: "/my",
|
||||||
|
expected: "/path///slash",
|
||||||
|
shouldFail: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: "/my/%2f/path/%2f",
|
||||||
|
prefix: "/my",
|
||||||
|
expected: "/%2f/path/%2f",
|
||||||
|
shouldFail: false,
|
||||||
|
}, {
|
||||||
|
path: "/my/%20/path",
|
||||||
|
prefix: "/my",
|
||||||
|
expected: "/%20/path",
|
||||||
|
shouldFail: false,
|
||||||
|
}, {
|
||||||
|
path: "/path",
|
||||||
|
prefix: "",
|
||||||
|
expected: "/path",
|
||||||
|
shouldFail: false,
|
||||||
|
}, {
|
||||||
|
path: "/path/my/",
|
||||||
|
prefix: "/my",
|
||||||
|
expected: "/path/my/",
|
||||||
|
shouldFail: false,
|
||||||
|
}, {
|
||||||
|
path: "",
|
||||||
|
prefix: "/my",
|
||||||
|
expected: "/",
|
||||||
|
shouldFail: false,
|
||||||
|
}, {
|
||||||
|
path: "/apath",
|
||||||
|
prefix: "",
|
||||||
|
expected: "/apath",
|
||||||
|
shouldFail: false,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
|
||||||
|
u, _ := url.Parse(pt.path)
|
||||||
|
if got, want := trimPathPrefix(u, pt.prefix), pt.expected; got.EscapedPath() != want {
|
||||||
|
if !pt.shouldFail {
|
||||||
|
|
||||||
|
t.Errorf("Test %d: Expected='%s', but was '%s' ", i, want, got.EscapedPath())
|
||||||
|
}
|
||||||
|
} else if pt.shouldFail {
|
||||||
|
t.Errorf("SHOULDFAIL Test %d: Expected='%s', and was '%s' but should fail", i, want, got.EscapedPath())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestMakeHTTPServerWithHeaderLimit(t *testing.T) {
|
func TestMakeHTTPServerWithHeaderLimit(t *testing.T) {
|
||||||
for name, c := range map[string]struct {
|
for name, c := range map[string]struct {
|
||||||
group []*SiteConfig
|
group []*SiteConfig
|
||||||
|
@ -16,6 +16,7 @@ package requestid
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@ -24,12 +25,29 @@ import (
|
|||||||
|
|
||||||
// Handler is a middleware handler
|
// Handler is a middleware handler
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
Next httpserver.Handler
|
Next httpserver.Handler
|
||||||
|
HeaderName string // (optional) header from which to read an existing ID
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||||
reqid := uuid.New().String()
|
var reqid uuid.UUID
|
||||||
c := context.WithValue(r.Context(), httpserver.RequestIDCtxKey, reqid)
|
|
||||||
|
uuidFromHeader := r.Header.Get(h.HeaderName)
|
||||||
|
if h.HeaderName != "" && uuidFromHeader != "" {
|
||||||
|
// use the ID in the header field if it exists
|
||||||
|
var err error
|
||||||
|
reqid, err = uuid.Parse(uuidFromHeader)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[NOTICE] Parsing request ID from %s header: %v", h.HeaderName, err)
|
||||||
|
reqid = uuid.New()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// otherwise, create a new one
|
||||||
|
reqid = uuid.New()
|
||||||
|
}
|
||||||
|
|
||||||
|
// set the request ID on the context
|
||||||
|
c := context.WithValue(r.Context(), httpserver.RequestIDCtxKey, reqid.String())
|
||||||
r = r.WithContext(c)
|
r = r.WithContext(c)
|
||||||
|
|
||||||
return h.Next.ServeHTTP(w, r)
|
return h.Next.ServeHTTP(w, r)
|
||||||
|
@ -15,34 +15,53 @@
|
|||||||
package requestid
|
package requestid
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRequestID(t *testing.T) {
|
func TestRequestIDHandler(t *testing.T) {
|
||||||
request, err := http.NewRequest("GET", "http://localhost/", nil)
|
handler := Handler{
|
||||||
|
Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||||
|
value, _ := r.Context().Value(httpserver.RequestIDCtxKey).(string)
|
||||||
|
if value == "" {
|
||||||
|
t.Error("Request ID should not be empty")
|
||||||
|
}
|
||||||
|
return 0, nil
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest("GET", "http://localhost/", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Could not create HTTP request:", err)
|
t.Fatal("Could not create HTTP request:", err)
|
||||||
}
|
}
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
reqid := uuid.New().String()
|
handler.ServeHTTP(rec, req)
|
||||||
|
}
|
||||||
c := context.WithValue(request.Context(), httpserver.RequestIDCtxKey, reqid)
|
|
||||||
|
func TestRequestIDFromHeader(t *testing.T) {
|
||||||
request = request.WithContext(c)
|
headerName := "X-Request-ID"
|
||||||
|
headerValue := "71a75329-d9f9-4d25-957e-e689a7b68d78"
|
||||||
// See caddyhttp/replacer.go
|
handler := Handler{
|
||||||
value, _ := request.Context().Value(httpserver.RequestIDCtxKey).(string)
|
Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||||
|
value, _ := r.Context().Value(httpserver.RequestIDCtxKey).(string)
|
||||||
if value == "" {
|
if value != headerValue {
|
||||||
t.Fatal("Request ID should not be empty")
|
t.Errorf("Request ID should be '%s' but got '%s'", headerValue, value)
|
||||||
}
|
}
|
||||||
|
return 0, nil
|
||||||
if value != reqid {
|
}),
|
||||||
t.Fatal("Request ID does not match")
|
HeaderName: headerName,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest("GET", "http://localhost/", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Could not create HTTP request:", err)
|
||||||
|
}
|
||||||
|
req.Header.Set(headerName, headerValue)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
}
|
}
|
||||||
|
@ -27,14 +27,19 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func setup(c *caddy.Controller) error {
|
func setup(c *caddy.Controller) error {
|
||||||
|
var headerName string
|
||||||
|
|
||||||
for c.Next() {
|
for c.Next() {
|
||||||
if c.NextArg() {
|
if c.NextArg() {
|
||||||
return c.ArgErr() //no arg expected.
|
headerName = c.Val()
|
||||||
|
}
|
||||||
|
if c.NextArg() {
|
||||||
|
return c.ArgErr()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
httpserver.GetConfig(c).AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
|
httpserver.GetConfig(c).AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
|
||||||
return Handler{Next: next}
|
return Handler{Next: next, HeaderName: headerName}
|
||||||
})
|
})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -45,7 +45,15 @@ func TestSetup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSetupWithArg(t *testing.T) {
|
func TestSetupWithArg(t *testing.T) {
|
||||||
c := caddy.NewTestController("http", `requestid abc`)
|
c := caddy.NewTestController("http", `requestid X-Request-ID`)
|
||||||
|
err := setup(c)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected no error, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetupWithTooManyArgs(t *testing.T) {
|
||||||
|
c := caddy.NewTestController("http", `requestid foo bar`)
|
||||||
err := setup(c)
|
err := setup(c)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("Expected an error, got: %v", err)
|
t.Errorf("Expected an error, got: %v", err)
|
||||||
|
@ -107,6 +107,10 @@ func (fs FileServer) serveFile(w http.ResponseWriter, r *http.Request) (int, err
|
|||||||
if d.IsDir() {
|
if d.IsDir() {
|
||||||
// ensure there is a trailing slash
|
// ensure there is a trailing slash
|
||||||
if urlCopy.Path[len(urlCopy.Path)-1] != '/' {
|
if urlCopy.Path[len(urlCopy.Path)-1] != '/' {
|
||||||
|
for strings.HasPrefix(urlCopy.Path, "//") {
|
||||||
|
// prevent path-based open redirects
|
||||||
|
urlCopy.Path = strings.TrimPrefix(urlCopy.Path, "/")
|
||||||
|
}
|
||||||
urlCopy.Path += "/"
|
urlCopy.Path += "/"
|
||||||
http.Redirect(w, r, urlCopy.String(), http.StatusMovedPermanently)
|
http.Redirect(w, r, urlCopy.String(), http.StatusMovedPermanently)
|
||||||
return http.StatusMovedPermanently, nil
|
return http.StatusMovedPermanently, nil
|
||||||
@ -131,6 +135,10 @@ func (fs FileServer) serveFile(w http.ResponseWriter, r *http.Request) (int, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
if redir {
|
if redir {
|
||||||
|
for strings.HasPrefix(urlCopy.Path, "//") {
|
||||||
|
// prevent path-based open redirects
|
||||||
|
urlCopy.Path = strings.TrimPrefix(urlCopy.Path, "/")
|
||||||
|
}
|
||||||
http.Redirect(w, r, urlCopy.String(), http.StatusMovedPermanently)
|
http.Redirect(w, r, urlCopy.String(), http.StatusMovedPermanently)
|
||||||
return http.StatusMovedPermanently, nil
|
return http.StatusMovedPermanently, nil
|
||||||
}
|
}
|
||||||
|
@ -77,9 +77,9 @@ func TestServeHTTP(t *testing.T) {
|
|||||||
{
|
{
|
||||||
url: "https://foo/dirwithindex/",
|
url: "https://foo/dirwithindex/",
|
||||||
expectedStatus: http.StatusOK,
|
expectedStatus: http.StatusOK,
|
||||||
expectedBodyContent: testFiles[webrootDirwithindexIndeHTML],
|
expectedBodyContent: testFiles[webrootDirwithindexIndexHTML],
|
||||||
expectedEtag: `"2n9cw"`,
|
expectedEtag: `"2n9cw"`,
|
||||||
expectedContentLength: strconv.Itoa(len(testFiles[webrootDirwithindexIndeHTML])),
|
expectedContentLength: strconv.Itoa(len(testFiles[webrootDirwithindexIndexHTML])),
|
||||||
},
|
},
|
||||||
// Test 4 - access folder with index file without trailing slash
|
// Test 4 - access folder with index file without trailing slash
|
||||||
{
|
{
|
||||||
@ -235,16 +235,38 @@ func TestServeHTTP(t *testing.T) {
|
|||||||
expectedBodyContent: movedPermanently,
|
expectedBodyContent: movedPermanently,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
// Test 27 - Check etag
|
||||||
url: "https://foo/notindex.html",
|
url: "https://foo/notindex.html",
|
||||||
expectedStatus: http.StatusOK,
|
expectedStatus: http.StatusOK,
|
||||||
expectedBodyContent: testFiles[webrootNotIndexHTML],
|
expectedBodyContent: testFiles[webrootNotIndexHTML],
|
||||||
expectedEtag: `"2n9cm"`,
|
expectedEtag: `"2n9cm"`,
|
||||||
expectedContentLength: strconv.Itoa(len(testFiles[webrootNotIndexHTML])),
|
expectedContentLength: strconv.Itoa(len(testFiles[webrootNotIndexHTML])),
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
// Test 28 - Prevent path-based open redirects (directory)
|
||||||
|
url: "https://foo//example.com%2f..",
|
||||||
|
expectedStatus: http.StatusMovedPermanently,
|
||||||
|
expectedLocation: "https://foo/example.com/../",
|
||||||
|
expectedBodyContent: movedPermanently,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Test 29 - Prevent path-based open redirects (file)
|
||||||
|
url: "https://foo//example.com%2f../dirwithindex/index.html",
|
||||||
|
expectedStatus: http.StatusMovedPermanently,
|
||||||
|
expectedLocation: "https://foo/example.com/../dirwithindex/",
|
||||||
|
expectedBodyContent: movedPermanently,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Test 29 - Prevent path-based open redirects (extra leading slashes)
|
||||||
|
url: "https://foo///example.com%2f..",
|
||||||
|
expectedStatus: http.StatusMovedPermanently,
|
||||||
|
expectedLocation: "https://foo/example.com/../",
|
||||||
|
expectedBodyContent: movedPermanently,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, test := range tests {
|
for i, test := range tests {
|
||||||
// set up response writer and rewuest
|
// set up response writer and request
|
||||||
responseRecorder := httptest.NewRecorder()
|
responseRecorder := httptest.NewRecorder()
|
||||||
request, err := http.NewRequest("GET", test.url, nil)
|
request, err := http.NewRequest("GET", test.url, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -518,7 +540,7 @@ var (
|
|||||||
webrootNotIndexHTML = filepath.Join(webrootName, "notindex.html")
|
webrootNotIndexHTML = filepath.Join(webrootName, "notindex.html")
|
||||||
webrootDirFile2HTML = filepath.Join(webrootName, "dir", "file2.html")
|
webrootDirFile2HTML = filepath.Join(webrootName, "dir", "file2.html")
|
||||||
webrootDirHiddenHTML = filepath.Join(webrootName, "dir", "hidden.html")
|
webrootDirHiddenHTML = filepath.Join(webrootName, "dir", "hidden.html")
|
||||||
webrootDirwithindexIndeHTML = filepath.Join(webrootName, "dirwithindex", "index.html")
|
webrootDirwithindexIndexHTML = filepath.Join(webrootName, "dirwithindex", "index.html")
|
||||||
webrootSubGzippedHTML = filepath.Join(webrootName, "sub", "gzipped.html")
|
webrootSubGzippedHTML = filepath.Join(webrootName, "sub", "gzipped.html")
|
||||||
webrootSubGzippedHTMLGz = filepath.Join(webrootName, "sub", "gzipped.html.gz")
|
webrootSubGzippedHTMLGz = filepath.Join(webrootName, "sub", "gzipped.html.gz")
|
||||||
webrootSubGzippedHTMLBr = filepath.Join(webrootName, "sub", "gzipped.html.br")
|
webrootSubGzippedHTMLBr = filepath.Join(webrootName, "sub", "gzipped.html.br")
|
||||||
@ -544,7 +566,7 @@ var testFiles = map[string]string{
|
|||||||
webrootFile1HTML: "<h1>file1.html</h1>",
|
webrootFile1HTML: "<h1>file1.html</h1>",
|
||||||
webrootNotIndexHTML: "<h1>notindex.html</h1>",
|
webrootNotIndexHTML: "<h1>notindex.html</h1>",
|
||||||
webrootDirFile2HTML: "<h1>dir/file2.html</h1>",
|
webrootDirFile2HTML: "<h1>dir/file2.html</h1>",
|
||||||
webrootDirwithindexIndeHTML: "<h1>dirwithindex/index.html</h1>",
|
webrootDirwithindexIndexHTML: "<h1>dirwithindex/index.html</h1>",
|
||||||
webrootDirHiddenHTML: "<h1>dir/hidden.html</h1>",
|
webrootDirHiddenHTML: "<h1>dir/hidden.html</h1>",
|
||||||
webrootSubGzippedHTML: "<h1>gzipped.html</h1>",
|
webrootSubGzippedHTML: "<h1>gzipped.html</h1>",
|
||||||
webrootSubGzippedHTMLGz: "1.gzipped.html.gz",
|
webrootSubGzippedHTMLGz: "1.gzipped.html.gz",
|
||||||
|
@ -15,9 +15,11 @@
|
|||||||
package caddytls
|
package caddytls
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
"strings"
|
"strings"
|
||||||
@ -27,24 +29,104 @@ import (
|
|||||||
"golang.org/x/crypto/ocsp"
|
"golang.org/x/crypto/ocsp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// certCache stores certificates in memory,
|
// certificateCache is to be an instance-wide cache of certs
|
||||||
// keying certificates by name. Certificates
|
// that site-specific TLS configs can refer to. Using a
|
||||||
// should not overlap in the names they serve,
|
// central map like this avoids duplication of certs in
|
||||||
// because a name only maps to one certificate.
|
// memory when the cert is used by multiple sites, and makes
|
||||||
var certCache = make(map[string]Certificate)
|
// maintenance easier. Because these are not to be global,
|
||||||
var certCacheMu sync.RWMutex
|
// the cache will get garbage collected after a config reload
|
||||||
|
// (a new instance will take its place).
|
||||||
|
type certificateCache struct {
|
||||||
|
sync.RWMutex
|
||||||
|
cache map[string]Certificate // keyed by certificate hash
|
||||||
|
}
|
||||||
|
|
||||||
|
// replaceCertificate replaces oldCert with newCert in the cache, and
|
||||||
|
// updates all configs that are pointing to the old certificate to
|
||||||
|
// point to the new one instead. newCert must already be loaded into
|
||||||
|
// the cache (this method does NOT load it into the cache).
|
||||||
|
//
|
||||||
|
// Note that all the names on the old certificate will be deleted
|
||||||
|
// from the name lookup maps of each config, then all the names on
|
||||||
|
// the new certificate will be added to the lookup maps as long as
|
||||||
|
// they do not overwrite any entries.
|
||||||
|
//
|
||||||
|
// The newCert may be modified and its cache entry updated.
|
||||||
|
//
|
||||||
|
// This method is safe for concurrent use.
|
||||||
|
func (certCache *certificateCache) replaceCertificate(oldCert, newCert Certificate) error {
|
||||||
|
certCache.Lock()
|
||||||
|
defer certCache.Unlock()
|
||||||
|
|
||||||
|
// have all the configs that are pointing to the old
|
||||||
|
// certificate point to the new certificate instead
|
||||||
|
for _, cfg := range oldCert.configs {
|
||||||
|
// first delete all the name lookup entries that
|
||||||
|
// pointed to the old certificate
|
||||||
|
for name, certKey := range cfg.Certificates {
|
||||||
|
if certKey == oldCert.Hash {
|
||||||
|
delete(cfg.Certificates, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// then add name lookup entries for the names
|
||||||
|
// on the new certificate, but don't overwrite
|
||||||
|
// entries that may already exist, not only as
|
||||||
|
// a courtesy, but importantly: because if we
|
||||||
|
// overwrote a value here, and this config no
|
||||||
|
// longer pointed to a certain certificate in
|
||||||
|
// the cache, that certificate's list of configs
|
||||||
|
// referring to it would be incorrect; so just
|
||||||
|
// insert entries, don't overwrite any
|
||||||
|
for _, name := range newCert.Names {
|
||||||
|
if _, ok := cfg.Certificates[name]; !ok {
|
||||||
|
cfg.Certificates[name] = newCert.Hash
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// since caching a new certificate attaches only the config
|
||||||
|
// that loaded it, the new certificate needs to be given the
|
||||||
|
// list of all the configs that use it, so copy the list
|
||||||
|
// over from the old certificate to the new certificate
|
||||||
|
// in the cache
|
||||||
|
newCert.configs = oldCert.configs
|
||||||
|
certCache.cache[newCert.Hash] = newCert
|
||||||
|
|
||||||
|
// finally, delete the old certificate from the cache
|
||||||
|
delete(certCache.cache, oldCert.Hash)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// reloadManagedCertificate reloads the certificate corresponding to the name(s)
|
||||||
|
// on oldCert into the cache, from storage. This also replaces the old certificate
|
||||||
|
// with the new one, so that all configurations that used the old cert now point
|
||||||
|
// to the new cert.
|
||||||
|
func (certCache *certificateCache) reloadManagedCertificate(oldCert Certificate) error {
|
||||||
|
// get the certificate from storage and cache it
|
||||||
|
newCert, err := oldCert.configs[0].CacheManagedCertificate(oldCert.Names[0])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to reload certificate for %v into cache: %v", oldCert.Names, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// and replace the old certificate with the new one
|
||||||
|
err = certCache.replaceCertificate(oldCert, newCert)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("replacing certificate %v: %v", oldCert.Names, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Certificate is a tls.Certificate with associated metadata tacked on.
|
// Certificate is a tls.Certificate with associated metadata tacked on.
|
||||||
// Even if the metadata can be obtained by parsing the certificate,
|
// Even if the metadata can be obtained by parsing the certificate,
|
||||||
// we can be more efficient by extracting the metadata once so it's
|
// we are more efficient by extracting the metadata onto this struct.
|
||||||
// just there, ready to use.
|
|
||||||
type Certificate struct {
|
type Certificate struct {
|
||||||
tls.Certificate
|
tls.Certificate
|
||||||
|
|
||||||
// Names is the list of names this certificate is written for.
|
// Names is the list of names this certificate is written for.
|
||||||
// The first is the CommonName (if any), the rest are SAN.
|
// The first is the CommonName (if any), the rest are SAN.
|
||||||
// This should be the exact list of keys by which this cert
|
|
||||||
// is accessed in the cache, careful to avoid overlap.
|
|
||||||
Names []string
|
Names []string
|
||||||
|
|
||||||
// NotAfter is when the certificate expires.
|
// NotAfter is when the certificate expires.
|
||||||
@ -53,59 +135,21 @@ type Certificate struct {
|
|||||||
// OCSP contains the certificate's parsed OCSP response.
|
// OCSP contains the certificate's parsed OCSP response.
|
||||||
OCSP *ocsp.Response
|
OCSP *ocsp.Response
|
||||||
|
|
||||||
// Config is the configuration with which the certificate was
|
// The hex-encoded hash of this cert's chain's bytes.
|
||||||
// loaded or obtained and with which it should be maintained.
|
Hash string
|
||||||
Config *Config
|
|
||||||
}
|
|
||||||
|
|
||||||
// getCertificate gets a certificate that matches name (a server name)
|
// configs is the list of configs that use or refer to
|
||||||
// from the in-memory cache. If there is no exact match for name, it
|
// The first one is assumed to be the config that is
|
||||||
// will be checked against names of the form '*.example.com' (wildcard
|
// "in charge" of this certificate (i.e. determines
|
||||||
// certificates) according to RFC 6125. If a match is found, matched will
|
// whether it is managed, how it is managed, etc).
|
||||||
// be true. If no matches are found, matched will be false and a default
|
// This field will be populated by cacheCertificate.
|
||||||
// certificate will be returned with defaulted set to true. If no default
|
// Only meddle with it if you know what you're doing!
|
||||||
// certificate is set, defaulted will be set to false.
|
configs []*Config
|
||||||
//
|
|
||||||
// The logic in this function is adapted from the Go standard library,
|
|
||||||
// which is by the Go Authors.
|
|
||||||
//
|
|
||||||
// This function is safe for concurrent use.
|
|
||||||
func getCertificate(name string) (cert Certificate, matched, defaulted bool) {
|
|
||||||
var ok bool
|
|
||||||
|
|
||||||
// Not going to trim trailing dots here since RFC 3546 says,
|
|
||||||
// "The hostname is represented ... without a trailing dot."
|
|
||||||
// Just normalize to lowercase.
|
|
||||||
name = strings.ToLower(name)
|
|
||||||
|
|
||||||
certCacheMu.RLock()
|
|
||||||
defer certCacheMu.RUnlock()
|
|
||||||
|
|
||||||
// exact match? great, let's use it
|
|
||||||
if cert, ok = certCache[name]; ok {
|
|
||||||
matched = true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// try replacing labels in the name with wildcards until we get a match
|
|
||||||
labels := strings.Split(name, ".")
|
|
||||||
for i := range labels {
|
|
||||||
labels[i] = "*"
|
|
||||||
candidate := strings.Join(labels, ".")
|
|
||||||
if cert, ok = certCache[candidate]; ok {
|
|
||||||
matched = true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// if nothing matches, use the default certificate or bust
|
|
||||||
cert, defaulted = certCache[""]
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CacheManagedCertificate loads the certificate for domain into the
|
// CacheManagedCertificate loads the certificate for domain into the
|
||||||
// cache, flagging it as Managed and, if onDemand is true, as "OnDemand"
|
// cache, from the TLS storage for managed certificates. It returns a
|
||||||
// (meaning that it was obtained or loaded during a TLS handshake).
|
// copy of the Certificate that was put into the cache.
|
||||||
//
|
//
|
||||||
// This method is safe for concurrent use.
|
// This method is safe for concurrent use.
|
||||||
func (cfg *Config) CacheManagedCertificate(domain string) (Certificate, error) {
|
func (cfg *Config) CacheManagedCertificate(domain string) (Certificate, error) {
|
||||||
@ -117,39 +161,24 @@ func (cfg *Config) CacheManagedCertificate(domain string) (Certificate, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return Certificate{}, err
|
return Certificate{}, err
|
||||||
}
|
}
|
||||||
cert, err := makeCertificate(siteData.Cert, siteData.Key)
|
cert, err := makeCertificateWithOCSP(siteData.Cert, siteData.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cert, err
|
return cert, err
|
||||||
}
|
}
|
||||||
cert.Config = cfg
|
return cfg.cacheCertificate(cert), nil
|
||||||
cacheCertificate(cert)
|
|
||||||
return cert, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// cacheUnmanagedCertificatePEMFile loads a certificate for host using certFile
|
// cacheUnmanagedCertificatePEMFile loads a certificate for host using certFile
|
||||||
// and keyFile, which must be in PEM format. It stores the certificate in
|
// and keyFile, which must be in PEM format. It stores the certificate in
|
||||||
// memory after evicting any other entries in the cache keyed by the names
|
// the in-memory cache.
|
||||||
// on this certificate. In other words, it replaces existing certificates keyed
|
|
||||||
// by the names on this certificate. The Managed and OnDemand flags of the
|
|
||||||
// certificate will be set to false.
|
|
||||||
//
|
//
|
||||||
// This function is safe for concurrent use.
|
// This function is safe for concurrent use.
|
||||||
func cacheUnmanagedCertificatePEMFile(certFile, keyFile string) error {
|
func (cfg *Config) cacheUnmanagedCertificatePEMFile(certFile, keyFile string) error {
|
||||||
cert, err := makeCertificateFromDisk(certFile, keyFile)
|
cert, err := makeCertificateFromDiskWithOCSP(certFile, keyFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
cfg.cacheCertificate(cert)
|
||||||
// since this is manually managed, this call might be part of a reload after
|
|
||||||
// the owner renewed a certificate; so clear cache of any previous cert first,
|
|
||||||
// otherwise the renewed certificate may never be loaded
|
|
||||||
certCacheMu.Lock()
|
|
||||||
for _, name := range cert.Names {
|
|
||||||
delete(certCache, name)
|
|
||||||
}
|
|
||||||
certCacheMu.Unlock()
|
|
||||||
|
|
||||||
cacheCertificate(cert)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -157,20 +186,20 @@ func cacheUnmanagedCertificatePEMFile(certFile, keyFile string) error {
|
|||||||
// of the certificate and key, then caches it in memory.
|
// of the certificate and key, then caches it in memory.
|
||||||
//
|
//
|
||||||
// This function is safe for concurrent use.
|
// This function is safe for concurrent use.
|
||||||
func cacheUnmanagedCertificatePEMBytes(certBytes, keyBytes []byte) error {
|
func (cfg *Config) cacheUnmanagedCertificatePEMBytes(certBytes, keyBytes []byte) error {
|
||||||
cert, err := makeCertificate(certBytes, keyBytes)
|
cert, err := makeCertificateWithOCSP(certBytes, keyBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
cacheCertificate(cert)
|
cfg.cacheCertificate(cert)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// makeCertificateFromDisk makes a Certificate by loading the
|
// makeCertificateFromDiskWithOCSP makes a Certificate by loading the
|
||||||
// certificate and key files. It fills out all the fields in
|
// certificate and key files. It fills out all the fields in
|
||||||
// the certificate except for the Managed and OnDemand flags.
|
// the certificate except for the Managed and OnDemand flags.
|
||||||
// (It is up to the caller to set those.)
|
// (It is up to the caller to set those.) It staples OCSP.
|
||||||
func makeCertificateFromDisk(certFile, keyFile string) (Certificate, error) {
|
func makeCertificateFromDiskWithOCSP(certFile, keyFile string) (Certificate, error) {
|
||||||
certPEMBlock, err := ioutil.ReadFile(certFile)
|
certPEMBlock, err := ioutil.ReadFile(certFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Certificate{}, err
|
return Certificate{}, err
|
||||||
@ -179,13 +208,14 @@ func makeCertificateFromDisk(certFile, keyFile string) (Certificate, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return Certificate{}, err
|
return Certificate{}, err
|
||||||
}
|
}
|
||||||
return makeCertificate(certPEMBlock, keyPEMBlock)
|
return makeCertificateWithOCSP(certPEMBlock, keyPEMBlock)
|
||||||
}
|
}
|
||||||
|
|
||||||
// makeCertificate turns a certificate PEM bundle and a key PEM block into
|
// makeCertificate turns a certificate PEM bundle and a key PEM block into
|
||||||
// a Certificate, with OCSP and other relevant metadata tagged with it,
|
// a Certificate with necessary metadata from parsing its bytes filled into
|
||||||
// except for the OnDemand and Managed flags. It is up to the caller to
|
// its struct fields for convenience (except for the OnDemand and Managed
|
||||||
// set those properties.
|
// flags; it is up to the caller to set those properties!). This function
|
||||||
|
// does NOT staple OCSP.
|
||||||
func makeCertificate(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
|
func makeCertificate(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
|
||||||
var cert Certificate
|
var cert Certificate
|
||||||
|
|
||||||
@ -195,16 +225,26 @@ func makeCertificate(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
|
|||||||
return cert, err
|
return cert, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract relevant metadata and staple OCSP
|
// Extract necessary metadata
|
||||||
err = fillCertFromLeaf(&cert, tlsCert)
|
err = fillCertFromLeaf(&cert, tlsCert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cert, err
|
return cert, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeCertificateWithOCSP is the same as makeCertificate except that it also
|
||||||
|
// staples OCSP to the certificate.
|
||||||
|
func makeCertificateWithOCSP(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
|
||||||
|
cert, err := makeCertificate(certPEMBlock, keyPEMBlock)
|
||||||
|
if err != nil {
|
||||||
|
return cert, err
|
||||||
|
}
|
||||||
err = stapleOCSP(&cert, certPEMBlock)
|
err = stapleOCSP(&cert, certPEMBlock)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[WARNING] Stapling OCSP: %v", err)
|
log.Printf("[WARNING] Stapling OCSP: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return cert, nil
|
return cert, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -243,65 +283,104 @@ func fillCertFromLeaf(cert *Certificate, tlsCert tls.Certificate) error {
|
|||||||
return errors.New("certificate has no names")
|
return errors.New("certificate has no names")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// save the hash of this certificate (chain) and
|
||||||
|
// expiration date, for necessity and efficiency
|
||||||
|
cert.Hash = hashCertificateChain(cert.Certificate.Certificate)
|
||||||
cert.NotAfter = leaf.NotAfter
|
cert.NotAfter = leaf.NotAfter
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// cacheCertificate adds cert to the in-memory cache. If the cache is
|
// hashCertificateChain computes the unique hash of certChain,
|
||||||
// empty, cert will be used as the default certificate. If the cache is
|
// which is the chain of DER-encoded bytes. It returns the
|
||||||
// full, random entries are deleted until there is room to map all the
|
// hex encoding of the hash.
|
||||||
// names on the certificate.
|
func hashCertificateChain(certChain [][]byte) string {
|
||||||
|
h := sha256.New()
|
||||||
|
for _, certInChain := range certChain {
|
||||||
|
h.Write(certInChain)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%x", h.Sum(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
// managedCertInStorageExpiresSoon returns true if cert (being a
|
||||||
|
// managed certificate) is expiring within RenewDurationBefore.
|
||||||
|
// It returns false if there was an error checking the expiration
|
||||||
|
// of the certificate as found in storage, or if the certificate
|
||||||
|
// in storage is NOT expiring soon. A certificate that is expiring
|
||||||
|
// soon in our cache but is not expiring soon in storage probably
|
||||||
|
// means that another instance renewed the certificate in the
|
||||||
|
// meantime, and it would be a good idea to simply load the cert
|
||||||
|
// into our cache rather than repeating the renewal process again.
|
||||||
|
func managedCertInStorageExpiresSoon(cert Certificate) (bool, error) {
|
||||||
|
if len(cert.configs) == 0 {
|
||||||
|
return false, fmt.Errorf("no configs for certificate")
|
||||||
|
}
|
||||||
|
storage, err := cert.configs[0].StorageFor(cert.configs[0].CAUrl)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
siteData, err := storage.LoadSite(cert.Names[0])
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
tlsCert, err := tls.X509KeyPair(siteData.Cert, siteData.Key)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
leaf, err := x509.ParseCertificate(tlsCert.Certificate[0])
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
timeLeft := leaf.NotAfter.Sub(time.Now().UTC())
|
||||||
|
return timeLeft < RenewDurationBefore, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// cacheCertificate adds cert to the in-memory cache. If a certificate
|
||||||
|
// with the same hash is already cached, it is NOT overwritten; instead,
|
||||||
|
// cfg is added to the existing certificate's list of configs if not
|
||||||
|
// already in the list. Then all the names on cert are used to add
|
||||||
|
// entries to cfg.Certificates (the config's name lookup map).
|
||||||
|
// Then the certificate is stored/updated in the cache. It returns
|
||||||
|
// a copy of the certificate that ends up being stored in the cache.
|
||||||
//
|
//
|
||||||
// This certificate will be keyed to the names in cert.Names. Any names
|
// It is VERY important, even for some test cases, that the Hash field
|
||||||
// already used as a cache key will NOT be replaced by this cert; in
|
// of the cert be set properly.
|
||||||
// other words, no overlap is allowed, and this certificate will not
|
|
||||||
// service those pre-existing names.
|
|
||||||
//
|
//
|
||||||
// This function is safe for concurrent use.
|
// This function is safe for concurrent use.
|
||||||
func cacheCertificate(cert Certificate) {
|
func (cfg *Config) cacheCertificate(cert Certificate) Certificate {
|
||||||
if cert.Config == nil {
|
cfg.certCache.Lock()
|
||||||
cert.Config = new(Config)
|
defer cfg.certCache.Unlock()
|
||||||
|
|
||||||
|
// if this certificate already exists in the cache,
|
||||||
|
// use it instead of overwriting it -- very important!
|
||||||
|
if existingCert, ok := cfg.certCache.cache[cert.Hash]; ok {
|
||||||
|
cert = existingCert
|
||||||
}
|
}
|
||||||
certCacheMu.Lock()
|
|
||||||
if _, ok := certCache[""]; !ok {
|
// attach this config to the certificate so we know which
|
||||||
// use as default - must be *appended* to end of list, or bad things happen!
|
// configs are referencing/using the certificate, but don't
|
||||||
cert.Names = append(cert.Names, "")
|
// duplicate entries
|
||||||
}
|
var found bool
|
||||||
for len(certCache)+len(cert.Names) > 10000 {
|
for _, c := range cert.configs {
|
||||||
// for simplicity, just remove random elements
|
if c == cfg {
|
||||||
for key := range certCache {
|
found = true
|
||||||
if key == "" { // ... but not the default cert
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
delete(certCache, key)
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for i := 0; i < len(cert.Names); i++ {
|
if !found {
|
||||||
name := cert.Names[i]
|
cert.configs = append(cert.configs, cfg)
|
||||||
if _, ok := certCache[name]; ok {
|
|
||||||
// do not allow certificates to overlap in the names they serve;
|
|
||||||
// this ambiguity causes problems because it is confusing while
|
|
||||||
// maintaining certificates; see OCSP maintenance code and
|
|
||||||
// https://caddy.community/t/random-ocsp-response-errors-for-random-clients/2473?u=matt.
|
|
||||||
log.Printf("[NOTICE] There is already a certificate loaded for %s, "+
|
|
||||||
"so certificate for %v will not service that name",
|
|
||||||
name, cert.Names)
|
|
||||||
cert.Names = append(cert.Names[:i], cert.Names[i+1:]...)
|
|
||||||
i--
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
certCache[name] = cert
|
|
||||||
}
|
}
|
||||||
certCacheMu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// uncacheCertificate deletes name's certificate from the
|
// key the certificate by all its names for this config only,
|
||||||
// cache. If name is not a key in the certificate cache,
|
// this is how we find the certificate during handshakes
|
||||||
// this function does nothing.
|
// (yes, if certs overlap in the names they serve, one will
|
||||||
func uncacheCertificate(name string) {
|
// overwrite another here, but that's just how it goes)
|
||||||
certCacheMu.Lock()
|
for _, name := range cert.Names {
|
||||||
delete(certCache, name)
|
cfg.Certificates[name] = cert.Hash
|
||||||
certCacheMu.Unlock()
|
}
|
||||||
|
|
||||||
|
// store the certificate
|
||||||
|
cfg.certCache.cache[cert.Hash] = cert
|
||||||
|
|
||||||
|
return cert
|
||||||
}
|
}
|
||||||
|
@ -17,57 +17,71 @@ package caddytls
|
|||||||
import "testing"
|
import "testing"
|
||||||
|
|
||||||
func TestUnexportedGetCertificate(t *testing.T) {
|
func TestUnexportedGetCertificate(t *testing.T) {
|
||||||
defer func() { certCache = make(map[string]Certificate) }()
|
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||||
|
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||||
|
|
||||||
// When cache is empty
|
// When cache is empty
|
||||||
if _, matched, defaulted := getCertificate("example.com"); matched || defaulted {
|
if _, matched, defaulted := cfg.getCertificate("example.com"); matched || defaulted {
|
||||||
t.Errorf("Got a certificate when cache was empty; matched=%v, defaulted=%v", matched, defaulted)
|
t.Errorf("Got a certificate when cache was empty; matched=%v, defaulted=%v", matched, defaulted)
|
||||||
}
|
}
|
||||||
|
|
||||||
// When cache has one certificate in it (also is default)
|
// When cache has one certificate in it
|
||||||
defaultCert := Certificate{Names: []string{"example.com", ""}}
|
firstCert := Certificate{Names: []string{"example.com"}}
|
||||||
certCache[""] = defaultCert
|
certCache.cache["0xdeadbeef"] = firstCert
|
||||||
certCache["example.com"] = defaultCert
|
cfg.Certificates["example.com"] = "0xdeadbeef"
|
||||||
if cert, matched, defaulted := getCertificate("Example.com"); !matched || defaulted || cert.Names[0] != "example.com" {
|
if cert, matched, defaulted := cfg.getCertificate("Example.com"); !matched || defaulted || cert.Names[0] != "example.com" {
|
||||||
t.Errorf("Didn't get a cert for 'Example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
|
t.Errorf("Didn't get a cert for 'Example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
|
||||||
}
|
}
|
||||||
if cert, matched, defaulted := getCertificate(""); !matched || defaulted || cert.Names[0] != "example.com" {
|
if cert, matched, defaulted := cfg.getCertificate("example.com"); !matched || defaulted || cert.Names[0] != "example.com" {
|
||||||
t.Errorf("Didn't get a cert for '' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
|
t.Errorf("Didn't get a cert for 'example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
|
||||||
}
|
}
|
||||||
|
|
||||||
// When retrieving wildcard certificate
|
// When retrieving wildcard certificate
|
||||||
certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}}
|
certCache.cache["0xb01dface"] = Certificate{Names: []string{"*.example.com"}}
|
||||||
if cert, matched, defaulted := getCertificate("sub.example.com"); !matched || defaulted || cert.Names[0] != "*.example.com" {
|
cfg.Certificates["*.example.com"] = "0xb01dface"
|
||||||
|
if cert, matched, defaulted := cfg.getCertificate("sub.example.com"); !matched || defaulted || cert.Names[0] != "*.example.com" {
|
||||||
t.Errorf("Didn't get wildcard cert for 'sub.example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
|
t.Errorf("Didn't get wildcard cert for 'sub.example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
|
||||||
}
|
}
|
||||||
|
|
||||||
// When no certificate matches, the default is returned
|
// When no certificate matches and SNI is provided, return no certificate (should be TLS alert)
|
||||||
if cert, matched, defaulted := getCertificate("nomatch"); matched || !defaulted {
|
if cert, matched, defaulted := cfg.getCertificate("nomatch"); matched || defaulted {
|
||||||
|
t.Errorf("Expected matched=false, defaulted=false; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert)
|
||||||
|
}
|
||||||
|
|
||||||
|
// When no certificate matches and SNI is NOT provided, a random is returned
|
||||||
|
if cert, matched, defaulted := cfg.getCertificate(""); matched || !defaulted {
|
||||||
t.Errorf("Expected matched=false, defaulted=true; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert)
|
t.Errorf("Expected matched=false, defaulted=true; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert)
|
||||||
} else if cert.Names[0] != "example.com" {
|
|
||||||
t.Errorf("Expected default cert, got: %v", cert)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCacheCertificate(t *testing.T) {
|
func TestCacheCertificate(t *testing.T) {
|
||||||
defer func() { certCache = make(map[string]Certificate) }()
|
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||||
|
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||||
|
|
||||||
cacheCertificate(Certificate{Names: []string{"example.com", "sub.example.com"}})
|
cfg.cacheCertificate(Certificate{Names: []string{"example.com", "sub.example.com"}, Hash: "foobar"})
|
||||||
if _, ok := certCache["example.com"]; !ok {
|
if len(certCache.cache) != 1 {
|
||||||
t.Error("Expected first cert to be cached by key 'example.com', but it wasn't")
|
t.Errorf("Expected length of certificate cache to be 1")
|
||||||
}
|
}
|
||||||
if _, ok := certCache["sub.example.com"]; !ok {
|
if _, ok := certCache.cache["foobar"]; !ok {
|
||||||
t.Error("Expected first cert to be cached by key 'sub.example.com', but it wasn't")
|
t.Error("Expected first cert to be cached by key 'foobar', but it wasn't")
|
||||||
}
|
}
|
||||||
if cert, ok := certCache[""]; !ok || cert.Names[2] != "" {
|
if _, ok := cfg.Certificates["example.com"]; !ok {
|
||||||
t.Error("Expected first cert to be cached additionally as the default certificate with empty name added, but it wasn't")
|
t.Error("Expected first cert to be keyed by 'example.com', but it wasn't")
|
||||||
|
}
|
||||||
|
if _, ok := cfg.Certificates["sub.example.com"]; !ok {
|
||||||
|
t.Error("Expected first cert to be keyed by 'sub.example.com', but it wasn't")
|
||||||
}
|
}
|
||||||
|
|
||||||
cacheCertificate(Certificate{Names: []string{"example2.com"}})
|
// different config, but using same cache; and has cert with overlapping name,
|
||||||
if _, ok := certCache["example2.com"]; !ok {
|
// but different hash
|
||||||
t.Error("Expected second cert to be cached by key 'exmaple2.com', but it wasn't")
|
cfg2 := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||||
|
cfg2.cacheCertificate(Certificate{Names: []string{"example.com"}, Hash: "barbaz"})
|
||||||
|
if _, ok := certCache.cache["barbaz"]; !ok {
|
||||||
|
t.Error("Expected second cert to be cached by key 'barbaz.com', but it wasn't")
|
||||||
}
|
}
|
||||||
if cert, ok := certCache[""]; ok && cert.Names[0] == "example2.com" {
|
if hash, ok := cfg2.Certificates["example.com"]; !ok {
|
||||||
t.Error("Expected second cert to NOT be cached as default, but it was")
|
t.Error("Expected second cert to be keyed by 'example.com', but it wasn't")
|
||||||
|
} else if hash != "barbaz" {
|
||||||
|
t.Errorf("Expected second cert to map to 'barbaz' but it was %s instead", hash)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -40,7 +40,7 @@ type ACMEClient struct {
|
|||||||
AllowPrompts bool
|
AllowPrompts bool
|
||||||
config *Config
|
config *Config
|
||||||
acmeClient *acme.Client
|
acmeClient *acme.Client
|
||||||
locker Locker
|
storage Storage
|
||||||
}
|
}
|
||||||
|
|
||||||
// newACMEClient creates a new ACMEClient given an email and whether
|
// newACMEClient creates a new ACMEClient given an email and whether
|
||||||
@ -122,10 +122,7 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error)
|
|||||||
AllowPrompts: allowPrompts,
|
AllowPrompts: allowPrompts,
|
||||||
config: config,
|
config: config,
|
||||||
acmeClient: client,
|
acmeClient: client,
|
||||||
locker: &syncLock{
|
storage: storage,
|
||||||
nameLocks: make(map[string]*sync.WaitGroup),
|
|
||||||
nameLocksMu: sync.Mutex{},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.DNSProvider == "" {
|
if config.DNSProvider == "" {
|
||||||
@ -161,7 +158,7 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error)
|
|||||||
|
|
||||||
// See if TLS challenge needs to be handled by our own facilities
|
// See if TLS challenge needs to be handled by our own facilities
|
||||||
if caddy.HasListenerWithAddress(net.JoinHostPort(config.ListenHost, useTLSSNIPort)) {
|
if caddy.HasListenerWithAddress(net.JoinHostPort(config.ListenHost, useTLSSNIPort)) {
|
||||||
c.acmeClient.SetChallengeProvider(acme.TLSSNI01, tlsSniSolver{})
|
c.acmeClient.SetChallengeProvider(acme.TLSSNI01, tlsSNISolver{certCache: config.certCache})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Disable any challenges that should not be used
|
// Disable any challenges that should not be used
|
||||||
@ -210,13 +207,7 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error)
|
|||||||
// Callers who have access to a Config value should use the ObtainCert
|
// Callers who have access to a Config value should use the ObtainCert
|
||||||
// method on that instead of this lower-level method.
|
// method on that instead of this lower-level method.
|
||||||
func (c *ACMEClient) Obtain(name string) error {
|
func (c *ACMEClient) Obtain(name string) error {
|
||||||
// Get access to ACME storage
|
waiter, err := c.storage.TryLock(name)
|
||||||
storage, err := c.config.StorageFor(c.config.CAUrl)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
waiter, err := c.locker.TryLock(name)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -226,7 +217,7 @@ func (c *ACMEClient) Obtain(name string) error {
|
|||||||
return nil // we assume the process with the lock succeeded, rather than hammering this execution path again
|
return nil // we assume the process with the lock succeeded, rather than hammering this execution path again
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := c.locker.Unlock(name); err != nil {
|
if err := c.storage.Unlock(name); err != nil {
|
||||||
log.Printf("[ERROR] Unable to unlock obtain call for %s: %v", name, err)
|
log.Printf("[ERROR] Unable to unlock obtain call for %s: %v", name, err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@ -269,7 +260,7 @@ Attempts:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Success - immediately save the certificate resource
|
// Success - immediately save the certificate resource
|
||||||
err = saveCertResource(storage, certificate)
|
err = saveCertResource(c.storage, certificate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error saving assets for %v: %v", name, err)
|
return fmt.Errorf("error saving assets for %v: %v", name, err)
|
||||||
}
|
}
|
||||||
@ -282,35 +273,30 @@ Attempts:
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Renew renews the managed certificate for name. This function is
|
// Renew renews the managed certificate for name. It puts the renewed
|
||||||
// safe for concurrent use.
|
// certificate into storage (not the cache). This function is safe for
|
||||||
|
// concurrent use.
|
||||||
//
|
//
|
||||||
// Callers who have access to a Config value should use the RenewCert
|
// Callers who have access to a Config value should use the RenewCert
|
||||||
// method on that instead of this lower-level method.
|
// method on that instead of this lower-level method.
|
||||||
func (c *ACMEClient) Renew(name string) error {
|
func (c *ACMEClient) Renew(name string) error {
|
||||||
// Get access to ACME storage
|
waiter, err := c.storage.TryLock(name)
|
||||||
storage, err := c.config.StorageFor(c.config.CAUrl)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
waiter, err := c.locker.TryLock(name)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if waiter != nil {
|
if waiter != nil {
|
||||||
log.Printf("[INFO] Certificate for %s is already being renewed elsewhere and stored; waiting", name)
|
log.Printf("[INFO] Certificate for %s is already being renewed elsewhere and stored; waiting", name)
|
||||||
waiter.Wait()
|
waiter.Wait()
|
||||||
return nil // we assume the process with the lock succeeded, rather than hammering this execution path again
|
return nil // assume that the worker that renewed the cert succeeded; avoid hammering this path over and over
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := c.locker.Unlock(name); err != nil {
|
if err := c.storage.Unlock(name); err != nil {
|
||||||
log.Printf("[ERROR] Unable to unlock renew call for %s: %v", name, err)
|
log.Printf("[ERROR] Unable to unlock renew call for %s: %v", name, err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Prepare for renewal (load PEM cert, key, and meta)
|
// Prepare for renewal (load PEM cert, key, and meta)
|
||||||
siteData, err := storage.LoadSite(name)
|
siteData, err := c.storage.LoadSite(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -357,18 +343,13 @@ func (c *ACMEClient) Renew(name string) error {
|
|||||||
go diagnostics.Increment("acme_certificates_obtained")
|
go diagnostics.Increment("acme_certificates_obtained")
|
||||||
go diagnostics.Increment("acme_certificates_renewed")
|
go diagnostics.Increment("acme_certificates_renewed")
|
||||||
|
|
||||||
return saveCertResource(storage, newCertMeta)
|
return saveCertResource(c.storage, newCertMeta)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Revoke revokes the certificate for name and deltes
|
// Revoke revokes the certificate for name and deletes
|
||||||
// it from storage.
|
// it from storage.
|
||||||
func (c *ACMEClient) Revoke(name string) error {
|
func (c *ACMEClient) Revoke(name string) error {
|
||||||
storage, err := c.config.StorageFor(c.config.CAUrl)
|
siteExists, err := c.storage.SiteExists(name)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
siteExists, err := storage.SiteExists(name)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -377,7 +358,7 @@ func (c *ACMEClient) Revoke(name string) error {
|
|||||||
return errors.New("no certificate and key for " + name)
|
return errors.New("no certificate and key for " + name)
|
||||||
}
|
}
|
||||||
|
|
||||||
siteData, err := storage.LoadSite(name)
|
siteData, err := c.storage.LoadSite(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -387,7 +368,7 @@ func (c *ACMEClient) Revoke(name string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = storage.DeleteSite(name)
|
err = c.storage.DeleteSite(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.New("certificate revoked, but unable to delete certificate file: " + err.Error())
|
return errors.New("certificate revoked, but unable to delete certificate file: " + err.Error())
|
||||||
}
|
}
|
||||||
|
@ -93,16 +93,17 @@ type Config struct {
|
|||||||
// an ACME challenge
|
// an ACME challenge
|
||||||
ListenHost string
|
ListenHost string
|
||||||
|
|
||||||
// The alternate port (ONLY port, not host)
|
// The alternate port (ONLY port, not host) to
|
||||||
// to use for the ACME HTTP challenge; this
|
// use for the ACME HTTP challenge; if non-empty,
|
||||||
// port will be used if we proxy challenges
|
// this port will be used instead of
|
||||||
// coming in on port 80 to this alternate port
|
// HTTPChallengePort to spin up a listener for
|
||||||
|
// the HTTP challenge
|
||||||
AltHTTPPort string
|
AltHTTPPort string
|
||||||
|
|
||||||
// The alternate port (ONLY port, not host)
|
// The alternate port (ONLY port, not host)
|
||||||
// to use for the ACME TLS-SNI challenge.
|
// to use for the ACME TLS-SNI challenge.
|
||||||
// The system must forward the standard port
|
// The system must forward TLSSNIChallengePort
|
||||||
// for the TLS-SNI challenge to this port.
|
// to this port for challenge to succeed
|
||||||
AltTLSSNIPort string
|
AltTLSSNIPort string
|
||||||
|
|
||||||
// The string identifier of the DNS provider
|
// The string identifier of the DNS provider
|
||||||
@ -134,7 +135,12 @@ type Config struct {
|
|||||||
// Protocol Negotiation (ALPN).
|
// Protocol Negotiation (ALPN).
|
||||||
ALPN []string
|
ALPN []string
|
||||||
|
|
||||||
tlsConfig *tls.Config // the final tls.Config created with buildStandardTLSConfig()
|
// The map of hostname to certificate hash. This is used to complete
|
||||||
|
// handshakes and serve the right certificate given the SNI.
|
||||||
|
Certificates map[string]string
|
||||||
|
|
||||||
|
certCache *certificateCache // pointer to the Instance's certificate store
|
||||||
|
tlsConfig *tls.Config // the final tls.Config created with buildStandardTLSConfig()
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnDemandState contains some state relevant for providing
|
// OnDemandState contains some state relevant for providing
|
||||||
@ -155,6 +161,25 @@ type OnDemandState struct {
|
|||||||
AskURL *url.URL
|
AskURL *url.URL
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewConfig returns a new Config with a pointer to the instance's
|
||||||
|
// certificate cache. You will usually need to set Other fields on
|
||||||
|
// the returned Config for successful practical use.
|
||||||
|
func NewConfig(inst *caddy.Instance) *Config {
|
||||||
|
inst.StorageMu.RLock()
|
||||||
|
certCache, ok := inst.Storage[CertCacheInstStorageKey].(*certificateCache)
|
||||||
|
inst.StorageMu.RUnlock()
|
||||||
|
if !ok || certCache == nil {
|
||||||
|
certCache = &certificateCache{cache: make(map[string]Certificate)}
|
||||||
|
inst.StorageMu.Lock()
|
||||||
|
inst.Storage[CertCacheInstStorageKey] = certCache
|
||||||
|
inst.StorageMu.Unlock()
|
||||||
|
}
|
||||||
|
cfg := new(Config)
|
||||||
|
cfg.Certificates = make(map[string]string)
|
||||||
|
cfg.certCache = certCache
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
|
||||||
// ObtainCert obtains a certificate for name using c, as long
|
// ObtainCert obtains a certificate for name using c, as long
|
||||||
// as a certificate does not already exist in storage for that
|
// as a certificate does not already exist in storage for that
|
||||||
// name. The name must qualify and c must be flagged as Managed.
|
// name. The name must qualify and c must be flagged as Managed.
|
||||||
@ -330,7 +355,9 @@ func (c *Config) buildStandardTLSConfig() error {
|
|||||||
|
|
||||||
// MakeTLSConfig makes a tls.Config from configs. The returned
|
// MakeTLSConfig makes a tls.Config from configs. The returned
|
||||||
// tls.Config is programmed to load the matching caddytls.Config
|
// tls.Config is programmed to load the matching caddytls.Config
|
||||||
// based on the hostname in SNI, but that's all.
|
// based on the hostname in SNI, but that's all. This is used
|
||||||
|
// to create a single TLS configuration for a listener (a group
|
||||||
|
// of sites).
|
||||||
func MakeTLSConfig(configs []*Config) (*tls.Config, error) {
|
func MakeTLSConfig(configs []*Config) (*tls.Config, error) {
|
||||||
if len(configs) == 0 {
|
if len(configs) == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -358,15 +385,28 @@ func MakeTLSConfig(configs []*Config) (*tls.Config, error) {
|
|||||||
configs[i-1].Hostname, lastConfProto, cfg.Hostname, thisConfProto)
|
configs[i-1].Hostname, lastConfProto, cfg.Hostname, thisConfProto)
|
||||||
}
|
}
|
||||||
|
|
||||||
// convert each caddytls.Config into a tls.Config
|
// convert this caddytls.Config into a tls.Config
|
||||||
if err := cfg.buildStandardTLSConfig(); err != nil {
|
if err := cfg.buildStandardTLSConfig(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Key this config by its hostname (overwriting
|
// if an existing config with this hostname was already
|
||||||
// configs with the same hostname pattern); during
|
// configured, then they must be identical (or at least
|
||||||
// TLS handshakes, configs are loaded based on
|
// compatible), otherwise that is a configuration error
|
||||||
// the hostname pattern, according to client's SNI.
|
if otherConfig, ok := configMap[cfg.Hostname]; ok {
|
||||||
|
if err := assertConfigsCompatible(cfg, otherConfig); err != nil {
|
||||||
|
return nil, fmt.Errorf("incompabile TLS configurations for the same SNI "+
|
||||||
|
"name (%s) on the same listener: %v",
|
||||||
|
cfg.Hostname, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// key this config by its hostname (overwrites
|
||||||
|
// configs with the same hostname pattern; should
|
||||||
|
// be OK since we already asserted they are roughly
|
||||||
|
// the same); during TLS handshakes, configs are
|
||||||
|
// loaded based on the hostname pattern, according
|
||||||
|
// to client's SNI
|
||||||
configMap[cfg.Hostname] = cfg
|
configMap[cfg.Hostname] = cfg
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -383,6 +423,63 @@ func MakeTLSConfig(configs []*Config) (*tls.Config, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// assertConfigsCompatible returns an error if the two Configs
|
||||||
|
// do not have the same (or roughly compatible) configurations.
|
||||||
|
// If one of the tlsConfig pointers on either Config is nil,
|
||||||
|
// an error will be returned. If both are nil, no error.
|
||||||
|
func assertConfigsCompatible(cfg1, cfg2 *Config) error {
|
||||||
|
c1, c2 := cfg1.tlsConfig, cfg2.tlsConfig
|
||||||
|
|
||||||
|
if (c1 == nil && c2 != nil) || (c1 != nil && c2 == nil) {
|
||||||
|
return fmt.Errorf("one config is not made")
|
||||||
|
}
|
||||||
|
if c1 == nil && c2 == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(c1.CipherSuites) != len(c2.CipherSuites) {
|
||||||
|
return fmt.Errorf("different number of allowed cipher suites")
|
||||||
|
}
|
||||||
|
for i, ciph := range c1.CipherSuites {
|
||||||
|
if c2.CipherSuites[i] != ciph {
|
||||||
|
return fmt.Errorf("different cipher suites or different order")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(c1.CurvePreferences) != len(c2.CurvePreferences) {
|
||||||
|
return fmt.Errorf("different number of allowed cipher suites")
|
||||||
|
}
|
||||||
|
for i, curve := range c1.CurvePreferences {
|
||||||
|
if c2.CurvePreferences[i] != curve {
|
||||||
|
return fmt.Errorf("different curve preferences or different order")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(c1.NextProtos) != len(c2.NextProtos) {
|
||||||
|
return fmt.Errorf("different number of ALPN (NextProtos) values")
|
||||||
|
}
|
||||||
|
for i, proto := range c1.NextProtos {
|
||||||
|
if c2.NextProtos[i] != proto {
|
||||||
|
return fmt.Errorf("different ALPN (NextProtos) values or different order")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if c1.PreferServerCipherSuites != c2.PreferServerCipherSuites {
|
||||||
|
return fmt.Errorf("one prefers server cipher suites, the other does not")
|
||||||
|
}
|
||||||
|
if c1.MinVersion != c2.MinVersion {
|
||||||
|
return fmt.Errorf("minimum TLS version mismatch")
|
||||||
|
}
|
||||||
|
if c1.MaxVersion != c2.MaxVersion {
|
||||||
|
return fmt.Errorf("maximum TLS version mismatch")
|
||||||
|
}
|
||||||
|
if c1.ClientAuth != c2.ClientAuth {
|
||||||
|
return fmt.Errorf("client authentication policy mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// ConfigGetter gets a Config keyed by key.
|
// ConfigGetter gets a Config keyed by key.
|
||||||
type ConfigGetter func(c *caddy.Controller) *Config
|
type ConfigGetter func(c *caddy.Controller) *Config
|
||||||
|
|
||||||
@ -522,7 +619,7 @@ var supportedCurvesMap = map[string]tls.CurveID{
|
|||||||
"P521": tls.CurveP521,
|
"P521": tls.CurveP521,
|
||||||
}
|
}
|
||||||
|
|
||||||
// List of all the curves we want to use by default
|
// List of all the curves we want to use by default.
|
||||||
//
|
//
|
||||||
// This list should only include curves which are fast by design (e.g. X25519)
|
// This list should only include curves which are fast by design (e.g. X25519)
|
||||||
// and those for which an optimized assembly implementation exists (e.g. P256).
|
// and those for which an optimized assembly implementation exists (e.g. P256).
|
||||||
@ -548,4 +645,8 @@ const (
|
|||||||
// be capable of proxying or forwarding the request to this
|
// be capable of proxying or forwarding the request to this
|
||||||
// alternate port.
|
// alternate port.
|
||||||
DefaultHTTPAlternatePort = "5033"
|
DefaultHTTPAlternatePort = "5033"
|
||||||
|
|
||||||
|
// CertCacheInstStorageKey is the name of the key for
|
||||||
|
// accessing the certificate storage on the *caddy.Instance.
|
||||||
|
CertCacheInstStorageKey = "tls_cert_cache"
|
||||||
)
|
)
|
||||||
|
@ -237,15 +237,17 @@ func makeSelfSignedCert(config *Config) error {
|
|||||||
return fmt.Errorf("could not create certificate: %v", err)
|
return fmt.Errorf("could not create certificate: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cacheCertificate(Certificate{
|
chain := [][]byte{derBytes}
|
||||||
|
|
||||||
|
config.cacheCertificate(Certificate{
|
||||||
Certificate: tls.Certificate{
|
Certificate: tls.Certificate{
|
||||||
Certificate: [][]byte{derBytes},
|
Certificate: chain,
|
||||||
PrivateKey: privKey,
|
PrivateKey: privKey,
|
||||||
Leaf: cert,
|
Leaf: cert,
|
||||||
},
|
},
|
||||||
Names: cert.DNSNames,
|
Names: cert.DNSNames,
|
||||||
NotAfter: cert.NotAfter,
|
NotAfter: cert.NotAfter,
|
||||||
Config: config,
|
Hash: hashCertificateChain(chain),
|
||||||
})
|
})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -38,9 +38,9 @@ var storageBasePath = filepath.Join(caddy.AssetsPath(), "acme")
|
|||||||
// Storage instance backed by the local disk. The resulting Storage
|
// Storage instance backed by the local disk. The resulting Storage
|
||||||
// instance is guaranteed to be non-nil if there is no error.
|
// instance is guaranteed to be non-nil if there is no error.
|
||||||
func NewFileStorage(caURL *url.URL) (Storage, error) {
|
func NewFileStorage(caURL *url.URL) (Storage, error) {
|
||||||
return &FileStorage{
|
storage := &FileStorage{Path: filepath.Join(storageBasePath, caURL.Host)}
|
||||||
Path: filepath.Join(storageBasePath, caURL.Host),
|
storage.Locker = &fileStorageLock{caURL: caURL.Host, storage: storage}
|
||||||
}, nil
|
return storage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// FileStorage facilitates forming file paths derived from a root
|
// FileStorage facilitates forming file paths derived from a root
|
||||||
@ -48,6 +48,7 @@ func NewFileStorage(caURL *url.URL) (Storage, error) {
|
|||||||
// cross-platform way or persisting ACME assets on the file system.
|
// cross-platform way or persisting ACME assets on the file system.
|
||||||
type FileStorage struct {
|
type FileStorage struct {
|
||||||
Path string
|
Path string
|
||||||
|
Locker
|
||||||
}
|
}
|
||||||
|
|
||||||
// sites gets the directory that stores site certificate and keys.
|
// sites gets the directory that stores site certificate and keys.
|
||||||
|
127
caddytls/filestoragesync.go
Normal file
127
caddytls/filestoragesync.go
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
// Copyright 2015 Light Code Labs, LLC
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package caddytls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mholt/caddy"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// be sure to remove lock files when exiting the process!
|
||||||
|
caddy.OnProcessExit = append(caddy.OnProcessExit, func() {
|
||||||
|
fileStorageNameLocksMu.Lock()
|
||||||
|
defer fileStorageNameLocksMu.Unlock()
|
||||||
|
for key, fw := range fileStorageNameLocks {
|
||||||
|
os.Remove(fw.filename)
|
||||||
|
delete(fileStorageNameLocks, key)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// fileStorageLock facilitates ACME-related locking by using
|
||||||
|
// the associated FileStorage, so multiple processes can coordinate
|
||||||
|
// renewals on the certificates on a shared file system.
|
||||||
|
type fileStorageLock struct {
|
||||||
|
caURL string
|
||||||
|
storage *FileStorage
|
||||||
|
}
|
||||||
|
|
||||||
|
// TryLock attempts to get a lock for name, otherwise it returns
|
||||||
|
// a Waiter value to wait until the other process is finished.
|
||||||
|
func (s *fileStorageLock) TryLock(name string) (Waiter, error) {
|
||||||
|
fileStorageNameLocksMu.Lock()
|
||||||
|
defer fileStorageNameLocksMu.Unlock()
|
||||||
|
|
||||||
|
// see if lock already exists within this process
|
||||||
|
fw, ok := fileStorageNameLocks[s.caURL+name]
|
||||||
|
if ok {
|
||||||
|
// lock already created within process, let caller wait on it
|
||||||
|
return fw, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// attempt to persist lock to disk by creating lock file
|
||||||
|
fw = &fileWaiter{
|
||||||
|
filename: s.storage.siteCertFile(name) + ".lock",
|
||||||
|
wg: new(sync.WaitGroup),
|
||||||
|
}
|
||||||
|
// parent dir must exist
|
||||||
|
if err := os.MkdirAll(s.storage.site(name), 0700); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
lf, err := os.OpenFile(fw.filename, os.O_CREATE|os.O_EXCL, 0644)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsExist(err) {
|
||||||
|
// another process has the lock; use it to wait
|
||||||
|
return fw, nil
|
||||||
|
}
|
||||||
|
// otherwise, this was some unexpected error
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
lf.Close()
|
||||||
|
|
||||||
|
// looks like we get the lock
|
||||||
|
fw.wg.Add(1)
|
||||||
|
fileStorageNameLocks[s.caURL+name] = fw
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unlock unlocks name.
|
||||||
|
func (s *fileStorageLock) Unlock(name string) error {
|
||||||
|
fileStorageNameLocksMu.Lock()
|
||||||
|
defer fileStorageNameLocksMu.Unlock()
|
||||||
|
fw, ok := fileStorageNameLocks[s.caURL+name]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("FileStorage: no lock to release for %s", name)
|
||||||
|
}
|
||||||
|
os.Remove(fw.filename)
|
||||||
|
fw.wg.Done()
|
||||||
|
delete(fileStorageNameLocks, s.caURL+name)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// fileWaiter waits for a file to disappear; it polls
|
||||||
|
// the file system to check for the existence of a file.
|
||||||
|
// It also has a WaitGroup which will be faster than
|
||||||
|
// polling, for when locking need only happen within this
|
||||||
|
// process.
|
||||||
|
type fileWaiter struct {
|
||||||
|
filename string
|
||||||
|
wg *sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait waits until the lock is released.
|
||||||
|
func (fw *fileWaiter) Wait() {
|
||||||
|
start := time.Now()
|
||||||
|
fw.wg.Wait()
|
||||||
|
for time.Since(start) < 1*time.Hour {
|
||||||
|
_, err := os.Stat(fw.filename)
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var fileStorageNameLocks = make(map[string]*fileWaiter) // keyed by CA + name
|
||||||
|
var fileStorageNameLocksMu sync.Mutex
|
||||||
|
|
||||||
|
var _ Locker = &fileStorageLock{}
|
||||||
|
var _ Waiter = &fileWaiter{}
|
@ -61,15 +61,15 @@ func (cg configGroup) getConfig(name string) *Config {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// as a fallback, try a config that serves all names
|
// try a config that serves all names (this
|
||||||
|
// is basically the same as a config defined
|
||||||
|
// for "*" -- I think -- but the above loop
|
||||||
|
// doesn't try an empty string)
|
||||||
if config, ok := cg[""]; ok {
|
if config, ok := cg[""]; ok {
|
||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
|
|
||||||
// as a last resort, use a random config
|
// no matches, so just serve up a random config
|
||||||
// (even if the config isn't for that hostname,
|
|
||||||
// it should help us serve clients without SNI
|
|
||||||
// or at least defer TLS alerts to the cert)
|
|
||||||
for _, config := range cg {
|
for _, config := range cg {
|
||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
@ -121,6 +121,86 @@ func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certif
|
|||||||
return &cert.Certificate, err
|
return &cert.Certificate, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getCertificate gets a certificate that matches name (a server name)
|
||||||
|
// from the in-memory cache, according to the lookup table associated with
|
||||||
|
// cfg. The lookup then points to a certificate in the Instance certificate
|
||||||
|
// cache.
|
||||||
|
//
|
||||||
|
// If there is no exact match for name, it will be checked against names of
|
||||||
|
// the form '*.example.com' (wildcard certificates) according to RFC 6125.
|
||||||
|
// If a match is found, matched will be true. If no matches are found, matched
|
||||||
|
// will be false and a "default" certificate will be returned with defaulted
|
||||||
|
// set to true. If defaulted is false, then no certificates were available.
|
||||||
|
//
|
||||||
|
// The logic in this function is adapted from the Go standard library,
|
||||||
|
// which is by the Go Authors.
|
||||||
|
//
|
||||||
|
// This function is safe for concurrent use.
|
||||||
|
func (cfg *Config) getCertificate(name string) (cert Certificate, matched, defaulted bool) {
|
||||||
|
var certKey string
|
||||||
|
var ok bool
|
||||||
|
|
||||||
|
// Not going to trim trailing dots here since RFC 3546 says,
|
||||||
|
// "The hostname is represented ... without a trailing dot."
|
||||||
|
// Just normalize to lowercase.
|
||||||
|
name = strings.ToLower(name)
|
||||||
|
|
||||||
|
cfg.certCache.RLock()
|
||||||
|
defer cfg.certCache.RUnlock()
|
||||||
|
|
||||||
|
// exact match? great, let's use it
|
||||||
|
if certKey, ok = cfg.Certificates[name]; ok {
|
||||||
|
cert = cfg.certCache.cache[certKey]
|
||||||
|
matched = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// try replacing labels in the name with wildcards until we get a match
|
||||||
|
labels := strings.Split(name, ".")
|
||||||
|
for i := range labels {
|
||||||
|
labels[i] = "*"
|
||||||
|
candidate := strings.Join(labels, ".")
|
||||||
|
if certKey, ok = cfg.Certificates[candidate]; ok {
|
||||||
|
cert = cfg.certCache.cache[certKey]
|
||||||
|
matched = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check the certCache directly to see if the SNI name is
|
||||||
|
// already the key of the certificate it wants! this is vital
|
||||||
|
// for supporting the TLS-SNI challenge, since the tlsSNISolver
|
||||||
|
// just puts the temporary certificate in the instance cache,
|
||||||
|
// with no regard for configs; this also means that the SNI
|
||||||
|
// can contain the hash of a specific cert (chain) it wants
|
||||||
|
// and we will still be able to serve it up
|
||||||
|
// (this behavior, by the way, could be controversial as to
|
||||||
|
// whether it complies with RFC 6066 about SNI, but I think
|
||||||
|
// it does soooo...)
|
||||||
|
// NOTE/TODO: TLS-SNI challenge is changing, as of Jan. 2018
|
||||||
|
// but what will be different, if it ever returns, is unclear
|
||||||
|
if directCert, ok := cfg.certCache.cache[name]; ok {
|
||||||
|
cert = directCert
|
||||||
|
matched = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// if nothing matches and SNI was not provided, use a random
|
||||||
|
// certificate; at least there's a chance this older client
|
||||||
|
// can connect, and in the future we won't need this provision
|
||||||
|
// (if SNI is present, it's probably best to just raise a TLS
|
||||||
|
// alert by not serving a certificate)
|
||||||
|
if name == "" {
|
||||||
|
for _, certKey := range cfg.Certificates {
|
||||||
|
defaulted = true
|
||||||
|
cert = cfg.certCache.cache[certKey]
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// getCertDuringHandshake will get a certificate for name. It first tries
|
// getCertDuringHandshake will get a certificate for name. It first tries
|
||||||
// the in-memory cache. If no certificate for name is in the cache, the
|
// the in-memory cache. If no certificate for name is in the cache, the
|
||||||
// config most closely corresponding to name will be loaded. If that config
|
// config most closely corresponding to name will be loaded. If that config
|
||||||
@ -134,7 +214,7 @@ func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certif
|
|||||||
// This function is safe for concurrent use.
|
// This function is safe for concurrent use.
|
||||||
func (cfg *Config) getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) {
|
func (cfg *Config) getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) {
|
||||||
// First check our in-memory cache to see if we've already loaded it
|
// First check our in-memory cache to see if we've already loaded it
|
||||||
cert, matched, defaulted := getCertificate(name)
|
cert, matched, defaulted := cfg.getCertificate(name)
|
||||||
if matched {
|
if matched {
|
||||||
return cert, nil
|
return cert, nil
|
||||||
}
|
}
|
||||||
@ -277,7 +357,7 @@ func (cfg *Config) obtainOnDemandCertificate(name string) (Certificate, error) {
|
|||||||
obtainCertWaitChans[name] = wait
|
obtainCertWaitChans[name] = wait
|
||||||
obtainCertWaitChansMu.Unlock()
|
obtainCertWaitChansMu.Unlock()
|
||||||
|
|
||||||
// do the obtain
|
// obtain the certificate
|
||||||
log.Printf("[INFO] Obtaining new certificate for %s", name)
|
log.Printf("[INFO] Obtaining new certificate for %s", name)
|
||||||
err := cfg.ObtainCert(name, false)
|
err := cfg.ObtainCert(name, false)
|
||||||
|
|
||||||
@ -336,9 +416,9 @@ func (cfg *Config) handshakeMaintenance(name string, cert Certificate) (Certific
|
|||||||
// quite common considering not all certs have issuer URLs that support it.
|
// quite common considering not all certs have issuer URLs that support it.
|
||||||
log.Printf("[ERROR] Getting OCSP for %s: %v", name, err)
|
log.Printf("[ERROR] Getting OCSP for %s: %v", name, err)
|
||||||
}
|
}
|
||||||
certCacheMu.Lock()
|
cfg.certCache.Lock()
|
||||||
certCache[name] = cert
|
cfg.certCache.cache[cert.Hash] = cert
|
||||||
certCacheMu.Unlock()
|
cfg.certCache.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -367,29 +447,22 @@ func (cfg *Config) renewDynamicCertificate(name string, currentCert Certificate)
|
|||||||
obtainCertWaitChans[name] = wait
|
obtainCertWaitChans[name] = wait
|
||||||
obtainCertWaitChansMu.Unlock()
|
obtainCertWaitChansMu.Unlock()
|
||||||
|
|
||||||
// do the renew and reload the certificate
|
// renew and reload the certificate
|
||||||
log.Printf("[INFO] Renewing certificate for %s", name)
|
log.Printf("[INFO] Renewing certificate for %s", name)
|
||||||
err := cfg.RenewCert(name, false)
|
err := cfg.RenewCert(name, false)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
// immediately flush this certificate from the cache so
|
|
||||||
// the name doesn't overlap when we try to replace it,
|
|
||||||
// which would fail, because overlapping existing cert
|
|
||||||
// names isn't allowed
|
|
||||||
certCacheMu.Lock()
|
|
||||||
for _, certName := range currentCert.Names {
|
|
||||||
delete(certCache, certName)
|
|
||||||
}
|
|
||||||
certCacheMu.Unlock()
|
|
||||||
|
|
||||||
// even though the recursive nature of the dynamic cert loading
|
// even though the recursive nature of the dynamic cert loading
|
||||||
// would just call this function anyway, we do it here to
|
// would just call this function anyway, we do it here to
|
||||||
// make the replacement as atomic as possible. (TODO: similar
|
// make the replacement as atomic as possible.
|
||||||
// to the note in maintain.go, it'd be nice if the clearing of
|
newCert, err := currentCert.configs[0].CacheManagedCertificate(name)
|
||||||
// the cache entries above and this load function were truly
|
|
||||||
// atomic...)
|
|
||||||
_, err := currentCert.Config.CacheManagedCertificate(name)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[ERROR] loading renewed certificate: %v", err)
|
log.Printf("[ERROR] loading renewed certificate for %s: %v", name, err)
|
||||||
|
} else {
|
||||||
|
// replace the old certificate with the new one
|
||||||
|
err = cfg.certCache.replaceCertificate(currentCert, newCert)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[ERROR] Replacing certificate for %s: %v", name, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -21,9 +21,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestGetCertificate(t *testing.T) {
|
func TestGetCertificate(t *testing.T) {
|
||||||
defer func() { certCache = make(map[string]Certificate) }()
|
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||||
|
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||||
cfg := new(Config)
|
|
||||||
|
|
||||||
hello := &tls.ClientHelloInfo{ServerName: "example.com"}
|
hello := &tls.ClientHelloInfo{ServerName: "example.com"}
|
||||||
helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"}
|
helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"}
|
||||||
@ -38,33 +37,40 @@ func TestGetCertificate(t *testing.T) {
|
|||||||
t.Errorf("GetCertificate should return error when cache is empty even if server name is blank, got: %v", cert)
|
t.Errorf("GetCertificate should return error when cache is empty even if server name is blank, got: %v", cert)
|
||||||
}
|
}
|
||||||
|
|
||||||
// When cache has one certificate in it (also is default)
|
// When cache has one certificate in it
|
||||||
defaultCert := Certificate{Names: []string{"example.com", ""}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"example.com"}}}}
|
firstCert := Certificate{Names: []string{"example.com"}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"example.com"}}}}
|
||||||
certCache[""] = defaultCert
|
cfg.cacheCertificate(firstCert)
|
||||||
certCache["example.com"] = defaultCert
|
|
||||||
if cert, err := cfg.GetCertificate(hello); err != nil {
|
if cert, err := cfg.GetCertificate(hello); err != nil {
|
||||||
t.Errorf("Got an error but shouldn't have, when cert exists in cache: %v", err)
|
t.Errorf("Got an error but shouldn't have, when cert exists in cache: %v", err)
|
||||||
} else if cert.Leaf.DNSNames[0] != "example.com" {
|
} else if cert.Leaf.DNSNames[0] != "example.com" {
|
||||||
t.Errorf("Got wrong certificate with exact match; expected 'example.com', got: %v", cert)
|
t.Errorf("Got wrong certificate with exact match; expected 'example.com', got: %v", cert)
|
||||||
}
|
}
|
||||||
if cert, err := cfg.GetCertificate(helloNoSNI); err != nil {
|
if _, err := cfg.GetCertificate(helloNoSNI); err != nil {
|
||||||
t.Errorf("Got an error with no SNI but shouldn't have, when cert exists in cache: %v", err)
|
t.Errorf("Got an error with no SNI but shouldn't have, when cert exists in cache: %v", err)
|
||||||
} else if cert.Leaf.DNSNames[0] != "example.com" {
|
|
||||||
t.Errorf("Got wrong certificate for no SNI; expected 'example.com' as default, got: %v", cert)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// When retrieving wildcard certificate
|
// When retrieving wildcard certificate
|
||||||
certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"*.example.com"}}}}
|
wildcardCert := Certificate{
|
||||||
|
Names: []string{"*.example.com"},
|
||||||
|
Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"*.example.com"}}},
|
||||||
|
Hash: "(don't overwrite the first one)",
|
||||||
|
}
|
||||||
|
cfg.cacheCertificate(wildcardCert)
|
||||||
if cert, err := cfg.GetCertificate(helloSub); err != nil {
|
if cert, err := cfg.GetCertificate(helloSub); err != nil {
|
||||||
t.Errorf("Didn't get wildcard cert, got: cert=%v, err=%v ", cert, err)
|
t.Errorf("Didn't get wildcard cert, got: cert=%v, err=%v ", cert, err)
|
||||||
} else if cert.Leaf.DNSNames[0] != "*.example.com" {
|
} else if cert.Leaf.DNSNames[0] != "*.example.com" {
|
||||||
t.Errorf("Got wrong certificate, expected wildcard: %v", cert)
|
t.Errorf("Got wrong certificate, expected wildcard: %v", cert)
|
||||||
}
|
}
|
||||||
|
|
||||||
// When no certificate matches, the default is returned
|
// When cache is NOT empty but there's no SNI
|
||||||
if cert, err := cfg.GetCertificate(helloNoMatch); err != nil {
|
if cert, err := cfg.GetCertificate(helloNoSNI); err != nil {
|
||||||
t.Errorf("Expected default certificate with no error when no matches, got err: %v", err)
|
t.Errorf("Expected random certificate with no error when no SNI, got err: %v", err)
|
||||||
} else if cert.Leaf.DNSNames[0] != "example.com" {
|
} else if cert == nil || len(cert.Leaf.DNSNames) == 0 {
|
||||||
t.Errorf("Expected default cert with no matches, got: %v", cert)
|
t.Errorf("Expected random cert with no matches, got: %v", cert)
|
||||||
|
}
|
||||||
|
|
||||||
|
// When no certificate matches, raise an alert
|
||||||
|
if _, err := cfg.GetCertificate(helloNoMatch); err == nil {
|
||||||
|
t.Errorf("Expected an error when no certificate matched the SNI, got: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -27,10 +27,11 @@ import (
|
|||||||
const challengeBasePath = "/.well-known/acme-challenge"
|
const challengeBasePath = "/.well-known/acme-challenge"
|
||||||
|
|
||||||
// HTTPChallengeHandler proxies challenge requests to ACME client if the
|
// HTTPChallengeHandler proxies challenge requests to ACME client if the
|
||||||
// request path starts with challengeBasePath. It returns true if it
|
// request path starts with challengeBasePath, if the HTTP challenge is not
|
||||||
// handled the request and no more needs to be done; it returns false
|
// disabled, and if we are known to be obtaining a certificate for the name.
|
||||||
// if this call was a no-op and the request still needs handling.
|
// It returns true if it handled the request and no more needs to be done;
|
||||||
func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost, altPort string) bool {
|
// it returns false if this call was a no-op and the request still needs handling.
|
||||||
|
func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost string) bool {
|
||||||
if !strings.HasPrefix(r.URL.Path, challengeBasePath) {
|
if !strings.HasPrefix(r.URL.Path, challengeBasePath) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -50,7 +51,11 @@ func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost, al
|
|||||||
listenHost = "localhost"
|
listenHost = "localhost"
|
||||||
}
|
}
|
||||||
|
|
||||||
upstream, err := url.Parse(fmt.Sprintf("%s://%s:%s", scheme, listenHost, altPort))
|
// always proxy to the DefaultHTTPAlternatePort because obviously the
|
||||||
|
// ACME challenge request already got into one of our HTTP handlers, so
|
||||||
|
// it means we must have started a HTTP listener on the alternate
|
||||||
|
// port instead; which is only accessible via listenHost
|
||||||
|
upstream, err := url.Parse(fmt.Sprintf("%s://%s:%s", scheme, listenHost, DefaultHTTPAlternatePort))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
log.Printf("[ERROR] ACME proxy handler: %v", err)
|
log.Printf("[ERROR] ACME proxy handler: %v", err)
|
||||||
|
@ -39,7 +39,7 @@ func TestHTTPChallengeHandlerNoOp(t *testing.T) {
|
|||||||
t.Fatalf("Could not craft request, got error: %v", err)
|
t.Fatalf("Could not craft request, got error: %v", err)
|
||||||
}
|
}
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
if HTTPChallengeHandler(rw, req, "", DefaultHTTPAlternatePort) {
|
if HTTPChallengeHandler(rw, req, "") {
|
||||||
t.Errorf("Got true with this URL, but shouldn't have: %s", url)
|
t.Errorf("Got true with this URL, but shouldn't have: %s", url)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -76,7 +76,7 @@ func TestHTTPChallengeHandlerSuccess(t *testing.T) {
|
|||||||
}
|
}
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
HTTPChallengeHandler(rw, req, "", DefaultHTTPAlternatePort)
|
HTTPChallengeHandler(rw, req, "")
|
||||||
|
|
||||||
if !proxySuccess {
|
if !proxySuccess {
|
||||||
t.Fatal("Expected request to be proxied, but it wasn't")
|
t.Fatal("Expected request to be proxied, but it wasn't")
|
||||||
|
@ -87,103 +87,127 @@ func maintainAssets(stopChan chan struct{}) {
|
|||||||
// RenewManagedCertificates renews managed certificates,
|
// RenewManagedCertificates renews managed certificates,
|
||||||
// including ones loaded on-demand.
|
// including ones loaded on-demand.
|
||||||
func RenewManagedCertificates(allowPrompts bool) (err error) {
|
func RenewManagedCertificates(allowPrompts bool) (err error) {
|
||||||
var renewQueue, deleteQueue []Certificate
|
for _, inst := range caddy.Instances() {
|
||||||
visitedNames := make(map[string]struct{})
|
inst.StorageMu.RLock()
|
||||||
|
certCache, ok := inst.Storage[CertCacheInstStorageKey].(*certificateCache)
|
||||||
certCacheMu.RLock()
|
inst.StorageMu.RUnlock()
|
||||||
for name, cert := range certCache {
|
if !ok || certCache == nil {
|
||||||
if !cert.Config.Managed || cert.Config.SelfSigned {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// the list of names on this cert should never be empty...
|
// we use the queues for a very important reason: to do any and all
|
||||||
if cert.Names == nil || len(cert.Names) == 0 {
|
// operations that could require an exclusive write lock outside
|
||||||
log.Printf("[WARNING] Certificate keyed by '%s' has no names: %v - removing from cache", name, cert.Names)
|
// of the read lock! otherwise we get a deadlock, yikes. in other
|
||||||
deleteQueue = append(deleteQueue, cert)
|
// words, our first iteration through the certificate cache does NOT
|
||||||
continue
|
// perform any operations--only queues them--so that more fine-grained
|
||||||
}
|
// write locks may be obtained during the actual operations.
|
||||||
|
var renewQueue, reloadQueue, deleteQueue []Certificate
|
||||||
|
|
||||||
// skip names whose certificate we've already renewed
|
certCache.RLock()
|
||||||
if _, ok := visitedNames[name]; ok {
|
for certKey, cert := range certCache.cache {
|
||||||
continue
|
if len(cert.configs) == 0 {
|
||||||
}
|
// this is bad if this happens, probably a programmer error (oops)
|
||||||
for _, name := range cert.Names {
|
log.Printf("[ERROR] No associated TLS config for certificate with names %v; unable to manage", cert.Names)
|
||||||
visitedNames[name] = struct{}{}
|
continue
|
||||||
}
|
}
|
||||||
|
if !cert.configs[0].Managed || cert.configs[0].SelfSigned {
|
||||||
// if its time is up or ending soon, we need to try to renew it
|
|
||||||
timeLeft := cert.NotAfter.Sub(time.Now().UTC())
|
|
||||||
if timeLeft < RenewDurationBefore {
|
|
||||||
log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft)
|
|
||||||
|
|
||||||
if cert.Config == nil {
|
|
||||||
log.Printf("[ERROR] %s: No associated TLS config; unable to renew", name)
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// queue for renewal when we aren't in a read lock anymore
|
// the list of names on this cert should never be empty... programmer error?
|
||||||
// (the TLS-SNI challenge will need a write lock in order to
|
if cert.Names == nil || len(cert.Names) == 0 {
|
||||||
// present the certificate, so we renew outside of read lock)
|
log.Printf("[WARNING] Certificate keyed by '%s' has no names: %v - removing from cache", certKey, cert.Names)
|
||||||
renewQueue = append(renewQueue, cert)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
certCacheMu.RUnlock()
|
|
||||||
|
|
||||||
// Perform renewals that are queued
|
|
||||||
for _, cert := range renewQueue {
|
|
||||||
// Get the name which we should use to renew this certificate;
|
|
||||||
// we only support managing certificates with one name per cert,
|
|
||||||
// so this should be easy. We can't rely on cert.Config.Hostname
|
|
||||||
// because it may be a wildcard value from the Caddyfile (e.g.
|
|
||||||
// *.something.com) which, as of Jan. 2017, is not supported by ACME.
|
|
||||||
var renewName string
|
|
||||||
for _, name := range cert.Names {
|
|
||||||
if name != "" {
|
|
||||||
renewName = name
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// perform renewal
|
|
||||||
err := cert.Config.RenewCert(renewName, allowPrompts)
|
|
||||||
if err != nil {
|
|
||||||
if allowPrompts {
|
|
||||||
// Certificate renewal failed and the operator is present. See a discussion
|
|
||||||
// about this in issue 642. For a while, we only stopped if the certificate
|
|
||||||
// was expired, but in reality, there is no difference between reporting
|
|
||||||
// it now versus later, except that there's somebody present to deal with
|
|
||||||
// it right now.
|
|
||||||
timeLeft := cert.NotAfter.Sub(time.Now().UTC())
|
|
||||||
if timeLeft < RenewDurationBeforeAtStartup {
|
|
||||||
// See issue 1680. Only fail at startup if the certificate is dangerously
|
|
||||||
// close to expiration.
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
log.Printf("[ERROR] %v", err)
|
|
||||||
if cert.Config.OnDemand {
|
|
||||||
// loaded dynamically, removed dynamically
|
|
||||||
deleteQueue = append(deleteQueue, cert)
|
deleteQueue = append(deleteQueue, cert)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
|
// if time is up or expires soon, we need to try to renew it
|
||||||
|
timeLeft := cert.NotAfter.Sub(time.Now().UTC())
|
||||||
|
if timeLeft < RenewDurationBefore {
|
||||||
|
// see if the certificate in storage has already been renewed, possibly by another
|
||||||
|
// instance of Caddy that didn't coordinate with this one; if so, just load it (this
|
||||||
|
// might happen if another instance already renewed it - kinda sloppy but checking disk
|
||||||
|
// first is a simple way to possibly drastically reduce rate limit problems)
|
||||||
|
storedCertExpiring, err := managedCertInStorageExpiresSoon(cert)
|
||||||
|
if err != nil {
|
||||||
|
// hmm, weird, but not a big deal, maybe it was deleted or something
|
||||||
|
log.Printf("[NOTICE] Error while checking if certificate for %v in storage is also expiring soon: %v",
|
||||||
|
cert.Names, err)
|
||||||
|
} else if !storedCertExpiring {
|
||||||
|
// if the certificate is NOT expiring soon and there was no error, then we
|
||||||
|
// are good to just reload the certificate from storage instead of repeating
|
||||||
|
// a likely-unnecessary renewal procedure
|
||||||
|
reloadQueue = append(reloadQueue, cert)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// the certificate in storage has not been renewed yet, so we will do it
|
||||||
|
// NOTE 1: This is not correct 100% of the time, if multiple Caddy instances
|
||||||
|
// happen to run their maintenance checks at approximately the same times;
|
||||||
|
// both might start renewal at about the same time and do two renewals and one
|
||||||
|
// will overwrite the other. Hence TLS storage plugins. This is sort of a TODO.
|
||||||
|
// NOTE 2: It is super-important to note that the TLS-SNI challenge requires
|
||||||
|
// a write lock on the cache in order to complete its challenge, so it is extra
|
||||||
|
// vital that this renew operation does not happen inside our read lock!
|
||||||
|
renewQueue = append(renewQueue, cert)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
certCache.RUnlock()
|
||||||
|
|
||||||
|
// Reload certificates that merely need to be updated in memory
|
||||||
|
for _, oldCert := range reloadQueue {
|
||||||
|
timeLeft := oldCert.NotAfter.Sub(time.Now().UTC())
|
||||||
|
log.Printf("[INFO] Certificate for %v expires in %v, but is already renewed in storage; reloading stored certificate",
|
||||||
|
oldCert.Names, timeLeft)
|
||||||
|
|
||||||
|
err = certCache.reloadManagedCertificate(oldCert)
|
||||||
|
if err != nil {
|
||||||
|
if allowPrompts {
|
||||||
|
return err // operator is present, so report error immediately
|
||||||
|
}
|
||||||
|
log.Printf("[ERROR] Loading renewed certificate: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Renewal queue
|
||||||
|
for _, oldCert := range renewQueue {
|
||||||
|
timeLeft := oldCert.NotAfter.Sub(time.Now().UTC())
|
||||||
|
log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", oldCert.Names, timeLeft)
|
||||||
|
|
||||||
|
// Get the name which we should use to renew this certificate;
|
||||||
|
// we only support managing certificates with one name per cert,
|
||||||
|
// so this should be easy. We can't rely on cert.Config.Hostname
|
||||||
|
// because it may be a wildcard value from the Caddyfile (e.g.
|
||||||
|
// *.something.com) which, as of Jan. 2017, is not supported by ACME.
|
||||||
|
// TODO: ^ ^ ^ (wildcards)
|
||||||
|
renewName := oldCert.Names[0]
|
||||||
|
|
||||||
|
// perform renewal
|
||||||
|
err := oldCert.configs[0].RenewCert(renewName, allowPrompts)
|
||||||
|
if err != nil {
|
||||||
|
if allowPrompts {
|
||||||
|
// Certificate renewal failed and the operator is present. See a discussion
|
||||||
|
// about this in issue 642. For a while, we only stopped if the certificate
|
||||||
|
// was expired, but in reality, there is no difference between reporting
|
||||||
|
// it now versus later, except that there's somebody present to deal with
|
||||||
|
// it right now. Follow-up: See issue 1680. Only fail in this case if the
|
||||||
|
// certificate is dangerously close to expiration.
|
||||||
|
timeLeft := oldCert.NotAfter.Sub(time.Now().UTC())
|
||||||
|
if timeLeft < RenewDurationBeforeAtStartup {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Printf("[ERROR] %v", err)
|
||||||
|
if oldCert.configs[0].OnDemand {
|
||||||
|
// loaded dynamically, remove dynamically
|
||||||
|
deleteQueue = append(deleteQueue, oldCert)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// successful renewal, so update in-memory cache by loading
|
// successful renewal, so update in-memory cache by loading
|
||||||
// renewed certificate so it will be used with handshakes
|
// renewed certificate so it will be used with handshakes
|
||||||
|
err = certCache.reloadManagedCertificate(oldCert)
|
||||||
// we must delete all the names this cert services from the cache
|
|
||||||
// so that we can replace the certificate, because replacing names
|
|
||||||
// already in the cache is not allowed, to avoid later conflicts
|
|
||||||
// with renewals.
|
|
||||||
// TODO: It would be nice if this whole operation were idempotent;
|
|
||||||
// i.e. a thread-safe function to replace a certificate in the cache,
|
|
||||||
// see also handshake.go for on-demand maintenance.
|
|
||||||
certCacheMu.Lock()
|
|
||||||
for _, name := range cert.Names {
|
|
||||||
delete(certCache, name)
|
|
||||||
}
|
|
||||||
certCacheMu.Unlock()
|
|
||||||
|
|
||||||
// put the certificate in the cache
|
|
||||||
_, err := cert.Config.CacheManagedCertificate(cert.Names[0])
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if allowPrompts {
|
if allowPrompts {
|
||||||
return err // operator is present, so report error immediately
|
return err // operator is present, so report error immediately
|
||||||
@ -191,15 +215,22 @@ func RenewManagedCertificates(allowPrompts bool) (err error) {
|
|||||||
log.Printf("[ERROR] %v", err)
|
log.Printf("[ERROR] %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Apply queued deletion changes to the cache
|
// Deletion queue
|
||||||
for _, cert := range deleteQueue {
|
for _, cert := range deleteQueue {
|
||||||
certCacheMu.Lock()
|
certCache.Lock()
|
||||||
for _, name := range cert.Names {
|
// remove any pointers to this certificate from Configs
|
||||||
delete(certCache, name)
|
for _, cfg := range cert.configs {
|
||||||
|
for name, certKey := range cfg.Certificates {
|
||||||
|
if certKey == cert.Hash {
|
||||||
|
delete(cfg.Certificates, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// then delete the certificate from the cache
|
||||||
|
delete(certCache.cache, cert.Hash)
|
||||||
|
certCache.Unlock()
|
||||||
}
|
}
|
||||||
certCacheMu.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -212,91 +243,75 @@ func RenewManagedCertificates(allowPrompts bool) (err error) {
|
|||||||
// Ryan Sleevi's recommendations for good OCSP support:
|
// Ryan Sleevi's recommendations for good OCSP support:
|
||||||
// https://gist.github.com/sleevi/5efe9ef98961ecfb4da8
|
// https://gist.github.com/sleevi/5efe9ef98961ecfb4da8
|
||||||
func UpdateOCSPStaples() {
|
func UpdateOCSPStaples() {
|
||||||
// Create a temporary place to store updates
|
for _, inst := range caddy.Instances() {
|
||||||
// until we release the potentially long-lived
|
inst.StorageMu.RLock()
|
||||||
// read lock and use a short-lived write lock.
|
certCache, ok := inst.Storage[CertCacheInstStorageKey].(*certificateCache)
|
||||||
type ocspUpdate struct {
|
inst.StorageMu.RUnlock()
|
||||||
rawBytes []byte
|
if !ok || certCache == nil {
|
||||||
parsed *ocsp.Response
|
|
||||||
}
|
|
||||||
updated := make(map[string]ocspUpdate)
|
|
||||||
|
|
||||||
// A single SAN certificate maps to multiple names, so we use this
|
|
||||||
// set to make sure we don't waste cycles checking OCSP for the same
|
|
||||||
// certificate multiple times.
|
|
||||||
visited := make(map[string]struct{})
|
|
||||||
|
|
||||||
certCacheMu.RLock()
|
|
||||||
for name, cert := range certCache {
|
|
||||||
// skip this certificate if we've already visited it,
|
|
||||||
// and if not, mark all the names as visited
|
|
||||||
if _, ok := visited[name]; ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, n := range cert.Names {
|
|
||||||
visited[n] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// no point in updating OCSP for expired certificates
|
|
||||||
if time.Now().After(cert.NotAfter) {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
var lastNextUpdate time.Time
|
// Create a temporary place to store updates
|
||||||
if cert.OCSP != nil {
|
// until we release the potentially long-lived
|
||||||
lastNextUpdate = cert.OCSP.NextUpdate
|
// read lock and use a short-lived write lock
|
||||||
if freshOCSP(cert.OCSP) {
|
// on the certificate cache.
|
||||||
// no need to update staple if ours is still fresh
|
type ocspUpdate struct {
|
||||||
|
rawBytes []byte
|
||||||
|
parsed *ocsp.Response
|
||||||
|
}
|
||||||
|
updated := make(map[string]ocspUpdate)
|
||||||
|
|
||||||
|
certCache.RLock()
|
||||||
|
for certHash, cert := range certCache.cache {
|
||||||
|
// no point in updating OCSP for expired certificates
|
||||||
|
if time.Now().After(cert.NotAfter) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
err := stapleOCSP(&cert, nil)
|
var lastNextUpdate time.Time
|
||||||
if err != nil {
|
|
||||||
if cert.OCSP != nil {
|
if cert.OCSP != nil {
|
||||||
// if there was no staple before, that's fine; otherwise we should log the error
|
lastNextUpdate = cert.OCSP.NextUpdate
|
||||||
log.Printf("[ERROR] Checking OCSP: %v", err)
|
if freshOCSP(cert.OCSP) {
|
||||||
|
continue // no need to update staple if ours is still fresh
|
||||||
|
}
|
||||||
}
|
}
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// By this point, we've obtained the latest OCSP response.
|
err := stapleOCSP(&cert, nil)
|
||||||
// If there was no staple before, or if the response is updated, make
|
if err != nil {
|
||||||
// sure we apply the update to all names on the certificate.
|
if cert.OCSP != nil {
|
||||||
if cert.OCSP != nil && (lastNextUpdate.IsZero() || lastNextUpdate != cert.OCSP.NextUpdate) {
|
// if there was no staple before, that's fine; otherwise we should log the error
|
||||||
log.Printf("[INFO] Advancing OCSP staple for %v from %s to %s",
|
log.Printf("[ERROR] Checking OCSP: %v", err)
|
||||||
cert.Names, lastNextUpdate, cert.OCSP.NextUpdate)
|
}
|
||||||
for _, n := range cert.Names {
|
continue
|
||||||
// BUG: If this certificate has names on it that appear on another
|
}
|
||||||
// certificate in the cache, AND the other certificate is keyed by
|
|
||||||
// that name in the cache, then this method of 'queueing' the staple
|
// By this point, we've obtained the latest OCSP response.
|
||||||
// update will cause this certificate's new OCSP to be stapled to
|
// If there was no staple before, or if the response is updated, make
|
||||||
// a different certificate! See:
|
// sure we apply the update to all names on the certificate.
|
||||||
// https://caddy.community/t/random-ocsp-response-errors-for-random-clients/2473?u=matt
|
if cert.OCSP != nil && (lastNextUpdate.IsZero() || lastNextUpdate != cert.OCSP.NextUpdate) {
|
||||||
// This problem should be avoided if names on certificates in the
|
log.Printf("[INFO] Advancing OCSP staple for %v from %s to %s",
|
||||||
// cache don't overlap with regards to the cache keys.
|
cert.Names, lastNextUpdate, cert.OCSP.NextUpdate)
|
||||||
// (This is isn't a bug anymore, since we're careful when we add
|
updated[certHash] = ocspUpdate{rawBytes: cert.Certificate.OCSPStaple, parsed: cert.OCSP}
|
||||||
// certificates to the cache by skipping keying when key already exists.)
|
|
||||||
updated[n] = ocspUpdate{rawBytes: cert.Certificate.OCSPStaple, parsed: cert.OCSP}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
certCache.RUnlock()
|
||||||
certCacheMu.RUnlock()
|
|
||||||
|
|
||||||
// This write lock should be brief since we have all the info we need now.
|
// These write locks should be brief since we have all the info we need now.
|
||||||
certCacheMu.Lock()
|
for certKey, update := range updated {
|
||||||
for name, update := range updated {
|
certCache.Lock()
|
||||||
cert := certCache[name]
|
cert := certCache.cache[certKey]
|
||||||
cert.OCSP = update.parsed
|
cert.OCSP = update.parsed
|
||||||
cert.Certificate.OCSPStaple = update.rawBytes
|
cert.Certificate.OCSPStaple = update.rawBytes
|
||||||
certCache[name] = cert
|
certCache.cache[certKey] = cert
|
||||||
|
certCache.Unlock()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
certCacheMu.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteOldStapleFiles deletes cached OCSP staples that have expired.
|
// DeleteOldStapleFiles deletes cached OCSP staples that have expired.
|
||||||
// TODO: Should we do this for certificates too?
|
// TODO: Should we do this for certificates too?
|
||||||
func DeleteOldStapleFiles() {
|
func DeleteOldStapleFiles() {
|
||||||
|
// TODO: Upgrade caddytls.Storage to support OCSP operations too
|
||||||
files, err := ioutil.ReadDir(ocspFolder)
|
files, err := ioutil.ReadDir(ocspFolder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// maybe just hasn't been created yet; no big deal
|
// maybe just hasn't been created yet; no big deal
|
||||||
|
@ -38,6 +38,7 @@ func init() {
|
|||||||
// are specified by the user in the config file. All the automatic HTTPS
|
// are specified by the user in the config file. All the automatic HTTPS
|
||||||
// stuff comes later outside of this function.
|
// stuff comes later outside of this function.
|
||||||
func setupTLS(c *caddy.Controller) error {
|
func setupTLS(c *caddy.Controller) error {
|
||||||
|
// obtain the configGetter, which loads the config we're, uh, configuring
|
||||||
configGetter, ok := configGetters[c.ServerType()]
|
configGetter, ok := configGetters[c.ServerType()]
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("no caddytls.ConfigGetter for %s server type; must call RegisterConfigGetter", c.ServerType())
|
return fmt.Errorf("no caddytls.ConfigGetter for %s server type; must call RegisterConfigGetter", c.ServerType())
|
||||||
@ -47,6 +48,14 @@ func setupTLS(c *caddy.Controller) error {
|
|||||||
return fmt.Errorf("no caddytls.Config to set up for %s", c.Key)
|
return fmt.Errorf("no caddytls.Config to set up for %s", c.Key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// the certificate cache is tied to the current caddy.Instance; get a pointer to it
|
||||||
|
certCache, ok := c.Get(CertCacheInstStorageKey).(*certificateCache)
|
||||||
|
if !ok || certCache == nil {
|
||||||
|
certCache = &certificateCache{cache: make(map[string]Certificate)}
|
||||||
|
c.Set(CertCacheInstStorageKey, certCache)
|
||||||
|
}
|
||||||
|
config.certCache = certCache
|
||||||
|
|
||||||
config.Enabled = true
|
config.Enabled = true
|
||||||
|
|
||||||
for c.Next() {
|
for c.Next() {
|
||||||
@ -237,7 +246,7 @@ func setupTLS(c *caddy.Controller) error {
|
|||||||
|
|
||||||
// load a single certificate and key, if specified
|
// load a single certificate and key, if specified
|
||||||
if certificateFile != "" && keyFile != "" {
|
if certificateFile != "" && keyFile != "" {
|
||||||
err := cacheUnmanagedCertificatePEMFile(certificateFile, keyFile)
|
err := config.cacheUnmanagedCertificatePEMFile(certificateFile, keyFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.Errf("Unable to load certificate and key files for '%s': %v", c.Key, err)
|
return c.Errf("Unable to load certificate and key files for '%s': %v", c.Key, err)
|
||||||
}
|
}
|
||||||
@ -246,7 +255,7 @@ func setupTLS(c *caddy.Controller) error {
|
|||||||
|
|
||||||
// load a directory of certificates, if specified
|
// load a directory of certificates, if specified
|
||||||
if loadDir != "" {
|
if loadDir != "" {
|
||||||
err := loadCertsInDir(c, loadDir)
|
err := loadCertsInDir(config, c, loadDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -273,7 +282,7 @@ func setupTLS(c *caddy.Controller) error {
|
|||||||
// https://cbonte.github.io/haproxy-dconv/configuration-1.5.html#5.1-crt
|
// https://cbonte.github.io/haproxy-dconv/configuration-1.5.html#5.1-crt
|
||||||
//
|
//
|
||||||
// This function may write to the log as it walks the directory tree.
|
// This function may write to the log as it walks the directory tree.
|
||||||
func loadCertsInDir(c *caddy.Controller, dir string) error {
|
func loadCertsInDir(cfg *Config, c *caddy.Controller, dir string) error {
|
||||||
return filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
|
return filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[WARNING] Unable to traverse into %s; skipping", path)
|
log.Printf("[WARNING] Unable to traverse into %s; skipping", path)
|
||||||
@ -336,7 +345,7 @@ func loadCertsInDir(c *caddy.Controller, dir string) error {
|
|||||||
return c.Errf("%s: no private key block found", path)
|
return c.Errf("%s: no private key block found", path)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = cacheUnmanagedCertificatePEMBytes(certPEMBytes, keyPEMBytes)
|
err = cfg.cacheUnmanagedCertificatePEMBytes(certPEMBytes, keyPEMBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.Errf("%s: failed to load cert and key for '%s': %v", path, c.Key, err)
|
return c.Errf("%s: failed to load cert and key for '%s': %v", path, c.Key, err)
|
||||||
}
|
}
|
||||||
|
@ -46,9 +46,12 @@ func TestMain(m *testing.M) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSetupParseBasic(t *testing.T) {
|
func TestSetupParseBasic(t *testing.T) {
|
||||||
cfg := new(Config)
|
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||||
|
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||||
|
|
||||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||||
c := caddy.NewTestController("", `tls `+certFile+` `+keyFile+``)
|
c := caddy.NewTestController("", `tls `+certFile+` `+keyFile+``)
|
||||||
|
c.Set(CertCacheInstStorageKey, certCache)
|
||||||
|
|
||||||
err := setupTLS(c)
|
err := setupTLS(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -124,9 +127,12 @@ func TestSetupParseWithOptionalParams(t *testing.T) {
|
|||||||
must_staple
|
must_staple
|
||||||
alpn http/1.1
|
alpn http/1.1
|
||||||
}`
|
}`
|
||||||
cfg := new(Config)
|
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||||
|
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||||
|
|
||||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||||
c := caddy.NewTestController("", params)
|
c := caddy.NewTestController("", params)
|
||||||
|
c.Set(CertCacheInstStorageKey, certCache)
|
||||||
|
|
||||||
err := setupTLS(c)
|
err := setupTLS(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -158,9 +164,11 @@ func TestSetupDefaultWithOptionalParams(t *testing.T) {
|
|||||||
params := `tls {
|
params := `tls {
|
||||||
ciphers RSA-3DES-EDE-CBC-SHA
|
ciphers RSA-3DES-EDE-CBC-SHA
|
||||||
}`
|
}`
|
||||||
cfg := new(Config)
|
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||||
|
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||||
c := caddy.NewTestController("", params)
|
c := caddy.NewTestController("", params)
|
||||||
|
c.Set(CertCacheInstStorageKey, certCache)
|
||||||
|
|
||||||
err := setupTLS(c)
|
err := setupTLS(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -176,9 +184,12 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) {
|
|||||||
params := `tls ` + certFile + ` ` + keyFile + ` {
|
params := `tls ` + certFile + ` ` + keyFile + ` {
|
||||||
protocols ssl tls
|
protocols ssl tls
|
||||||
}`
|
}`
|
||||||
cfg := new(Config)
|
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||||
|
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||||
c := caddy.NewTestController("", params)
|
c := caddy.NewTestController("", params)
|
||||||
|
c.Set(CertCacheInstStorageKey, certCache)
|
||||||
|
|
||||||
err := setupTLS(c)
|
err := setupTLS(c)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("Expected errors, but no error returned")
|
t.Errorf("Expected errors, but no error returned")
|
||||||
@ -191,6 +202,7 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) {
|
|||||||
cfg = new(Config)
|
cfg = new(Config)
|
||||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||||
c = caddy.NewTestController("", params)
|
c = caddy.NewTestController("", params)
|
||||||
|
c.Set(CertCacheInstStorageKey, certCache)
|
||||||
err = setupTLS(c)
|
err = setupTLS(c)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Expected errors, but no error returned")
|
t.Error("Expected errors, but no error returned")
|
||||||
@ -215,6 +227,7 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) {
|
|||||||
cfg = new(Config)
|
cfg = new(Config)
|
||||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||||
c = caddy.NewTestController("", params)
|
c = caddy.NewTestController("", params)
|
||||||
|
c.Set(CertCacheInstStorageKey, certCache)
|
||||||
err = setupTLS(c)
|
err = setupTLS(c)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Expected errors, but no error returned")
|
t.Error("Expected errors, but no error returned")
|
||||||
@ -226,7 +239,8 @@ func TestSetupParseWithClientAuth(t *testing.T) {
|
|||||||
params := `tls ` + certFile + ` ` + keyFile + ` {
|
params := `tls ` + certFile + ` ` + keyFile + ` {
|
||||||
clients
|
clients
|
||||||
}`
|
}`
|
||||||
cfg := new(Config)
|
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||||
|
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||||
c := caddy.NewTestController("", params)
|
c := caddy.NewTestController("", params)
|
||||||
err := setupTLS(c)
|
err := setupTLS(c)
|
||||||
@ -259,9 +273,11 @@ func TestSetupParseWithClientAuth(t *testing.T) {
|
|||||||
clients verify_if_given
|
clients verify_if_given
|
||||||
}`, tls.VerifyClientCertIfGiven, true, noCAs},
|
}`, tls.VerifyClientCertIfGiven, true, noCAs},
|
||||||
} {
|
} {
|
||||||
cfg := new(Config)
|
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||||
|
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||||
c := caddy.NewTestController("", caseData.params)
|
c := caddy.NewTestController("", caseData.params)
|
||||||
|
c.Set(CertCacheInstStorageKey, certCache)
|
||||||
err := setupTLS(c)
|
err := setupTLS(c)
|
||||||
if caseData.expectedErr {
|
if caseData.expectedErr {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@ -311,9 +327,11 @@ func TestSetupParseWithCAUrl(t *testing.T) {
|
|||||||
ca 1 2
|
ca 1 2
|
||||||
}`, true, ""},
|
}`, true, ""},
|
||||||
} {
|
} {
|
||||||
cfg := new(Config)
|
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||||
|
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||||
c := caddy.NewTestController("", caseData.params)
|
c := caddy.NewTestController("", caseData.params)
|
||||||
|
c.Set(CertCacheInstStorageKey, certCache)
|
||||||
err := setupTLS(c)
|
err := setupTLS(c)
|
||||||
if caseData.expectedErr {
|
if caseData.expectedErr {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@ -335,9 +353,11 @@ func TestSetupParseWithKeyType(t *testing.T) {
|
|||||||
params := `tls {
|
params := `tls {
|
||||||
key_type p384
|
key_type p384
|
||||||
}`
|
}`
|
||||||
cfg := new(Config)
|
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||||
|
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||||
c := caddy.NewTestController("", params)
|
c := caddy.NewTestController("", params)
|
||||||
|
c.Set(CertCacheInstStorageKey, certCache)
|
||||||
|
|
||||||
err := setupTLS(c)
|
err := setupTLS(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -353,9 +373,11 @@ func TestSetupParseWithCurves(t *testing.T) {
|
|||||||
params := `tls {
|
params := `tls {
|
||||||
curves x25519 p256 p384 p521
|
curves x25519 p256 p384 p521
|
||||||
}`
|
}`
|
||||||
cfg := new(Config)
|
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||||
|
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||||
c := caddy.NewTestController("", params)
|
c := caddy.NewTestController("", params)
|
||||||
|
c.Set(CertCacheInstStorageKey, certCache)
|
||||||
|
|
||||||
err := setupTLS(c)
|
err := setupTLS(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -380,9 +402,11 @@ func TestSetupParseWithOneTLSProtocol(t *testing.T) {
|
|||||||
params := `tls {
|
params := `tls {
|
||||||
protocols tls1.2
|
protocols tls1.2
|
||||||
}`
|
}`
|
||||||
cfg := new(Config)
|
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||||
|
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||||
c := caddy.NewTestController("", params)
|
c := caddy.NewTestController("", params)
|
||||||
|
c.Set(CertCacheInstStorageKey, certCache)
|
||||||
|
|
||||||
err := setupTLS(c)
|
err := setupTLS(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -107,6 +107,10 @@ type Storage interface {
|
|||||||
// in StoreUser. The result is an empty string if there are no
|
// in StoreUser. The result is an empty string if there are no
|
||||||
// persisted users in storage.
|
// persisted users in storage.
|
||||||
MostRecentUserEmail() string
|
MostRecentUserEmail() string
|
||||||
|
|
||||||
|
// Locker is necessary because synchronizing certificate maintenance
|
||||||
|
// depends on how storage is implemented.
|
||||||
|
Locker
|
||||||
}
|
}
|
||||||
|
|
||||||
// ErrNotExist is returned by Storage implementations when
|
// ErrNotExist is returned by Storage implementations when
|
||||||
|
@ -1,57 +0,0 @@
|
|||||||
// Copyright 2015 Light Code Labs, LLC
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package caddytls
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ Locker = &syncLock{}
|
|
||||||
|
|
||||||
type syncLock struct {
|
|
||||||
nameLocks map[string]*sync.WaitGroup
|
|
||||||
nameLocksMu sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
// TryLock attempts to get a lock for name, otherwise it returns
|
|
||||||
// a Waiter value to wait until the other process is finished.
|
|
||||||
func (s *syncLock) TryLock(name string) (Waiter, error) {
|
|
||||||
s.nameLocksMu.Lock()
|
|
||||||
defer s.nameLocksMu.Unlock()
|
|
||||||
wg, ok := s.nameLocks[name]
|
|
||||||
if ok {
|
|
||||||
// lock already obtained, let caller wait on it
|
|
||||||
return wg, nil
|
|
||||||
}
|
|
||||||
// caller gets lock
|
|
||||||
wg = new(sync.WaitGroup)
|
|
||||||
wg.Add(1)
|
|
||||||
s.nameLocks[name] = wg
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unlock unlocks name.
|
|
||||||
func (s *syncLock) Unlock(name string) error {
|
|
||||||
s.nameLocksMu.Lock()
|
|
||||||
defer s.nameLocksMu.Unlock()
|
|
||||||
wg, ok := s.nameLocks[name]
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("FileStorage: no lock to release for %s", name)
|
|
||||||
}
|
|
||||||
wg.Done()
|
|
||||||
delete(s.nameLocks, name)
|
|
||||||
return nil
|
|
||||||
}
|
|
@ -88,30 +88,38 @@ func Revoke(host string) error {
|
|||||||
return client.Revoke(host)
|
return client.Revoke(host)
|
||||||
}
|
}
|
||||||
|
|
||||||
// tlsSniSolver is a type that can solve tls-sni challenges using
|
// tlsSNISolver is a type that can solve TLS-SNI challenges using
|
||||||
// an existing listener and our custom, in-memory certificate cache.
|
// an existing listener and our custom, in-memory certificate cache.
|
||||||
type tlsSniSolver struct{}
|
type tlsSNISolver struct {
|
||||||
|
certCache *certificateCache
|
||||||
|
}
|
||||||
|
|
||||||
// Present adds the challenge certificate to the cache.
|
// Present adds the challenge certificate to the cache.
|
||||||
func (s tlsSniSolver) Present(domain, token, keyAuth string) error {
|
func (s tlsSNISolver) Present(domain, token, keyAuth string) error {
|
||||||
cert, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth)
|
cert, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
cacheCertificate(Certificate{
|
certHash := hashCertificateChain(cert.Certificate)
|
||||||
|
s.certCache.Lock()
|
||||||
|
s.certCache.cache[acmeDomain] = Certificate{
|
||||||
Certificate: cert,
|
Certificate: cert,
|
||||||
Names: []string{acmeDomain},
|
Names: []string{acmeDomain},
|
||||||
})
|
Hash: certHash, // perhaps not necesssary
|
||||||
|
}
|
||||||
|
s.certCache.Unlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CleanUp removes the challenge certificate from the cache.
|
// CleanUp removes the challenge certificate from the cache.
|
||||||
func (s tlsSniSolver) CleanUp(domain, token, keyAuth string) error {
|
func (s tlsSNISolver) CleanUp(domain, token, keyAuth string) error {
|
||||||
_, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth)
|
_, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
uncacheCertificate(acmeDomain)
|
s.certCache.Lock()
|
||||||
|
delete(s.certCache.cache, acmeDomain)
|
||||||
|
s.certCache.Unlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,6 +103,20 @@ func (c *Controller) Context() Context {
|
|||||||
return c.instance.context
|
return c.instance.context
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get safely gets a value from the Instance's storage.
|
||||||
|
func (c *Controller) Get(key interface{}) interface{} {
|
||||||
|
c.instance.StorageMu.RLock()
|
||||||
|
defer c.instance.StorageMu.RUnlock()
|
||||||
|
return c.instance.Storage[key]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set safely sets a value on the Instance's storage.
|
||||||
|
func (c *Controller) Set(key, val interface{}) {
|
||||||
|
c.instance.StorageMu.Lock()
|
||||||
|
c.instance.Storage[key] = val
|
||||||
|
c.instance.StorageMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
// NewTestController creates a new Controller for
|
// NewTestController creates a new Controller for
|
||||||
// the server type and input specified. The filename
|
// the server type and input specified. The filename
|
||||||
// is "Testfile". If the server type is not empty and
|
// is "Testfile". If the server type is not empty and
|
||||||
@ -113,12 +127,12 @@ func (c *Controller) Context() Context {
|
|||||||
// Used only for testing, but exported so plugins can
|
// Used only for testing, but exported so plugins can
|
||||||
// use this for convenience.
|
// use this for convenience.
|
||||||
func NewTestController(serverType, input string) *Controller {
|
func NewTestController(serverType, input string) *Controller {
|
||||||
var ctx Context
|
testInst := &Instance{serverType: serverType, Storage: make(map[interface{}]interface{})}
|
||||||
if stype, err := getServerType(serverType); err == nil {
|
if stype, err := getServerType(serverType); err == nil {
|
||||||
ctx = stype.NewContext()
|
testInst.context = stype.NewContext(testInst)
|
||||||
}
|
}
|
||||||
return &Controller{
|
return &Controller{
|
||||||
instance: &Instance{serverType: serverType, context: ctx},
|
instance: testInst,
|
||||||
Dispenser: caddyfile.NewDispenser("Testfile", strings.NewReader(input)),
|
Dispenser: caddyfile.NewDispenser("Testfile", strings.NewReader(input)),
|
||||||
OncePerServerBlock: func(f func() error) error { return f() },
|
OncePerServerBlock: func(f func() error) error { return f() },
|
||||||
}
|
}
|
||||||
|
1
dist/init/linux-systemd/README.md
vendored
1
dist/init/linux-systemd/README.md
vendored
@ -91,6 +91,7 @@ Install the systemd service unit configuration file, reload the systemd daemon,
|
|||||||
and start caddy:
|
and start caddy:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
wget https://raw.githubusercontent.com/mholt/caddy/master/dist/init/linux-systemd/caddy.service
|
||||||
sudo cp caddy.service /etc/systemd/system/
|
sudo cp caddy.service /etc/systemd/system/
|
||||||
sudo chown root:root /etc/systemd/system/caddy.service
|
sudo chown root:root /etc/systemd/system/caddy.service
|
||||||
sudo chmod 644 /etc/systemd/system/caddy.service
|
sudo chmod 644 /etc/systemd/system/caddy.service
|
||||||
|
4
dist/init/linux-systemd/caddy.service
vendored
4
dist/init/linux-systemd/caddy.service
vendored
@ -30,8 +30,8 @@ LimitNPROC=512
|
|||||||
|
|
||||||
; Use private /tmp and /var/tmp, which are discarded after caddy stops.
|
; Use private /tmp and /var/tmp, which are discarded after caddy stops.
|
||||||
PrivateTmp=true
|
PrivateTmp=true
|
||||||
; Use a minimal /dev
|
; Use a minimal /dev (May bring additional security if switched to 'true', but it may not work on Raspberry Pi's or other devices, so it has been disabled in this dist.)
|
||||||
PrivateDevices=true
|
PrivateDevices=false
|
||||||
; Hide /home, /root, and /run/user. Nobody will steal your SSH-keys.
|
; Hide /home, /root, and /run/user. Nobody will steal your SSH-keys.
|
||||||
ProtectHome=true
|
ProtectHome=true
|
||||||
; Make /usr, /boot, /etc and possibly some more folders read-only.
|
; Make /usr, /boot, /etc and possibly some more folders read-only.
|
||||||
|
41
plugins.go
41
plugins.go
@ -19,6 +19,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"sort"
|
"sort"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/mholt/caddy/caddyfile"
|
"github.com/mholt/caddy/caddyfile"
|
||||||
)
|
)
|
||||||
@ -38,7 +39,7 @@ var (
|
|||||||
|
|
||||||
// eventHooks is a map of hook name to Hook. All hooks plugins
|
// eventHooks is a map of hook name to Hook. All hooks plugins
|
||||||
// must have a name.
|
// must have a name.
|
||||||
eventHooks = make(map[string]EventHook)
|
eventHooks = sync.Map{}
|
||||||
|
|
||||||
// parsingCallbacks maps server type to map of directive
|
// parsingCallbacks maps server type to map of directive
|
||||||
// to list of callback functions. These aren't really
|
// to list of callback functions. These aren't really
|
||||||
@ -98,11 +99,15 @@ func ListPlugins() map[string][]string {
|
|||||||
p["caddyfile_loaders"] = append(p["caddyfile_loaders"], defaultCaddyfileLoader.name)
|
p["caddyfile_loaders"] = append(p["caddyfile_loaders"], defaultCaddyfileLoader.name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// event hook plugins
|
// List the event hook plugins
|
||||||
if len(eventHooks) > 0 {
|
hooks := ""
|
||||||
for name := range eventHooks {
|
eventHooks.Range(func(k, _ interface{}) bool {
|
||||||
p["event_hooks"] = append(p["event_hooks"], name)
|
hooks += " hook." + k.(string) + "\n"
|
||||||
}
|
return true
|
||||||
|
})
|
||||||
|
if hooks != "" {
|
||||||
|
str += "\nEvent hook plugins:\n"
|
||||||
|
str += hooks
|
||||||
}
|
}
|
||||||
|
|
||||||
// alphabetize the rest of the plugins
|
// alphabetize the rest of the plugins
|
||||||
@ -220,7 +225,7 @@ type ServerType struct {
|
|||||||
// startup phases before this one. It's a way to keep
|
// startup phases before this one. It's a way to keep
|
||||||
// each set of server instances separate and to reduce
|
// each set of server instances separate and to reduce
|
||||||
// the amount of global state you need.
|
// the amount of global state you need.
|
||||||
NewContext func() Context
|
NewContext func(inst *Instance) Context
|
||||||
}
|
}
|
||||||
|
|
||||||
// Plugin is a type which holds information about a plugin.
|
// Plugin is a type which holds information about a plugin.
|
||||||
@ -277,23 +282,23 @@ func RegisterEventHook(name string, hook EventHook) {
|
|||||||
if name == "" {
|
if name == "" {
|
||||||
panic("event hook must have a name")
|
panic("event hook must have a name")
|
||||||
}
|
}
|
||||||
if _, dup := eventHooks[name]; dup {
|
_, dup := eventHooks.LoadOrStore(name, hook)
|
||||||
|
if dup {
|
||||||
panic("hook named " + name + " already registered")
|
panic("hook named " + name + " already registered")
|
||||||
}
|
}
|
||||||
eventHooks[name] = hook
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// EmitEvent executes the different hooks passing the EventType as an
|
// EmitEvent executes the different hooks passing the EventType as an
|
||||||
// argument. This is a blocking function. Hook developers should
|
// argument. This is a blocking function. Hook developers should
|
||||||
// use 'go' keyword if they don't want to block Caddy.
|
// use 'go' keyword if they don't want to block Caddy.
|
||||||
func EmitEvent(event EventName, info interface{}) {
|
func EmitEvent(event EventName, info interface{}) {
|
||||||
for name, hook := range eventHooks {
|
eventHooks.Range(func(k, v interface{}) bool {
|
||||||
err := hook(event, info)
|
err := v.(EventHook)(event, info)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("error on '%s' hook: %v", name, err)
|
log.Printf("error on '%s' hook: %v", k.(string), err)
|
||||||
}
|
}
|
||||||
}
|
return true
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParsingCallback is a function that is called after
|
// ParsingCallback is a function that is called after
|
||||||
@ -412,6 +417,14 @@ func loadCaddyfileInput(serverType string) (Input, error) {
|
|||||||
return caddyfileToUse, nil
|
return caddyfileToUse, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OnProcessExit is a list of functions to run when the process
|
||||||
|
// exits -- they are ONLY for cleanup and should not block,
|
||||||
|
// return errors, or do anything fancy. They will be run with
|
||||||
|
// every signal, even if "shutdown callbacks" are not executed.
|
||||||
|
// This variable must only be modified in the main goroutine
|
||||||
|
// from init() functions.
|
||||||
|
var OnProcessExit []func()
|
||||||
|
|
||||||
// caddyfileLoader pairs the name of a loader to the loader.
|
// caddyfileLoader pairs the name of a loader to the loader.
|
||||||
type caddyfileLoader struct {
|
type caddyfileLoader struct {
|
||||||
name string
|
name string
|
||||||
|
@ -44,16 +44,17 @@ func trapSignalsCrossPlatform() {
|
|||||||
|
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
log.Println("[INFO] SIGINT: Force quit")
|
log.Println("[INFO] SIGINT: Force quit")
|
||||||
if PidFile != "" {
|
for _, f := range OnProcessExit {
|
||||||
os.Remove(PidFile)
|
f() // important cleanup actions only
|
||||||
}
|
}
|
||||||
os.Exit(2)
|
os.Exit(2)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Println("[INFO] SIGINT: Shutting down")
|
log.Println("[INFO] SIGINT: Shutting down")
|
||||||
|
|
||||||
if PidFile != "" {
|
// important cleanup actions before shutdown callbacks
|
||||||
os.Remove(PidFile)
|
for _, f := range OnProcessExit {
|
||||||
|
f()
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -33,22 +33,22 @@ func trapSignalsPosix() {
|
|||||||
switch sig {
|
switch sig {
|
||||||
case syscall.SIGQUIT:
|
case syscall.SIGQUIT:
|
||||||
log.Println("[INFO] SIGQUIT: Quitting process immediately")
|
log.Println("[INFO] SIGQUIT: Quitting process immediately")
|
||||||
if PidFile != "" {
|
for _, f := range OnProcessExit {
|
||||||
os.Remove(PidFile)
|
f() // only perform important cleanup actions
|
||||||
}
|
}
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
|
|
||||||
case syscall.SIGTERM:
|
case syscall.SIGTERM:
|
||||||
log.Println("[INFO] SIGTERM: Shutting down servers then terminating")
|
log.Println("[INFO] SIGTERM: Shutting down servers then terminating")
|
||||||
exitCode := executeShutdownCallbacks("SIGTERM")
|
exitCode := executeShutdownCallbacks("SIGTERM")
|
||||||
|
for _, f := range OnProcessExit {
|
||||||
|
f() // only perform important cleanup actions
|
||||||
|
}
|
||||||
err := Stop()
|
err := Stop()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[ERROR] SIGTERM stop: %v", err)
|
log.Printf("[ERROR] SIGTERM stop: %v", err)
|
||||||
exitCode = 3
|
exitCode = 3
|
||||||
}
|
}
|
||||||
if PidFile != "" {
|
|
||||||
os.Remove(PidFile)
|
|
||||||
}
|
|
||||||
os.Exit(exitCode)
|
os.Exit(exitCode)
|
||||||
|
|
||||||
case syscall.SIGUSR1:
|
case syscall.SIGUSR1:
|
||||||
|
21
vendor/github.com/bifurcation/mint/LICENSE.md
generated
vendored
Normal file
21
vendor/github.com/bifurcation/mint/LICENSE.md
generated
vendored
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
The MIT License (MIT)
|
||||||
|
|
||||||
|
Copyright (c) 2016 Richard Barnes
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in
|
||||||
|
all copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||||
|
THE SOFTWARE.
|
99
vendor/github.com/bifurcation/mint/alert.go
generated
vendored
Normal file
99
vendor/github.com/bifurcation/mint/alert.go
generated
vendored
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
// Copyright 2009 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package mint
|
||||||
|
|
||||||
|
import "strconv"
|
||||||
|
|
||||||
|
type Alert uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
// alert level
|
||||||
|
AlertLevelWarning = 1
|
||||||
|
AlertLevelError = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
AlertCloseNotify Alert = 0
|
||||||
|
AlertUnexpectedMessage Alert = 10
|
||||||
|
AlertBadRecordMAC Alert = 20
|
||||||
|
AlertDecryptionFailed Alert = 21
|
||||||
|
AlertRecordOverflow Alert = 22
|
||||||
|
AlertDecompressionFailure Alert = 30
|
||||||
|
AlertHandshakeFailure Alert = 40
|
||||||
|
AlertBadCertificate Alert = 42
|
||||||
|
AlertUnsupportedCertificate Alert = 43
|
||||||
|
AlertCertificateRevoked Alert = 44
|
||||||
|
AlertCertificateExpired Alert = 45
|
||||||
|
AlertCertificateUnknown Alert = 46
|
||||||
|
AlertIllegalParameter Alert = 47
|
||||||
|
AlertUnknownCA Alert = 48
|
||||||
|
AlertAccessDenied Alert = 49
|
||||||
|
AlertDecodeError Alert = 50
|
||||||
|
AlertDecryptError Alert = 51
|
||||||
|
AlertProtocolVersion Alert = 70
|
||||||
|
AlertInsufficientSecurity Alert = 71
|
||||||
|
AlertInternalError Alert = 80
|
||||||
|
AlertInappropriateFallback Alert = 86
|
||||||
|
AlertUserCanceled Alert = 90
|
||||||
|
AlertNoRenegotiation Alert = 100
|
||||||
|
AlertMissingExtension Alert = 109
|
||||||
|
AlertUnsupportedExtension Alert = 110
|
||||||
|
AlertCertificateUnobtainable Alert = 111
|
||||||
|
AlertUnrecognizedName Alert = 112
|
||||||
|
AlertBadCertificateStatsResponse Alert = 113
|
||||||
|
AlertBadCertificateHashValue Alert = 114
|
||||||
|
AlertUnknownPSKIdentity Alert = 115
|
||||||
|
AlertNoApplicationProtocol Alert = 120
|
||||||
|
AlertWouldBlock Alert = 254
|
||||||
|
AlertNoAlert Alert = 255
|
||||||
|
)
|
||||||
|
|
||||||
|
var alertText = map[Alert]string{
|
||||||
|
AlertCloseNotify: "close notify",
|
||||||
|
AlertUnexpectedMessage: "unexpected message",
|
||||||
|
AlertBadRecordMAC: "bad record MAC",
|
||||||
|
AlertDecryptionFailed: "decryption failed",
|
||||||
|
AlertRecordOverflow: "record overflow",
|
||||||
|
AlertDecompressionFailure: "decompression failure",
|
||||||
|
AlertHandshakeFailure: "handshake failure",
|
||||||
|
AlertBadCertificate: "bad certificate",
|
||||||
|
AlertUnsupportedCertificate: "unsupported certificate",
|
||||||
|
AlertCertificateRevoked: "revoked certificate",
|
||||||
|
AlertCertificateExpired: "expired certificate",
|
||||||
|
AlertCertificateUnknown: "unknown certificate",
|
||||||
|
AlertIllegalParameter: "illegal parameter",
|
||||||
|
AlertUnknownCA: "unknown certificate authority",
|
||||||
|
AlertAccessDenied: "access denied",
|
||||||
|
AlertDecodeError: "error decoding message",
|
||||||
|
AlertDecryptError: "error decrypting message",
|
||||||
|
AlertProtocolVersion: "protocol version not supported",
|
||||||
|
AlertInsufficientSecurity: "insufficient security level",
|
||||||
|
AlertInternalError: "internal error",
|
||||||
|
AlertInappropriateFallback: "inappropriate fallback",
|
||||||
|
AlertUserCanceled: "user canceled",
|
||||||
|
AlertMissingExtension: "missing extension",
|
||||||
|
AlertUnsupportedExtension: "unsupported extension",
|
||||||
|
AlertCertificateUnobtainable: "certificate unobtainable",
|
||||||
|
AlertUnrecognizedName: "unrecognized name",
|
||||||
|
AlertBadCertificateStatsResponse: "bad certificate status response",
|
||||||
|
AlertBadCertificateHashValue: "bad certificate hash value",
|
||||||
|
AlertUnknownPSKIdentity: "unknown PSK identity",
|
||||||
|
AlertNoApplicationProtocol: "no application protocol",
|
||||||
|
AlertNoRenegotiation: "no renegotiation",
|
||||||
|
AlertWouldBlock: "would have blocked",
|
||||||
|
AlertNoAlert: "no alert",
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e Alert) String() string {
|
||||||
|
s, ok := alertText[e]
|
||||||
|
if ok {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return "alert(" + strconv.Itoa(int(e)) + ")"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e Alert) Error() string {
|
||||||
|
return e.String()
|
||||||
|
}
|
42
vendor/github.com/bifurcation/mint/bin/mint-client-https/main.go
generated
vendored
Normal file
42
vendor/github.com/bifurcation/mint/bin/mint-client-https/main.go
generated
vendored
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/bifurcation/mint"
|
||||||
|
)
|
||||||
|
|
||||||
|
var url string
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
url := flag.String("url", "https://localhost:4430", "URL to send request")
|
||||||
|
flag.Parse()
|
||||||
|
mintdial := func(network, addr string) (net.Conn, error) {
|
||||||
|
return mint.Dial(network, addr, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
tr := &http.Transport{
|
||||||
|
DialTLS: mintdial,
|
||||||
|
DisableCompression: true,
|
||||||
|
}
|
||||||
|
client := &http.Client{Transport: tr}
|
||||||
|
|
||||||
|
response, err := client.Get(*url)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("err:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer response.Body.Close()
|
||||||
|
|
||||||
|
contents, err := ioutil.ReadAll(response.Body)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("%s", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
fmt.Printf("%s\n", string(contents))
|
||||||
|
}
|
37
vendor/github.com/bifurcation/mint/bin/mint-client/main.go
generated
vendored
Normal file
37
vendor/github.com/bifurcation/mint/bin/mint-client/main.go
generated
vendored
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/bifurcation/mint"
|
||||||
|
)
|
||||||
|
|
||||||
|
var addr string
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
flag.StringVar(&addr, "addr", "localhost:4430", "port")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
conn, err := mint.Dial("tcp", addr, nil)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("TLS handshake failed:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
request := "GET / HTTP/1.0\r\n\r\n"
|
||||||
|
conn.Write([]byte(request))
|
||||||
|
|
||||||
|
response := ""
|
||||||
|
buffer := make([]byte, 1024)
|
||||||
|
var read int
|
||||||
|
for err == nil {
|
||||||
|
read, err = conn.Read(buffer)
|
||||||
|
fmt.Println(" ~~ read: ", read)
|
||||||
|
response += string(buffer)
|
||||||
|
}
|
||||||
|
fmt.Println("err:", err)
|
||||||
|
fmt.Println("Received from server:")
|
||||||
|
fmt.Println(response)
|
||||||
|
}
|
226
vendor/github.com/bifurcation/mint/bin/mint-server-https/main.go
generated
vendored
Normal file
226
vendor/github.com/bifurcation/mint/bin/mint-server-https/main.go
generated
vendored
Normal file
@ -0,0 +1,226 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/bifurcation/mint"
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
port string
|
||||||
|
serverName string
|
||||||
|
certFile string
|
||||||
|
keyFile string
|
||||||
|
responseFile string
|
||||||
|
h2 bool
|
||||||
|
sendTickets bool
|
||||||
|
)
|
||||||
|
|
||||||
|
type responder []byte
|
||||||
|
|
||||||
|
func (rsp responder) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write(rsp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParsePrivateKeyDER parses a PKCS #1, PKCS #8, or elliptic curve
|
||||||
|
// PEM-encoded private key.
|
||||||
|
// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module
|
||||||
|
func ParsePrivateKeyPEM(keyPEM []byte) (key crypto.Signer, err error) {
|
||||||
|
keyDER, _ := pem.Decode(keyPEM)
|
||||||
|
if keyDER == nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
generalKey, err := x509.ParsePKCS8PrivateKey(keyDER.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
generalKey, err = x509.ParsePKCS1PrivateKey(keyDER.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
generalKey, err = x509.ParseECPrivateKey(keyDER.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
// We don't include the actual error into
|
||||||
|
// the final error. The reason might be
|
||||||
|
// we don't want to leak any info about
|
||||||
|
// the private key.
|
||||||
|
return nil, fmt.Errorf("No successful private key decoder")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch generalKey.(type) {
|
||||||
|
case *rsa.PrivateKey:
|
||||||
|
return generalKey.(*rsa.PrivateKey), nil
|
||||||
|
case *ecdsa.PrivateKey:
|
||||||
|
return generalKey.(*ecdsa.PrivateKey), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// should never reach here
|
||||||
|
return nil, fmt.Errorf("Should be unreachable")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseOneCertificateFromPEM attempts to parse one PEM encoded certificate object,
|
||||||
|
// either a raw x509 certificate or a PKCS #7 structure possibly containing
|
||||||
|
// multiple certificates, from the top of certsPEM, which itself may
|
||||||
|
// contain multiple PEM encoded certificate objects.
|
||||||
|
// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module
|
||||||
|
func ParseOneCertificateFromPEM(certsPEM []byte) ([]*x509.Certificate, []byte, error) {
|
||||||
|
block, rest := pem.Decode(certsPEM)
|
||||||
|
if block == nil {
|
||||||
|
return nil, rest, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cert, err := x509.ParseCertificate(block.Bytes)
|
||||||
|
var certs = []*x509.Certificate{cert}
|
||||||
|
return certs, rest, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseCertificatesPEM parses a sequence of PEM-encoded certificate and returns them,
|
||||||
|
// can handle PEM encoded PKCS #7 structures.
|
||||||
|
// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module
|
||||||
|
func ParseCertificatesPEM(certsPEM []byte) ([]*x509.Certificate, error) {
|
||||||
|
var certs []*x509.Certificate
|
||||||
|
var err error
|
||||||
|
certsPEM = bytes.TrimSpace(certsPEM)
|
||||||
|
for len(certsPEM) > 0 {
|
||||||
|
var cert []*x509.Certificate
|
||||||
|
cert, certsPEM, err = ParseOneCertificateFromPEM(certsPEM)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else if cert == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
certs = append(certs, cert...)
|
||||||
|
}
|
||||||
|
if len(certsPEM) > 0 {
|
||||||
|
return nil, fmt.Errorf("Trailing PEM data")
|
||||||
|
}
|
||||||
|
return certs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
flag.StringVar(&port, "port", "4430", "port")
|
||||||
|
flag.StringVar(&serverName, "host", "example.com", "hostname")
|
||||||
|
flag.StringVar(&certFile, "cert", "", "certificate chain in PEM or DER")
|
||||||
|
flag.StringVar(&keyFile, "key", "", "private key in PEM format")
|
||||||
|
flag.StringVar(&responseFile, "response", "", "file to serve")
|
||||||
|
flag.BoolVar(&h2, "h2", false, "whether to use HTTP/2 (exclusively)")
|
||||||
|
flag.BoolVar(&sendTickets, "tickets", true, "whether to send session tickets")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
var certChain []*x509.Certificate
|
||||||
|
var priv crypto.Signer
|
||||||
|
var response []byte
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Load the key and certificate chain
|
||||||
|
if certFile != "" {
|
||||||
|
certs, err := ioutil.ReadFile(certFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error: %v", err)
|
||||||
|
} else {
|
||||||
|
certChain, err = ParseCertificatesPEM(certs)
|
||||||
|
if err != nil {
|
||||||
|
certChain, err = x509.ParseCertificates(certs)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error parsing certificates: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if keyFile != "" {
|
||||||
|
keyPEM, err := ioutil.ReadFile(keyFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error: %v", err)
|
||||||
|
} else {
|
||||||
|
priv, err = ParsePrivateKeyPEM(keyPEM)
|
||||||
|
if priv == nil || err != nil {
|
||||||
|
log.Fatalf("Error parsing private key: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load response file
|
||||||
|
if responseFile != "" {
|
||||||
|
log.Printf("Loading response file: %v", responseFile)
|
||||||
|
response, err = ioutil.ReadFile(responseFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
response = []byte("Welcome to the TLS 1.3 zone!")
|
||||||
|
}
|
||||||
|
handler := responder(response)
|
||||||
|
|
||||||
|
config := mint.Config{
|
||||||
|
SendSessionTickets: true,
|
||||||
|
ServerName: serverName,
|
||||||
|
NextProtos: []string{"http/1.1"},
|
||||||
|
}
|
||||||
|
|
||||||
|
if h2 {
|
||||||
|
config.NextProtos = []string{"h2"}
|
||||||
|
}
|
||||||
|
|
||||||
|
config.SendSessionTickets = sendTickets
|
||||||
|
|
||||||
|
if certChain != nil && priv != nil {
|
||||||
|
log.Printf("Loading cert: %v key: %v", certFile, keyFile)
|
||||||
|
config.Certificates = []*mint.Certificate{
|
||||||
|
{
|
||||||
|
Chain: certChain,
|
||||||
|
PrivateKey: priv,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
config.Init(false)
|
||||||
|
|
||||||
|
service := "0.0.0.0:" + port
|
||||||
|
srv := &http.Server{Handler: handler}
|
||||||
|
|
||||||
|
log.Printf("Listening on port %v", port)
|
||||||
|
// Need the inner loop here because the h1 server errors on a dropped connection
|
||||||
|
// Need the outer loop here because the h2 server is per-connection
|
||||||
|
for {
|
||||||
|
listener, err := mint.Listen("tcp", service, &config)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Listen Error: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !h2 {
|
||||||
|
alert := srv.Serve(listener)
|
||||||
|
if alert != mint.AlertNoAlert {
|
||||||
|
log.Printf("Serve Error: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
srv2 := new(http2.Server)
|
||||||
|
opts := &http2.ServeConnOpts{
|
||||||
|
Handler: handler,
|
||||||
|
BaseConfig: srv,
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Accept error: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
go srv2.ServeConn(conn, opts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
65
vendor/github.com/bifurcation/mint/bin/mint-server/main.go
generated
vendored
Normal file
65
vendor/github.com/bifurcation/mint/bin/mint-server/main.go
generated
vendored
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/bifurcation/mint"
|
||||||
|
)
|
||||||
|
|
||||||
|
var port string
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
var config mint.Config
|
||||||
|
config.SendSessionTickets = true
|
||||||
|
config.ServerName = "localhost"
|
||||||
|
config.Init(false)
|
||||||
|
|
||||||
|
flag.StringVar(&port, "port", "4430", "port")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
service := "0.0.0.0:" + port
|
||||||
|
listener, err := mint.Listen("tcp", service, &config)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("server: listen: %s", err)
|
||||||
|
}
|
||||||
|
log.Print("server: listening")
|
||||||
|
|
||||||
|
for {
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("server: accept: %s", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
log.Printf("server: accepted from %s", conn.RemoteAddr())
|
||||||
|
go handleClient(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleClient(conn net.Conn) {
|
||||||
|
defer conn.Close()
|
||||||
|
buf := make([]byte, 10)
|
||||||
|
for {
|
||||||
|
log.Print("server: conn: waiting")
|
||||||
|
n, err := conn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("server: conn: read: %s", err)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err = conn.Write([]byte("hello world"))
|
||||||
|
log.Printf("server: conn: wrote %d bytes", n)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("server: write: %s", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
log.Println("server: conn: closed")
|
||||||
|
}
|
942
vendor/github.com/bifurcation/mint/client-state-machine.go
generated
vendored
Normal file
942
vendor/github.com/bifurcation/mint/client-state-machine.go
generated
vendored
Normal file
@ -0,0 +1,942 @@
|
|||||||
|
package mint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto"
|
||||||
|
"hash"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client State Machine
|
||||||
|
//
|
||||||
|
// START <----+
|
||||||
|
// Send ClientHello | | Recv HelloRetryRequest
|
||||||
|
// / v |
|
||||||
|
// | WAIT_SH ---+
|
||||||
|
// Can | | Recv ServerHello
|
||||||
|
// send | V
|
||||||
|
// early | WAIT_EE
|
||||||
|
// data | | Recv EncryptedExtensions
|
||||||
|
// | +--------+--------+
|
||||||
|
// | Using | | Using certificate
|
||||||
|
// | PSK | v
|
||||||
|
// | | WAIT_CERT_CR
|
||||||
|
// | | Recv | | Recv CertificateRequest
|
||||||
|
// | | Certificate | v
|
||||||
|
// | | | WAIT_CERT
|
||||||
|
// | | | | Recv Certificate
|
||||||
|
// | | v v
|
||||||
|
// | | WAIT_CV
|
||||||
|
// | | | Recv CertificateVerify
|
||||||
|
// | +> WAIT_FINISHED <+
|
||||||
|
// | | Recv Finished
|
||||||
|
// \ |
|
||||||
|
// | [Send EndOfEarlyData]
|
||||||
|
// | [Send Certificate [+ CertificateVerify]]
|
||||||
|
// | Send Finished
|
||||||
|
// Can send v
|
||||||
|
// app data --> CONNECTED
|
||||||
|
// after
|
||||||
|
// here
|
||||||
|
//
|
||||||
|
// State Instructions
|
||||||
|
// START Send(CH); [RekeyOut; SendEarlyData]
|
||||||
|
// WAIT_SH Send(CH) || RekeyIn
|
||||||
|
// WAIT_EE {}
|
||||||
|
// WAIT_CERT_CR {}
|
||||||
|
// WAIT_CERT {}
|
||||||
|
// WAIT_CV {}
|
||||||
|
// WAIT_FINISHED RekeyIn; [Send(EOED);] RekeyOut; [SendCert; SendCV;] SendFin; RekeyOut;
|
||||||
|
// CONNECTED StoreTicket || (RekeyIn; [RekeyOut])
|
||||||
|
|
||||||
|
type ClientStateStart struct {
|
||||||
|
Caps Capabilities
|
||||||
|
Opts ConnectionOptions
|
||||||
|
Params ConnectionParameters
|
||||||
|
|
||||||
|
cookie []byte
|
||||||
|
firstClientHello *HandshakeMessage
|
||||||
|
helloRetryRequest *HandshakeMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
if hm != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateStart] Unexpected non-nil message")
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
// key_shares
|
||||||
|
offeredDH := map[NamedGroup][]byte{}
|
||||||
|
ks := KeyShareExtension{
|
||||||
|
HandshakeType: HandshakeTypeClientHello,
|
||||||
|
Shares: make([]KeyShareEntry, len(state.Caps.Groups)),
|
||||||
|
}
|
||||||
|
for i, group := range state.Caps.Groups {
|
||||||
|
pub, priv, err := newKeyShare(group)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateStart] Error generating key share [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
ks.Shares[i].Group = group
|
||||||
|
ks.Shares[i].KeyExchange = pub
|
||||||
|
offeredDH[group] = priv
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "opts: %+v", state.Opts)
|
||||||
|
|
||||||
|
// supported_versions, supported_groups, signature_algorithms, server_name
|
||||||
|
sv := SupportedVersionsExtension{Versions: []uint16{supportedVersion}}
|
||||||
|
sni := ServerNameExtension(state.Opts.ServerName)
|
||||||
|
sg := SupportedGroupsExtension{Groups: state.Caps.Groups}
|
||||||
|
sa := SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes}
|
||||||
|
|
||||||
|
state.Params.ServerName = state.Opts.ServerName
|
||||||
|
|
||||||
|
// Application Layer Protocol Negotiation
|
||||||
|
var alpn *ALPNExtension
|
||||||
|
if (state.Opts.NextProtos != nil) && (len(state.Opts.NextProtos) > 0) {
|
||||||
|
alpn = &ALPNExtension{Protocols: state.Opts.NextProtos}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Construct base ClientHello
|
||||||
|
ch := &ClientHelloBody{
|
||||||
|
CipherSuites: state.Caps.CipherSuites,
|
||||||
|
}
|
||||||
|
_, err := prng.Read(ch.Random[:])
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateStart] Error creating ClientHello random [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
for _, ext := range []ExtensionBody{&sv, &sni, &ks, &sg, &sa} {
|
||||||
|
err := ch.Extensions.Add(ext)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateStart] Error adding extension type=[%v] [%v]", ext.Type(), err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// XXX: These optional extensions can't be folded into the above because Go
|
||||||
|
// interface-typed values are never reported as nil
|
||||||
|
if alpn != nil {
|
||||||
|
err := ch.Extensions.Add(alpn)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateStart] Error adding ALPN extension [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if state.cookie != nil {
|
||||||
|
err := ch.Extensions.Add(&CookieExtension{Cookie: state.cookie})
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateStart] Error adding ALPN extension [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run the external extension handler.
|
||||||
|
if state.Caps.ExtensionHandler != nil {
|
||||||
|
err := state.Caps.ExtensionHandler.Send(HandshakeTypeClientHello, &ch.Extensions)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateStart] Error running external extension sender [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle PSK and EarlyData just before transmitting, so that we can
|
||||||
|
// calculate the PSK binder value
|
||||||
|
var psk *PreSharedKeyExtension
|
||||||
|
var ed *EarlyDataExtension
|
||||||
|
var offeredPSK PreSharedKey
|
||||||
|
var earlyHash crypto.Hash
|
||||||
|
var earlySecret []byte
|
||||||
|
var clientEarlyTrafficKeys keySet
|
||||||
|
var clientHello *HandshakeMessage
|
||||||
|
if key, ok := state.Caps.PSKs.Get(state.Opts.ServerName); ok {
|
||||||
|
offeredPSK = key
|
||||||
|
|
||||||
|
// Narrow ciphersuites to ones that match PSK hash
|
||||||
|
params, ok := cipherSuiteMap[key.CipherSuite]
|
||||||
|
if !ok {
|
||||||
|
logf(logTypeHandshake, "[ClientStateStart] PSK for unknown ciphersuite")
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
compatibleSuites := []CipherSuite{}
|
||||||
|
for _, suite := range ch.CipherSuites {
|
||||||
|
if cipherSuiteMap[suite].Hash == params.Hash {
|
||||||
|
compatibleSuites = append(compatibleSuites, suite)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ch.CipherSuites = compatibleSuites
|
||||||
|
|
||||||
|
// Signal early data if we're going to do it
|
||||||
|
if len(state.Opts.EarlyData) > 0 {
|
||||||
|
state.Params.ClientSendingEarlyData = true
|
||||||
|
ed = &EarlyDataExtension{}
|
||||||
|
err = ch.Extensions.Add(ed)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "Error adding early data extension: %v", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Signal supported PSK key exchange modes
|
||||||
|
if len(state.Caps.PSKModes) == 0 {
|
||||||
|
logf(logTypeHandshake, "PSK selected, but no PSKModes")
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
kem := &PSKKeyExchangeModesExtension{KEModes: state.Caps.PSKModes}
|
||||||
|
err = ch.Extensions.Add(kem)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "Error adding PSKKeyExchangeModes extension: %v", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the shim PSK extension to the ClientHello
|
||||||
|
logf(logTypeHandshake, "Adding PSK extension with id = %x", key.Identity)
|
||||||
|
psk = &PreSharedKeyExtension{
|
||||||
|
HandshakeType: HandshakeTypeClientHello,
|
||||||
|
Identities: []PSKIdentity{
|
||||||
|
{
|
||||||
|
Identity: key.Identity,
|
||||||
|
ObfuscatedTicketAge: uint32(time.Since(key.ReceivedAt)/time.Millisecond) + key.TicketAgeAdd,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Binders: []PSKBinderEntry{
|
||||||
|
// Note: Stub to get the length fields right
|
||||||
|
{Binder: bytes.Repeat([]byte{0x00}, params.Hash.Size())},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ch.Extensions.Add(psk)
|
||||||
|
|
||||||
|
// Compute the binder key
|
||||||
|
h0 := params.Hash.New().Sum(nil)
|
||||||
|
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
|
||||||
|
|
||||||
|
earlyHash = params.Hash
|
||||||
|
earlySecret = HkdfExtract(params.Hash, zero, key.Key)
|
||||||
|
logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret)
|
||||||
|
|
||||||
|
binderLabel := labelExternalBinder
|
||||||
|
if key.IsResumption {
|
||||||
|
binderLabel = labelResumptionBinder
|
||||||
|
}
|
||||||
|
binderKey := deriveSecret(params, earlySecret, binderLabel, h0)
|
||||||
|
logf(logTypeCrypto, "binder key: [%d] %x", len(binderKey), binderKey)
|
||||||
|
|
||||||
|
// Compute the binder value
|
||||||
|
trunc, err := ch.Truncated()
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateStart] Error marshaling truncated ClientHello [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
truncHash := params.Hash.New()
|
||||||
|
truncHash.Write(trunc)
|
||||||
|
|
||||||
|
binder := computeFinishedData(params, binderKey, truncHash.Sum(nil))
|
||||||
|
|
||||||
|
// Replace the PSK extension
|
||||||
|
psk.Binders[0].Binder = binder
|
||||||
|
ch.Extensions.Add(psk)
|
||||||
|
|
||||||
|
// If we got here, the earlier marshal succeeded (in ch.Truncated()), so
|
||||||
|
// this one should too.
|
||||||
|
clientHello, _ = HandshakeMessageFromBody(ch)
|
||||||
|
|
||||||
|
// Compute early traffic keys
|
||||||
|
h := params.Hash.New()
|
||||||
|
h.Write(clientHello.Marshal())
|
||||||
|
chHash := h.Sum(nil)
|
||||||
|
|
||||||
|
earlyTrafficSecret := deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash)
|
||||||
|
logf(logTypeCrypto, "early traffic secret: [%d] %x", len(earlyTrafficSecret), earlyTrafficSecret)
|
||||||
|
clientEarlyTrafficKeys = makeTrafficKeys(params, earlyTrafficSecret)
|
||||||
|
} else if len(state.Opts.EarlyData) > 0 {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitSH] Early data without PSK")
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
} else {
|
||||||
|
clientHello, err = HandshakeMessageFromBody(ch)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateStart] Error marshaling ClientHello [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ClientStateStart] -> [ClientStateWaitSH]")
|
||||||
|
nextState := ClientStateWaitSH{
|
||||||
|
Caps: state.Caps,
|
||||||
|
Opts: state.Opts,
|
||||||
|
Params: state.Params,
|
||||||
|
OfferedDH: offeredDH,
|
||||||
|
OfferedPSK: offeredPSK,
|
||||||
|
|
||||||
|
earlySecret: earlySecret,
|
||||||
|
earlyHash: earlyHash,
|
||||||
|
|
||||||
|
firstClientHello: state.firstClientHello,
|
||||||
|
helloRetryRequest: state.helloRetryRequest,
|
||||||
|
clientHello: clientHello,
|
||||||
|
}
|
||||||
|
|
||||||
|
toSend := []HandshakeAction{
|
||||||
|
SendHandshakeMessage{clientHello},
|
||||||
|
}
|
||||||
|
if state.Params.ClientSendingEarlyData {
|
||||||
|
toSend = append(toSend, []HandshakeAction{
|
||||||
|
RekeyOut{Label: "early", KeySet: clientEarlyTrafficKeys},
|
||||||
|
SendEarlyData{},
|
||||||
|
}...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nextState, toSend, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClientStateWaitSH struct {
|
||||||
|
Caps Capabilities
|
||||||
|
Opts ConnectionOptions
|
||||||
|
Params ConnectionParameters
|
||||||
|
OfferedDH map[NamedGroup][]byte
|
||||||
|
OfferedPSK PreSharedKey
|
||||||
|
PSK []byte
|
||||||
|
|
||||||
|
earlySecret []byte
|
||||||
|
earlyHash crypto.Hash
|
||||||
|
|
||||||
|
firstClientHello *HandshakeMessage
|
||||||
|
helloRetryRequest *HandshakeMessage
|
||||||
|
clientHello *HandshakeMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state ClientStateWaitSH) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
if hm == nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected nil message")
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyGeneric, err := hm.ToBody()
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitSH] Error decoding message: %v", err)
|
||||||
|
return nil, nil, AlertDecodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
switch body := bodyGeneric.(type) {
|
||||||
|
case *HelloRetryRequestBody:
|
||||||
|
hrr := body
|
||||||
|
|
||||||
|
if state.helloRetryRequest != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitSH] Received a second HelloRetryRequest")
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that the version sent by the server is the one we support
|
||||||
|
if hrr.Version != supportedVersion {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", hrr.Version)
|
||||||
|
return nil, nil, AlertProtocolVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that the server provided a supported ciphersuite
|
||||||
|
supportedCipherSuite := false
|
||||||
|
for _, suite := range state.Caps.CipherSuites {
|
||||||
|
supportedCipherSuite = supportedCipherSuite || (suite == hrr.CipherSuite)
|
||||||
|
}
|
||||||
|
if !supportedCipherSuite {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", hrr.CipherSuite)
|
||||||
|
return nil, nil, AlertHandshakeFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
// Narrow the supported ciphersuites to the server-provided one
|
||||||
|
state.Caps.CipherSuites = []CipherSuite{hrr.CipherSuite}
|
||||||
|
|
||||||
|
// Handle external extensions.
|
||||||
|
if state.Caps.ExtensionHandler != nil {
|
||||||
|
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeHelloRetryRequest, &hrr.Extensions)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// The only thing we know how to respond to in an HRR is the Cookie
|
||||||
|
// extension, so if there is either no Cookie extension or anything other
|
||||||
|
// than a Cookie extension, we have to fail.
|
||||||
|
serverCookie := new(CookieExtension)
|
||||||
|
foundCookie := hrr.Extensions.Find(serverCookie)
|
||||||
|
if !foundCookie || len(hrr.Extensions) != 1 {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitSH] No Cookie or extra extensions [%v] [%d]", foundCookie, len(hrr.Extensions))
|
||||||
|
return nil, nil, AlertIllegalParameter
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hash the body into a pseudo-message
|
||||||
|
// XXX: Ignoring some errors here
|
||||||
|
params := cipherSuiteMap[hrr.CipherSuite]
|
||||||
|
h := params.Hash.New()
|
||||||
|
h.Write(state.clientHello.Marshal())
|
||||||
|
firstClientHello := &HandshakeMessage{
|
||||||
|
msgType: HandshakeTypeMessageHash,
|
||||||
|
body: h.Sum(nil),
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateStart]")
|
||||||
|
return ClientStateStart{
|
||||||
|
Caps: state.Caps,
|
||||||
|
Opts: state.Opts,
|
||||||
|
cookie: serverCookie.Cookie,
|
||||||
|
firstClientHello: firstClientHello,
|
||||||
|
helloRetryRequest: hm,
|
||||||
|
}.Next(nil)
|
||||||
|
|
||||||
|
case *ServerHelloBody:
|
||||||
|
sh := body
|
||||||
|
|
||||||
|
// Check that the version sent by the server is the one we support
|
||||||
|
if sh.Version != supportedVersion {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", sh.Version)
|
||||||
|
return nil, nil, AlertProtocolVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that the server provided a supported ciphersuite
|
||||||
|
supportedCipherSuite := false
|
||||||
|
for _, suite := range state.Caps.CipherSuites {
|
||||||
|
supportedCipherSuite = supportedCipherSuite || (suite == sh.CipherSuite)
|
||||||
|
}
|
||||||
|
if !supportedCipherSuite {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", sh.CipherSuite)
|
||||||
|
return nil, nil, AlertHandshakeFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle external extensions.
|
||||||
|
if state.Caps.ExtensionHandler != nil {
|
||||||
|
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeServerHello, &sh.Extensions)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do PSK or key agreement depending on extensions
|
||||||
|
serverPSK := PreSharedKeyExtension{HandshakeType: HandshakeTypeServerHello}
|
||||||
|
serverKeyShare := KeyShareExtension{HandshakeType: HandshakeTypeServerHello}
|
||||||
|
|
||||||
|
foundPSK := sh.Extensions.Find(&serverPSK)
|
||||||
|
foundKeyShare := sh.Extensions.Find(&serverKeyShare)
|
||||||
|
|
||||||
|
if foundPSK && (serverPSK.SelectedIdentity == 0) {
|
||||||
|
state.Params.UsingPSK = true
|
||||||
|
}
|
||||||
|
|
||||||
|
var dhSecret []byte
|
||||||
|
if foundKeyShare {
|
||||||
|
sks := serverKeyShare.Shares[0]
|
||||||
|
priv, ok := state.OfferedDH[sks.Group]
|
||||||
|
if !ok {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitSH] Key share for unknown group")
|
||||||
|
return nil, nil, AlertIllegalParameter
|
||||||
|
}
|
||||||
|
|
||||||
|
state.Params.UsingDH = true
|
||||||
|
dhSecret, _ = keyAgreement(sks.Group, sks.KeyExchange, priv)
|
||||||
|
}
|
||||||
|
|
||||||
|
suite := sh.CipherSuite
|
||||||
|
state.Params.CipherSuite = suite
|
||||||
|
|
||||||
|
params, ok := cipherSuiteMap[suite]
|
||||||
|
if !ok {
|
||||||
|
logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", suite)
|
||||||
|
return nil, nil, AlertHandshakeFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start up the handshake hash
|
||||||
|
handshakeHash := params.Hash.New()
|
||||||
|
handshakeHash.Write(state.firstClientHello.Marshal())
|
||||||
|
handshakeHash.Write(state.helloRetryRequest.Marshal())
|
||||||
|
handshakeHash.Write(state.clientHello.Marshal())
|
||||||
|
handshakeHash.Write(hm.Marshal())
|
||||||
|
|
||||||
|
// Compute handshake secrets
|
||||||
|
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
|
||||||
|
|
||||||
|
var earlySecret []byte
|
||||||
|
if state.Params.UsingPSK {
|
||||||
|
if params.Hash != state.earlyHash {
|
||||||
|
logf(logTypeCrypto, "Change of hash between early and normal init early=[%02x] suite=[%04x] hash=[%02x]",
|
||||||
|
state.earlyHash, suite, params.Hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
earlySecret = state.earlySecret
|
||||||
|
} else {
|
||||||
|
earlySecret = HkdfExtract(params.Hash, zero, zero)
|
||||||
|
}
|
||||||
|
|
||||||
|
if dhSecret == nil {
|
||||||
|
dhSecret = zero
|
||||||
|
}
|
||||||
|
|
||||||
|
h0 := params.Hash.New().Sum(nil)
|
||||||
|
h2 := handshakeHash.Sum(nil)
|
||||||
|
preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0)
|
||||||
|
handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, dhSecret)
|
||||||
|
clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2)
|
||||||
|
serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2)
|
||||||
|
preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0)
|
||||||
|
masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero)
|
||||||
|
|
||||||
|
logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret)
|
||||||
|
logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret)
|
||||||
|
logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret)
|
||||||
|
logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret)
|
||||||
|
logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret)
|
||||||
|
|
||||||
|
serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret)
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateWaitEE]")
|
||||||
|
nextState := ClientStateWaitEE{
|
||||||
|
Caps: state.Caps,
|
||||||
|
Params: state.Params,
|
||||||
|
cryptoParams: params,
|
||||||
|
handshakeHash: handshakeHash,
|
||||||
|
certificates: state.Caps.Certificates,
|
||||||
|
masterSecret: masterSecret,
|
||||||
|
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret,
|
||||||
|
serverHandshakeTrafficSecret: serverHandshakeTrafficSecret,
|
||||||
|
}
|
||||||
|
toSend := []HandshakeAction{
|
||||||
|
RekeyIn{Label: "handshake", KeySet: serverHandshakeKeys},
|
||||||
|
}
|
||||||
|
return nextState, toSend, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected message [%d]", hm.msgType)
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClientStateWaitEE struct {
|
||||||
|
Caps Capabilities
|
||||||
|
AuthCertificate func(chain []CertificateEntry) error
|
||||||
|
Params ConnectionParameters
|
||||||
|
cryptoParams CipherSuiteParams
|
||||||
|
handshakeHash hash.Hash
|
||||||
|
certificates []*Certificate
|
||||||
|
masterSecret []byte
|
||||||
|
clientHandshakeTrafficSecret []byte
|
||||||
|
serverHandshakeTrafficSecret []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state ClientStateWaitEE) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
if hm == nil || hm.msgType != HandshakeTypeEncryptedExtensions {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitEE] Unexpected message")
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
ee := EncryptedExtensionsBody{}
|
||||||
|
_, err := ee.Unmarshal(hm.body)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding message: %v", err)
|
||||||
|
return nil, nil, AlertDecodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle external extensions.
|
||||||
|
if state.Caps.ExtensionHandler != nil {
|
||||||
|
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeEncryptedExtensions, &ee.Extensions)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientWaitStateEE] Error running external extension handler [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
serverALPN := ALPNExtension{}
|
||||||
|
serverEarlyData := EarlyDataExtension{}
|
||||||
|
|
||||||
|
gotALPN := ee.Extensions.Find(&serverALPN)
|
||||||
|
state.Params.UsingEarlyData = ee.Extensions.Find(&serverEarlyData)
|
||||||
|
|
||||||
|
if gotALPN && len(serverALPN.Protocols) > 0 {
|
||||||
|
state.Params.NextProto = serverALPN.Protocols[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
state.handshakeHash.Write(hm.Marshal())
|
||||||
|
|
||||||
|
if state.Params.UsingPSK {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitFinished]")
|
||||||
|
nextState := ClientStateWaitFinished{
|
||||||
|
Params: state.Params,
|
||||||
|
cryptoParams: state.cryptoParams,
|
||||||
|
handshakeHash: state.handshakeHash,
|
||||||
|
certificates: state.certificates,
|
||||||
|
masterSecret: state.masterSecret,
|
||||||
|
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||||
|
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||||
|
}
|
||||||
|
return nextState, nil, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitCertCR]")
|
||||||
|
nextState := ClientStateWaitCertCR{
|
||||||
|
AuthCertificate: state.AuthCertificate,
|
||||||
|
Params: state.Params,
|
||||||
|
cryptoParams: state.cryptoParams,
|
||||||
|
handshakeHash: state.handshakeHash,
|
||||||
|
certificates: state.certificates,
|
||||||
|
masterSecret: state.masterSecret,
|
||||||
|
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||||
|
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||||
|
}
|
||||||
|
return nextState, nil, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClientStateWaitCertCR struct {
|
||||||
|
AuthCertificate func(chain []CertificateEntry) error
|
||||||
|
Params ConnectionParameters
|
||||||
|
cryptoParams CipherSuiteParams
|
||||||
|
handshakeHash hash.Hash
|
||||||
|
certificates []*Certificate
|
||||||
|
masterSecret []byte
|
||||||
|
clientHandshakeTrafficSecret []byte
|
||||||
|
serverHandshakeTrafficSecret []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state ClientStateWaitCertCR) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
if hm == nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitCertCR] Unexpected message")
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyGeneric, err := hm.ToBody()
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitCertCR] Error decoding message: %v", err)
|
||||||
|
return nil, nil, AlertDecodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
state.handshakeHash.Write(hm.Marshal())
|
||||||
|
|
||||||
|
switch body := bodyGeneric.(type) {
|
||||||
|
case *CertificateBody:
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCV]")
|
||||||
|
nextState := ClientStateWaitCV{
|
||||||
|
AuthCertificate: state.AuthCertificate,
|
||||||
|
Params: state.Params,
|
||||||
|
cryptoParams: state.cryptoParams,
|
||||||
|
handshakeHash: state.handshakeHash,
|
||||||
|
certificates: state.certificates,
|
||||||
|
serverCertificate: body,
|
||||||
|
masterSecret: state.masterSecret,
|
||||||
|
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||||
|
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||||
|
}
|
||||||
|
return nextState, nil, AlertNoAlert
|
||||||
|
|
||||||
|
case *CertificateRequestBody:
|
||||||
|
// A certificate request in the handshake should have a zero-length context
|
||||||
|
if len(body.CertificateRequestContext) > 0 {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitCertCR] Certificate request with non-empty context: %v", err)
|
||||||
|
return nil, nil, AlertIllegalParameter
|
||||||
|
}
|
||||||
|
|
||||||
|
state.Params.UsingClientAuth = true
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCert]")
|
||||||
|
nextState := ClientStateWaitCert{
|
||||||
|
AuthCertificate: state.AuthCertificate,
|
||||||
|
Params: state.Params,
|
||||||
|
cryptoParams: state.cryptoParams,
|
||||||
|
handshakeHash: state.handshakeHash,
|
||||||
|
certificates: state.certificates,
|
||||||
|
serverCertificateRequest: body,
|
||||||
|
masterSecret: state.masterSecret,
|
||||||
|
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||||
|
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||||
|
}
|
||||||
|
return nextState, nil, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClientStateWaitCert struct {
|
||||||
|
AuthCertificate func(chain []CertificateEntry) error
|
||||||
|
Params ConnectionParameters
|
||||||
|
cryptoParams CipherSuiteParams
|
||||||
|
handshakeHash hash.Hash
|
||||||
|
|
||||||
|
certificates []*Certificate
|
||||||
|
serverCertificateRequest *CertificateRequestBody
|
||||||
|
|
||||||
|
masterSecret []byte
|
||||||
|
clientHandshakeTrafficSecret []byte
|
||||||
|
serverHandshakeTrafficSecret []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state ClientStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
if hm == nil || hm.msgType != HandshakeTypeCertificate {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitCert] Unexpected message")
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
cert := &CertificateBody{}
|
||||||
|
_, err := cert.Unmarshal(hm.body)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitCert] Error decoding message: %v", err)
|
||||||
|
return nil, nil, AlertDecodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
state.handshakeHash.Write(hm.Marshal())
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitCert] -> [ClientStateWaitCV]")
|
||||||
|
nextState := ClientStateWaitCV{
|
||||||
|
AuthCertificate: state.AuthCertificate,
|
||||||
|
Params: state.Params,
|
||||||
|
cryptoParams: state.cryptoParams,
|
||||||
|
handshakeHash: state.handshakeHash,
|
||||||
|
certificates: state.certificates,
|
||||||
|
serverCertificate: cert,
|
||||||
|
serverCertificateRequest: state.serverCertificateRequest,
|
||||||
|
masterSecret: state.masterSecret,
|
||||||
|
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||||
|
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||||
|
}
|
||||||
|
return nextState, nil, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClientStateWaitCV struct {
|
||||||
|
AuthCertificate func(chain []CertificateEntry) error
|
||||||
|
Params ConnectionParameters
|
||||||
|
cryptoParams CipherSuiteParams
|
||||||
|
handshakeHash hash.Hash
|
||||||
|
|
||||||
|
certificates []*Certificate
|
||||||
|
serverCertificate *CertificateBody
|
||||||
|
serverCertificateRequest *CertificateRequestBody
|
||||||
|
|
||||||
|
masterSecret []byte
|
||||||
|
clientHandshakeTrafficSecret []byte
|
||||||
|
serverHandshakeTrafficSecret []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state ClientStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
if hm == nil || hm.msgType != HandshakeTypeCertificateVerify {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitCV] Unexpected message")
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
certVerify := CertificateVerifyBody{}
|
||||||
|
_, err := certVerify.Unmarshal(hm.body)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitCV] Error decoding message: %v", err)
|
||||||
|
return nil, nil, AlertDecodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
hcv := state.handshakeHash.Sum(nil)
|
||||||
|
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
|
||||||
|
|
||||||
|
serverPublicKey := state.serverCertificate.CertificateList[0].CertData.PublicKey
|
||||||
|
if err := certVerify.Verify(serverPublicKey, hcv); err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitCV] Server signature failed to verify")
|
||||||
|
return nil, nil, AlertHandshakeFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
if state.AuthCertificate != nil {
|
||||||
|
err := state.AuthCertificate(state.serverCertificate.CertificateList)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitCV] Application rejected server certificate")
|
||||||
|
return nil, nil, AlertBadCertificate
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitCV] WARNING: No verification of server certificate")
|
||||||
|
}
|
||||||
|
|
||||||
|
state.handshakeHash.Write(hm.Marshal())
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitCV] -> [ClientStateWaitFinished]")
|
||||||
|
nextState := ClientStateWaitFinished{
|
||||||
|
Params: state.Params,
|
||||||
|
cryptoParams: state.cryptoParams,
|
||||||
|
handshakeHash: state.handshakeHash,
|
||||||
|
certificates: state.certificates,
|
||||||
|
serverCertificateRequest: state.serverCertificateRequest,
|
||||||
|
masterSecret: state.masterSecret,
|
||||||
|
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||||
|
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||||
|
}
|
||||||
|
return nextState, nil, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClientStateWaitFinished struct {
|
||||||
|
Params ConnectionParameters
|
||||||
|
cryptoParams CipherSuiteParams
|
||||||
|
handshakeHash hash.Hash
|
||||||
|
|
||||||
|
certificates []*Certificate
|
||||||
|
serverCertificateRequest *CertificateRequestBody
|
||||||
|
|
||||||
|
masterSecret []byte
|
||||||
|
clientHandshakeTrafficSecret []byte
|
||||||
|
serverHandshakeTrafficSecret []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
if hm == nil || hm.msgType != HandshakeTypeFinished {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitFinished] Unexpected message")
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify server's Finished
|
||||||
|
h3 := state.handshakeHash.Sum(nil)
|
||||||
|
logf(logTypeCrypto, "handshake hash 3 [%d] %x", len(h3), h3)
|
||||||
|
logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h3), h3)
|
||||||
|
|
||||||
|
serverFinishedData := computeFinishedData(state.cryptoParams, state.serverHandshakeTrafficSecret, h3)
|
||||||
|
logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData)
|
||||||
|
|
||||||
|
fin := &FinishedBody{VerifyDataLen: len(serverFinishedData)}
|
||||||
|
_, err := fin.Unmarshal(hm.body)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitFinished] Error decoding message: %v", err)
|
||||||
|
return nil, nil, AlertDecodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(fin.VerifyData, serverFinishedData) {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitFinished] Server's Finished failed to verify [%x] != [%x]",
|
||||||
|
fin.VerifyData, serverFinishedData)
|
||||||
|
return nil, nil, AlertHandshakeFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the handshake hash with the Finished
|
||||||
|
state.handshakeHash.Write(hm.Marshal())
|
||||||
|
logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(hm.Marshal()), hm.Marshal())
|
||||||
|
h4 := state.handshakeHash.Sum(nil)
|
||||||
|
logf(logTypeCrypto, "handshake hash 4 [%d]: %x", len(h4), h4)
|
||||||
|
|
||||||
|
// Compute traffic secrets and keys
|
||||||
|
clientTrafficSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelClientApplicationTrafficSecret, h4)
|
||||||
|
serverTrafficSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelServerApplicationTrafficSecret, h4)
|
||||||
|
logf(logTypeCrypto, "client traffic secret: [%d] %x", len(clientTrafficSecret), clientTrafficSecret)
|
||||||
|
logf(logTypeCrypto, "server traffic secret: [%d] %x", len(serverTrafficSecret), serverTrafficSecret)
|
||||||
|
|
||||||
|
clientTrafficKeys := makeTrafficKeys(state.cryptoParams, clientTrafficSecret)
|
||||||
|
serverTrafficKeys := makeTrafficKeys(state.cryptoParams, serverTrafficSecret)
|
||||||
|
|
||||||
|
exporterSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelExporterSecret, h4)
|
||||||
|
logf(logTypeCrypto, "client exporter secret: [%d] %x", len(exporterSecret), exporterSecret)
|
||||||
|
|
||||||
|
// Assemble client's second flight
|
||||||
|
toSend := []HandshakeAction{}
|
||||||
|
|
||||||
|
if state.Params.UsingEarlyData {
|
||||||
|
// Note: We only send EOED if the server is actually going to use the early
|
||||||
|
// data. Otherwise, it will never see it, and the transcripts will
|
||||||
|
// mismatch.
|
||||||
|
// EOED marshal is infallible
|
||||||
|
eoedm, _ := HandshakeMessageFromBody(&EndOfEarlyDataBody{})
|
||||||
|
toSend = append(toSend, SendHandshakeMessage{eoedm})
|
||||||
|
state.handshakeHash.Write(eoedm.Marshal())
|
||||||
|
logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(eoedm.Marshal()), eoedm.Marshal())
|
||||||
|
}
|
||||||
|
|
||||||
|
clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)
|
||||||
|
toSend = append(toSend, RekeyOut{Label: "handshake", KeySet: clientHandshakeKeys})
|
||||||
|
|
||||||
|
if state.Params.UsingClientAuth {
|
||||||
|
// Extract constraints from certicateRequest
|
||||||
|
schemes := SignatureAlgorithmsExtension{}
|
||||||
|
gotSchemes := state.serverCertificateRequest.Extensions.Find(&schemes)
|
||||||
|
if !gotSchemes {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err)
|
||||||
|
return nil, nil, AlertIllegalParameter
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select a certificate
|
||||||
|
cert, certScheme, err := CertificateSelection(nil, schemes.Algorithms, state.certificates)
|
||||||
|
if err != nil {
|
||||||
|
// XXX: Signal this to the application layer?
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err)
|
||||||
|
|
||||||
|
certificate := &CertificateBody{}
|
||||||
|
certm, err := HandshakeMessageFromBody(certificate)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
toSend = append(toSend, SendHandshakeMessage{certm})
|
||||||
|
state.handshakeHash.Write(certm.Marshal())
|
||||||
|
} else {
|
||||||
|
// Create and send Certificate, CertificateVerify
|
||||||
|
certificate := &CertificateBody{
|
||||||
|
CertificateList: make([]CertificateEntry, len(cert.Chain)),
|
||||||
|
}
|
||||||
|
for i, entry := range cert.Chain {
|
||||||
|
certificate.CertificateList[i] = CertificateEntry{CertData: entry}
|
||||||
|
}
|
||||||
|
certm, err := HandshakeMessageFromBody(certificate)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
toSend = append(toSend, SendHandshakeMessage{certm})
|
||||||
|
state.handshakeHash.Write(certm.Marshal())
|
||||||
|
|
||||||
|
hcv := state.handshakeHash.Sum(nil)
|
||||||
|
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
|
||||||
|
|
||||||
|
certificateVerify := &CertificateVerifyBody{Algorithm: certScheme}
|
||||||
|
logf(logTypeHandshake, "Creating CertVerify: %04x %v", certScheme, state.cryptoParams.Hash)
|
||||||
|
|
||||||
|
err = certificateVerify.Sign(cert.PrivateKey, hcv)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitFinished] Error signing CertificateVerify [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
certvm, err := HandshakeMessageFromBody(certificateVerify)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling CertificateVerify [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
toSend = append(toSend, SendHandshakeMessage{certvm})
|
||||||
|
state.handshakeHash.Write(certvm.Marshal())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the client's Finished message
|
||||||
|
h5 := state.handshakeHash.Sum(nil)
|
||||||
|
logf(logTypeCrypto, "handshake hash for client Finished: [%d] %x", len(h5), h5)
|
||||||
|
|
||||||
|
clientFinishedData := computeFinishedData(state.cryptoParams, state.clientHandshakeTrafficSecret, h5)
|
||||||
|
logf(logTypeCrypto, "client Finished data: [%d] %x", len(clientFinishedData), clientFinishedData)
|
||||||
|
|
||||||
|
fin = &FinishedBody{
|
||||||
|
VerifyDataLen: len(clientFinishedData),
|
||||||
|
VerifyData: clientFinishedData,
|
||||||
|
}
|
||||||
|
finm, err := HandshakeMessageFromBody(fin)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling client Finished [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the resumption secret
|
||||||
|
state.handshakeHash.Write(finm.Marshal())
|
||||||
|
h6 := state.handshakeHash.Sum(nil)
|
||||||
|
|
||||||
|
resumptionSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelResumptionSecret, h6)
|
||||||
|
logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret)
|
||||||
|
|
||||||
|
toSend = append(toSend, []HandshakeAction{
|
||||||
|
SendHandshakeMessage{finm},
|
||||||
|
RekeyIn{Label: "application", KeySet: serverTrafficKeys},
|
||||||
|
RekeyOut{Label: "application", KeySet: clientTrafficKeys},
|
||||||
|
}...)
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitFinished] -> [StateConnected]")
|
||||||
|
nextState := StateConnected{
|
||||||
|
Params: state.Params,
|
||||||
|
isClient: true,
|
||||||
|
cryptoParams: state.cryptoParams,
|
||||||
|
resumptionSecret: resumptionSecret,
|
||||||
|
clientTrafficSecret: clientTrafficSecret,
|
||||||
|
serverTrafficSecret: serverTrafficSecret,
|
||||||
|
exporterSecret: exporterSecret,
|
||||||
|
}
|
||||||
|
return nextState, toSend, AlertNoAlert
|
||||||
|
}
|
152
vendor/github.com/bifurcation/mint/common.go
generated
vendored
Normal file
152
vendor/github.com/bifurcation/mint/common.go
generated
vendored
Normal file
@ -0,0 +1,152 @@
|
|||||||
|
package mint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
supportedVersion uint16 = 0x7f15 // draft-21
|
||||||
|
|
||||||
|
// Flags for some minor compat issues
|
||||||
|
allowWrongVersionNumber = true
|
||||||
|
allowPKCS1 = true
|
||||||
|
)
|
||||||
|
|
||||||
|
// enum {...} ContentType;
|
||||||
|
type RecordType byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
RecordTypeAlert RecordType = 21
|
||||||
|
RecordTypeHandshake RecordType = 22
|
||||||
|
RecordTypeApplicationData RecordType = 23
|
||||||
|
)
|
||||||
|
|
||||||
|
// enum {...} HandshakeType;
|
||||||
|
type HandshakeType byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Omitted: *_RESERVED
|
||||||
|
HandshakeTypeClientHello HandshakeType = 1
|
||||||
|
HandshakeTypeServerHello HandshakeType = 2
|
||||||
|
HandshakeTypeNewSessionTicket HandshakeType = 4
|
||||||
|
HandshakeTypeEndOfEarlyData HandshakeType = 5
|
||||||
|
HandshakeTypeHelloRetryRequest HandshakeType = 6
|
||||||
|
HandshakeTypeEncryptedExtensions HandshakeType = 8
|
||||||
|
HandshakeTypeCertificate HandshakeType = 11
|
||||||
|
HandshakeTypeCertificateRequest HandshakeType = 13
|
||||||
|
HandshakeTypeCertificateVerify HandshakeType = 15
|
||||||
|
HandshakeTypeServerConfiguration HandshakeType = 17
|
||||||
|
HandshakeTypeFinished HandshakeType = 20
|
||||||
|
HandshakeTypeKeyUpdate HandshakeType = 24
|
||||||
|
HandshakeTypeMessageHash HandshakeType = 254
|
||||||
|
)
|
||||||
|
|
||||||
|
// uint8 CipherSuite[2];
|
||||||
|
type CipherSuite uint16
|
||||||
|
|
||||||
|
const (
|
||||||
|
// XXX: Actually TLS_NULL_WITH_NULL_NULL, but we need a way to label the zero
|
||||||
|
// value for this type so that we can detect when a field is set.
|
||||||
|
CIPHER_SUITE_UNKNOWN CipherSuite = 0x0000
|
||||||
|
TLS_AES_128_GCM_SHA256 CipherSuite = 0x1301
|
||||||
|
TLS_AES_256_GCM_SHA384 CipherSuite = 0x1302
|
||||||
|
TLS_CHACHA20_POLY1305_SHA256 CipherSuite = 0x1303
|
||||||
|
TLS_AES_128_CCM_SHA256 CipherSuite = 0x1304
|
||||||
|
TLS_AES_256_CCM_8_SHA256 CipherSuite = 0x1305
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c CipherSuite) String() string {
|
||||||
|
switch c {
|
||||||
|
case CIPHER_SUITE_UNKNOWN:
|
||||||
|
return "unknown"
|
||||||
|
case TLS_AES_128_GCM_SHA256:
|
||||||
|
return "TLS_AES_128_GCM_SHA256"
|
||||||
|
case TLS_AES_256_GCM_SHA384:
|
||||||
|
return "TLS_AES_256_GCM_SHA384"
|
||||||
|
case TLS_CHACHA20_POLY1305_SHA256:
|
||||||
|
return "TLS_CHACHA20_POLY1305_SHA256"
|
||||||
|
case TLS_AES_128_CCM_SHA256:
|
||||||
|
return "TLS_AES_128_CCM_SHA256"
|
||||||
|
case TLS_AES_256_CCM_8_SHA256:
|
||||||
|
return "TLS_AES_256_CCM_8_SHA256"
|
||||||
|
}
|
||||||
|
// cannot use %x here, since it calls String(), leading to infinite recursion
|
||||||
|
return fmt.Sprintf("invalid CipherSuite value: 0x%s", strconv.FormatUint(uint64(c), 16))
|
||||||
|
}
|
||||||
|
|
||||||
|
// enum {...} SignatureScheme
|
||||||
|
type SignatureScheme uint16
|
||||||
|
|
||||||
|
const (
|
||||||
|
// RSASSA-PKCS1-v1_5 algorithms
|
||||||
|
RSA_PKCS1_SHA1 SignatureScheme = 0x0201
|
||||||
|
RSA_PKCS1_SHA256 SignatureScheme = 0x0401
|
||||||
|
RSA_PKCS1_SHA384 SignatureScheme = 0x0501
|
||||||
|
RSA_PKCS1_SHA512 SignatureScheme = 0x0601
|
||||||
|
// ECDSA algorithms
|
||||||
|
ECDSA_P256_SHA256 SignatureScheme = 0x0403
|
||||||
|
ECDSA_P384_SHA384 SignatureScheme = 0x0503
|
||||||
|
ECDSA_P521_SHA512 SignatureScheme = 0x0603
|
||||||
|
// RSASSA-PSS algorithms
|
||||||
|
RSA_PSS_SHA256 SignatureScheme = 0x0804
|
||||||
|
RSA_PSS_SHA384 SignatureScheme = 0x0805
|
||||||
|
RSA_PSS_SHA512 SignatureScheme = 0x0806
|
||||||
|
// EdDSA algorithms
|
||||||
|
Ed25519 SignatureScheme = 0x0807
|
||||||
|
Ed448 SignatureScheme = 0x0808
|
||||||
|
)
|
||||||
|
|
||||||
|
// enum {...} ExtensionType
|
||||||
|
type ExtensionType uint16
|
||||||
|
|
||||||
|
const (
|
||||||
|
ExtensionTypeServerName ExtensionType = 0
|
||||||
|
ExtensionTypeSupportedGroups ExtensionType = 10
|
||||||
|
ExtensionTypeSignatureAlgorithms ExtensionType = 13
|
||||||
|
ExtensionTypeALPN ExtensionType = 16
|
||||||
|
ExtensionTypeKeyShare ExtensionType = 40
|
||||||
|
ExtensionTypePreSharedKey ExtensionType = 41
|
||||||
|
ExtensionTypeEarlyData ExtensionType = 42
|
||||||
|
ExtensionTypeSupportedVersions ExtensionType = 43
|
||||||
|
ExtensionTypeCookie ExtensionType = 44
|
||||||
|
ExtensionTypePSKKeyExchangeModes ExtensionType = 45
|
||||||
|
ExtensionTypeTicketEarlyDataInfo ExtensionType = 46
|
||||||
|
)
|
||||||
|
|
||||||
|
// enum {...} NamedGroup
|
||||||
|
type NamedGroup uint16
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Elliptic Curve Groups.
|
||||||
|
P256 NamedGroup = 23
|
||||||
|
P384 NamedGroup = 24
|
||||||
|
P521 NamedGroup = 25
|
||||||
|
// ECDH functions.
|
||||||
|
X25519 NamedGroup = 29
|
||||||
|
X448 NamedGroup = 30
|
||||||
|
// Finite field groups.
|
||||||
|
FFDHE2048 NamedGroup = 256
|
||||||
|
FFDHE3072 NamedGroup = 257
|
||||||
|
FFDHE4096 NamedGroup = 258
|
||||||
|
FFDHE6144 NamedGroup = 259
|
||||||
|
FFDHE8192 NamedGroup = 260
|
||||||
|
)
|
||||||
|
|
||||||
|
// enum {...} PskKeyExchangeMode;
|
||||||
|
type PSKKeyExchangeMode uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
PSKModeKE PSKKeyExchangeMode = 0
|
||||||
|
PSKModeDHEKE PSKKeyExchangeMode = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
// enum {
|
||||||
|
// update_not_requested(0), update_requested(1), (255)
|
||||||
|
// } KeyUpdateRequest;
|
||||||
|
type KeyUpdateRequest uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
KeyUpdateNotRequested KeyUpdateRequest = 0
|
||||||
|
KeyUpdateRequested KeyUpdateRequest = 1
|
||||||
|
)
|
819
vendor/github.com/bifurcation/mint/conn.go
generated
vendored
Normal file
819
vendor/github.com/bifurcation/mint/conn.go
generated
vendored
Normal file
@ -0,0 +1,819 @@
|
|||||||
|
package mint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"reflect"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var WouldBlock = fmt.Errorf("Would have blocked")
|
||||||
|
|
||||||
|
type Certificate struct {
|
||||||
|
Chain []*x509.Certificate
|
||||||
|
PrivateKey crypto.Signer
|
||||||
|
}
|
||||||
|
|
||||||
|
type PreSharedKey struct {
|
||||||
|
CipherSuite CipherSuite
|
||||||
|
IsResumption bool
|
||||||
|
Identity []byte
|
||||||
|
Key []byte
|
||||||
|
NextProto string
|
||||||
|
ReceivedAt time.Time
|
||||||
|
ExpiresAt time.Time
|
||||||
|
TicketAgeAdd uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
type PreSharedKeyCache interface {
|
||||||
|
Get(string) (PreSharedKey, bool)
|
||||||
|
Put(string, PreSharedKey)
|
||||||
|
Size() int
|
||||||
|
}
|
||||||
|
|
||||||
|
type PSKMapCache map[string]PreSharedKey
|
||||||
|
|
||||||
|
// A CookieHandler does two things:
|
||||||
|
// - generates a byte string that is sent as a part of a cookie to the client in the HelloRetryRequest
|
||||||
|
// - validates this byte string echoed by the client in the ClientHello
|
||||||
|
type CookieHandler interface {
|
||||||
|
Generate(*Conn) ([]byte, error)
|
||||||
|
Validate(*Conn, []byte) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cache PSKMapCache) Get(key string) (psk PreSharedKey, ok bool) {
|
||||||
|
psk, ok = cache[key]
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cache *PSKMapCache) Put(key string, psk PreSharedKey) {
|
||||||
|
(*cache)[key] = psk
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cache PSKMapCache) Size() int {
|
||||||
|
return len(cache)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config is the struct used to pass configuration settings to a TLS client or
|
||||||
|
// server instance. The settings for client and server are pretty different,
|
||||||
|
// but we just throw them all in here.
|
||||||
|
type Config struct {
|
||||||
|
// Client fields
|
||||||
|
ServerName string
|
||||||
|
|
||||||
|
// Server fields
|
||||||
|
SendSessionTickets bool
|
||||||
|
TicketLifetime uint32
|
||||||
|
TicketLen int
|
||||||
|
EarlyDataLifetime uint32
|
||||||
|
AllowEarlyData bool
|
||||||
|
// Require the client to echo a cookie.
|
||||||
|
RequireCookie bool
|
||||||
|
// If cookies are required and no CookieHandler is set, a default cookie handler is used.
|
||||||
|
// The default cookie handler uses 32 random bytes as a cookie.
|
||||||
|
CookieHandler CookieHandler
|
||||||
|
RequireClientAuth bool
|
||||||
|
|
||||||
|
// Shared fields
|
||||||
|
Certificates []*Certificate
|
||||||
|
AuthCertificate func(chain []CertificateEntry) error
|
||||||
|
CipherSuites []CipherSuite
|
||||||
|
Groups []NamedGroup
|
||||||
|
SignatureSchemes []SignatureScheme
|
||||||
|
NextProtos []string
|
||||||
|
PSKs PreSharedKeyCache
|
||||||
|
PSKModes []PSKKeyExchangeMode
|
||||||
|
NonBlocking bool
|
||||||
|
|
||||||
|
// The same config object can be shared among different connections, so it
|
||||||
|
// needs its own mutex
|
||||||
|
mutex sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clone returns a shallow clone of c. It is safe to clone a Config that is
|
||||||
|
// being used concurrently by a TLS client or server.
|
||||||
|
func (c *Config) Clone() *Config {
|
||||||
|
c.mutex.Lock()
|
||||||
|
defer c.mutex.Unlock()
|
||||||
|
|
||||||
|
return &Config{
|
||||||
|
ServerName: c.ServerName,
|
||||||
|
|
||||||
|
SendSessionTickets: c.SendSessionTickets,
|
||||||
|
TicketLifetime: c.TicketLifetime,
|
||||||
|
TicketLen: c.TicketLen,
|
||||||
|
EarlyDataLifetime: c.EarlyDataLifetime,
|
||||||
|
AllowEarlyData: c.AllowEarlyData,
|
||||||
|
RequireCookie: c.RequireCookie,
|
||||||
|
RequireClientAuth: c.RequireClientAuth,
|
||||||
|
|
||||||
|
Certificates: c.Certificates,
|
||||||
|
AuthCertificate: c.AuthCertificate,
|
||||||
|
CipherSuites: c.CipherSuites,
|
||||||
|
Groups: c.Groups,
|
||||||
|
SignatureSchemes: c.SignatureSchemes,
|
||||||
|
NextProtos: c.NextProtos,
|
||||||
|
PSKs: c.PSKs,
|
||||||
|
PSKModes: c.PSKModes,
|
||||||
|
NonBlocking: c.NonBlocking,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) Init(isClient bool) error {
|
||||||
|
c.mutex.Lock()
|
||||||
|
defer c.mutex.Unlock()
|
||||||
|
|
||||||
|
// Set defaults
|
||||||
|
if len(c.CipherSuites) == 0 {
|
||||||
|
c.CipherSuites = defaultSupportedCipherSuites
|
||||||
|
}
|
||||||
|
if len(c.Groups) == 0 {
|
||||||
|
c.Groups = defaultSupportedGroups
|
||||||
|
}
|
||||||
|
if len(c.SignatureSchemes) == 0 {
|
||||||
|
c.SignatureSchemes = defaultSignatureSchemes
|
||||||
|
}
|
||||||
|
if c.TicketLen == 0 {
|
||||||
|
c.TicketLen = defaultTicketLen
|
||||||
|
}
|
||||||
|
if !reflect.ValueOf(c.PSKs).IsValid() {
|
||||||
|
c.PSKs = &PSKMapCache{}
|
||||||
|
}
|
||||||
|
if len(c.PSKModes) == 0 {
|
||||||
|
c.PSKModes = defaultPSKModes
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there is no certificate, generate one
|
||||||
|
if !isClient && len(c.Certificates) == 0 {
|
||||||
|
logf(logTypeHandshake, "Generating key name=%v", c.ServerName)
|
||||||
|
priv, err := newSigningKey(RSA_PSS_SHA256)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cert, err := newSelfSigned(c.ServerName, RSA_PKCS1_SHA256, priv)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Certificates = []*Certificate{
|
||||||
|
{
|
||||||
|
Chain: []*x509.Certificate{cert},
|
||||||
|
PrivateKey: priv,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) ValidForServer() bool {
|
||||||
|
return (reflect.ValueOf(c.PSKs).IsValid() && c.PSKs.Size() > 0) ||
|
||||||
|
(len(c.Certificates) > 0 &&
|
||||||
|
len(c.Certificates[0].Chain) > 0 &&
|
||||||
|
c.Certificates[0].PrivateKey != nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) ValidForClient() bool {
|
||||||
|
return len(c.ServerName) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
defaultSupportedCipherSuites = []CipherSuite{
|
||||||
|
TLS_AES_128_GCM_SHA256,
|
||||||
|
TLS_AES_256_GCM_SHA384,
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultSupportedGroups = []NamedGroup{
|
||||||
|
P256,
|
||||||
|
P384,
|
||||||
|
FFDHE2048,
|
||||||
|
X25519,
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultSignatureSchemes = []SignatureScheme{
|
||||||
|
RSA_PSS_SHA256,
|
||||||
|
RSA_PSS_SHA384,
|
||||||
|
RSA_PSS_SHA512,
|
||||||
|
ECDSA_P256_SHA256,
|
||||||
|
ECDSA_P384_SHA384,
|
||||||
|
ECDSA_P521_SHA512,
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultTicketLen = 16
|
||||||
|
|
||||||
|
defaultPSKModes = []PSKKeyExchangeMode{
|
||||||
|
PSKModeKE,
|
||||||
|
PSKModeDHEKE,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
type ConnectionState struct {
|
||||||
|
HandshakeState string // string representation of the handshake state.
|
||||||
|
CipherSuite CipherSuiteParams // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...)
|
||||||
|
PeerCertificates []*x509.Certificate // certificate chain presented by remote peer TODO(ekr@rtfm.com): implement
|
||||||
|
NextProto string // Selected ALPN proto
|
||||||
|
}
|
||||||
|
|
||||||
|
// Conn implements the net.Conn interface, as with "crypto/tls"
|
||||||
|
// * Read, Write, and Close are provided locally
|
||||||
|
// * LocalAddr, RemoteAddr, and Set*Deadline are forwarded to the inner Conn
|
||||||
|
type Conn struct {
|
||||||
|
config *Config
|
||||||
|
conn net.Conn
|
||||||
|
isClient bool
|
||||||
|
|
||||||
|
EarlyData []byte
|
||||||
|
|
||||||
|
state StateConnected
|
||||||
|
hState HandshakeState
|
||||||
|
handshakeMutex sync.Mutex
|
||||||
|
handshakeAlert Alert
|
||||||
|
handshakeComplete bool
|
||||||
|
|
||||||
|
readBuffer []byte
|
||||||
|
in, out *RecordLayer
|
||||||
|
hIn, hOut *HandshakeLayer
|
||||||
|
|
||||||
|
extHandler AppExtensionHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConn(conn net.Conn, config *Config, isClient bool) *Conn {
|
||||||
|
c := &Conn{conn: conn, config: config, isClient: isClient}
|
||||||
|
c.in = NewRecordLayer(c.conn)
|
||||||
|
c.out = NewRecordLayer(c.conn)
|
||||||
|
c.hIn = NewHandshakeLayer(c.in)
|
||||||
|
c.hIn.nonblocking = c.config.NonBlocking
|
||||||
|
c.hOut = NewHandshakeLayer(c.out)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read up
|
||||||
|
func (c *Conn) consumeRecord() error {
|
||||||
|
pt, err := c.in.ReadRecord()
|
||||||
|
if pt == nil {
|
||||||
|
logf(logTypeIO, "extendBuffer returns error %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch pt.contentType {
|
||||||
|
case RecordTypeHandshake:
|
||||||
|
logf(logTypeHandshake, "Received post-handshake message")
|
||||||
|
// We do not support fragmentation of post-handshake handshake messages.
|
||||||
|
// TODO: Factor this more elegantly; coalesce with handshakeLayer.ReadMessage()
|
||||||
|
start := 0
|
||||||
|
for start < len(pt.fragment) {
|
||||||
|
if len(pt.fragment[start:]) < handshakeHeaderLen {
|
||||||
|
return fmt.Errorf("Post-handshake handshake message too short for header")
|
||||||
|
}
|
||||||
|
|
||||||
|
hm := &HandshakeMessage{}
|
||||||
|
hm.msgType = HandshakeType(pt.fragment[start])
|
||||||
|
hmLen := (int(pt.fragment[start+1]) << 16) + (int(pt.fragment[start+2]) << 8) + int(pt.fragment[start+3])
|
||||||
|
|
||||||
|
if len(pt.fragment[start+handshakeHeaderLen:]) < hmLen {
|
||||||
|
return fmt.Errorf("Post-handshake handshake message too short for body")
|
||||||
|
}
|
||||||
|
hm.body = pt.fragment[start+handshakeHeaderLen : start+handshakeHeaderLen+hmLen]
|
||||||
|
|
||||||
|
// Advance state machine
|
||||||
|
state, actions, alert := c.state.Next(hm)
|
||||||
|
|
||||||
|
if alert != AlertNoAlert {
|
||||||
|
logf(logTypeHandshake, "Error in state transition: %v", alert)
|
||||||
|
c.sendAlert(alert)
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, action := range actions {
|
||||||
|
alert = c.takeAction(action)
|
||||||
|
if alert != AlertNoAlert {
|
||||||
|
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
|
||||||
|
c.sendAlert(alert)
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// XXX: If we want to support more advanced cases, e.g., post-handshake
|
||||||
|
// authentication, we'll need to allow transitions other than
|
||||||
|
// Connected -> Connected
|
||||||
|
var connected bool
|
||||||
|
c.state, connected = state.(StateConnected)
|
||||||
|
if !connected {
|
||||||
|
logf(logTypeHandshake, "Disconnected after state transition: %v", alert)
|
||||||
|
c.sendAlert(alert)
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
start += handshakeHeaderLen + hmLen
|
||||||
|
}
|
||||||
|
case RecordTypeAlert:
|
||||||
|
logf(logTypeIO, "extended buffer (for alert): [%d] %x", len(c.readBuffer), c.readBuffer)
|
||||||
|
if len(pt.fragment) != 2 {
|
||||||
|
c.sendAlert(AlertUnexpectedMessage)
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
if Alert(pt.fragment[1]) == AlertCloseNotify {
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
switch pt.fragment[0] {
|
||||||
|
case AlertLevelWarning:
|
||||||
|
// drop on the floor
|
||||||
|
case AlertLevelError:
|
||||||
|
return Alert(pt.fragment[1])
|
||||||
|
default:
|
||||||
|
c.sendAlert(AlertUnexpectedMessage)
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
case RecordTypeApplicationData:
|
||||||
|
c.readBuffer = append(c.readBuffer, pt.fragment...)
|
||||||
|
logf(logTypeIO, "extended buffer: [%d] %x", len(c.readBuffer), c.readBuffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read application data up to the size of buffer. Handshake and alert records
|
||||||
|
// are consumed by the Conn object directly.
|
||||||
|
func (c *Conn) Read(buffer []byte) (int, error) {
|
||||||
|
logf(logTypeHandshake, "conn.Read with buffer = %d", len(buffer))
|
||||||
|
if alert := c.Handshake(); alert != AlertNoAlert {
|
||||||
|
return 0, alert
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(buffer) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lock the input channel
|
||||||
|
c.in.Lock()
|
||||||
|
defer c.in.Unlock()
|
||||||
|
for len(c.readBuffer) == 0 {
|
||||||
|
err := c.consumeRecord()
|
||||||
|
|
||||||
|
// err can be nil if consumeRecord processed a non app-data
|
||||||
|
// record.
|
||||||
|
if err != nil {
|
||||||
|
if c.config.NonBlocking || err != WouldBlock {
|
||||||
|
logf(logTypeIO, "conn.Read returns err=%v", err)
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var read int
|
||||||
|
n := len(buffer)
|
||||||
|
logf(logTypeIO, "conn.Read input buffer now has len %d", len(c.readBuffer))
|
||||||
|
if len(c.readBuffer) <= n {
|
||||||
|
buffer = buffer[:len(c.readBuffer)]
|
||||||
|
copy(buffer, c.readBuffer)
|
||||||
|
read = len(c.readBuffer)
|
||||||
|
c.readBuffer = c.readBuffer[:0]
|
||||||
|
} else {
|
||||||
|
logf(logTypeIO, "read buffer larger than input buffer (%d > %d)", len(c.readBuffer), n)
|
||||||
|
copy(buffer[:n], c.readBuffer[:n])
|
||||||
|
c.readBuffer = c.readBuffer[n:]
|
||||||
|
read = n
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeVerbose, "Returning %v", string(buffer))
|
||||||
|
return read, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write application data
|
||||||
|
func (c *Conn) Write(buffer []byte) (int, error) {
|
||||||
|
// Lock the output channel
|
||||||
|
c.out.Lock()
|
||||||
|
defer c.out.Unlock()
|
||||||
|
|
||||||
|
// Send full-size fragments
|
||||||
|
var start int
|
||||||
|
sent := 0
|
||||||
|
for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen {
|
||||||
|
err := c.out.WriteRecord(&TLSPlaintext{
|
||||||
|
contentType: RecordTypeApplicationData,
|
||||||
|
fragment: buffer[start : start+maxFragmentLen],
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return sent, err
|
||||||
|
}
|
||||||
|
sent += maxFragmentLen
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send a final partial fragment if necessary
|
||||||
|
if start < len(buffer) {
|
||||||
|
err := c.out.WriteRecord(&TLSPlaintext{
|
||||||
|
contentType: RecordTypeApplicationData,
|
||||||
|
fragment: buffer[start:],
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return sent, err
|
||||||
|
}
|
||||||
|
sent += len(buffer[start:])
|
||||||
|
}
|
||||||
|
return sent, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendAlert sends a TLS alert message.
|
||||||
|
// c.out.Mutex <= L.
|
||||||
|
func (c *Conn) sendAlert(err Alert) error {
|
||||||
|
c.handshakeMutex.Lock()
|
||||||
|
defer c.handshakeMutex.Unlock()
|
||||||
|
|
||||||
|
var level int
|
||||||
|
switch err {
|
||||||
|
case AlertNoRenegotiation, AlertCloseNotify:
|
||||||
|
level = AlertLevelWarning
|
||||||
|
default:
|
||||||
|
level = AlertLevelError
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := []byte{byte(err), byte(level)}
|
||||||
|
c.out.WriteRecord(&TLSPlaintext{
|
||||||
|
contentType: RecordTypeAlert,
|
||||||
|
fragment: buf,
|
||||||
|
})
|
||||||
|
|
||||||
|
// close_notify and end_of_early_data are not actually errors
|
||||||
|
if level == AlertLevelWarning {
|
||||||
|
return &net.OpError{Op: "local error", Err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the connection.
|
||||||
|
func (c *Conn) Close() error {
|
||||||
|
// XXX crypto/tls has an interlock with Write here. Do we need that?
|
||||||
|
|
||||||
|
return c.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// LocalAddr returns the local network address.
|
||||||
|
func (c *Conn) LocalAddr() net.Addr {
|
||||||
|
return c.conn.LocalAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoteAddr returns the remote network address.
|
||||||
|
func (c *Conn) RemoteAddr() net.Addr {
|
||||||
|
return c.conn.RemoteAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDeadline sets the read and write deadlines associated with the connection.
|
||||||
|
// A zero value for t means Read and Write will not time out.
|
||||||
|
// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
|
||||||
|
func (c *Conn) SetDeadline(t time.Time) error {
|
||||||
|
return c.conn.SetDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetReadDeadline sets the read deadline on the underlying connection.
|
||||||
|
// A zero value for t means Read will not time out.
|
||||||
|
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||||
|
return c.conn.SetReadDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetWriteDeadline sets the write deadline on the underlying connection.
|
||||||
|
// A zero value for t means Write will not time out.
|
||||||
|
// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
|
||||||
|
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||||
|
return c.conn.SetWriteDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert {
|
||||||
|
label := "[server]"
|
||||||
|
if c.isClient {
|
||||||
|
label = "[client]"
|
||||||
|
}
|
||||||
|
|
||||||
|
switch action := actionGeneric.(type) {
|
||||||
|
case SendHandshakeMessage:
|
||||||
|
err := c.hOut.WriteMessage(action.Message)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err)
|
||||||
|
return AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
case RekeyIn:
|
||||||
|
logf(logTypeHandshake, "%s Rekeying in to %s: %+v", label, action.Label, action.KeySet)
|
||||||
|
err := c.in.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "%s Unable to rekey inbound: %v", label, err)
|
||||||
|
return AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
case RekeyOut:
|
||||||
|
logf(logTypeHandshake, "%s Rekeying out to %s: %+v", label, action.Label, action.KeySet)
|
||||||
|
err := c.out.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "%s Unable to rekey outbound: %v", label, err)
|
||||||
|
return AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
case SendEarlyData:
|
||||||
|
logf(logTypeHandshake, "%s Sending early data...", label)
|
||||||
|
_, err := c.Write(c.EarlyData)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "%s Error writing early data: %v", label, err)
|
||||||
|
return AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
case ReadPastEarlyData:
|
||||||
|
logf(logTypeHandshake, "%s Reading past early data...", label)
|
||||||
|
// Scan past all records that fail to decrypt
|
||||||
|
_, err := c.in.PeekRecordType(!c.config.NonBlocking)
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
_, ok := err.(DecryptError)
|
||||||
|
|
||||||
|
for ok {
|
||||||
|
_, err = c.in.PeekRecordType(!c.config.NonBlocking)
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
_, ok = err.(DecryptError)
|
||||||
|
}
|
||||||
|
|
||||||
|
case ReadEarlyData:
|
||||||
|
logf(logTypeHandshake, "%s Reading early data...", label)
|
||||||
|
t, err := c.in.PeekRecordType(!c.config.NonBlocking)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "%s Error reading record type (1): %v", label, err)
|
||||||
|
return AlertInternalError
|
||||||
|
}
|
||||||
|
logf(logTypeHandshake, "%s Got record type(1): %v", label, t)
|
||||||
|
|
||||||
|
for t == RecordTypeApplicationData {
|
||||||
|
// Read a record into the buffer. Note that this is safe
|
||||||
|
// in blocking mode because we read the record in in
|
||||||
|
// PeekRecordType.
|
||||||
|
pt, err := c.in.ReadRecord()
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "%s Error reading early data record: %v", label, err)
|
||||||
|
return AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "%s Read early data: %x", label, pt.fragment)
|
||||||
|
c.EarlyData = append(c.EarlyData, pt.fragment...)
|
||||||
|
|
||||||
|
t, err = c.in.PeekRecordType(!c.config.NonBlocking)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "%s Error reading record type (2): %v", label, err)
|
||||||
|
return AlertInternalError
|
||||||
|
}
|
||||||
|
logf(logTypeHandshake, "%s Got record type (2): %v", label, t)
|
||||||
|
}
|
||||||
|
logf(logTypeHandshake, "%s Done reading early data", label)
|
||||||
|
|
||||||
|
case StorePSK:
|
||||||
|
logf(logTypeHandshake, "%s Storing new session ticket with identity [%x]", label, action.PSK.Identity)
|
||||||
|
if c.isClient {
|
||||||
|
// Clients look up PSKs based on server name
|
||||||
|
c.config.PSKs.Put(c.config.ServerName, action.PSK)
|
||||||
|
} else {
|
||||||
|
// Servers look them up based on the identity in the extension
|
||||||
|
c.config.PSKs.Put(hex.EncodeToString(action.PSK.Identity), action.PSK)
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
logf(logTypeHandshake, "%s Unknown actionuction type", label)
|
||||||
|
return AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
return AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) HandshakeSetup() Alert {
|
||||||
|
var state HandshakeState
|
||||||
|
var actions []HandshakeAction
|
||||||
|
var alert Alert
|
||||||
|
|
||||||
|
if err := c.config.Init(c.isClient); err != nil {
|
||||||
|
logf(logTypeHandshake, "Error initializing config: %v", err)
|
||||||
|
return AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set things up
|
||||||
|
caps := Capabilities{
|
||||||
|
CipherSuites: c.config.CipherSuites,
|
||||||
|
Groups: c.config.Groups,
|
||||||
|
SignatureSchemes: c.config.SignatureSchemes,
|
||||||
|
PSKs: c.config.PSKs,
|
||||||
|
PSKModes: c.config.PSKModes,
|
||||||
|
AllowEarlyData: c.config.AllowEarlyData,
|
||||||
|
RequireCookie: c.config.RequireCookie,
|
||||||
|
CookieHandler: c.config.CookieHandler,
|
||||||
|
RequireClientAuth: c.config.RequireClientAuth,
|
||||||
|
NextProtos: c.config.NextProtos,
|
||||||
|
Certificates: c.config.Certificates,
|
||||||
|
ExtensionHandler: c.extHandler,
|
||||||
|
}
|
||||||
|
opts := ConnectionOptions{
|
||||||
|
ServerName: c.config.ServerName,
|
||||||
|
NextProtos: c.config.NextProtos,
|
||||||
|
EarlyData: c.EarlyData,
|
||||||
|
}
|
||||||
|
|
||||||
|
if caps.RequireCookie && caps.CookieHandler == nil {
|
||||||
|
caps.CookieHandler = &defaultCookieHandler{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.isClient {
|
||||||
|
state, actions, alert = ClientStateStart{Caps: caps, Opts: opts}.Next(nil)
|
||||||
|
if alert != AlertNoAlert {
|
||||||
|
logf(logTypeHandshake, "Error initializing client state: %v", alert)
|
||||||
|
return alert
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, action := range actions {
|
||||||
|
alert = c.takeAction(action)
|
||||||
|
if alert != AlertNoAlert {
|
||||||
|
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
|
||||||
|
return alert
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
state = ServerStateStart{Caps: caps, conn: c}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.hState = state
|
||||||
|
|
||||||
|
return AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handshake causes a TLS handshake on the connection. The `isClient` member
|
||||||
|
// determines whether a client or server handshake is performed. If a
|
||||||
|
// handshake has already been performed, then its result will be returned.
|
||||||
|
func (c *Conn) Handshake() Alert {
|
||||||
|
label := "[server]"
|
||||||
|
if c.isClient {
|
||||||
|
label = "[client]"
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO Lock handshakeMutex
|
||||||
|
// TODO Remove CloseNotify hack
|
||||||
|
if c.handshakeAlert != AlertNoAlert && c.handshakeAlert != AlertCloseNotify {
|
||||||
|
logf(logTypeHandshake, "Pre-existing handshake error: %v", c.handshakeAlert)
|
||||||
|
return c.handshakeAlert
|
||||||
|
}
|
||||||
|
if c.handshakeComplete {
|
||||||
|
return AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
var alert Alert
|
||||||
|
if c.hState == nil {
|
||||||
|
logf(logTypeHandshake, "%s First time through handshake, setting up", label)
|
||||||
|
alert = c.HandshakeSetup()
|
||||||
|
if alert != AlertNoAlert {
|
||||||
|
return alert
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logf(logTypeHandshake, "Re-entering handshake, state=%v", c.hState)
|
||||||
|
}
|
||||||
|
|
||||||
|
state := c.hState
|
||||||
|
_, connected := state.(StateConnected)
|
||||||
|
|
||||||
|
var actions []HandshakeAction
|
||||||
|
|
||||||
|
for !connected {
|
||||||
|
// Read a handshake message
|
||||||
|
hm, err := c.hIn.ReadMessage()
|
||||||
|
if err == WouldBlock {
|
||||||
|
logf(logTypeHandshake, "%s Would block reading message: %v", label, err)
|
||||||
|
return AlertWouldBlock
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "%s Error reading message: %v", label, err)
|
||||||
|
c.sendAlert(AlertCloseNotify)
|
||||||
|
return AlertCloseNotify
|
||||||
|
}
|
||||||
|
logf(logTypeHandshake, "Read message with type: %v", hm.msgType)
|
||||||
|
|
||||||
|
// Advance the state machine
|
||||||
|
state, actions, alert = state.Next(hm)
|
||||||
|
|
||||||
|
if alert != AlertNoAlert {
|
||||||
|
logf(logTypeHandshake, "Error in state transition: %v", alert)
|
||||||
|
return alert
|
||||||
|
}
|
||||||
|
|
||||||
|
for index, action := range actions {
|
||||||
|
logf(logTypeHandshake, "%s taking next action (%d)", label, index)
|
||||||
|
alert = c.takeAction(action)
|
||||||
|
if alert != AlertNoAlert {
|
||||||
|
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
|
||||||
|
c.sendAlert(alert)
|
||||||
|
return alert
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.hState = state
|
||||||
|
logf(logTypeHandshake, "state is now %s", c.GetHsState())
|
||||||
|
|
||||||
|
_, connected = state.(StateConnected)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.state = state.(StateConnected)
|
||||||
|
|
||||||
|
// Send NewSessionTicket if acting as server
|
||||||
|
if !c.isClient && c.config.SendSessionTickets {
|
||||||
|
actions, alert := c.state.NewSessionTicket(
|
||||||
|
c.config.TicketLen,
|
||||||
|
c.config.TicketLifetime,
|
||||||
|
c.config.EarlyDataLifetime)
|
||||||
|
|
||||||
|
for _, action := range actions {
|
||||||
|
alert = c.takeAction(action)
|
||||||
|
if alert != AlertNoAlert {
|
||||||
|
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
|
||||||
|
c.sendAlert(alert)
|
||||||
|
return alert
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.handshakeComplete = true
|
||||||
|
return AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) SendKeyUpdate(requestUpdate bool) error {
|
||||||
|
if !c.handshakeComplete {
|
||||||
|
return fmt.Errorf("Cannot update keys until after handshake")
|
||||||
|
}
|
||||||
|
|
||||||
|
request := KeyUpdateNotRequested
|
||||||
|
if requestUpdate {
|
||||||
|
request = KeyUpdateRequested
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the key update and update state
|
||||||
|
actions, alert := c.state.KeyUpdate(request)
|
||||||
|
if alert != AlertNoAlert {
|
||||||
|
c.sendAlert(alert)
|
||||||
|
return fmt.Errorf("Alert while generating key update: %v", alert)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Take actions (send key update and rekey)
|
||||||
|
for _, action := range actions {
|
||||||
|
alert = c.takeAction(action)
|
||||||
|
if alert != AlertNoAlert {
|
||||||
|
c.sendAlert(alert)
|
||||||
|
return fmt.Errorf("Alert during key update actions: %v", alert)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) GetHsState() string {
|
||||||
|
return reflect.TypeOf(c.hState).Name()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
|
||||||
|
_, connected := c.hState.(StateConnected)
|
||||||
|
if !connected {
|
||||||
|
return nil, fmt.Errorf("Cannot compute exporter when state is not connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.state.exporterSecret == nil {
|
||||||
|
return nil, fmt.Errorf("Internal error: no exporter secret")
|
||||||
|
}
|
||||||
|
|
||||||
|
h0 := c.state.cryptoParams.Hash.New().Sum(nil)
|
||||||
|
tmpSecret := deriveSecret(c.state.cryptoParams, c.state.exporterSecret, label, h0)
|
||||||
|
|
||||||
|
hc := c.state.cryptoParams.Hash.New().Sum(context)
|
||||||
|
return HkdfExpandLabel(c.state.cryptoParams.Hash, tmpSecret, "exporter", hc, keyLength), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) State() ConnectionState {
|
||||||
|
state := ConnectionState{
|
||||||
|
HandshakeState: c.GetHsState(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.handshakeComplete {
|
||||||
|
state.CipherSuite = cipherSuiteMap[c.state.Params.CipherSuite]
|
||||||
|
state.NextProto = c.state.Params.NextProto
|
||||||
|
}
|
||||||
|
|
||||||
|
return state
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) SetExtensionHandler(h AppExtensionHandler) error {
|
||||||
|
if c.hState != nil {
|
||||||
|
return fmt.Errorf("Can't set extension handler after setup")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.extHandler = h
|
||||||
|
return nil
|
||||||
|
}
|
654
vendor/github.com/bifurcation/mint/crypto.go
generated
vendored
Normal file
654
vendor/github.com/bifurcation/mint/crypto.go
generated
vendored
Normal file
@ -0,0 +1,654 @@
|
|||||||
|
package mint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto"
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/asn1"
|
||||||
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/curve25519"
|
||||||
|
|
||||||
|
// Blank includes to ensure hash support
|
||||||
|
_ "crypto/sha1"
|
||||||
|
_ "crypto/sha256"
|
||||||
|
_ "crypto/sha512"
|
||||||
|
)
|
||||||
|
|
||||||
|
var prng = rand.Reader
|
||||||
|
|
||||||
|
type aeadFactory func(key []byte) (cipher.AEAD, error)
|
||||||
|
|
||||||
|
type CipherSuiteParams struct {
|
||||||
|
Suite CipherSuite
|
||||||
|
Cipher aeadFactory // Cipher factory
|
||||||
|
Hash crypto.Hash // Hash function
|
||||||
|
KeyLen int // Key length in octets
|
||||||
|
IvLen int // IV length in octets
|
||||||
|
}
|
||||||
|
|
||||||
|
type signatureAlgorithm uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
signatureAlgorithmUnknown = iota
|
||||||
|
signatureAlgorithmRSA_PKCS1
|
||||||
|
signatureAlgorithmRSA_PSS
|
||||||
|
signatureAlgorithmECDSA
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
hashMap = map[SignatureScheme]crypto.Hash{
|
||||||
|
RSA_PKCS1_SHA1: crypto.SHA1,
|
||||||
|
RSA_PKCS1_SHA256: crypto.SHA256,
|
||||||
|
RSA_PKCS1_SHA384: crypto.SHA384,
|
||||||
|
RSA_PKCS1_SHA512: crypto.SHA512,
|
||||||
|
ECDSA_P256_SHA256: crypto.SHA256,
|
||||||
|
ECDSA_P384_SHA384: crypto.SHA384,
|
||||||
|
ECDSA_P521_SHA512: crypto.SHA512,
|
||||||
|
RSA_PSS_SHA256: crypto.SHA256,
|
||||||
|
RSA_PSS_SHA384: crypto.SHA384,
|
||||||
|
RSA_PSS_SHA512: crypto.SHA512,
|
||||||
|
}
|
||||||
|
|
||||||
|
sigMap = map[SignatureScheme]signatureAlgorithm{
|
||||||
|
RSA_PKCS1_SHA1: signatureAlgorithmRSA_PKCS1,
|
||||||
|
RSA_PKCS1_SHA256: signatureAlgorithmRSA_PKCS1,
|
||||||
|
RSA_PKCS1_SHA384: signatureAlgorithmRSA_PKCS1,
|
||||||
|
RSA_PKCS1_SHA512: signatureAlgorithmRSA_PKCS1,
|
||||||
|
ECDSA_P256_SHA256: signatureAlgorithmECDSA,
|
||||||
|
ECDSA_P384_SHA384: signatureAlgorithmECDSA,
|
||||||
|
ECDSA_P521_SHA512: signatureAlgorithmECDSA,
|
||||||
|
RSA_PSS_SHA256: signatureAlgorithmRSA_PSS,
|
||||||
|
RSA_PSS_SHA384: signatureAlgorithmRSA_PSS,
|
||||||
|
RSA_PSS_SHA512: signatureAlgorithmRSA_PSS,
|
||||||
|
}
|
||||||
|
|
||||||
|
curveMap = map[SignatureScheme]NamedGroup{
|
||||||
|
ECDSA_P256_SHA256: P256,
|
||||||
|
ECDSA_P384_SHA384: P384,
|
||||||
|
ECDSA_P521_SHA512: P521,
|
||||||
|
}
|
||||||
|
|
||||||
|
newAESGCM = func(key []byte) (cipher.AEAD, error) {
|
||||||
|
block, err := aes.NewCipher(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TLS always uses 12-byte nonces
|
||||||
|
return cipher.NewGCMWithNonceSize(block, 12)
|
||||||
|
}
|
||||||
|
|
||||||
|
cipherSuiteMap = map[CipherSuite]CipherSuiteParams{
|
||||||
|
TLS_AES_128_GCM_SHA256: {
|
||||||
|
Suite: TLS_AES_128_GCM_SHA256,
|
||||||
|
Cipher: newAESGCM,
|
||||||
|
Hash: crypto.SHA256,
|
||||||
|
KeyLen: 16,
|
||||||
|
IvLen: 12,
|
||||||
|
},
|
||||||
|
TLS_AES_256_GCM_SHA384: {
|
||||||
|
Suite: TLS_AES_256_GCM_SHA384,
|
||||||
|
Cipher: newAESGCM,
|
||||||
|
Hash: crypto.SHA384,
|
||||||
|
KeyLen: 32,
|
||||||
|
IvLen: 12,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
x509AlgMap = map[SignatureScheme]x509.SignatureAlgorithm{
|
||||||
|
RSA_PKCS1_SHA1: x509.SHA1WithRSA,
|
||||||
|
RSA_PKCS1_SHA256: x509.SHA256WithRSA,
|
||||||
|
RSA_PKCS1_SHA384: x509.SHA384WithRSA,
|
||||||
|
RSA_PKCS1_SHA512: x509.SHA512WithRSA,
|
||||||
|
ECDSA_P256_SHA256: x509.ECDSAWithSHA256,
|
||||||
|
ECDSA_P384_SHA384: x509.ECDSAWithSHA384,
|
||||||
|
ECDSA_P521_SHA512: x509.ECDSAWithSHA512,
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultRSAKeySize = 2048
|
||||||
|
)
|
||||||
|
|
||||||
|
func curveFromNamedGroup(group NamedGroup) (crv elliptic.Curve) {
|
||||||
|
switch group {
|
||||||
|
case P256:
|
||||||
|
crv = elliptic.P256()
|
||||||
|
case P384:
|
||||||
|
crv = elliptic.P384()
|
||||||
|
case P521:
|
||||||
|
crv = elliptic.P521()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func namedGroupFromECDSAKey(key *ecdsa.PublicKey) (g NamedGroup) {
|
||||||
|
switch key.Curve.Params().Name {
|
||||||
|
case elliptic.P256().Params().Name:
|
||||||
|
g = P256
|
||||||
|
case elliptic.P384().Params().Name:
|
||||||
|
g = P384
|
||||||
|
case elliptic.P521().Params().Name:
|
||||||
|
g = P521
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func keyExchangeSizeFromNamedGroup(group NamedGroup) (size int) {
|
||||||
|
size = 0
|
||||||
|
switch group {
|
||||||
|
case X25519:
|
||||||
|
size = 32
|
||||||
|
case P256:
|
||||||
|
size = 65
|
||||||
|
case P384:
|
||||||
|
size = 97
|
||||||
|
case P521:
|
||||||
|
size = 133
|
||||||
|
case FFDHE2048:
|
||||||
|
size = 256
|
||||||
|
case FFDHE3072:
|
||||||
|
size = 384
|
||||||
|
case FFDHE4096:
|
||||||
|
size = 512
|
||||||
|
case FFDHE6144:
|
||||||
|
size = 768
|
||||||
|
case FFDHE8192:
|
||||||
|
size = 1024
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func primeFromNamedGroup(group NamedGroup) (p *big.Int) {
|
||||||
|
switch group {
|
||||||
|
case FFDHE2048:
|
||||||
|
p = finiteFieldPrime2048
|
||||||
|
case FFDHE3072:
|
||||||
|
p = finiteFieldPrime3072
|
||||||
|
case FFDHE4096:
|
||||||
|
p = finiteFieldPrime4096
|
||||||
|
case FFDHE6144:
|
||||||
|
p = finiteFieldPrime6144
|
||||||
|
case FFDHE8192:
|
||||||
|
p = finiteFieldPrime8192
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func schemeValidForKey(alg SignatureScheme, key crypto.Signer) bool {
|
||||||
|
sigType := sigMap[alg]
|
||||||
|
switch key.(type) {
|
||||||
|
case *rsa.PrivateKey:
|
||||||
|
return sigType == signatureAlgorithmRSA_PKCS1 || sigType == signatureAlgorithmRSA_PSS
|
||||||
|
case *ecdsa.PrivateKey:
|
||||||
|
return sigType == signatureAlgorithmECDSA
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ffdheKeyShareFromPrime(p *big.Int) (priv, pub *big.Int, err error) {
|
||||||
|
primeLen := len(p.Bytes())
|
||||||
|
for {
|
||||||
|
// g = 2 for all ffdhe groups
|
||||||
|
priv, err = rand.Int(prng, p)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pub = big.NewInt(0)
|
||||||
|
pub.Exp(big.NewInt(2), priv, p)
|
||||||
|
|
||||||
|
if len(pub.Bytes()) == primeLen {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newKeyShare(group NamedGroup) (pub []byte, priv []byte, err error) {
|
||||||
|
switch group {
|
||||||
|
case P256, P384, P521:
|
||||||
|
var x, y *big.Int
|
||||||
|
crv := curveFromNamedGroup(group)
|
||||||
|
priv, x, y, err = elliptic.GenerateKey(crv, prng)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pub = elliptic.Marshal(crv, x, y)
|
||||||
|
return
|
||||||
|
|
||||||
|
case FFDHE2048, FFDHE3072, FFDHE4096, FFDHE6144, FFDHE8192:
|
||||||
|
p := primeFromNamedGroup(group)
|
||||||
|
x, X, err2 := ffdheKeyShareFromPrime(p)
|
||||||
|
if err2 != nil {
|
||||||
|
err = err2
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
priv = x.Bytes()
|
||||||
|
pubBytes := X.Bytes()
|
||||||
|
|
||||||
|
numBytes := keyExchangeSizeFromNamedGroup(group)
|
||||||
|
|
||||||
|
pub = make([]byte, numBytes)
|
||||||
|
copy(pub[numBytes-len(pubBytes):], pubBytes)
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
case X25519:
|
||||||
|
var private, public [32]byte
|
||||||
|
_, err = prng.Read(private[:])
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
curve25519.ScalarBaseMult(&public, &private)
|
||||||
|
priv = private[:]
|
||||||
|
pub = public[:]
|
||||||
|
return
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, nil, fmt.Errorf("tls.newkeyshare: Unsupported group %v", group)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func keyAgreement(group NamedGroup, pub []byte, priv []byte) ([]byte, error) {
|
||||||
|
switch group {
|
||||||
|
case P256, P384, P521:
|
||||||
|
if len(pub) != keyExchangeSizeFromNamedGroup(group) {
|
||||||
|
return nil, fmt.Errorf("tls.keyagreement: Wrong public key size")
|
||||||
|
}
|
||||||
|
|
||||||
|
crv := curveFromNamedGroup(group)
|
||||||
|
pubX, pubY := elliptic.Unmarshal(crv, pub)
|
||||||
|
x, _ := crv.Params().ScalarMult(pubX, pubY, priv)
|
||||||
|
xBytes := x.Bytes()
|
||||||
|
|
||||||
|
numBytes := len(crv.Params().P.Bytes())
|
||||||
|
|
||||||
|
ret := make([]byte, numBytes)
|
||||||
|
copy(ret[numBytes-len(xBytes):], xBytes)
|
||||||
|
|
||||||
|
return ret, nil
|
||||||
|
|
||||||
|
case FFDHE2048, FFDHE3072, FFDHE4096, FFDHE6144, FFDHE8192:
|
||||||
|
numBytes := keyExchangeSizeFromNamedGroup(group)
|
||||||
|
if len(pub) != numBytes {
|
||||||
|
return nil, fmt.Errorf("tls.keyagreement: Wrong public key size")
|
||||||
|
}
|
||||||
|
p := primeFromNamedGroup(group)
|
||||||
|
x := big.NewInt(0).SetBytes(priv)
|
||||||
|
Y := big.NewInt(0).SetBytes(pub)
|
||||||
|
ZBytes := big.NewInt(0).Exp(Y, x, p).Bytes()
|
||||||
|
|
||||||
|
ret := make([]byte, numBytes)
|
||||||
|
copy(ret[numBytes-len(ZBytes):], ZBytes)
|
||||||
|
|
||||||
|
return ret, nil
|
||||||
|
|
||||||
|
case X25519:
|
||||||
|
if len(pub) != keyExchangeSizeFromNamedGroup(group) {
|
||||||
|
return nil, fmt.Errorf("tls.keyagreement: Wrong public key size")
|
||||||
|
}
|
||||||
|
|
||||||
|
var private, public, ret [32]byte
|
||||||
|
copy(private[:], priv)
|
||||||
|
copy(public[:], pub)
|
||||||
|
curve25519.ScalarMult(&ret, &private, &public)
|
||||||
|
|
||||||
|
return ret[:], nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("tls.keyagreement: Unsupported group %v", group)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSigningKey(sig SignatureScheme) (crypto.Signer, error) {
|
||||||
|
switch sig {
|
||||||
|
case RSA_PKCS1_SHA1, RSA_PKCS1_SHA256,
|
||||||
|
RSA_PKCS1_SHA384, RSA_PKCS1_SHA512,
|
||||||
|
RSA_PSS_SHA256, RSA_PSS_SHA384,
|
||||||
|
RSA_PSS_SHA512:
|
||||||
|
return rsa.GenerateKey(prng, defaultRSAKeySize)
|
||||||
|
case ECDSA_P256_SHA256:
|
||||||
|
return ecdsa.GenerateKey(elliptic.P256(), prng)
|
||||||
|
case ECDSA_P384_SHA384:
|
||||||
|
return ecdsa.GenerateKey(elliptic.P384(), prng)
|
||||||
|
case ECDSA_P521_SHA512:
|
||||||
|
return ecdsa.GenerateKey(elliptic.P521(), prng)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("tls.newsigningkey: Unsupported signature algorithm [%04x]", sig)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSelfSigned(name string, alg SignatureScheme, priv crypto.Signer) (*x509.Certificate, error) {
|
||||||
|
sigAlg, ok := x509AlgMap[alg]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("tls.selfsigned: Unknown signature algorithm [%04x]", alg)
|
||||||
|
}
|
||||||
|
if len(name) == 0 {
|
||||||
|
return nil, fmt.Errorf("tls.selfsigned: No name provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
serial, err := rand.Int(rand.Reader, big.NewInt(0xA0A0A0A0))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
template := &x509.Certificate{
|
||||||
|
SerialNumber: serial,
|
||||||
|
NotBefore: time.Now(),
|
||||||
|
NotAfter: time.Now().AddDate(0, 0, 1),
|
||||||
|
SignatureAlgorithm: sigAlg,
|
||||||
|
Subject: pkix.Name{CommonName: name},
|
||||||
|
DNSNames: []string{name},
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||||
|
}
|
||||||
|
der, err := x509.CreateCertificate(prng, template, template, priv.Public(), priv)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// It is safe to ignore the error here because we're parsing known-good data
|
||||||
|
cert, _ := x509.ParseCertificate(der)
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// XXX(rlb): Copied from crypto/x509
|
||||||
|
type ecdsaSignature struct {
|
||||||
|
R, S *big.Int
|
||||||
|
}
|
||||||
|
|
||||||
|
func sign(alg SignatureScheme, privateKey crypto.Signer, sigInput []byte) ([]byte, error) {
|
||||||
|
var opts crypto.SignerOpts
|
||||||
|
|
||||||
|
hash := hashMap[alg]
|
||||||
|
if hash == crypto.SHA1 {
|
||||||
|
return nil, fmt.Errorf("tls.crypt.sign: Use of SHA-1 is forbidden")
|
||||||
|
}
|
||||||
|
|
||||||
|
sigType := sigMap[alg]
|
||||||
|
var realInput []byte
|
||||||
|
switch key := privateKey.(type) {
|
||||||
|
case *rsa.PrivateKey:
|
||||||
|
switch {
|
||||||
|
case allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
|
||||||
|
logf(logTypeCrypto, "signing with PKCS1, hashSize=[%d]", hash.Size())
|
||||||
|
opts = hash
|
||||||
|
case !allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
|
||||||
|
fallthrough
|
||||||
|
case sigType == signatureAlgorithmRSA_PSS:
|
||||||
|
logf(logTypeCrypto, "signing with PSS, hashSize=[%d]", hash.Size())
|
||||||
|
opts = &rsa.PSSOptions{SaltLength: hash.Size(), Hash: hash}
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("tls.crypto.sign: Unsupported algorithm for RSA key")
|
||||||
|
}
|
||||||
|
|
||||||
|
h := hash.New()
|
||||||
|
h.Write(sigInput)
|
||||||
|
realInput = h.Sum(nil)
|
||||||
|
case *ecdsa.PrivateKey:
|
||||||
|
if sigType != signatureAlgorithmECDSA {
|
||||||
|
return nil, fmt.Errorf("tls.crypto.sign: Unsupported algorithm for ECDSA key")
|
||||||
|
}
|
||||||
|
|
||||||
|
algGroup := curveMap[alg]
|
||||||
|
keyGroup := namedGroupFromECDSAKey(key.Public().(*ecdsa.PublicKey))
|
||||||
|
if algGroup != keyGroup {
|
||||||
|
return nil, fmt.Errorf("tls.crypto.sign: Unsupported hash/curve combination")
|
||||||
|
}
|
||||||
|
|
||||||
|
h := hash.New()
|
||||||
|
h.Write(sigInput)
|
||||||
|
realInput = h.Sum(nil)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("tls.crypto.sign: Unsupported private key type")
|
||||||
|
}
|
||||||
|
|
||||||
|
sig, err := privateKey.Sign(prng, realInput, opts)
|
||||||
|
logf(logTypeCrypto, "signature: %x", sig)
|
||||||
|
return sig, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func verify(alg SignatureScheme, publicKey crypto.PublicKey, sigInput []byte, sig []byte) error {
|
||||||
|
hash := hashMap[alg]
|
||||||
|
|
||||||
|
if hash == crypto.SHA1 {
|
||||||
|
return fmt.Errorf("tls.crypt.sign: Use of SHA-1 is forbidden")
|
||||||
|
}
|
||||||
|
|
||||||
|
sigType := sigMap[alg]
|
||||||
|
switch pub := publicKey.(type) {
|
||||||
|
case *rsa.PublicKey:
|
||||||
|
switch {
|
||||||
|
case allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
|
||||||
|
logf(logTypeCrypto, "verifying with PKCS1, hashSize=[%d]", hash.Size())
|
||||||
|
|
||||||
|
h := hash.New()
|
||||||
|
h.Write(sigInput)
|
||||||
|
realInput := h.Sum(nil)
|
||||||
|
return rsa.VerifyPKCS1v15(pub, hash, realInput, sig)
|
||||||
|
case !allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
|
||||||
|
fallthrough
|
||||||
|
case sigType == signatureAlgorithmRSA_PSS:
|
||||||
|
logf(logTypeCrypto, "verifying with PSS, hashSize=[%d]", hash.Size())
|
||||||
|
opts := &rsa.PSSOptions{SaltLength: hash.Size(), Hash: hash}
|
||||||
|
|
||||||
|
h := hash.New()
|
||||||
|
h.Write(sigInput)
|
||||||
|
realInput := h.Sum(nil)
|
||||||
|
return rsa.VerifyPSS(pub, hash, realInput, sig, opts)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("tls.verify: Unsupported algorithm for RSA key")
|
||||||
|
}
|
||||||
|
|
||||||
|
case *ecdsa.PublicKey:
|
||||||
|
if sigType != signatureAlgorithmECDSA {
|
||||||
|
return fmt.Errorf("tls.verify: Unsupported algorithm for ECDSA key")
|
||||||
|
}
|
||||||
|
|
||||||
|
if curveMap[alg] != namedGroupFromECDSAKey(pub) {
|
||||||
|
return fmt.Errorf("tls.verify: Unsupported curve for ECDSA key")
|
||||||
|
}
|
||||||
|
|
||||||
|
ecdsaSig := new(ecdsaSignature)
|
||||||
|
if rest, err := asn1.Unmarshal(sig, ecdsaSig); err != nil {
|
||||||
|
return err
|
||||||
|
} else if len(rest) != 0 {
|
||||||
|
return fmt.Errorf("tls.verify: trailing data after ECDSA signature")
|
||||||
|
}
|
||||||
|
if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 {
|
||||||
|
return fmt.Errorf("tls.verify: ECDSA signature contained zero or negative values")
|
||||||
|
}
|
||||||
|
|
||||||
|
h := hash.New()
|
||||||
|
h.Write(sigInput)
|
||||||
|
realInput := h.Sum(nil)
|
||||||
|
if !ecdsa.Verify(pub, realInput, ecdsaSig.R, ecdsaSig.S) {
|
||||||
|
return fmt.Errorf("tls.verify: ECDSA verification failure")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("tls.verify: Unsupported key type")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 0
|
||||||
|
// |
|
||||||
|
// v
|
||||||
|
// PSK -> HKDF-Extract = Early Secret
|
||||||
|
// |
|
||||||
|
// +-----> Derive-Secret(.,
|
||||||
|
// | "ext binder" |
|
||||||
|
// | "res binder",
|
||||||
|
// | "")
|
||||||
|
// | = binder_key
|
||||||
|
// |
|
||||||
|
// +-----> Derive-Secret(., "c e traffic",
|
||||||
|
// | ClientHello)
|
||||||
|
// | = client_early_traffic_secret
|
||||||
|
// |
|
||||||
|
// +-----> Derive-Secret(., "e exp master",
|
||||||
|
// | ClientHello)
|
||||||
|
// | = early_exporter_master_secret
|
||||||
|
// v
|
||||||
|
// Derive-Secret(., "derived", "")
|
||||||
|
// |
|
||||||
|
// v
|
||||||
|
// (EC)DHE -> HKDF-Extract = Handshake Secret
|
||||||
|
// |
|
||||||
|
// +-----> Derive-Secret(., "c hs traffic",
|
||||||
|
// | ClientHello...ServerHello)
|
||||||
|
// | = client_handshake_traffic_secret
|
||||||
|
// |
|
||||||
|
// +-----> Derive-Secret(., "s hs traffic",
|
||||||
|
// | ClientHello...ServerHello)
|
||||||
|
// | = server_handshake_traffic_secret
|
||||||
|
// v
|
||||||
|
// Derive-Secret(., "derived", "")
|
||||||
|
// |
|
||||||
|
// v
|
||||||
|
// 0 -> HKDF-Extract = Master Secret
|
||||||
|
// |
|
||||||
|
// +-----> Derive-Secret(., "c ap traffic",
|
||||||
|
// | ClientHello...server Finished)
|
||||||
|
// | = client_application_traffic_secret_0
|
||||||
|
// |
|
||||||
|
// +-----> Derive-Secret(., "s ap traffic",
|
||||||
|
// | ClientHello...server Finished)
|
||||||
|
// | = server_application_traffic_secret_0
|
||||||
|
// |
|
||||||
|
// +-----> Derive-Secret(., "exp master",
|
||||||
|
// | ClientHello...server Finished)
|
||||||
|
// | = exporter_master_secret
|
||||||
|
// |
|
||||||
|
// +-----> Derive-Secret(., "res master",
|
||||||
|
// ClientHello...client Finished)
|
||||||
|
// = resumption_master_secret
|
||||||
|
|
||||||
|
// From RFC 5869
|
||||||
|
// PRK = HMAC-Hash(salt, IKM)
|
||||||
|
func HkdfExtract(hash crypto.Hash, saltIn, input []byte) []byte {
|
||||||
|
salt := saltIn
|
||||||
|
|
||||||
|
// if [salt is] not provided, it is set to a string of HashLen zeros
|
||||||
|
if salt == nil {
|
||||||
|
salt = bytes.Repeat([]byte{0}, hash.Size())
|
||||||
|
}
|
||||||
|
|
||||||
|
h := hmac.New(hash.New, salt)
|
||||||
|
h.Write(input)
|
||||||
|
out := h.Sum(nil)
|
||||||
|
|
||||||
|
logf(logTypeCrypto, "HKDF Extract:\n")
|
||||||
|
logf(logTypeCrypto, "Salt [%d]: %x\n", len(salt), salt)
|
||||||
|
logf(logTypeCrypto, "Input [%d]: %x\n", len(input), input)
|
||||||
|
logf(logTypeCrypto, "Output [%d]: %x\n", len(out), out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
labelExternalBinder = "ext binder"
|
||||||
|
labelResumptionBinder = "res binder"
|
||||||
|
labelEarlyTrafficSecret = "c e traffic"
|
||||||
|
labelEarlyExporterSecret = "e exp master"
|
||||||
|
labelClientHandshakeTrafficSecret = "c hs traffic"
|
||||||
|
labelServerHandshakeTrafficSecret = "s hs traffic"
|
||||||
|
labelClientApplicationTrafficSecret = "c ap traffic"
|
||||||
|
labelServerApplicationTrafficSecret = "s ap traffic"
|
||||||
|
labelExporterSecret = "exp master"
|
||||||
|
labelResumptionSecret = "res master"
|
||||||
|
labelDerived = "derived"
|
||||||
|
labelFinished = "finished"
|
||||||
|
labelResumption = "resumption"
|
||||||
|
)
|
||||||
|
|
||||||
|
// struct HkdfLabel {
|
||||||
|
// uint16 length;
|
||||||
|
// opaque label<9..255>;
|
||||||
|
// opaque hash_value<0..255>;
|
||||||
|
// };
|
||||||
|
func hkdfEncodeLabel(labelIn string, hashValue []byte, outLen int) []byte {
|
||||||
|
label := "tls13 " + labelIn
|
||||||
|
|
||||||
|
labelLen := len(label)
|
||||||
|
hashLen := len(hashValue)
|
||||||
|
hkdfLabel := make([]byte, 2+1+labelLen+1+hashLen)
|
||||||
|
hkdfLabel[0] = byte(outLen >> 8)
|
||||||
|
hkdfLabel[1] = byte(outLen)
|
||||||
|
hkdfLabel[2] = byte(labelLen)
|
||||||
|
copy(hkdfLabel[3:3+labelLen], []byte(label))
|
||||||
|
hkdfLabel[3+labelLen] = byte(hashLen)
|
||||||
|
copy(hkdfLabel[3+labelLen+1:], hashValue)
|
||||||
|
|
||||||
|
return hkdfLabel
|
||||||
|
}
|
||||||
|
|
||||||
|
func HkdfExpand(hash crypto.Hash, prk, info []byte, outLen int) []byte {
|
||||||
|
out := []byte{}
|
||||||
|
T := []byte{}
|
||||||
|
i := byte(1)
|
||||||
|
for len(out) < outLen {
|
||||||
|
block := append(T, info...)
|
||||||
|
block = append(block, i)
|
||||||
|
|
||||||
|
h := hmac.New(hash.New, prk)
|
||||||
|
h.Write(block)
|
||||||
|
|
||||||
|
T = h.Sum(nil)
|
||||||
|
out = append(out, T...)
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
return out[:outLen]
|
||||||
|
}
|
||||||
|
|
||||||
|
func HkdfExpandLabel(hash crypto.Hash, secret []byte, label string, hashValue []byte, outLen int) []byte {
|
||||||
|
info := hkdfEncodeLabel(label, hashValue, outLen)
|
||||||
|
derived := HkdfExpand(hash, secret, info, outLen)
|
||||||
|
|
||||||
|
logf(logTypeCrypto, "HKDF Expand: label=[tls13 ] + '%s',requested length=%d\n", label, outLen)
|
||||||
|
logf(logTypeCrypto, "PRK [%d]: %x\n", len(secret), secret)
|
||||||
|
logf(logTypeCrypto, "Hash [%d]: %x\n", len(hashValue), hashValue)
|
||||||
|
logf(logTypeCrypto, "Info [%d]: %x\n", len(info), info)
|
||||||
|
logf(logTypeCrypto, "Derived key [%d]: %x\n", len(derived), derived)
|
||||||
|
|
||||||
|
return derived
|
||||||
|
}
|
||||||
|
|
||||||
|
func deriveSecret(params CipherSuiteParams, secret []byte, label string, messageHash []byte) []byte {
|
||||||
|
return HkdfExpandLabel(params.Hash, secret, label, messageHash, params.Hash.Size())
|
||||||
|
}
|
||||||
|
|
||||||
|
func computeFinishedData(params CipherSuiteParams, baseKey []byte, input []byte) []byte {
|
||||||
|
macKey := HkdfExpandLabel(params.Hash, baseKey, labelFinished, []byte{}, params.Hash.Size())
|
||||||
|
mac := hmac.New(params.Hash.New, macKey)
|
||||||
|
mac.Write(input)
|
||||||
|
return mac.Sum(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
type keySet struct {
|
||||||
|
cipher aeadFactory
|
||||||
|
key []byte
|
||||||
|
iv []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeTrafficKeys(params CipherSuiteParams, secret []byte) keySet {
|
||||||
|
logf(logTypeCrypto, "making traffic keys: secret=%x", secret)
|
||||||
|
return keySet{
|
||||||
|
cipher: params.Cipher,
|
||||||
|
key: HkdfExpandLabel(params.Hash, secret, "key", []byte{}, params.KeyLen),
|
||||||
|
iv: HkdfExpandLabel(params.Hash, secret, "iv", []byte{}, params.IvLen),
|
||||||
|
}
|
||||||
|
}
|
586
vendor/github.com/bifurcation/mint/extensions.go
generated
vendored
Normal file
586
vendor/github.com/bifurcation/mint/extensions.go
generated
vendored
Normal file
@ -0,0 +1,586 @@
|
|||||||
|
package mint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/bifurcation/mint/syntax"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ExtensionBody interface {
|
||||||
|
Type() ExtensionType
|
||||||
|
Marshal() ([]byte, error)
|
||||||
|
Unmarshal(data []byte) (int, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// ExtensionType extension_type;
|
||||||
|
// opaque extension_data<0..2^16-1>;
|
||||||
|
// } Extension;
|
||||||
|
type Extension struct {
|
||||||
|
ExtensionType ExtensionType
|
||||||
|
ExtensionData []byte `tls:"head=2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ext Extension) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(ext)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ext *Extension) Unmarshal(data []byte) (int, error) {
|
||||||
|
return syntax.Unmarshal(data, ext)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ExtensionList []Extension
|
||||||
|
|
||||||
|
type extensionListInner struct {
|
||||||
|
List []Extension `tls:"head=2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (el ExtensionList) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(extensionListInner{el})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (el *ExtensionList) Unmarshal(data []byte) (int, error) {
|
||||||
|
var list extensionListInner
|
||||||
|
read, err := syntax.Unmarshal(data, &list)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
*el = list.List
|
||||||
|
return read, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (el *ExtensionList) Add(src ExtensionBody) error {
|
||||||
|
data, err := src.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if el == nil {
|
||||||
|
el = new(ExtensionList)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If one already exists with this type, replace it
|
||||||
|
for i := range *el {
|
||||||
|
if (*el)[i].ExtensionType == src.Type() {
|
||||||
|
(*el)[i].ExtensionData = data
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise append
|
||||||
|
*el = append(*el, Extension{
|
||||||
|
ExtensionType: src.Type(),
|
||||||
|
ExtensionData: data,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (el ExtensionList) Find(dst ExtensionBody) bool {
|
||||||
|
for _, ext := range el {
|
||||||
|
if ext.ExtensionType == dst.Type() {
|
||||||
|
_, err := dst.Unmarshal(ext.ExtensionData)
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// NameType name_type;
|
||||||
|
// select (name_type) {
|
||||||
|
// case host_name: HostName;
|
||||||
|
// } name;
|
||||||
|
// } ServerName;
|
||||||
|
//
|
||||||
|
// enum {
|
||||||
|
// host_name(0), (255)
|
||||||
|
// } NameType;
|
||||||
|
//
|
||||||
|
// opaque HostName<1..2^16-1>;
|
||||||
|
//
|
||||||
|
// struct {
|
||||||
|
// ServerName server_name_list<1..2^16-1>
|
||||||
|
// } ServerNameList;
|
||||||
|
//
|
||||||
|
// But we only care about the case where there's a single DNS hostname. We
|
||||||
|
// will never create anything else, and throw if we receive something else
|
||||||
|
//
|
||||||
|
// 2 1 2
|
||||||
|
// | listLen | NameType | nameLen | name |
|
||||||
|
type ServerNameExtension string
|
||||||
|
|
||||||
|
type serverNameInner struct {
|
||||||
|
NameType uint8
|
||||||
|
HostName []byte `tls:"head=2,min=1"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type serverNameListInner struct {
|
||||||
|
ServerNameList []serverNameInner `tls:"head=2,min=1"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sni ServerNameExtension) Type() ExtensionType {
|
||||||
|
return ExtensionTypeServerName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sni ServerNameExtension) Marshal() ([]byte, error) {
|
||||||
|
list := serverNameListInner{
|
||||||
|
ServerNameList: []serverNameInner{{
|
||||||
|
NameType: 0x00, // host_name
|
||||||
|
HostName: []byte(sni),
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
return syntax.Marshal(list)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sni *ServerNameExtension) Unmarshal(data []byte) (int, error) {
|
||||||
|
var list serverNameListInner
|
||||||
|
read, err := syntax.Unmarshal(data, &list)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Syntax requires at least one entry
|
||||||
|
// Entries beyond the first are ignored
|
||||||
|
if nameType := list.ServerNameList[0].NameType; nameType != 0x00 {
|
||||||
|
return 0, fmt.Errorf("tls.servername: Unsupported name type [%x]", nameType)
|
||||||
|
}
|
||||||
|
|
||||||
|
*sni = ServerNameExtension(list.ServerNameList[0].HostName)
|
||||||
|
return read, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// NamedGroup group;
|
||||||
|
// opaque key_exchange<1..2^16-1>;
|
||||||
|
// } KeyShareEntry;
|
||||||
|
//
|
||||||
|
// struct {
|
||||||
|
// select (Handshake.msg_type) {
|
||||||
|
// case client_hello:
|
||||||
|
// KeyShareEntry client_shares<0..2^16-1>;
|
||||||
|
//
|
||||||
|
// case hello_retry_request:
|
||||||
|
// NamedGroup selected_group;
|
||||||
|
//
|
||||||
|
// case server_hello:
|
||||||
|
// KeyShareEntry server_share;
|
||||||
|
// };
|
||||||
|
// } KeyShare;
|
||||||
|
type KeyShareEntry struct {
|
||||||
|
Group NamedGroup
|
||||||
|
KeyExchange []byte `tls:"head=2,min=1"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kse KeyShareEntry) SizeValid() bool {
|
||||||
|
return len(kse.KeyExchange) == keyExchangeSizeFromNamedGroup(kse.Group)
|
||||||
|
}
|
||||||
|
|
||||||
|
type KeyShareExtension struct {
|
||||||
|
HandshakeType HandshakeType
|
||||||
|
SelectedGroup NamedGroup
|
||||||
|
Shares []KeyShareEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
type KeyShareClientHelloInner struct {
|
||||||
|
ClientShares []KeyShareEntry `tls:"head=2,min=0"`
|
||||||
|
}
|
||||||
|
type KeyShareHelloRetryInner struct {
|
||||||
|
SelectedGroup NamedGroup
|
||||||
|
}
|
||||||
|
type KeyShareServerHelloInner struct {
|
||||||
|
ServerShare KeyShareEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ks KeyShareExtension) Type() ExtensionType {
|
||||||
|
return ExtensionTypeKeyShare
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ks KeyShareExtension) Marshal() ([]byte, error) {
|
||||||
|
switch ks.HandshakeType {
|
||||||
|
case HandshakeTypeClientHello:
|
||||||
|
for _, share := range ks.Shares {
|
||||||
|
if !share.SizeValid() {
|
||||||
|
return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return syntax.Marshal(KeyShareClientHelloInner{ks.Shares})
|
||||||
|
|
||||||
|
case HandshakeTypeHelloRetryRequest:
|
||||||
|
if len(ks.Shares) > 0 {
|
||||||
|
return nil, fmt.Errorf("tls.keyshare: Key shares not allowed for HelloRetryRequest")
|
||||||
|
}
|
||||||
|
|
||||||
|
return syntax.Marshal(KeyShareHelloRetryInner{ks.SelectedGroup})
|
||||||
|
|
||||||
|
case HandshakeTypeServerHello:
|
||||||
|
if len(ks.Shares) != 1 {
|
||||||
|
return nil, fmt.Errorf("tls.keyshare: Server must send exactly one key share")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ks.Shares[0].SizeValid() {
|
||||||
|
return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
|
||||||
|
}
|
||||||
|
|
||||||
|
return syntax.Marshal(KeyShareServerHelloInner{ks.Shares[0]})
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("tls.keyshare: Handshake type not allowed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ks *KeyShareExtension) Unmarshal(data []byte) (int, error) {
|
||||||
|
switch ks.HandshakeType {
|
||||||
|
case HandshakeTypeClientHello:
|
||||||
|
var inner KeyShareClientHelloInner
|
||||||
|
read, err := syntax.Unmarshal(data, &inner)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, share := range inner.ClientShares {
|
||||||
|
if !share.SizeValid() {
|
||||||
|
return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ks.Shares = inner.ClientShares
|
||||||
|
return read, nil
|
||||||
|
|
||||||
|
case HandshakeTypeHelloRetryRequest:
|
||||||
|
var inner KeyShareHelloRetryInner
|
||||||
|
read, err := syntax.Unmarshal(data, &inner)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ks.SelectedGroup = inner.SelectedGroup
|
||||||
|
return read, nil
|
||||||
|
|
||||||
|
case HandshakeTypeServerHello:
|
||||||
|
var inner KeyShareServerHelloInner
|
||||||
|
read, err := syntax.Unmarshal(data, &inner)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !inner.ServerShare.SizeValid() {
|
||||||
|
return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
|
||||||
|
}
|
||||||
|
|
||||||
|
ks.Shares = []KeyShareEntry{inner.ServerShare}
|
||||||
|
return read, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("tls.keyshare: Handshake type not allowed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// NamedGroup named_group_list<2..2^16-1>;
|
||||||
|
// } NamedGroupList;
|
||||||
|
type SupportedGroupsExtension struct {
|
||||||
|
Groups []NamedGroup `tls:"head=2,min=2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sg SupportedGroupsExtension) Type() ExtensionType {
|
||||||
|
return ExtensionTypeSupportedGroups
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sg SupportedGroupsExtension) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(sg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sg *SupportedGroupsExtension) Unmarshal(data []byte) (int, error) {
|
||||||
|
return syntax.Unmarshal(data, sg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// SignatureScheme supported_signature_algorithms<2..2^16-2>;
|
||||||
|
// } SignatureSchemeList
|
||||||
|
type SignatureAlgorithmsExtension struct {
|
||||||
|
Algorithms []SignatureScheme `tls:"head=2,min=2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sa SignatureAlgorithmsExtension) Type() ExtensionType {
|
||||||
|
return ExtensionTypeSignatureAlgorithms
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sa SignatureAlgorithmsExtension) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(sa)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sa *SignatureAlgorithmsExtension) Unmarshal(data []byte) (int, error) {
|
||||||
|
return syntax.Unmarshal(data, sa)
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// opaque identity<1..2^16-1>;
|
||||||
|
// uint32 obfuscated_ticket_age;
|
||||||
|
// } PskIdentity;
|
||||||
|
//
|
||||||
|
// opaque PskBinderEntry<32..255>;
|
||||||
|
//
|
||||||
|
// struct {
|
||||||
|
// select (Handshake.msg_type) {
|
||||||
|
// case client_hello:
|
||||||
|
// PskIdentity identities<7..2^16-1>;
|
||||||
|
// PskBinderEntry binders<33..2^16-1>;
|
||||||
|
//
|
||||||
|
// case server_hello:
|
||||||
|
// uint16 selected_identity;
|
||||||
|
// };
|
||||||
|
//
|
||||||
|
// } PreSharedKeyExtension;
|
||||||
|
type PSKIdentity struct {
|
||||||
|
Identity []byte `tls:"head=2,min=1"`
|
||||||
|
ObfuscatedTicketAge uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
type PSKBinderEntry struct {
|
||||||
|
Binder []byte `tls:"head=1,min=32"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type PreSharedKeyExtension struct {
|
||||||
|
HandshakeType HandshakeType
|
||||||
|
Identities []PSKIdentity
|
||||||
|
Binders []PSKBinderEntry
|
||||||
|
SelectedIdentity uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
type preSharedKeyClientInner struct {
|
||||||
|
Identities []PSKIdentity `tls:"head=2,min=7"`
|
||||||
|
Binders []PSKBinderEntry `tls:"head=2,min=33"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type preSharedKeyServerInner struct {
|
||||||
|
SelectedIdentity uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
func (psk PreSharedKeyExtension) Type() ExtensionType {
|
||||||
|
return ExtensionTypePreSharedKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func (psk PreSharedKeyExtension) Marshal() ([]byte, error) {
|
||||||
|
switch psk.HandshakeType {
|
||||||
|
case HandshakeTypeClientHello:
|
||||||
|
return syntax.Marshal(preSharedKeyClientInner{
|
||||||
|
Identities: psk.Identities,
|
||||||
|
Binders: psk.Binders,
|
||||||
|
})
|
||||||
|
|
||||||
|
case HandshakeTypeServerHello:
|
||||||
|
if len(psk.Identities) > 0 || len(psk.Binders) > 0 {
|
||||||
|
return nil, fmt.Errorf("tls.presharedkey: Server can only provide an index")
|
||||||
|
}
|
||||||
|
return syntax.Marshal(preSharedKeyServerInner{psk.SelectedIdentity})
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("tls.presharedkey: Handshake type not supported")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (psk *PreSharedKeyExtension) Unmarshal(data []byte) (int, error) {
|
||||||
|
switch psk.HandshakeType {
|
||||||
|
case HandshakeTypeClientHello:
|
||||||
|
var inner preSharedKeyClientInner
|
||||||
|
read, err := syntax.Unmarshal(data, &inner)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(inner.Identities) != len(inner.Binders) {
|
||||||
|
return 0, fmt.Errorf("Lengths of identities and binders not equal")
|
||||||
|
}
|
||||||
|
|
||||||
|
psk.Identities = inner.Identities
|
||||||
|
psk.Binders = inner.Binders
|
||||||
|
return read, nil
|
||||||
|
|
||||||
|
case HandshakeTypeServerHello:
|
||||||
|
var inner preSharedKeyServerInner
|
||||||
|
read, err := syntax.Unmarshal(data, &inner)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
psk.SelectedIdentity = inner.SelectedIdentity
|
||||||
|
return read, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("tls.presharedkey: Handshake type not supported")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (psk PreSharedKeyExtension) HasIdentity(id []byte) ([]byte, bool) {
|
||||||
|
for i, localID := range psk.Identities {
|
||||||
|
if bytes.Equal(localID.Identity, id) {
|
||||||
|
return psk.Binders[i].Binder, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// enum { psk_ke(0), psk_dhe_ke(1), (255) } PskKeyExchangeMode;
|
||||||
|
//
|
||||||
|
// struct {
|
||||||
|
// PskKeyExchangeMode ke_modes<1..255>;
|
||||||
|
// } PskKeyExchangeModes;
|
||||||
|
type PSKKeyExchangeModesExtension struct {
|
||||||
|
KEModes []PSKKeyExchangeMode `tls:"head=1,min=1"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pkem PSKKeyExchangeModesExtension) Type() ExtensionType {
|
||||||
|
return ExtensionTypePSKKeyExchangeModes
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pkem PSKKeyExchangeModesExtension) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(pkem)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pkem *PSKKeyExchangeModesExtension) Unmarshal(data []byte) (int, error) {
|
||||||
|
return syntax.Unmarshal(data, pkem)
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// } EarlyDataIndication;
|
||||||
|
|
||||||
|
type EarlyDataExtension struct{}
|
||||||
|
|
||||||
|
func (ed EarlyDataExtension) Type() ExtensionType {
|
||||||
|
return ExtensionTypeEarlyData
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ed EarlyDataExtension) Marshal() ([]byte, error) {
|
||||||
|
return []byte{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ed *EarlyDataExtension) Unmarshal(data []byte) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// uint32 max_early_data_size;
|
||||||
|
// } TicketEarlyDataInfo;
|
||||||
|
|
||||||
|
type TicketEarlyDataInfoExtension struct {
|
||||||
|
MaxEarlyDataSize uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tedi TicketEarlyDataInfoExtension) Type() ExtensionType {
|
||||||
|
return ExtensionTypeTicketEarlyDataInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tedi TicketEarlyDataInfoExtension) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(tedi)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tedi *TicketEarlyDataInfoExtension) Unmarshal(data []byte) (int, error) {
|
||||||
|
return syntax.Unmarshal(data, tedi)
|
||||||
|
}
|
||||||
|
|
||||||
|
// opaque ProtocolName<1..2^8-1>;
|
||||||
|
//
|
||||||
|
// struct {
|
||||||
|
// ProtocolName protocol_name_list<2..2^16-1>
|
||||||
|
// } ProtocolNameList;
|
||||||
|
type ALPNExtension struct {
|
||||||
|
Protocols []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type protocolNameInner struct {
|
||||||
|
Name []byte `tls:"head=1,min=1"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type alpnExtensionInner struct {
|
||||||
|
Protocols []protocolNameInner `tls:"head=2,min=2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (alpn ALPNExtension) Type() ExtensionType {
|
||||||
|
return ExtensionTypeALPN
|
||||||
|
}
|
||||||
|
|
||||||
|
func (alpn ALPNExtension) Marshal() ([]byte, error) {
|
||||||
|
protocols := make([]protocolNameInner, len(alpn.Protocols))
|
||||||
|
for i, protocol := range alpn.Protocols {
|
||||||
|
protocols[i] = protocolNameInner{[]byte(protocol)}
|
||||||
|
}
|
||||||
|
return syntax.Marshal(alpnExtensionInner{protocols})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (alpn *ALPNExtension) Unmarshal(data []byte) (int, error) {
|
||||||
|
var inner alpnExtensionInner
|
||||||
|
read, err := syntax.Unmarshal(data, &inner)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
alpn.Protocols = make([]string, len(inner.Protocols))
|
||||||
|
for i, protocol := range inner.Protocols {
|
||||||
|
alpn.Protocols[i] = string(protocol.Name)
|
||||||
|
}
|
||||||
|
return read, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// ProtocolVersion versions<2..254>;
|
||||||
|
// } SupportedVersions;
|
||||||
|
type SupportedVersionsExtension struct {
|
||||||
|
Versions []uint16 `tls:"head=1,min=2,max=254"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sv SupportedVersionsExtension) Type() ExtensionType {
|
||||||
|
return ExtensionTypeSupportedVersions
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sv SupportedVersionsExtension) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(sv)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sv *SupportedVersionsExtension) Unmarshal(data []byte) (int, error) {
|
||||||
|
return syntax.Unmarshal(data, sv)
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// opaque cookie<1..2^16-1>;
|
||||||
|
// } Cookie;
|
||||||
|
type CookieExtension struct {
|
||||||
|
Cookie []byte `tls:"head=2,min=1"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c CookieExtension) Type() ExtensionType {
|
||||||
|
return ExtensionTypeCookie
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c CookieExtension) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CookieExtension) Unmarshal(data []byte) (int, error) {
|
||||||
|
return syntax.Unmarshal(data, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultCookieLength is the default length of a cookie
|
||||||
|
const defaultCookieLength = 32
|
||||||
|
|
||||||
|
type defaultCookieHandler struct {
|
||||||
|
data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ CookieHandler = &defaultCookieHandler{}
|
||||||
|
|
||||||
|
// NewRandomCookie generates a cookie with DefaultCookieLength bytes of random data
|
||||||
|
func (h *defaultCookieHandler) Generate(*Conn) ([]byte, error) {
|
||||||
|
h.data = make([]byte, defaultCookieLength)
|
||||||
|
if _, err := prng.Read(h.data); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return h.data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *defaultCookieHandler) Validate(_ *Conn, data []byte) bool {
|
||||||
|
return bytes.Equal(h.data, data)
|
||||||
|
}
|
147
vendor/github.com/bifurcation/mint/ffdhe.go
generated
vendored
Normal file
147
vendor/github.com/bifurcation/mint/ffdhe.go
generated
vendored
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
package mint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"math/big"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
finiteFieldPrime2048hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
|
||||||
|
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
|
||||||
|
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
|
||||||
|
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
|
||||||
|
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
|
||||||
|
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
|
||||||
|
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
|
||||||
|
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
|
||||||
|
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
|
||||||
|
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
|
||||||
|
"886B423861285C97FFFFFFFFFFFFFFFF"
|
||||||
|
finiteFieldPrime2048bytes, _ = hex.DecodeString(finiteFieldPrime2048hex)
|
||||||
|
finiteFieldPrime2048 = big.NewInt(0).SetBytes(finiteFieldPrime2048bytes)
|
||||||
|
|
||||||
|
finiteFieldPrime3072hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
|
||||||
|
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
|
||||||
|
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
|
||||||
|
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
|
||||||
|
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
|
||||||
|
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
|
||||||
|
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
|
||||||
|
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
|
||||||
|
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
|
||||||
|
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
|
||||||
|
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
|
||||||
|
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
|
||||||
|
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
|
||||||
|
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
|
||||||
|
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
|
||||||
|
"3C1B20EE3FD59D7C25E41D2B66C62E37FFFFFFFFFFFFFFFF"
|
||||||
|
finiteFieldPrime3072bytes, _ = hex.DecodeString(finiteFieldPrime3072hex)
|
||||||
|
finiteFieldPrime3072 = big.NewInt(0).SetBytes(finiteFieldPrime3072bytes)
|
||||||
|
|
||||||
|
finiteFieldPrime4096hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
|
||||||
|
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
|
||||||
|
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
|
||||||
|
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
|
||||||
|
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
|
||||||
|
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
|
||||||
|
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
|
||||||
|
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
|
||||||
|
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
|
||||||
|
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
|
||||||
|
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
|
||||||
|
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
|
||||||
|
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
|
||||||
|
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
|
||||||
|
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
|
||||||
|
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" +
|
||||||
|
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" +
|
||||||
|
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" +
|
||||||
|
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" +
|
||||||
|
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" +
|
||||||
|
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E655F6A" +
|
||||||
|
"FFFFFFFFFFFFFFFF"
|
||||||
|
finiteFieldPrime4096bytes, _ = hex.DecodeString(finiteFieldPrime4096hex)
|
||||||
|
finiteFieldPrime4096 = big.NewInt(0).SetBytes(finiteFieldPrime4096bytes)
|
||||||
|
|
||||||
|
finiteFieldPrime6144hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
|
||||||
|
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
|
||||||
|
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
|
||||||
|
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
|
||||||
|
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
|
||||||
|
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
|
||||||
|
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
|
||||||
|
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
|
||||||
|
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
|
||||||
|
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
|
||||||
|
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
|
||||||
|
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
|
||||||
|
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
|
||||||
|
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
|
||||||
|
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
|
||||||
|
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" +
|
||||||
|
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" +
|
||||||
|
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" +
|
||||||
|
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" +
|
||||||
|
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" +
|
||||||
|
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" +
|
||||||
|
"0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" +
|
||||||
|
"3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" +
|
||||||
|
"CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" +
|
||||||
|
"A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" +
|
||||||
|
"0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" +
|
||||||
|
"763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" +
|
||||||
|
"B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" +
|
||||||
|
"D72B03746AE77F5E62292C311562A846505DC82DB854338A" +
|
||||||
|
"E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" +
|
||||||
|
"5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" +
|
||||||
|
"A41D570D7938DAD4A40E329CD0E40E65FFFFFFFFFFFFFFFF"
|
||||||
|
finiteFieldPrime6144bytes, _ = hex.DecodeString(finiteFieldPrime6144hex)
|
||||||
|
finiteFieldPrime6144 = big.NewInt(0).SetBytes(finiteFieldPrime6144bytes)
|
||||||
|
|
||||||
|
finiteFieldPrime8192hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
|
||||||
|
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
|
||||||
|
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
|
||||||
|
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
|
||||||
|
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
|
||||||
|
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
|
||||||
|
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
|
||||||
|
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
|
||||||
|
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
|
||||||
|
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
|
||||||
|
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
|
||||||
|
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
|
||||||
|
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
|
||||||
|
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
|
||||||
|
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
|
||||||
|
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" +
|
||||||
|
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" +
|
||||||
|
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" +
|
||||||
|
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" +
|
||||||
|
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" +
|
||||||
|
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" +
|
||||||
|
"0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" +
|
||||||
|
"3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" +
|
||||||
|
"CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" +
|
||||||
|
"A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" +
|
||||||
|
"0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" +
|
||||||
|
"763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" +
|
||||||
|
"B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" +
|
||||||
|
"D72B03746AE77F5E62292C311562A846505DC82DB854338A" +
|
||||||
|
"E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" +
|
||||||
|
"5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" +
|
||||||
|
"A41D570D7938DAD4A40E329CCFF46AAA36AD004CF600C838" +
|
||||||
|
"1E425A31D951AE64FDB23FCEC9509D43687FEB69EDD1CC5E" +
|
||||||
|
"0B8CC3BDF64B10EF86B63142A3AB8829555B2F747C932665" +
|
||||||
|
"CB2C0F1CC01BD70229388839D2AF05E454504AC78B758282" +
|
||||||
|
"2846C0BA35C35F5C59160CC046FD8251541FC68C9C86B022" +
|
||||||
|
"BB7099876A460E7451A8A93109703FEE1C217E6C3826E52C" +
|
||||||
|
"51AA691E0E423CFC99E9E31650C1217B624816CDAD9A95F9" +
|
||||||
|
"D5B8019488D9C0A0A1FE3075A577E23183F81D4A3F2FA457" +
|
||||||
|
"1EFC8CE0BA8A4FE8B6855DFE72B0A66EDED2FBABFBE58A30" +
|
||||||
|
"FAFABE1C5D71A87E2F741EF8C1FE86FEA6BBFDE530677F0D" +
|
||||||
|
"97D11D49F7A8443D0822E506A9F4614E011E2A94838FF88C" +
|
||||||
|
"D68C8BB7C5C6424CFFFFFFFFFFFFFFFF"
|
||||||
|
finiteFieldPrime8192bytes, _ = hex.DecodeString(finiteFieldPrime8192hex)
|
||||||
|
finiteFieldPrime8192 = big.NewInt(0).SetBytes(finiteFieldPrime8192bytes)
|
||||||
|
)
|
98
vendor/github.com/bifurcation/mint/frame-reader.go
generated
vendored
Normal file
98
vendor/github.com/bifurcation/mint/frame-reader.go
generated
vendored
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
// Read a generic "framed" packet consisting of a header and a
|
||||||
|
// This is used for both TLS Records and TLS Handshake Messages
|
||||||
|
package mint
|
||||||
|
|
||||||
|
type framing interface {
|
||||||
|
headerLen() int
|
||||||
|
defaultReadLen() int
|
||||||
|
frameLen(hdr []byte) (int, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
kFrameReaderHdr = 0
|
||||||
|
kFrameReaderBody = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
type frameNextAction func(f *frameReader) error
|
||||||
|
|
||||||
|
type frameReader struct {
|
||||||
|
details framing
|
||||||
|
state uint8
|
||||||
|
header []byte
|
||||||
|
body []byte
|
||||||
|
working []byte
|
||||||
|
writeOffset int
|
||||||
|
remainder []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFrameReader(d framing) *frameReader {
|
||||||
|
hdr := make([]byte, d.headerLen())
|
||||||
|
return &frameReader{
|
||||||
|
d,
|
||||||
|
kFrameReaderHdr,
|
||||||
|
hdr,
|
||||||
|
nil,
|
||||||
|
hdr,
|
||||||
|
0,
|
||||||
|
nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func dup(a []byte) []byte {
|
||||||
|
r := make([]byte, len(a))
|
||||||
|
copy(r, a)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *frameReader) needed() int {
|
||||||
|
tmp := (len(f.working) - f.writeOffset) - len(f.remainder)
|
||||||
|
if tmp < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return tmp
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *frameReader) addChunk(in []byte) {
|
||||||
|
// Append to the buffer.
|
||||||
|
logf(logTypeFrameReader, "Appending %v", len(in))
|
||||||
|
f.remainder = append(f.remainder, in...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *frameReader) process() (hdr []byte, body []byte, err error) {
|
||||||
|
for f.needed() == 0 {
|
||||||
|
logf(logTypeFrameReader, "%v bytes needed for next block", len(f.working)-f.writeOffset)
|
||||||
|
// Fill out our working block
|
||||||
|
copied := copy(f.working[f.writeOffset:], f.remainder)
|
||||||
|
f.remainder = f.remainder[copied:]
|
||||||
|
f.writeOffset += copied
|
||||||
|
if f.writeOffset < len(f.working) {
|
||||||
|
logf(logTypeFrameReader, "Read would have blocked 1")
|
||||||
|
return nil, nil, WouldBlock
|
||||||
|
}
|
||||||
|
// Reset the write offset, because we are now full.
|
||||||
|
f.writeOffset = 0
|
||||||
|
|
||||||
|
// We have read a full frame
|
||||||
|
if f.state == kFrameReaderBody {
|
||||||
|
logf(logTypeFrameReader, "Returning frame hdr=%#x len=%d buffered=%d", f.header, len(f.body), len(f.remainder))
|
||||||
|
f.state = kFrameReaderHdr
|
||||||
|
f.working = f.header
|
||||||
|
return dup(f.header), dup(f.body), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// We have read the header
|
||||||
|
bodyLen, err := f.details.frameLen(f.header)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
logf(logTypeFrameReader, "Processed header, body len = %v", bodyLen)
|
||||||
|
|
||||||
|
f.body = make([]byte, bodyLen)
|
||||||
|
f.working = f.body
|
||||||
|
f.writeOffset = 0
|
||||||
|
f.state = kFrameReaderBody
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeFrameReader, "Read would have blocked 2")
|
||||||
|
return nil, nil, WouldBlock
|
||||||
|
}
|
253
vendor/github.com/bifurcation/mint/handshake-layer.go
generated
vendored
Normal file
253
vendor/github.com/bifurcation/mint/handshake-layer.go
generated
vendored
Normal file
@ -0,0 +1,253 @@
|
|||||||
|
package mint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
handshakeHeaderLen = 4 // handshake message header length
|
||||||
|
maxHandshakeMessageLen = 1 << 24 // max handshake message length
|
||||||
|
)
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// HandshakeType msg_type; /* handshake type */
|
||||||
|
// uint24 length; /* bytes in message */
|
||||||
|
// select (HandshakeType) {
|
||||||
|
// ...
|
||||||
|
// } body;
|
||||||
|
// } Handshake;
|
||||||
|
//
|
||||||
|
// We do the select{...} part in a different layer, so we treat the
|
||||||
|
// actual message body as opaque:
|
||||||
|
//
|
||||||
|
// struct {
|
||||||
|
// HandshakeType msg_type;
|
||||||
|
// opaque msg<0..2^24-1>
|
||||||
|
// } Handshake;
|
||||||
|
//
|
||||||
|
// TODO: File a spec bug
|
||||||
|
type HandshakeMessage struct {
|
||||||
|
// Omitted: length
|
||||||
|
msgType HandshakeType
|
||||||
|
body []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note: This could be done with the `syntax` module, using the simplified
|
||||||
|
// syntax as discussed above. However, since this is so simple, there's not
|
||||||
|
// much benefit to doing so.
|
||||||
|
func (hm *HandshakeMessage) Marshal() []byte {
|
||||||
|
if hm == nil {
|
||||||
|
return []byte{}
|
||||||
|
}
|
||||||
|
|
||||||
|
msgLen := len(hm.body)
|
||||||
|
data := make([]byte, 4+len(hm.body))
|
||||||
|
data[0] = byte(hm.msgType)
|
||||||
|
data[1] = byte(msgLen >> 16)
|
||||||
|
data[2] = byte(msgLen >> 8)
|
||||||
|
data[3] = byte(msgLen)
|
||||||
|
copy(data[4:], hm.body)
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) {
|
||||||
|
logf(logTypeHandshake, "HandshakeMessage.toBody [%d] [%x]", hm.msgType, hm.body)
|
||||||
|
|
||||||
|
var body HandshakeMessageBody
|
||||||
|
switch hm.msgType {
|
||||||
|
case HandshakeTypeClientHello:
|
||||||
|
body = new(ClientHelloBody)
|
||||||
|
case HandshakeTypeServerHello:
|
||||||
|
body = new(ServerHelloBody)
|
||||||
|
case HandshakeTypeHelloRetryRequest:
|
||||||
|
body = new(HelloRetryRequestBody)
|
||||||
|
case HandshakeTypeEncryptedExtensions:
|
||||||
|
body = new(EncryptedExtensionsBody)
|
||||||
|
case HandshakeTypeCertificate:
|
||||||
|
body = new(CertificateBody)
|
||||||
|
case HandshakeTypeCertificateRequest:
|
||||||
|
body = new(CertificateRequestBody)
|
||||||
|
case HandshakeTypeCertificateVerify:
|
||||||
|
body = new(CertificateVerifyBody)
|
||||||
|
case HandshakeTypeFinished:
|
||||||
|
body = &FinishedBody{VerifyDataLen: len(hm.body)}
|
||||||
|
case HandshakeTypeNewSessionTicket:
|
||||||
|
body = new(NewSessionTicketBody)
|
||||||
|
case HandshakeTypeKeyUpdate:
|
||||||
|
body = new(KeyUpdateBody)
|
||||||
|
case HandshakeTypeEndOfEarlyData:
|
||||||
|
body = new(EndOfEarlyDataBody)
|
||||||
|
default:
|
||||||
|
return body, fmt.Errorf("tls.handshakemessage: Unsupported body type")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := body.Unmarshal(hm.body)
|
||||||
|
return body, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func HandshakeMessageFromBody(body HandshakeMessageBody) (*HandshakeMessage, error) {
|
||||||
|
data, err := body.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &HandshakeMessage{
|
||||||
|
msgType: body.Type(),
|
||||||
|
body: data,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type HandshakeLayer struct {
|
||||||
|
nonblocking bool // Should we operate in nonblocking mode
|
||||||
|
conn *RecordLayer // Used for reading/writing records
|
||||||
|
frame *frameReader // The buffered frame reader
|
||||||
|
}
|
||||||
|
|
||||||
|
type handshakeLayerFrameDetails struct{}
|
||||||
|
|
||||||
|
func (d handshakeLayerFrameDetails) headerLen() int {
|
||||||
|
return handshakeHeaderLen
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d handshakeLayerFrameDetails) defaultReadLen() int {
|
||||||
|
return handshakeHeaderLen + maxFragmentLen
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) {
|
||||||
|
logf(logTypeIO, "Header=%x", hdr)
|
||||||
|
return (int(hdr[1]) << 16) | (int(hdr[2]) << 8) | int(hdr[3]), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHandshakeLayer(r *RecordLayer) *HandshakeLayer {
|
||||||
|
h := HandshakeLayer{}
|
||||||
|
h.conn = r
|
||||||
|
h.frame = newFrameReader(&handshakeLayerFrameDetails{})
|
||||||
|
return &h
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HandshakeLayer) readRecord() error {
|
||||||
|
logf(logTypeIO, "Trying to read record")
|
||||||
|
pt, err := h.conn.ReadRecord()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if pt.contentType != RecordTypeHandshake &&
|
||||||
|
pt.contentType != RecordTypeAlert {
|
||||||
|
return fmt.Errorf("tls.handshakelayer: Unexpected record type %d", pt.contentType)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pt.contentType == RecordTypeAlert {
|
||||||
|
logf(logTypeIO, "read alert %v", pt.fragment[1])
|
||||||
|
if len(pt.fragment) < 2 {
|
||||||
|
h.sendAlert(AlertUnexpectedMessage)
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
return Alert(pt.fragment[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeIO, "read handshake record of len %v", len(pt.fragment))
|
||||||
|
h.frame.addChunk(pt.fragment)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendAlert sends a TLS alert message.
|
||||||
|
func (h *HandshakeLayer) sendAlert(err Alert) error {
|
||||||
|
tmp := make([]byte, 2)
|
||||||
|
tmp[0] = AlertLevelError
|
||||||
|
tmp[1] = byte(err)
|
||||||
|
h.conn.WriteRecord(&TLSPlaintext{
|
||||||
|
contentType: RecordTypeAlert,
|
||||||
|
fragment: tmp},
|
||||||
|
)
|
||||||
|
|
||||||
|
// closeNotify is a special case in that it isn't an error:
|
||||||
|
if err != AlertCloseNotify {
|
||||||
|
return &net.OpError{Op: "local error", Err: err}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) {
|
||||||
|
var hdr, body []byte
|
||||||
|
var err error
|
||||||
|
|
||||||
|
for {
|
||||||
|
logf(logTypeHandshake, "ReadMessage() buffered=%v", len(h.frame.remainder))
|
||||||
|
if h.frame.needed() > 0 {
|
||||||
|
logf(logTypeHandshake, "Trying to read a new record")
|
||||||
|
err = h.readRecord()
|
||||||
|
}
|
||||||
|
if err != nil && (h.nonblocking || err != WouldBlock) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
hdr, body, err = h.frame.process()
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil && (h.nonblocking || err != WouldBlock) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "read handshake message")
|
||||||
|
|
||||||
|
hm := &HandshakeMessage{}
|
||||||
|
hm.msgType = HandshakeType(hdr[0])
|
||||||
|
|
||||||
|
hm.body = make([]byte, len(body))
|
||||||
|
copy(hm.body, body)
|
||||||
|
|
||||||
|
return hm, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) error {
|
||||||
|
return h.WriteMessages([]*HandshakeMessage{hm})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) error {
|
||||||
|
for _, hm := range hms {
|
||||||
|
logf(logTypeHandshake, "WriteMessage [%d] %x", hm.msgType, hm.body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write out headers and bodies
|
||||||
|
buffer := []byte{}
|
||||||
|
for _, msg := range hms {
|
||||||
|
msgLen := len(msg.body)
|
||||||
|
if msgLen > maxHandshakeMessageLen {
|
||||||
|
return fmt.Errorf("tls.handshakelayer: Message too large to send")
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer = append(buffer, msg.Marshal()...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send full-size fragments
|
||||||
|
var start int
|
||||||
|
for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen {
|
||||||
|
err := h.conn.WriteRecord(&TLSPlaintext{
|
||||||
|
contentType: RecordTypeHandshake,
|
||||||
|
fragment: buffer[start : start+maxFragmentLen],
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send a final partial fragment if necessary
|
||||||
|
if start < len(buffer) {
|
||||||
|
err := h.conn.WriteRecord(&TLSPlaintext{
|
||||||
|
contentType: RecordTypeHandshake,
|
||||||
|
fragment: buffer[start:],
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
450
vendor/github.com/bifurcation/mint/handshake-messages.go
generated
vendored
Normal file
450
vendor/github.com/bifurcation/mint/handshake-messages.go
generated
vendored
Normal file
@ -0,0 +1,450 @@
|
|||||||
|
package mint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/bifurcation/mint/syntax"
|
||||||
|
)
|
||||||
|
|
||||||
|
type HandshakeMessageBody interface {
|
||||||
|
Type() HandshakeType
|
||||||
|
Marshal() ([]byte, error)
|
||||||
|
Unmarshal(data []byte) (int, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// ProtocolVersion legacy_version = 0x0303; /* TLS v1.2 */
|
||||||
|
// Random random;
|
||||||
|
// opaque legacy_session_id<0..32>;
|
||||||
|
// CipherSuite cipher_suites<2..2^16-2>;
|
||||||
|
// opaque legacy_compression_methods<1..2^8-1>;
|
||||||
|
// Extension extensions<0..2^16-1>;
|
||||||
|
// } ClientHello;
|
||||||
|
type ClientHelloBody struct {
|
||||||
|
// Omitted: clientVersion
|
||||||
|
// Omitted: legacySessionID
|
||||||
|
// Omitted: legacyCompressionMethods
|
||||||
|
Random [32]byte
|
||||||
|
CipherSuites []CipherSuite
|
||||||
|
Extensions ExtensionList
|
||||||
|
}
|
||||||
|
|
||||||
|
type clientHelloBodyInner struct {
|
||||||
|
LegacyVersion uint16
|
||||||
|
Random [32]byte
|
||||||
|
LegacySessionID []byte `tls:"head=1,max=32"`
|
||||||
|
CipherSuites []CipherSuite `tls:"head=2,min=2"`
|
||||||
|
LegacyCompressionMethods []byte `tls:"head=1,min=1"`
|
||||||
|
Extensions []Extension `tls:"head=2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ch ClientHelloBody) Type() HandshakeType {
|
||||||
|
return HandshakeTypeClientHello
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ch ClientHelloBody) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(clientHelloBodyInner{
|
||||||
|
LegacyVersion: 0x0303,
|
||||||
|
Random: ch.Random,
|
||||||
|
LegacySessionID: []byte{},
|
||||||
|
CipherSuites: ch.CipherSuites,
|
||||||
|
LegacyCompressionMethods: []byte{0},
|
||||||
|
Extensions: ch.Extensions,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ch *ClientHelloBody) Unmarshal(data []byte) (int, error) {
|
||||||
|
var inner clientHelloBodyInner
|
||||||
|
read, err := syntax.Unmarshal(data, &inner)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// We are strict about these things because we only support 1.3
|
||||||
|
if inner.LegacyVersion != 0x0303 {
|
||||||
|
return 0, fmt.Errorf("tls.clienthello: Incorrect version number")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 {
|
||||||
|
return 0, fmt.Errorf("tls.clienthello: Invalid compression method")
|
||||||
|
}
|
||||||
|
|
||||||
|
ch.Random = inner.Random
|
||||||
|
ch.CipherSuites = inner.CipherSuites
|
||||||
|
ch.Extensions = inner.Extensions
|
||||||
|
return read, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: File a spec bug to clarify this
|
||||||
|
func (ch ClientHelloBody) Truncated() ([]byte, error) {
|
||||||
|
if len(ch.Extensions) == 0 {
|
||||||
|
return nil, fmt.Errorf("tls.clienthello.truncate: No extensions")
|
||||||
|
}
|
||||||
|
|
||||||
|
pskExt := ch.Extensions[len(ch.Extensions)-1]
|
||||||
|
if pskExt.ExtensionType != ExtensionTypePreSharedKey {
|
||||||
|
return nil, fmt.Errorf("tls.clienthello.truncate: Last extension is not PSK")
|
||||||
|
}
|
||||||
|
|
||||||
|
chm, err := HandshakeMessageFromBody(&ch)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
chData := chm.Marshal()
|
||||||
|
|
||||||
|
psk := PreSharedKeyExtension{
|
||||||
|
HandshakeType: HandshakeTypeClientHello,
|
||||||
|
}
|
||||||
|
_, err = psk.Unmarshal(pskExt.ExtensionData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal just the binders so that we know how much to truncate
|
||||||
|
binders := struct {
|
||||||
|
Binders []PSKBinderEntry `tls:"head=2,min=33"`
|
||||||
|
}{Binders: psk.Binders}
|
||||||
|
binderData, _ := syntax.Marshal(binders)
|
||||||
|
binderLen := len(binderData)
|
||||||
|
|
||||||
|
chLen := len(chData)
|
||||||
|
return chData[:chLen-binderLen], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// ProtocolVersion server_version;
|
||||||
|
// CipherSuite cipher_suite;
|
||||||
|
// Extension extensions<2..2^16-1>;
|
||||||
|
// } HelloRetryRequest;
|
||||||
|
type HelloRetryRequestBody struct {
|
||||||
|
Version uint16
|
||||||
|
CipherSuite CipherSuite
|
||||||
|
Extensions ExtensionList `tls:"head=2,min=2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hrr HelloRetryRequestBody) Type() HandshakeType {
|
||||||
|
return HandshakeTypeHelloRetryRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hrr HelloRetryRequestBody) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(hrr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hrr *HelloRetryRequestBody) Unmarshal(data []byte) (int, error) {
|
||||||
|
return syntax.Unmarshal(data, hrr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// ProtocolVersion version;
|
||||||
|
// Random random;
|
||||||
|
// CipherSuite cipher_suite;
|
||||||
|
// Extension extensions<0..2^16-1>;
|
||||||
|
// } ServerHello;
|
||||||
|
type ServerHelloBody struct {
|
||||||
|
Version uint16
|
||||||
|
Random [32]byte
|
||||||
|
CipherSuite CipherSuite
|
||||||
|
Extensions ExtensionList `tls:"head=2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sh ServerHelloBody) Type() HandshakeType {
|
||||||
|
return HandshakeTypeServerHello
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sh ServerHelloBody) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(sh)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sh *ServerHelloBody) Unmarshal(data []byte) (int, error) {
|
||||||
|
return syntax.Unmarshal(data, sh)
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// opaque verify_data[verify_data_length];
|
||||||
|
// } Finished;
|
||||||
|
//
|
||||||
|
// verifyDataLen is not a field in the TLS struct, but we add it here so
|
||||||
|
// that calling code can tell us how much data to expect when we marshal /
|
||||||
|
// unmarshal. (We could add this to the marshal/unmarshal methods, but let's
|
||||||
|
// try to keep the signature consistent for now.)
|
||||||
|
//
|
||||||
|
// For similar reasons, we don't use the `syntax` module here, because this
|
||||||
|
// struct doesn't map well to standard TLS presentation language concepts.
|
||||||
|
//
|
||||||
|
// TODO: File a spec bug
|
||||||
|
type FinishedBody struct {
|
||||||
|
VerifyDataLen int
|
||||||
|
VerifyData []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fin FinishedBody) Type() HandshakeType {
|
||||||
|
return HandshakeTypeFinished
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fin FinishedBody) Marshal() ([]byte, error) {
|
||||||
|
if len(fin.VerifyData) != fin.VerifyDataLen {
|
||||||
|
return nil, fmt.Errorf("tls.finished: data length mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
body := make([]byte, len(fin.VerifyData))
|
||||||
|
copy(body, fin.VerifyData)
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fin *FinishedBody) Unmarshal(data []byte) (int, error) {
|
||||||
|
if len(data) < fin.VerifyDataLen {
|
||||||
|
return 0, fmt.Errorf("tls.finished: Malformed finished; too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
fin.VerifyData = make([]byte, fin.VerifyDataLen)
|
||||||
|
copy(fin.VerifyData, data[:fin.VerifyDataLen])
|
||||||
|
return fin.VerifyDataLen, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// Extension extensions<0..2^16-1>;
|
||||||
|
// } EncryptedExtensions;
|
||||||
|
//
|
||||||
|
// Marshal() and Unmarshal() are handled by ExtensionList
|
||||||
|
type EncryptedExtensionsBody struct {
|
||||||
|
Extensions ExtensionList `tls:"head=2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ee EncryptedExtensionsBody) Type() HandshakeType {
|
||||||
|
return HandshakeTypeEncryptedExtensions
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ee EncryptedExtensionsBody) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(ee)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ee *EncryptedExtensionsBody) Unmarshal(data []byte) (int, error) {
|
||||||
|
return syntax.Unmarshal(data, ee)
|
||||||
|
}
|
||||||
|
|
||||||
|
// opaque ASN1Cert<1..2^24-1>;
|
||||||
|
//
|
||||||
|
// struct {
|
||||||
|
// ASN1Cert cert_data;
|
||||||
|
// Extension extensions<0..2^16-1>
|
||||||
|
// } CertificateEntry;
|
||||||
|
//
|
||||||
|
// struct {
|
||||||
|
// opaque certificate_request_context<0..2^8-1>;
|
||||||
|
// CertificateEntry certificate_list<0..2^24-1>;
|
||||||
|
// } Certificate;
|
||||||
|
type CertificateEntry struct {
|
||||||
|
CertData *x509.Certificate
|
||||||
|
Extensions ExtensionList
|
||||||
|
}
|
||||||
|
|
||||||
|
type CertificateBody struct {
|
||||||
|
CertificateRequestContext []byte
|
||||||
|
CertificateList []CertificateEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
type certificateEntryInner struct {
|
||||||
|
CertData []byte `tls:"head=3,min=1"`
|
||||||
|
Extensions ExtensionList `tls:"head=2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type certificateBodyInner struct {
|
||||||
|
CertificateRequestContext []byte `tls:"head=1"`
|
||||||
|
CertificateList []certificateEntryInner `tls:"head=3"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c CertificateBody) Type() HandshakeType {
|
||||||
|
return HandshakeTypeCertificate
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c CertificateBody) Marshal() ([]byte, error) {
|
||||||
|
inner := certificateBodyInner{
|
||||||
|
CertificateRequestContext: c.CertificateRequestContext,
|
||||||
|
CertificateList: make([]certificateEntryInner, len(c.CertificateList)),
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, entry := range c.CertificateList {
|
||||||
|
inner.CertificateList[i] = certificateEntryInner{
|
||||||
|
CertData: entry.CertData.Raw,
|
||||||
|
Extensions: entry.Extensions,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return syntax.Marshal(inner)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CertificateBody) Unmarshal(data []byte) (int, error) {
|
||||||
|
inner := certificateBodyInner{}
|
||||||
|
read, err := syntax.Unmarshal(data, &inner)
|
||||||
|
if err != nil {
|
||||||
|
return read, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.CertificateRequestContext = inner.CertificateRequestContext
|
||||||
|
c.CertificateList = make([]CertificateEntry, len(inner.CertificateList))
|
||||||
|
|
||||||
|
for i, entry := range inner.CertificateList {
|
||||||
|
c.CertificateList[i].CertData, err = x509.ParseCertificate(entry.CertData)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("tls:certificate: Certificate failed to parse: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.CertificateList[i].Extensions = entry.Extensions
|
||||||
|
}
|
||||||
|
|
||||||
|
return read, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// SignatureScheme algorithm;
|
||||||
|
// opaque signature<0..2^16-1>;
|
||||||
|
// } CertificateVerify;
|
||||||
|
type CertificateVerifyBody struct {
|
||||||
|
Algorithm SignatureScheme
|
||||||
|
Signature []byte `tls:"head=2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cv CertificateVerifyBody) Type() HandshakeType {
|
||||||
|
return HandshakeTypeCertificateVerify
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cv CertificateVerifyBody) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(cv)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cv *CertificateVerifyBody) Unmarshal(data []byte) (int, error) {
|
||||||
|
return syntax.Unmarshal(data, cv)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cv *CertificateVerifyBody) EncodeSignatureInput(data []byte) []byte {
|
||||||
|
// TODO: Change context for client auth
|
||||||
|
// TODO: Put this in a const
|
||||||
|
const context = "TLS 1.3, server CertificateVerify"
|
||||||
|
sigInput := bytes.Repeat([]byte{0x20}, 64)
|
||||||
|
sigInput = append(sigInput, []byte(context)...)
|
||||||
|
sigInput = append(sigInput, []byte{0}...)
|
||||||
|
sigInput = append(sigInput, data...)
|
||||||
|
return sigInput
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cv *CertificateVerifyBody) Sign(privateKey crypto.Signer, handshakeHash []byte) (err error) {
|
||||||
|
sigInput := cv.EncodeSignatureInput(handshakeHash)
|
||||||
|
cv.Signature, err = sign(cv.Algorithm, privateKey, sigInput)
|
||||||
|
logf(logTypeHandshake, "Signed: alg=[%04x] sigInput=[%x], sig=[%x]", cv.Algorithm, sigInput, cv.Signature)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cv *CertificateVerifyBody) Verify(publicKey crypto.PublicKey, handshakeHash []byte) error {
|
||||||
|
sigInput := cv.EncodeSignatureInput(handshakeHash)
|
||||||
|
logf(logTypeHandshake, "About to verify: alg=[%04x] sigInput=[%x], sig=[%x]", cv.Algorithm, sigInput, cv.Signature)
|
||||||
|
return verify(cv.Algorithm, publicKey, sigInput, cv.Signature)
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// opaque certificate_request_context<0..2^8-1>;
|
||||||
|
// Extension extensions<2..2^16-1>;
|
||||||
|
// } CertificateRequest;
|
||||||
|
type CertificateRequestBody struct {
|
||||||
|
CertificateRequestContext []byte `tls:"head=1"`
|
||||||
|
Extensions ExtensionList `tls:"head=2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cr CertificateRequestBody) Type() HandshakeType {
|
||||||
|
return HandshakeTypeCertificateRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cr CertificateRequestBody) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(cr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cr *CertificateRequestBody) Unmarshal(data []byte) (int, error) {
|
||||||
|
return syntax.Unmarshal(data, cr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// uint32 ticket_lifetime;
|
||||||
|
// uint32 ticket_age_add;
|
||||||
|
// opaque ticket_nonce<1..255>;
|
||||||
|
// opaque ticket<1..2^16-1>;
|
||||||
|
// Extension extensions<0..2^16-2>;
|
||||||
|
// } NewSessionTicket;
|
||||||
|
type NewSessionTicketBody struct {
|
||||||
|
TicketLifetime uint32
|
||||||
|
TicketAgeAdd uint32
|
||||||
|
TicketNonce []byte `tls:"head=1,min=1"`
|
||||||
|
Ticket []byte `tls:"head=2,min=1"`
|
||||||
|
Extensions ExtensionList `tls:"head=2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
const ticketNonceLen = 16
|
||||||
|
|
||||||
|
func NewSessionTicket(ticketLen int, ticketLifetime uint32) (*NewSessionTicketBody, error) {
|
||||||
|
buf := make([]byte, 4+ticketNonceLen+ticketLen)
|
||||||
|
_, err := prng.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tkt := &NewSessionTicketBody{
|
||||||
|
TicketLifetime: ticketLifetime,
|
||||||
|
TicketAgeAdd: binary.BigEndian.Uint32(buf[:4]),
|
||||||
|
TicketNonce: buf[4 : 4+ticketNonceLen],
|
||||||
|
Ticket: buf[4+ticketNonceLen:],
|
||||||
|
}
|
||||||
|
|
||||||
|
return tkt, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tkt NewSessionTicketBody) Type() HandshakeType {
|
||||||
|
return HandshakeTypeNewSessionTicket
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tkt NewSessionTicketBody) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(tkt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tkt *NewSessionTicketBody) Unmarshal(data []byte) (int, error) {
|
||||||
|
return syntax.Unmarshal(data, tkt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// enum {
|
||||||
|
// update_not_requested(0), update_requested(1), (255)
|
||||||
|
// } KeyUpdateRequest;
|
||||||
|
//
|
||||||
|
// struct {
|
||||||
|
// KeyUpdateRequest request_update;
|
||||||
|
// } KeyUpdate;
|
||||||
|
type KeyUpdateBody struct {
|
||||||
|
KeyUpdateRequest KeyUpdateRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ku KeyUpdateBody) Type() HandshakeType {
|
||||||
|
return HandshakeTypeKeyUpdate
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ku KeyUpdateBody) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(ku)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ku *KeyUpdateBody) Unmarshal(data []byte) (int, error) {
|
||||||
|
return syntax.Unmarshal(data, ku)
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {} EndOfEarlyData;
|
||||||
|
type EndOfEarlyDataBody struct{}
|
||||||
|
|
||||||
|
func (eoed EndOfEarlyDataBody) Type() HandshakeType {
|
||||||
|
return HandshakeTypeEndOfEarlyData
|
||||||
|
}
|
||||||
|
|
||||||
|
func (eoed EndOfEarlyDataBody) Marshal() ([]byte, error) {
|
||||||
|
return []byte{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (eoed *EndOfEarlyDataBody) Unmarshal(data []byte) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
55
vendor/github.com/bifurcation/mint/log.go
generated
vendored
Normal file
55
vendor/github.com/bifurcation/mint/log.go
generated
vendored
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
package mint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// We use this environment variable to control logging. It should be a
|
||||||
|
// comma-separated list of log tags (see below) or "*" to enable all logging.
|
||||||
|
const logConfigVar = "MINT_LOG"
|
||||||
|
|
||||||
|
// Pre-defined log types
|
||||||
|
const (
|
||||||
|
logTypeCrypto = "crypto"
|
||||||
|
logTypeHandshake = "handshake"
|
||||||
|
logTypeNegotiation = "negotiation"
|
||||||
|
logTypeIO = "io"
|
||||||
|
logTypeFrameReader = "frame"
|
||||||
|
logTypeVerbose = "verbose"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
logFunction = log.Printf
|
||||||
|
logAll = false
|
||||||
|
logSettings = map[string]bool{}
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
parseLogEnv(os.Environ())
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseLogEnv(env []string) {
|
||||||
|
for _, stmt := range env {
|
||||||
|
if strings.HasPrefix(stmt, logConfigVar+"=") {
|
||||||
|
val := stmt[len(logConfigVar)+1:]
|
||||||
|
|
||||||
|
if val == "*" {
|
||||||
|
logAll = true
|
||||||
|
} else {
|
||||||
|
for _, t := range strings.Split(val, ",") {
|
||||||
|
logSettings[t] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func logf(tag string, format string, args ...interface{}) {
|
||||||
|
if logAll || logSettings[tag] {
|
||||||
|
fullFormat := fmt.Sprintf("[%s] %s", tag, format)
|
||||||
|
logFunction(fullFormat, args...)
|
||||||
|
}
|
||||||
|
}
|
217
vendor/github.com/bifurcation/mint/negotiation.go
generated
vendored
Normal file
217
vendor/github.com/bifurcation/mint/negotiation.go
generated
vendored
Normal file
@ -0,0 +1,217 @@
|
|||||||
|
package mint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func VersionNegotiation(offered, supported []uint16) (bool, uint16) {
|
||||||
|
for _, offeredVersion := range offered {
|
||||||
|
for _, supportedVersion := range supported {
|
||||||
|
logf(logTypeHandshake, "[server] version offered by client [%04x] <> [%04x]", offeredVersion, supportedVersion)
|
||||||
|
if offeredVersion == supportedVersion {
|
||||||
|
// XXX: Should probably be highest supported version, but for now, we
|
||||||
|
// only support one version, so it doesn't really matter.
|
||||||
|
return true, offeredVersion
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func DHNegotiation(keyShares []KeyShareEntry, groups []NamedGroup) (bool, NamedGroup, []byte, []byte) {
|
||||||
|
for _, share := range keyShares {
|
||||||
|
for _, group := range groups {
|
||||||
|
if group != share.Group {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
pub, priv, err := newKeyShare(share.Group)
|
||||||
|
if err != nil {
|
||||||
|
// If we encounter an error, just keep looking
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
dhSecret, err := keyAgreement(share.Group, share.KeyExchange, priv)
|
||||||
|
if err != nil {
|
||||||
|
// If we encounter an error, just keep looking
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, group, pub, dhSecret
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, 0, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
ticketAgeTolerance uint32 = 5 * 1000 // five seconds in milliseconds
|
||||||
|
)
|
||||||
|
|
||||||
|
func PSKNegotiation(identities []PSKIdentity, binders []PSKBinderEntry, context []byte, psks PreSharedKeyCache) (bool, int, *PreSharedKey, CipherSuiteParams, error) {
|
||||||
|
logf(logTypeNegotiation, "Negotiating PSK offered=[%d] supported=[%d]", len(identities), psks.Size())
|
||||||
|
for i, id := range identities {
|
||||||
|
identityHex := hex.EncodeToString(id.Identity)
|
||||||
|
|
||||||
|
psk, ok := psks.Get(identityHex)
|
||||||
|
if !ok {
|
||||||
|
logf(logTypeNegotiation, "No PSK for identity %x", identityHex)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// For resumption, make sure the ticket age is correct
|
||||||
|
if psk.IsResumption {
|
||||||
|
extTicketAge := id.ObfuscatedTicketAge - psk.TicketAgeAdd
|
||||||
|
knownTicketAge := uint32(time.Since(psk.ReceivedAt) / time.Millisecond)
|
||||||
|
ticketAgeDelta := knownTicketAge - extTicketAge
|
||||||
|
if knownTicketAge < extTicketAge {
|
||||||
|
ticketAgeDelta = extTicketAge - knownTicketAge
|
||||||
|
}
|
||||||
|
if ticketAgeDelta > ticketAgeTolerance {
|
||||||
|
logf(logTypeNegotiation, "WARNING potential replay [%x]", psk.Identity)
|
||||||
|
logf(logTypeNegotiation, "Ticket age exceeds tolerance |%d - %d| = [%d] > [%d]",
|
||||||
|
extTicketAge, knownTicketAge, ticketAgeDelta, ticketAgeTolerance)
|
||||||
|
return false, 0, nil, CipherSuiteParams{}, fmt.Errorf("WARNING Potential replay for identity %x", psk.Identity)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
params, ok := cipherSuiteMap[psk.CipherSuite]
|
||||||
|
if !ok {
|
||||||
|
err := fmt.Errorf("tls.cryptoinit: Unsupported ciphersuite from PSK [%04x]", psk.CipherSuite)
|
||||||
|
return false, 0, nil, CipherSuiteParams{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute binder
|
||||||
|
binderLabel := labelExternalBinder
|
||||||
|
if psk.IsResumption {
|
||||||
|
binderLabel = labelResumptionBinder
|
||||||
|
}
|
||||||
|
|
||||||
|
h0 := params.Hash.New().Sum(nil)
|
||||||
|
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
|
||||||
|
earlySecret := HkdfExtract(params.Hash, zero, psk.Key)
|
||||||
|
binderKey := deriveSecret(params, earlySecret, binderLabel, h0)
|
||||||
|
|
||||||
|
// context = ClientHello[truncated]
|
||||||
|
// context = ClientHello1 + HelloRetryRequest + ClientHello2[truncated]
|
||||||
|
ctxHash := params.Hash.New()
|
||||||
|
ctxHash.Write(context)
|
||||||
|
|
||||||
|
binder := computeFinishedData(params, binderKey, ctxHash.Sum(nil))
|
||||||
|
if !bytes.Equal(binder, binders[i].Binder) {
|
||||||
|
logf(logTypeNegotiation, "Binder check failed for identity %x; [%x] != [%x]", psk.Identity, binder, binders[i].Binder)
|
||||||
|
return false, 0, nil, CipherSuiteParams{}, fmt.Errorf("Binder check failed identity %x", psk.Identity)
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeNegotiation, "Using PSK with identity %x", psk.Identity)
|
||||||
|
return true, i, &psk, params, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeNegotiation, "Failed to find a usable PSK")
|
||||||
|
return false, 0, nil, CipherSuiteParams{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func PSKModeNegotiation(canDoDH, canDoPSK bool, modes []PSKKeyExchangeMode) (bool, bool) {
|
||||||
|
logf(logTypeNegotiation, "Negotiating PSK modes [%v] [%v] [%+v]", canDoDH, canDoPSK, modes)
|
||||||
|
dhAllowed := false
|
||||||
|
dhRequired := true
|
||||||
|
for _, mode := range modes {
|
||||||
|
dhAllowed = dhAllowed || (mode == PSKModeDHEKE)
|
||||||
|
dhRequired = dhRequired && (mode == PSKModeDHEKE)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use PSK if we can meet DH requirement and modes were provided
|
||||||
|
usingPSK := canDoPSK && (!dhRequired || canDoDH) && (len(modes) > 0)
|
||||||
|
|
||||||
|
// Use DH if allowed
|
||||||
|
usingDH := canDoDH && (dhAllowed || !usingPSK)
|
||||||
|
|
||||||
|
logf(logTypeNegotiation, "Results of PSK mode negotiation: usingDH=[%v] usingPSK=[%v]", usingDH, usingPSK)
|
||||||
|
return usingDH, usingPSK
|
||||||
|
}
|
||||||
|
|
||||||
|
func CertificateSelection(serverName *string, signatureSchemes []SignatureScheme, certs []*Certificate) (*Certificate, SignatureScheme, error) {
|
||||||
|
// Select for server name if provided
|
||||||
|
candidates := certs
|
||||||
|
if serverName != nil {
|
||||||
|
candidatesByName := []*Certificate{}
|
||||||
|
for _, cert := range certs {
|
||||||
|
for _, name := range cert.Chain[0].DNSNames {
|
||||||
|
if len(*serverName) > 0 && name == *serverName {
|
||||||
|
candidatesByName = append(candidatesByName, cert)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(candidatesByName) == 0 {
|
||||||
|
return nil, 0, fmt.Errorf("No certificates available for server name")
|
||||||
|
}
|
||||||
|
|
||||||
|
candidates = candidatesByName
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select for signature scheme
|
||||||
|
for _, cert := range candidates {
|
||||||
|
for _, scheme := range signatureSchemes {
|
||||||
|
if !schemeValidForKey(scheme, cert.PrivateKey) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return cert, scheme, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, 0, fmt.Errorf("No certificates compatible with signature schemes")
|
||||||
|
}
|
||||||
|
|
||||||
|
func EarlyDataNegotiation(usingPSK, gotEarlyData, allowEarlyData bool) bool {
|
||||||
|
usingEarlyData := gotEarlyData && usingPSK && allowEarlyData
|
||||||
|
logf(logTypeNegotiation, "Early data negotiation (%v, %v, %v) => %v", usingPSK, gotEarlyData, allowEarlyData, usingEarlyData)
|
||||||
|
return usingEarlyData
|
||||||
|
}
|
||||||
|
|
||||||
|
func CipherSuiteNegotiation(psk *PreSharedKey, offered, supported []CipherSuite) (CipherSuite, error) {
|
||||||
|
for _, s1 := range offered {
|
||||||
|
if psk != nil {
|
||||||
|
if s1 == psk.CipherSuite {
|
||||||
|
return s1, nil
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, s2 := range supported {
|
||||||
|
if s1 == s2 {
|
||||||
|
return s1, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, fmt.Errorf("No overlap between offered and supproted ciphersuites (psk? [%v])", psk != nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ALPNNegotiation(psk *PreSharedKey, offered, supported []string) (string, error) {
|
||||||
|
for _, p1 := range offered {
|
||||||
|
if psk != nil {
|
||||||
|
if p1 != psk.NextProto {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p2 := range supported {
|
||||||
|
if p1 == p2 {
|
||||||
|
return p1, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the client offers ALPN on resumption, it must match the earlier one
|
||||||
|
var err error
|
||||||
|
if psk != nil && psk.IsResumption && (len(offered) > 0) {
|
||||||
|
err = fmt.Errorf("ALPN for PSK not provided")
|
||||||
|
}
|
||||||
|
return "", err
|
||||||
|
}
|
296
vendor/github.com/bifurcation/mint/record-layer.go
generated
vendored
Normal file
296
vendor/github.com/bifurcation/mint/record-layer.go
generated
vendored
Normal file
@ -0,0 +1,296 @@
|
|||||||
|
package mint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
sequenceNumberLen = 8 // sequence number length
|
||||||
|
recordHeaderLen = 5 // record header length
|
||||||
|
maxFragmentLen = 1 << 14 // max number of bytes in a record
|
||||||
|
)
|
||||||
|
|
||||||
|
type DecryptError string
|
||||||
|
|
||||||
|
func (err DecryptError) Error() string {
|
||||||
|
return string(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// ContentType type;
|
||||||
|
// ProtocolVersion record_version = { 3, 1 }; /* TLS v1.x */
|
||||||
|
// uint16 length;
|
||||||
|
// opaque fragment[TLSPlaintext.length];
|
||||||
|
// } TLSPlaintext;
|
||||||
|
type TLSPlaintext struct {
|
||||||
|
// Omitted: record_version (static)
|
||||||
|
// Omitted: length (computed from fragment)
|
||||||
|
contentType RecordType
|
||||||
|
fragment []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type RecordLayer struct {
|
||||||
|
sync.Mutex
|
||||||
|
|
||||||
|
conn io.ReadWriter // The underlying connection
|
||||||
|
frame *frameReader // The buffered frame reader
|
||||||
|
nextData []byte // The next record to send
|
||||||
|
cachedRecord *TLSPlaintext // Last record read, cached to enable "peek"
|
||||||
|
cachedError error // Error on the last record read
|
||||||
|
|
||||||
|
ivLength int // Length of the seq and nonce fields
|
||||||
|
seq []byte // Zero-padded sequence number
|
||||||
|
nonce []byte // Buffer for per-record nonces
|
||||||
|
cipher cipher.AEAD // AEAD cipher
|
||||||
|
}
|
||||||
|
|
||||||
|
type recordLayerFrameDetails struct{}
|
||||||
|
|
||||||
|
func (d recordLayerFrameDetails) headerLen() int {
|
||||||
|
return recordHeaderLen
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d recordLayerFrameDetails) defaultReadLen() int {
|
||||||
|
return recordHeaderLen + maxFragmentLen
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) {
|
||||||
|
return (int(hdr[3]) << 8) | int(hdr[4]), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRecordLayer(conn io.ReadWriter) *RecordLayer {
|
||||||
|
r := RecordLayer{}
|
||||||
|
r.conn = conn
|
||||||
|
r.frame = newFrameReader(recordLayerFrameDetails{})
|
||||||
|
r.ivLength = 0
|
||||||
|
return &r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RecordLayer) Rekey(cipher aeadFactory, key []byte, iv []byte) error {
|
||||||
|
var err error
|
||||||
|
r.cipher, err = cipher(key)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
r.ivLength = len(iv)
|
||||||
|
r.seq = bytes.Repeat([]byte{0}, r.ivLength)
|
||||||
|
r.nonce = make([]byte, r.ivLength)
|
||||||
|
copy(r.nonce, iv)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RecordLayer) incrementSequenceNumber() {
|
||||||
|
if r.ivLength == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := r.ivLength - 1; i > r.ivLength-sequenceNumberLen; i-- {
|
||||||
|
r.seq[i]++
|
||||||
|
r.nonce[i] ^= (r.seq[i] - 1) ^ r.seq[i]
|
||||||
|
if r.seq[i] != 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not allowed to let sequence number wrap.
|
||||||
|
// Instead, must renegotiate before it does.
|
||||||
|
// Not likely enough to bother.
|
||||||
|
panic("TLS: sequence number wraparound")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RecordLayer) encrypt(pt *TLSPlaintext, padLen int) *TLSPlaintext {
|
||||||
|
// Expand the fragment to hold contentType, padding, and overhead
|
||||||
|
originalLen := len(pt.fragment)
|
||||||
|
plaintextLen := originalLen + 1 + padLen
|
||||||
|
ciphertextLen := plaintextLen + r.cipher.Overhead()
|
||||||
|
|
||||||
|
// Assemble the revised plaintext
|
||||||
|
out := &TLSPlaintext{
|
||||||
|
contentType: RecordTypeApplicationData,
|
||||||
|
fragment: make([]byte, ciphertextLen),
|
||||||
|
}
|
||||||
|
copy(out.fragment, pt.fragment)
|
||||||
|
out.fragment[originalLen] = byte(pt.contentType)
|
||||||
|
for i := 1; i <= padLen; i++ {
|
||||||
|
out.fragment[originalLen+i] = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt the fragment
|
||||||
|
payload := out.fragment[:plaintextLen]
|
||||||
|
r.cipher.Seal(payload[:0], r.nonce, payload, nil)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RecordLayer) decrypt(pt *TLSPlaintext) (*TLSPlaintext, int, error) {
|
||||||
|
if len(pt.fragment) < r.cipher.Overhead() {
|
||||||
|
msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.Overhead())
|
||||||
|
return nil, 0, DecryptError(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
decryptLen := len(pt.fragment) - r.cipher.Overhead()
|
||||||
|
out := &TLSPlaintext{
|
||||||
|
contentType: pt.contentType,
|
||||||
|
fragment: make([]byte, decryptLen),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrypt
|
||||||
|
_, err := r.cipher.Open(out.fragment[:0], r.nonce, pt.fragment, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, DecryptError("tls.record.decrypt: AEAD decrypt failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the padding boundary
|
||||||
|
padLen := 0
|
||||||
|
for ; padLen < decryptLen+1 && out.fragment[decryptLen-padLen-1] == 0; padLen++ {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transfer the content type
|
||||||
|
newLen := decryptLen - padLen - 1
|
||||||
|
out.contentType = RecordType(out.fragment[newLen])
|
||||||
|
|
||||||
|
// Truncate the message to remove contentType, padding, overhead
|
||||||
|
out.fragment = out.fragment[:newLen]
|
||||||
|
return out, padLen, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) {
|
||||||
|
var pt *TLSPlaintext
|
||||||
|
var err error
|
||||||
|
|
||||||
|
for {
|
||||||
|
pt, err = r.nextRecord()
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if !block || err != WouldBlock {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return pt.contentType, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) {
|
||||||
|
pt, err := r.nextRecord()
|
||||||
|
|
||||||
|
// Consume the cached record if there was one
|
||||||
|
r.cachedRecord = nil
|
||||||
|
r.cachedError = nil
|
||||||
|
|
||||||
|
return pt, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) {
|
||||||
|
if r.cachedRecord != nil {
|
||||||
|
logf(logTypeIO, "Returning cached record")
|
||||||
|
return r.cachedRecord, r.cachedError
|
||||||
|
}
|
||||||
|
|
||||||
|
// Loop until one of three things happens:
|
||||||
|
//
|
||||||
|
// 1. We get a frame
|
||||||
|
// 2. We try to read off the socket and get nothing, in which case
|
||||||
|
// return WouldBlock
|
||||||
|
// 3. We get an error.
|
||||||
|
err := WouldBlock
|
||||||
|
var header, body []byte
|
||||||
|
|
||||||
|
for err != nil {
|
||||||
|
if r.frame.needed() > 0 {
|
||||||
|
buf := make([]byte, recordHeaderLen+maxFragmentLen)
|
||||||
|
n, err := r.conn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeIO, "Error reading, %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if n == 0 {
|
||||||
|
return nil, WouldBlock
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeIO, "Read %v bytes", n)
|
||||||
|
|
||||||
|
buf = buf[:n]
|
||||||
|
r.frame.addChunk(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
header, body, err = r.frame.process()
|
||||||
|
// Loop around on WouldBlock to see if some
|
||||||
|
// data is now available.
|
||||||
|
if err != nil && err != WouldBlock {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pt := &TLSPlaintext{}
|
||||||
|
// Validate content type
|
||||||
|
switch RecordType(header[0]) {
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("tls.record: Unknown content type %02x", header[0])
|
||||||
|
case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData:
|
||||||
|
pt.contentType = RecordType(header[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate version
|
||||||
|
if !allowWrongVersionNumber && (header[1] != 0x03 || header[2] != 0x01) {
|
||||||
|
return nil, fmt.Errorf("tls.record: Invalid version %02x%02x", header[1], header[2])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate size < max
|
||||||
|
size := (int(header[3]) << 8) + int(header[4])
|
||||||
|
if size > maxFragmentLen+256 {
|
||||||
|
return nil, fmt.Errorf("tls.record: Ciphertext size too big")
|
||||||
|
}
|
||||||
|
|
||||||
|
pt.fragment = make([]byte, size)
|
||||||
|
copy(pt.fragment, body)
|
||||||
|
|
||||||
|
// Attempt to decrypt fragment
|
||||||
|
if r.cipher != nil {
|
||||||
|
pt, _, err = r.decrypt(pt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that plaintext length is not too long
|
||||||
|
if len(pt.fragment) > maxFragmentLen {
|
||||||
|
return nil, fmt.Errorf("tls.record: Plaintext size too big")
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeIO, "RecordLayer.ReadRecord [%d] [%x]", pt.contentType, pt.fragment)
|
||||||
|
|
||||||
|
r.cachedRecord = pt
|
||||||
|
r.incrementSequenceNumber()
|
||||||
|
return pt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RecordLayer) WriteRecord(pt *TLSPlaintext) error {
|
||||||
|
return r.WriteRecordWithPadding(pt, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error {
|
||||||
|
if r.cipher != nil {
|
||||||
|
pt = r.encrypt(pt, padLen)
|
||||||
|
} else if padLen > 0 {
|
||||||
|
return fmt.Errorf("tls.record: Padding can only be done on encrypted records")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(pt.fragment) > maxFragmentLen {
|
||||||
|
return fmt.Errorf("tls.record: Record size too big")
|
||||||
|
}
|
||||||
|
|
||||||
|
length := len(pt.fragment)
|
||||||
|
header := []byte{byte(pt.contentType), 0x03, 0x01, byte(length >> 8), byte(length)}
|
||||||
|
record := append(header, pt.fragment...)
|
||||||
|
|
||||||
|
logf(logTypeIO, "RecordLayer.WriteRecord [%d] [%x]", pt.contentType, pt.fragment)
|
||||||
|
|
||||||
|
r.incrementSequenceNumber()
|
||||||
|
_, err := r.conn.Write(record)
|
||||||
|
return err
|
||||||
|
}
|
898
vendor/github.com/bifurcation/mint/server-state-machine.go
generated
vendored
Normal file
898
vendor/github.com/bifurcation/mint/server-state-machine.go
generated
vendored
Normal file
@ -0,0 +1,898 @@
|
|||||||
|
package mint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"hash"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Server State Machine
|
||||||
|
//
|
||||||
|
// START <-----+
|
||||||
|
// Recv ClientHello | | Send HelloRetryRequest
|
||||||
|
// v |
|
||||||
|
// RECVD_CH ----+
|
||||||
|
// | Select parameters
|
||||||
|
// | Send ServerHello
|
||||||
|
// v
|
||||||
|
// NEGOTIATED
|
||||||
|
// | Send EncryptedExtensions
|
||||||
|
// | [Send CertificateRequest]
|
||||||
|
// Can send | [Send Certificate + CertificateVerify]
|
||||||
|
// app data --> | Send Finished
|
||||||
|
// after +--------+--------+
|
||||||
|
// here No 0-RTT | | 0-RTT
|
||||||
|
// | v
|
||||||
|
// | WAIT_EOED <---+
|
||||||
|
// | Recv | | | Recv
|
||||||
|
// | EndOfEarlyData | | | early data
|
||||||
|
// | | +-----+
|
||||||
|
// +> WAIT_FLIGHT2 <-+
|
||||||
|
// |
|
||||||
|
// +--------+--------+
|
||||||
|
// No auth | | Client auth
|
||||||
|
// | |
|
||||||
|
// | v
|
||||||
|
// | WAIT_CERT
|
||||||
|
// | Recv | | Recv Certificate
|
||||||
|
// | empty | v
|
||||||
|
// | Certificate | WAIT_CV
|
||||||
|
// | | | Recv
|
||||||
|
// | v | CertificateVerify
|
||||||
|
// +-> WAIT_FINISHED <---+
|
||||||
|
// | Recv Finished
|
||||||
|
// v
|
||||||
|
// CONNECTED
|
||||||
|
//
|
||||||
|
// NB: Not using state RECVD_CH
|
||||||
|
//
|
||||||
|
// State Instructions
|
||||||
|
// START {}
|
||||||
|
// NEGOTIATED Send(SH); [RekeyIn;] RekeyOut; Send(EE); [Send(CertReq);] [Send(Cert); Send(CV)]
|
||||||
|
// WAIT_EOED RekeyIn;
|
||||||
|
// WAIT_FLIGHT2 {}
|
||||||
|
// WAIT_CERT_CR {}
|
||||||
|
// WAIT_CERT {}
|
||||||
|
// WAIT_CV {}
|
||||||
|
// WAIT_FINISHED RekeyIn; RekeyOut;
|
||||||
|
// CONNECTED StoreTicket || (RekeyIn; [RekeyOut])
|
||||||
|
|
||||||
|
type ServerStateStart struct {
|
||||||
|
Caps Capabilities
|
||||||
|
conn *Conn
|
||||||
|
|
||||||
|
cookieSent bool
|
||||||
|
firstClientHello *HandshakeMessage
|
||||||
|
helloRetryRequest *HandshakeMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state ServerStateStart) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
if hm == nil || hm.msgType != HandshakeTypeClientHello {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] unexpected message")
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
ch := &ClientHelloBody{}
|
||||||
|
_, err := ch.Unmarshal(hm.body)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] Error decoding message: %v", err)
|
||||||
|
return nil, nil, AlertDecodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
clientHello := hm
|
||||||
|
connParams := ConnectionParameters{}
|
||||||
|
|
||||||
|
supportedVersions := new(SupportedVersionsExtension)
|
||||||
|
serverName := new(ServerNameExtension)
|
||||||
|
supportedGroups := new(SupportedGroupsExtension)
|
||||||
|
signatureAlgorithms := new(SignatureAlgorithmsExtension)
|
||||||
|
clientKeyShares := &KeyShareExtension{HandshakeType: HandshakeTypeClientHello}
|
||||||
|
clientPSK := &PreSharedKeyExtension{HandshakeType: HandshakeTypeClientHello}
|
||||||
|
clientEarlyData := &EarlyDataExtension{}
|
||||||
|
clientALPN := new(ALPNExtension)
|
||||||
|
clientPSKModes := new(PSKKeyExchangeModesExtension)
|
||||||
|
clientCookie := new(CookieExtension)
|
||||||
|
|
||||||
|
// Handle external extensions.
|
||||||
|
if state.Caps.ExtensionHandler != nil {
|
||||||
|
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeClientHello, &ch.Extensions)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] Error running external extension handler [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
gotSupportedVersions := ch.Extensions.Find(supportedVersions)
|
||||||
|
gotServerName := ch.Extensions.Find(serverName)
|
||||||
|
gotSupportedGroups := ch.Extensions.Find(supportedGroups)
|
||||||
|
gotSignatureAlgorithms := ch.Extensions.Find(signatureAlgorithms)
|
||||||
|
gotEarlyData := ch.Extensions.Find(clientEarlyData)
|
||||||
|
ch.Extensions.Find(clientKeyShares)
|
||||||
|
ch.Extensions.Find(clientPSK)
|
||||||
|
ch.Extensions.Find(clientALPN)
|
||||||
|
ch.Extensions.Find(clientPSKModes)
|
||||||
|
ch.Extensions.Find(clientCookie)
|
||||||
|
|
||||||
|
if gotServerName {
|
||||||
|
connParams.ServerName = string(*serverName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the client didn't send supportedVersions or doesn't support 1.3,
|
||||||
|
// then we're done here.
|
||||||
|
if !gotSupportedVersions {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] Client did not send supported_versions")
|
||||||
|
return nil, nil, AlertProtocolVersion
|
||||||
|
}
|
||||||
|
versionOK, _ := VersionNegotiation(supportedVersions.Versions, []uint16{supportedVersion})
|
||||||
|
if !versionOK {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] Client does not support the same version")
|
||||||
|
return nil, nil, AlertProtocolVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
if state.Caps.RequireCookie && state.cookieSent && !state.Caps.CookieHandler.Validate(state.conn, clientCookie.Cookie) {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] Cookie mismatch")
|
||||||
|
return nil, nil, AlertAccessDenied
|
||||||
|
}
|
||||||
|
|
||||||
|
// Figure out if we can do DH
|
||||||
|
canDoDH, dhGroup, dhPublic, dhSecret := DHNegotiation(clientKeyShares.Shares, state.Caps.Groups)
|
||||||
|
|
||||||
|
// Figure out if we can do PSK
|
||||||
|
canDoPSK := false
|
||||||
|
var selectedPSK int
|
||||||
|
var psk *PreSharedKey
|
||||||
|
var params CipherSuiteParams
|
||||||
|
if len(clientPSK.Identities) > 0 {
|
||||||
|
contextBase := []byte{}
|
||||||
|
if state.helloRetryRequest != nil {
|
||||||
|
chBytes := state.firstClientHello.Marshal()
|
||||||
|
hrrBytes := state.helloRetryRequest.Marshal()
|
||||||
|
contextBase = append(chBytes, hrrBytes...)
|
||||||
|
}
|
||||||
|
|
||||||
|
chTrunc, err := ch.Truncated()
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] Error computing truncated ClientHello [%v]", err)
|
||||||
|
return nil, nil, AlertDecodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
context := append(contextBase, chTrunc...)
|
||||||
|
|
||||||
|
canDoPSK, selectedPSK, psk, params, err = PSKNegotiation(clientPSK.Identities, clientPSK.Binders, context, state.Caps.PSKs)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] Error in PSK negotiation [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Figure out if we actually should do DH / PSK
|
||||||
|
connParams.UsingDH, connParams.UsingPSK = PSKModeNegotiation(canDoDH, canDoPSK, clientPSKModes.KEModes)
|
||||||
|
|
||||||
|
// Select a ciphersuite
|
||||||
|
connParams.CipherSuite, err = CipherSuiteNegotiation(psk, ch.CipherSuites, state.Caps.CipherSuites)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] No common ciphersuite found [%v]", err)
|
||||||
|
return nil, nil, AlertHandshakeFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send a cookie if required
|
||||||
|
// NB: Need to do this here because it's after ciphersuite selection, which
|
||||||
|
// has to be after PSK selection.
|
||||||
|
// XXX: Doing this statefully for now, could be stateless
|
||||||
|
var cookieData []byte
|
||||||
|
if state.Caps.RequireCookie && !state.cookieSent {
|
||||||
|
var err error
|
||||||
|
cookieData, err = state.Caps.CookieHandler.Generate(state.conn)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] Error generating cookie [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cookieData != nil {
|
||||||
|
// Ignoring errors because everything here is newly constructed, so there
|
||||||
|
// shouldn't be marshal errors
|
||||||
|
hrr := &HelloRetryRequestBody{
|
||||||
|
Version: supportedVersion,
|
||||||
|
CipherSuite: connParams.CipherSuite,
|
||||||
|
}
|
||||||
|
hrr.Extensions.Add(&CookieExtension{Cookie: cookieData})
|
||||||
|
|
||||||
|
// Run the external extension handler.
|
||||||
|
if state.Caps.ExtensionHandler != nil {
|
||||||
|
err := state.Caps.ExtensionHandler.Send(HandshakeTypeHelloRetryRequest, &hrr.Extensions)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] Error running external extension sender [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
helloRetryRequest, err := HandshakeMessageFromBody(hrr)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] Error marshaling HRR [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
params := cipherSuiteMap[connParams.CipherSuite]
|
||||||
|
h := params.Hash.New()
|
||||||
|
h.Write(clientHello.Marshal())
|
||||||
|
firstClientHello := &HandshakeMessage{
|
||||||
|
msgType: HandshakeTypeMessageHash,
|
||||||
|
body: h.Sum(nil),
|
||||||
|
}
|
||||||
|
|
||||||
|
nextState := ServerStateStart{
|
||||||
|
Caps: state.Caps,
|
||||||
|
conn: state.conn,
|
||||||
|
cookieSent: true,
|
||||||
|
firstClientHello: firstClientHello,
|
||||||
|
helloRetryRequest: helloRetryRequest,
|
||||||
|
}
|
||||||
|
toSend := []HandshakeAction{SendHandshakeMessage{helloRetryRequest}}
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateStart]")
|
||||||
|
return nextState, toSend, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we've got no entropy to make keys from, fail
|
||||||
|
if !connParams.UsingDH && !connParams.UsingPSK {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] Neither DH nor PSK negotiated")
|
||||||
|
return nil, nil, AlertHandshakeFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
var pskSecret []byte
|
||||||
|
var cert *Certificate
|
||||||
|
var certScheme SignatureScheme
|
||||||
|
if connParams.UsingPSK {
|
||||||
|
pskSecret = psk.Key
|
||||||
|
} else {
|
||||||
|
psk = nil
|
||||||
|
|
||||||
|
// If we're not using a PSK mode, then we need to have certain extensions
|
||||||
|
if !gotServerName || !gotSupportedGroups || !gotSignatureAlgorithms {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] Insufficient extensions (%v %v %v)",
|
||||||
|
gotServerName, gotSupportedGroups, gotSignatureAlgorithms)
|
||||||
|
return nil, nil, AlertMissingExtension
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select a certificate
|
||||||
|
name := string(*serverName)
|
||||||
|
var err error
|
||||||
|
cert, certScheme, err = CertificateSelection(&name, signatureAlgorithms.Algorithms, state.Caps.Certificates)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] No appropriate certificate found [%v]", err)
|
||||||
|
return nil, nil, AlertAccessDenied
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !connParams.UsingDH {
|
||||||
|
dhSecret = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Figure out if we're going to do early data
|
||||||
|
var clientEarlyTrafficSecret []byte
|
||||||
|
connParams.ClientSendingEarlyData = gotEarlyData
|
||||||
|
connParams.UsingEarlyData = EarlyDataNegotiation(connParams.UsingPSK, gotEarlyData, state.Caps.AllowEarlyData)
|
||||||
|
if connParams.UsingEarlyData {
|
||||||
|
|
||||||
|
h := params.Hash.New()
|
||||||
|
h.Write(clientHello.Marshal())
|
||||||
|
chHash := h.Sum(nil)
|
||||||
|
|
||||||
|
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
|
||||||
|
earlySecret := HkdfExtract(params.Hash, zero, pskSecret)
|
||||||
|
clientEarlyTrafficSecret = deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select a next protocol
|
||||||
|
connParams.NextProto, err = ALPNNegotiation(psk, clientALPN.Protocols, state.Caps.NextProtos)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] No common application-layer protocol found [%v]", err)
|
||||||
|
return nil, nil, AlertNoApplicationProtocol
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateNegotiated]")
|
||||||
|
return ServerStateNegotiated{
|
||||||
|
Caps: state.Caps,
|
||||||
|
Params: connParams,
|
||||||
|
|
||||||
|
dhGroup: dhGroup,
|
||||||
|
dhPublic: dhPublic,
|
||||||
|
dhSecret: dhSecret,
|
||||||
|
pskSecret: pskSecret,
|
||||||
|
selectedPSK: selectedPSK,
|
||||||
|
cert: cert,
|
||||||
|
certScheme: certScheme,
|
||||||
|
clientEarlyTrafficSecret: clientEarlyTrafficSecret,
|
||||||
|
|
||||||
|
firstClientHello: state.firstClientHello,
|
||||||
|
helloRetryRequest: state.helloRetryRequest,
|
||||||
|
clientHello: clientHello,
|
||||||
|
}.Next(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerStateNegotiated struct {
|
||||||
|
Caps Capabilities
|
||||||
|
Params ConnectionParameters
|
||||||
|
|
||||||
|
dhGroup NamedGroup
|
||||||
|
dhPublic []byte
|
||||||
|
dhSecret []byte
|
||||||
|
pskSecret []byte
|
||||||
|
clientEarlyTrafficSecret []byte
|
||||||
|
selectedPSK int
|
||||||
|
cert *Certificate
|
||||||
|
certScheme SignatureScheme
|
||||||
|
|
||||||
|
firstClientHello *HandshakeMessage
|
||||||
|
helloRetryRequest *HandshakeMessage
|
||||||
|
clientHello *HandshakeMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state ServerStateNegotiated) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
if hm != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] Unexpected message")
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the ServerHello
|
||||||
|
sh := &ServerHelloBody{
|
||||||
|
Version: supportedVersion,
|
||||||
|
CipherSuite: state.Params.CipherSuite,
|
||||||
|
}
|
||||||
|
_, err := prng.Read(sh.Random[:])
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] Error creating server random [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
if state.Params.UsingDH {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] sending DH extension")
|
||||||
|
err = sh.Extensions.Add(&KeyShareExtension{
|
||||||
|
HandshakeType: HandshakeTypeServerHello,
|
||||||
|
Shares: []KeyShareEntry{{Group: state.dhGroup, KeyExchange: state.dhPublic}},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding key_shares extension [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if state.Params.UsingPSK {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] sending PSK extension")
|
||||||
|
err = sh.Extensions.Add(&PreSharedKeyExtension{
|
||||||
|
HandshakeType: HandshakeTypeServerHello,
|
||||||
|
SelectedIdentity: uint16(state.selectedPSK),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding PSK extension [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run the external extension handler.
|
||||||
|
if state.Caps.ExtensionHandler != nil {
|
||||||
|
err := state.Caps.ExtensionHandler.Send(HandshakeTypeServerHello, &sh.Extensions)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
serverHello, err := HandshakeMessageFromBody(sh)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling ServerHello [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look up crypto params
|
||||||
|
params, ok := cipherSuiteMap[sh.CipherSuite]
|
||||||
|
if !ok {
|
||||||
|
logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", sh.CipherSuite)
|
||||||
|
return nil, nil, AlertHandshakeFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start up the handshake hash
|
||||||
|
handshakeHash := params.Hash.New()
|
||||||
|
handshakeHash.Write(state.firstClientHello.Marshal())
|
||||||
|
handshakeHash.Write(state.helloRetryRequest.Marshal())
|
||||||
|
handshakeHash.Write(state.clientHello.Marshal())
|
||||||
|
handshakeHash.Write(serverHello.Marshal())
|
||||||
|
|
||||||
|
// Compute handshake secrets
|
||||||
|
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
|
||||||
|
|
||||||
|
var earlySecret []byte
|
||||||
|
if state.Params.UsingPSK {
|
||||||
|
earlySecret = HkdfExtract(params.Hash, zero, state.pskSecret)
|
||||||
|
} else {
|
||||||
|
earlySecret = HkdfExtract(params.Hash, zero, zero)
|
||||||
|
}
|
||||||
|
|
||||||
|
if state.dhSecret == nil {
|
||||||
|
state.dhSecret = zero
|
||||||
|
}
|
||||||
|
|
||||||
|
h0 := params.Hash.New().Sum(nil)
|
||||||
|
h2 := handshakeHash.Sum(nil)
|
||||||
|
preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0)
|
||||||
|
handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, state.dhSecret)
|
||||||
|
clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2)
|
||||||
|
serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2)
|
||||||
|
preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0)
|
||||||
|
masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero)
|
||||||
|
|
||||||
|
logf(logTypeCrypto, "early secret (init!): [%d] %x", len(earlySecret), earlySecret)
|
||||||
|
logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret)
|
||||||
|
logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret)
|
||||||
|
logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret)
|
||||||
|
logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret)
|
||||||
|
|
||||||
|
clientHandshakeKeys := makeTrafficKeys(params, clientHandshakeTrafficSecret)
|
||||||
|
serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret)
|
||||||
|
|
||||||
|
// Send an EncryptedExtensions message (even if it's empty)
|
||||||
|
eeList := ExtensionList{}
|
||||||
|
if state.Params.NextProto != "" {
|
||||||
|
logf(logTypeHandshake, "[server] sending ALPN extension")
|
||||||
|
err = eeList.Add(&ALPNExtension{Protocols: []string{state.Params.NextProto}})
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding ALPN to EncryptedExtensions [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if state.Params.UsingEarlyData {
|
||||||
|
logf(logTypeHandshake, "[server] sending EDI extension")
|
||||||
|
err = eeList.Add(&EarlyDataExtension{})
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding EDI to EncryptedExtensions [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ee := &EncryptedExtensionsBody{eeList}
|
||||||
|
|
||||||
|
// Run the external extension handler.
|
||||||
|
if state.Caps.ExtensionHandler != nil {
|
||||||
|
err := state.Caps.ExtensionHandler.Send(HandshakeTypeEncryptedExtensions, &ee.Extensions)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
eem, err := HandshakeMessageFromBody(ee)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling EncryptedExtensions [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
handshakeHash.Write(eem.Marshal())
|
||||||
|
|
||||||
|
toSend := []HandshakeAction{
|
||||||
|
SendHandshakeMessage{serverHello},
|
||||||
|
RekeyOut{Label: "handshake", KeySet: serverHandshakeKeys},
|
||||||
|
SendHandshakeMessage{eem},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authenticate with a certificate if required
|
||||||
|
if !state.Params.UsingPSK {
|
||||||
|
// Send a CertificateRequest message if we want client auth
|
||||||
|
if state.Caps.RequireClientAuth {
|
||||||
|
state.Params.UsingClientAuth = true
|
||||||
|
|
||||||
|
// XXX: We don't support sending any constraints besides a list of
|
||||||
|
// supported signature algorithms
|
||||||
|
cr := &CertificateRequestBody{}
|
||||||
|
schemes := &SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes}
|
||||||
|
err := cr.Extensions.Add(schemes)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding supported schemes to CertificateRequest [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
crm, err := HandshakeMessageFromBody(cr)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling CertificateRequest [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
//TODO state.state.serverCertificateRequest = cr
|
||||||
|
|
||||||
|
toSend = append(toSend, SendHandshakeMessage{crm})
|
||||||
|
handshakeHash.Write(crm.Marshal())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create and send Certificate, CertificateVerify
|
||||||
|
certificate := &CertificateBody{
|
||||||
|
CertificateList: make([]CertificateEntry, len(state.cert.Chain)),
|
||||||
|
}
|
||||||
|
for i, entry := range state.cert.Chain {
|
||||||
|
certificate.CertificateList[i] = CertificateEntry{CertData: entry}
|
||||||
|
}
|
||||||
|
certm, err := HandshakeMessageFromBody(certificate)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling Certificate [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
toSend = append(toSend, SendHandshakeMessage{certm})
|
||||||
|
handshakeHash.Write(certm.Marshal())
|
||||||
|
|
||||||
|
certificateVerify := &CertificateVerifyBody{Algorithm: state.certScheme}
|
||||||
|
logf(logTypeHandshake, "Creating CertVerify: %04x %v", state.certScheme, params.Hash)
|
||||||
|
|
||||||
|
hcv := handshakeHash.Sum(nil)
|
||||||
|
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
|
||||||
|
|
||||||
|
err = certificateVerify.Sign(state.cert.PrivateKey, hcv)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] Error signing CertificateVerify [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
certvm, err := HandshakeMessageFromBody(certificateVerify)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling CertificateVerify [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
toSend = append(toSend, SendHandshakeMessage{certvm})
|
||||||
|
handshakeHash.Write(certvm.Marshal())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute secrets resulting from the server's first flight
|
||||||
|
h3 := handshakeHash.Sum(nil)
|
||||||
|
logf(logTypeCrypto, "handshake hash 3 [%d] %x", len(h3), h3)
|
||||||
|
logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h3), h3)
|
||||||
|
|
||||||
|
serverFinishedData := computeFinishedData(params, serverHandshakeTrafficSecret, h3)
|
||||||
|
logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData)
|
||||||
|
|
||||||
|
// Assemble the Finished message
|
||||||
|
fin := &FinishedBody{
|
||||||
|
VerifyDataLen: len(serverFinishedData),
|
||||||
|
VerifyData: serverFinishedData,
|
||||||
|
}
|
||||||
|
finm, _ := HandshakeMessageFromBody(fin)
|
||||||
|
|
||||||
|
toSend = append(toSend, SendHandshakeMessage{finm})
|
||||||
|
handshakeHash.Write(finm.Marshal())
|
||||||
|
|
||||||
|
// Compute traffic secrets
|
||||||
|
h4 := handshakeHash.Sum(nil)
|
||||||
|
logf(logTypeCrypto, "handshake hash 4 [%d] %x", len(h4), h4)
|
||||||
|
logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h4), h4)
|
||||||
|
|
||||||
|
clientTrafficSecret := deriveSecret(params, masterSecret, labelClientApplicationTrafficSecret, h4)
|
||||||
|
serverTrafficSecret := deriveSecret(params, masterSecret, labelServerApplicationTrafficSecret, h4)
|
||||||
|
logf(logTypeCrypto, "client traffic secret: [%d] %x", len(clientTrafficSecret), clientTrafficSecret)
|
||||||
|
logf(logTypeCrypto, "server traffic secret: [%d] %x", len(serverTrafficSecret), serverTrafficSecret)
|
||||||
|
|
||||||
|
serverTrafficKeys := makeTrafficKeys(params, serverTrafficSecret)
|
||||||
|
toSend = append(toSend, RekeyOut{Label: "application", KeySet: serverTrafficKeys})
|
||||||
|
|
||||||
|
exporterSecret := deriveSecret(params, masterSecret, labelExporterSecret, h4)
|
||||||
|
logf(logTypeCrypto, "server exporter secret: [%d] %x", len(exporterSecret), exporterSecret)
|
||||||
|
|
||||||
|
if state.Params.UsingEarlyData {
|
||||||
|
clientEarlyTrafficKeys := makeTrafficKeys(params, state.clientEarlyTrafficSecret)
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitEOED]")
|
||||||
|
nextState := ServerStateWaitEOED{
|
||||||
|
AuthCertificate: state.Caps.AuthCertificate,
|
||||||
|
Params: state.Params,
|
||||||
|
cryptoParams: params,
|
||||||
|
handshakeHash: handshakeHash,
|
||||||
|
masterSecret: masterSecret,
|
||||||
|
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret,
|
||||||
|
clientTrafficSecret: clientTrafficSecret,
|
||||||
|
serverTrafficSecret: serverTrafficSecret,
|
||||||
|
exporterSecret: exporterSecret,
|
||||||
|
}
|
||||||
|
toSend = append(toSend, []HandshakeAction{
|
||||||
|
RekeyIn{Label: "early", KeySet: clientEarlyTrafficKeys},
|
||||||
|
ReadEarlyData{},
|
||||||
|
}...)
|
||||||
|
return nextState, toSend, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitFlight2]")
|
||||||
|
toSend = append(toSend, []HandshakeAction{
|
||||||
|
RekeyIn{Label: "handshake", KeySet: clientHandshakeKeys},
|
||||||
|
ReadPastEarlyData{},
|
||||||
|
}...)
|
||||||
|
waitFlight2 := ServerStateWaitFlight2{
|
||||||
|
AuthCertificate: state.Caps.AuthCertificate,
|
||||||
|
Params: state.Params,
|
||||||
|
cryptoParams: params,
|
||||||
|
handshakeHash: handshakeHash,
|
||||||
|
masterSecret: masterSecret,
|
||||||
|
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret,
|
||||||
|
clientTrafficSecret: clientTrafficSecret,
|
||||||
|
serverTrafficSecret: serverTrafficSecret,
|
||||||
|
exporterSecret: exporterSecret,
|
||||||
|
}
|
||||||
|
nextState, moreToSend, alert := waitFlight2.Next(nil)
|
||||||
|
toSend = append(toSend, moreToSend...)
|
||||||
|
return nextState, toSend, alert
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerStateWaitEOED struct {
|
||||||
|
AuthCertificate func(chain []CertificateEntry) error
|
||||||
|
Params ConnectionParameters
|
||||||
|
cryptoParams CipherSuiteParams
|
||||||
|
masterSecret []byte
|
||||||
|
clientHandshakeTrafficSecret []byte
|
||||||
|
handshakeHash hash.Hash
|
||||||
|
clientTrafficSecret []byte
|
||||||
|
serverTrafficSecret []byte
|
||||||
|
exporterSecret []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state ServerStateWaitEOED) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
if hm == nil || hm.msgType != HandshakeTypeEndOfEarlyData {
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitEOED] Unexpected message")
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(hm.body) > 0 {
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitEOED] Error decoding message [len > 0]")
|
||||||
|
return nil, nil, AlertDecodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
state.handshakeHash.Write(hm.Marshal())
|
||||||
|
|
||||||
|
clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitEOED] -> [ServerStateWaitFlight2]")
|
||||||
|
toSend := []HandshakeAction{
|
||||||
|
RekeyIn{Label: "handshake", KeySet: clientHandshakeKeys},
|
||||||
|
}
|
||||||
|
waitFlight2 := ServerStateWaitFlight2{
|
||||||
|
AuthCertificate: state.AuthCertificate,
|
||||||
|
Params: state.Params,
|
||||||
|
cryptoParams: state.cryptoParams,
|
||||||
|
handshakeHash: state.handshakeHash,
|
||||||
|
masterSecret: state.masterSecret,
|
||||||
|
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||||
|
clientTrafficSecret: state.clientTrafficSecret,
|
||||||
|
serverTrafficSecret: state.serverTrafficSecret,
|
||||||
|
exporterSecret: state.exporterSecret,
|
||||||
|
}
|
||||||
|
nextState, moreToSend, alert := waitFlight2.Next(nil)
|
||||||
|
toSend = append(toSend, moreToSend...)
|
||||||
|
return nextState, toSend, alert
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerStateWaitFlight2 struct {
|
||||||
|
AuthCertificate func(chain []CertificateEntry) error
|
||||||
|
Params ConnectionParameters
|
||||||
|
cryptoParams CipherSuiteParams
|
||||||
|
masterSecret []byte
|
||||||
|
clientHandshakeTrafficSecret []byte
|
||||||
|
handshakeHash hash.Hash
|
||||||
|
clientTrafficSecret []byte
|
||||||
|
serverTrafficSecret []byte
|
||||||
|
exporterSecret []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state ServerStateWaitFlight2) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
if hm != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitFlight2] Unexpected message")
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
if state.Params.UsingClientAuth {
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitFlight2] -> [ServerStateWaitCert]")
|
||||||
|
nextState := ServerStateWaitCert{
|
||||||
|
AuthCertificate: state.AuthCertificate,
|
||||||
|
Params: state.Params,
|
||||||
|
cryptoParams: state.cryptoParams,
|
||||||
|
handshakeHash: state.handshakeHash,
|
||||||
|
masterSecret: state.masterSecret,
|
||||||
|
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||||
|
clientTrafficSecret: state.clientTrafficSecret,
|
||||||
|
serverTrafficSecret: state.serverTrafficSecret,
|
||||||
|
exporterSecret: state.exporterSecret,
|
||||||
|
}
|
||||||
|
return nextState, nil, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitFlight2] -> [ServerStateWaitFinished]")
|
||||||
|
nextState := ServerStateWaitFinished{
|
||||||
|
Params: state.Params,
|
||||||
|
cryptoParams: state.cryptoParams,
|
||||||
|
masterSecret: state.masterSecret,
|
||||||
|
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||||
|
handshakeHash: state.handshakeHash,
|
||||||
|
clientTrafficSecret: state.clientTrafficSecret,
|
||||||
|
serverTrafficSecret: state.serverTrafficSecret,
|
||||||
|
exporterSecret: state.exporterSecret,
|
||||||
|
}
|
||||||
|
return nextState, nil, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerStateWaitCert struct {
|
||||||
|
AuthCertificate func(chain []CertificateEntry) error
|
||||||
|
Params ConnectionParameters
|
||||||
|
cryptoParams CipherSuiteParams
|
||||||
|
masterSecret []byte
|
||||||
|
clientHandshakeTrafficSecret []byte
|
||||||
|
handshakeHash hash.Hash
|
||||||
|
clientTrafficSecret []byte
|
||||||
|
serverTrafficSecret []byte
|
||||||
|
exporterSecret []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state ServerStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
if hm == nil || hm.msgType != HandshakeTypeCertificate {
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message")
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
cert := &CertificateBody{}
|
||||||
|
_, err := cert.Unmarshal(hm.body)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message")
|
||||||
|
return nil, nil, AlertDecodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
state.handshakeHash.Write(hm.Marshal())
|
||||||
|
|
||||||
|
if len(cert.CertificateList) == 0 {
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitCert] WARNING client did not provide a certificate")
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitCert] -> [ServerStateWaitFinished]")
|
||||||
|
nextState := ServerStateWaitFinished{
|
||||||
|
Params: state.Params,
|
||||||
|
cryptoParams: state.cryptoParams,
|
||||||
|
masterSecret: state.masterSecret,
|
||||||
|
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||||
|
handshakeHash: state.handshakeHash,
|
||||||
|
clientTrafficSecret: state.clientTrafficSecret,
|
||||||
|
serverTrafficSecret: state.serverTrafficSecret,
|
||||||
|
exporterSecret: state.exporterSecret,
|
||||||
|
}
|
||||||
|
return nextState, nil, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitCert] -> [ServerStateWaitCV]")
|
||||||
|
nextState := ServerStateWaitCV{
|
||||||
|
AuthCertificate: state.AuthCertificate,
|
||||||
|
Params: state.Params,
|
||||||
|
cryptoParams: state.cryptoParams,
|
||||||
|
masterSecret: state.masterSecret,
|
||||||
|
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||||
|
handshakeHash: state.handshakeHash,
|
||||||
|
clientTrafficSecret: state.clientTrafficSecret,
|
||||||
|
serverTrafficSecret: state.serverTrafficSecret,
|
||||||
|
clientCertificate: cert,
|
||||||
|
exporterSecret: state.exporterSecret,
|
||||||
|
}
|
||||||
|
return nextState, nil, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerStateWaitCV struct {
|
||||||
|
AuthCertificate func(chain []CertificateEntry) error
|
||||||
|
Params ConnectionParameters
|
||||||
|
cryptoParams CipherSuiteParams
|
||||||
|
|
||||||
|
masterSecret []byte
|
||||||
|
clientHandshakeTrafficSecret []byte
|
||||||
|
|
||||||
|
handshakeHash hash.Hash
|
||||||
|
clientTrafficSecret []byte
|
||||||
|
serverTrafficSecret []byte
|
||||||
|
exporterSecret []byte
|
||||||
|
|
||||||
|
clientCertificate *CertificateBody
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state ServerStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
if hm == nil || hm.msgType != HandshakeTypeCertificateVerify {
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitCV] Unexpected message [%+v] [%s]", hm, reflect.TypeOf(hm))
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
certVerify := &CertificateVerifyBody{}
|
||||||
|
_, err := certVerify.Unmarshal(hm.body)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitCert] Error decoding message %v", err)
|
||||||
|
return nil, nil, AlertDecodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify client signature over handshake hash
|
||||||
|
hcv := state.handshakeHash.Sum(nil)
|
||||||
|
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
|
||||||
|
|
||||||
|
clientPublicKey := state.clientCertificate.CertificateList[0].CertData.PublicKey
|
||||||
|
if err := certVerify.Verify(clientPublicKey, hcv); err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitCV] Failure in client auth verification [%v]", err)
|
||||||
|
return nil, nil, AlertHandshakeFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
if state.AuthCertificate != nil {
|
||||||
|
err := state.AuthCertificate(state.clientCertificate.CertificateList)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitCV] Application rejected client certificate")
|
||||||
|
return nil, nil, AlertBadCertificate
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitCV] WARNING: No verification of client certificate")
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it passes, record the certificateVerify in the transcript hash
|
||||||
|
state.handshakeHash.Write(hm.Marshal())
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitCV] -> [ServerStateWaitFinished]")
|
||||||
|
nextState := ServerStateWaitFinished{
|
||||||
|
Params: state.Params,
|
||||||
|
cryptoParams: state.cryptoParams,
|
||||||
|
masterSecret: state.masterSecret,
|
||||||
|
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||||
|
handshakeHash: state.handshakeHash,
|
||||||
|
clientTrafficSecret: state.clientTrafficSecret,
|
||||||
|
serverTrafficSecret: state.serverTrafficSecret,
|
||||||
|
exporterSecret: state.exporterSecret,
|
||||||
|
}
|
||||||
|
return nextState, nil, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerStateWaitFinished struct {
|
||||||
|
Params ConnectionParameters
|
||||||
|
cryptoParams CipherSuiteParams
|
||||||
|
|
||||||
|
masterSecret []byte
|
||||||
|
clientHandshakeTrafficSecret []byte
|
||||||
|
|
||||||
|
handshakeHash hash.Hash
|
||||||
|
clientTrafficSecret []byte
|
||||||
|
serverTrafficSecret []byte
|
||||||
|
exporterSecret []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state ServerStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
if hm == nil || hm.msgType != HandshakeTypeFinished {
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitFinished] Unexpected message")
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
fin := &FinishedBody{VerifyDataLen: state.cryptoParams.Hash.Size()}
|
||||||
|
_, err := fin.Unmarshal(hm.body)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitFinished] Error decoding message %v", err)
|
||||||
|
return nil, nil, AlertDecodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify client Finished data
|
||||||
|
h5 := state.handshakeHash.Sum(nil)
|
||||||
|
logf(logTypeCrypto, "handshake hash for client Finished: [%d] %x", len(h5), h5)
|
||||||
|
|
||||||
|
clientFinishedData := computeFinishedData(state.cryptoParams, state.clientHandshakeTrafficSecret, h5)
|
||||||
|
logf(logTypeCrypto, "client Finished data: [%d] %x", len(clientFinishedData), clientFinishedData)
|
||||||
|
|
||||||
|
if !bytes.Equal(fin.VerifyData, clientFinishedData) {
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitFinished] Client's Finished failed to verify")
|
||||||
|
return nil, nil, AlertHandshakeFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the resumption secret
|
||||||
|
state.handshakeHash.Write(hm.Marshal())
|
||||||
|
h6 := state.handshakeHash.Sum(nil)
|
||||||
|
logf(logTypeCrypto, "handshake hash 6 [%d]: %x", len(h6), h6)
|
||||||
|
|
||||||
|
resumptionSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelResumptionSecret, h6)
|
||||||
|
logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret)
|
||||||
|
|
||||||
|
// Compute client traffic keys
|
||||||
|
clientTrafficKeys := makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret)
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitFinished] -> [StateConnected]")
|
||||||
|
nextState := StateConnected{
|
||||||
|
Params: state.Params,
|
||||||
|
isClient: false,
|
||||||
|
cryptoParams: state.cryptoParams,
|
||||||
|
resumptionSecret: resumptionSecret,
|
||||||
|
clientTrafficSecret: state.clientTrafficSecret,
|
||||||
|
serverTrafficSecret: state.serverTrafficSecret,
|
||||||
|
exporterSecret: state.exporterSecret,
|
||||||
|
}
|
||||||
|
toSend := []HandshakeAction{
|
||||||
|
RekeyIn{Label: "application", KeySet: clientTrafficKeys},
|
||||||
|
}
|
||||||
|
return nextState, toSend, AlertNoAlert
|
||||||
|
}
|
230
vendor/github.com/bifurcation/mint/state-machine.go
generated
vendored
Normal file
230
vendor/github.com/bifurcation/mint/state-machine.go
generated
vendored
Normal file
@ -0,0 +1,230 @@
|
|||||||
|
package mint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Marker interface for actions that an implementation should take based on
|
||||||
|
// state transitions.
|
||||||
|
type HandshakeAction interface{}
|
||||||
|
|
||||||
|
type SendHandshakeMessage struct {
|
||||||
|
Message *HandshakeMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
type SendEarlyData struct{}
|
||||||
|
|
||||||
|
type ReadEarlyData struct{}
|
||||||
|
|
||||||
|
type ReadPastEarlyData struct{}
|
||||||
|
|
||||||
|
type RekeyIn struct {
|
||||||
|
Label string
|
||||||
|
KeySet keySet
|
||||||
|
}
|
||||||
|
|
||||||
|
type RekeyOut struct {
|
||||||
|
Label string
|
||||||
|
KeySet keySet
|
||||||
|
}
|
||||||
|
|
||||||
|
type StorePSK struct {
|
||||||
|
PSK PreSharedKey
|
||||||
|
}
|
||||||
|
|
||||||
|
type HandshakeState interface {
|
||||||
|
Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert)
|
||||||
|
}
|
||||||
|
|
||||||
|
type AppExtensionHandler interface {
|
||||||
|
Send(hs HandshakeType, el *ExtensionList) error
|
||||||
|
Receive(hs HandshakeType, el *ExtensionList) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Capabilities objects represent the capabilities of a TLS client or server,
|
||||||
|
// as an input to TLS negotiation
|
||||||
|
type Capabilities struct {
|
||||||
|
// For both client and server
|
||||||
|
CipherSuites []CipherSuite
|
||||||
|
Groups []NamedGroup
|
||||||
|
SignatureSchemes []SignatureScheme
|
||||||
|
PSKs PreSharedKeyCache
|
||||||
|
Certificates []*Certificate
|
||||||
|
AuthCertificate func(chain []CertificateEntry) error
|
||||||
|
ExtensionHandler AppExtensionHandler
|
||||||
|
|
||||||
|
// For client
|
||||||
|
PSKModes []PSKKeyExchangeMode
|
||||||
|
|
||||||
|
// For server
|
||||||
|
NextProtos []string
|
||||||
|
AllowEarlyData bool
|
||||||
|
RequireCookie bool
|
||||||
|
CookieHandler CookieHandler
|
||||||
|
RequireClientAuth bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnectionOptions objects represent per-connection settings for a client
|
||||||
|
// initiating a connection
|
||||||
|
type ConnectionOptions struct {
|
||||||
|
ServerName string
|
||||||
|
NextProtos []string
|
||||||
|
EarlyData []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnectionParameters objects represent the parameters negotiated for a
|
||||||
|
// connection.
|
||||||
|
type ConnectionParameters struct {
|
||||||
|
UsingPSK bool
|
||||||
|
UsingDH bool
|
||||||
|
ClientSendingEarlyData bool
|
||||||
|
UsingEarlyData bool
|
||||||
|
UsingClientAuth bool
|
||||||
|
|
||||||
|
CipherSuite CipherSuite
|
||||||
|
ServerName string
|
||||||
|
NextProto string
|
||||||
|
}
|
||||||
|
|
||||||
|
// StateConnected is symmetric between client and server
|
||||||
|
type StateConnected struct {
|
||||||
|
Params ConnectionParameters
|
||||||
|
isClient bool
|
||||||
|
cryptoParams CipherSuiteParams
|
||||||
|
resumptionSecret []byte
|
||||||
|
clientTrafficSecret []byte
|
||||||
|
serverTrafficSecret []byte
|
||||||
|
exporterSecret []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state *StateConnected) KeyUpdate(request KeyUpdateRequest) ([]HandshakeAction, Alert) {
|
||||||
|
var trafficKeys keySet
|
||||||
|
if state.isClient {
|
||||||
|
state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret,
|
||||||
|
labelClientApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
|
||||||
|
trafficKeys = makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret)
|
||||||
|
} else {
|
||||||
|
state.serverTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.serverTrafficSecret,
|
||||||
|
labelServerApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
|
||||||
|
trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret)
|
||||||
|
}
|
||||||
|
|
||||||
|
kum, err := HandshakeMessageFromBody(&KeyUpdateBody{KeyUpdateRequest: request})
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[StateConnected] Error marshaling key update message: %v", err)
|
||||||
|
return nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
toSend := []HandshakeAction{
|
||||||
|
SendHandshakeMessage{kum},
|
||||||
|
RekeyOut{Label: "update", KeySet: trafficKeys},
|
||||||
|
}
|
||||||
|
return toSend, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state *StateConnected) NewSessionTicket(length int, lifetime, earlyDataLifetime uint32) ([]HandshakeAction, Alert) {
|
||||||
|
tkt, err := NewSessionTicket(length, lifetime)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[StateConnected] Error generating NewSessionTicket: %v", err)
|
||||||
|
return nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tkt.Extensions.Add(&TicketEarlyDataInfoExtension{earlyDataLifetime})
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[StateConnected] Error adding extension to NewSessionTicket: %v", err)
|
||||||
|
return nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret,
|
||||||
|
labelResumption, tkt.TicketNonce, state.cryptoParams.Hash.Size())
|
||||||
|
|
||||||
|
newPSK := PreSharedKey{
|
||||||
|
CipherSuite: state.cryptoParams.Suite,
|
||||||
|
IsResumption: true,
|
||||||
|
Identity: tkt.Ticket,
|
||||||
|
Key: resumptionKey,
|
||||||
|
NextProto: state.Params.NextProto,
|
||||||
|
ReceivedAt: time.Now(),
|
||||||
|
ExpiresAt: time.Now().Add(time.Duration(tkt.TicketLifetime) * time.Second),
|
||||||
|
TicketAgeAdd: tkt.TicketAgeAdd,
|
||||||
|
}
|
||||||
|
|
||||||
|
tktm, err := HandshakeMessageFromBody(tkt)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[StateConnected] Error marshaling NewSessionTicket: %v", err)
|
||||||
|
return nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
toSend := []HandshakeAction{
|
||||||
|
StorePSK{newPSK},
|
||||||
|
SendHandshakeMessage{tktm},
|
||||||
|
}
|
||||||
|
return toSend, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state StateConnected) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
if hm == nil {
|
||||||
|
logf(logTypeHandshake, "[StateConnected] Unexpected message")
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyGeneric, err := hm.ToBody()
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[StateConnected] Error decoding message: %v", err)
|
||||||
|
return nil, nil, AlertDecodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
switch body := bodyGeneric.(type) {
|
||||||
|
case *KeyUpdateBody:
|
||||||
|
var trafficKeys keySet
|
||||||
|
if !state.isClient {
|
||||||
|
state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret,
|
||||||
|
labelClientApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
|
||||||
|
trafficKeys = makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret)
|
||||||
|
} else {
|
||||||
|
state.serverTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.serverTrafficSecret,
|
||||||
|
labelServerApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
|
||||||
|
trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret)
|
||||||
|
}
|
||||||
|
|
||||||
|
toSend := []HandshakeAction{RekeyIn{Label: "update", KeySet: trafficKeys}}
|
||||||
|
|
||||||
|
// If requested, roll outbound keys and send a KeyUpdate
|
||||||
|
if body.KeyUpdateRequest == KeyUpdateRequested {
|
||||||
|
moreToSend, alert := state.KeyUpdate(KeyUpdateNotRequested)
|
||||||
|
if alert != AlertNoAlert {
|
||||||
|
return nil, nil, alert
|
||||||
|
}
|
||||||
|
|
||||||
|
toSend = append(toSend, moreToSend...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return state, toSend, AlertNoAlert
|
||||||
|
|
||||||
|
case *NewSessionTicketBody:
|
||||||
|
// XXX: Allow NewSessionTicket in both directions?
|
||||||
|
if !state.isClient {
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret,
|
||||||
|
labelResumption, body.TicketNonce, state.cryptoParams.Hash.Size())
|
||||||
|
|
||||||
|
psk := PreSharedKey{
|
||||||
|
CipherSuite: state.cryptoParams.Suite,
|
||||||
|
IsResumption: true,
|
||||||
|
Identity: body.Ticket,
|
||||||
|
Key: resumptionKey,
|
||||||
|
NextProto: state.Params.NextProto,
|
||||||
|
ReceivedAt: time.Now(),
|
||||||
|
ExpiresAt: time.Now().Add(time.Duration(body.TicketLifetime) * time.Second),
|
||||||
|
TicketAgeAdd: body.TicketAgeAdd,
|
||||||
|
}
|
||||||
|
|
||||||
|
toSend := []HandshakeAction{StorePSK{psk}}
|
||||||
|
return state, toSend, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[StateConnected] Unexpected message type %v", hm.msgType)
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
243
vendor/github.com/bifurcation/mint/syntax/decode.go
generated
vendored
Normal file
243
vendor/github.com/bifurcation/mint/syntax/decode.go
generated
vendored
Normal file
@ -0,0 +1,243 @@
|
|||||||
|
package syntax
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"runtime"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Unmarshal(data []byte, v interface{}) (int, error) {
|
||||||
|
// Check for well-formedness.
|
||||||
|
// Avoids filling out half a data structure
|
||||||
|
// before discovering a JSON syntax error.
|
||||||
|
d := decodeState{}
|
||||||
|
d.Write(data)
|
||||||
|
return d.unmarshal(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// These are the options that can be specified in the struct tag. Right now,
|
||||||
|
// all of them apply to variable-length vectors and nothing else
|
||||||
|
type decOpts struct {
|
||||||
|
head uint // length of length in bytes
|
||||||
|
min uint // minimum size in bytes
|
||||||
|
max uint // maximum size in bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
type decodeState struct {
|
||||||
|
bytes.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *decodeState) unmarshal(v interface{}) (read int, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
if _, ok := r.(runtime.Error); ok {
|
||||||
|
panic(r)
|
||||||
|
}
|
||||||
|
if s, ok := r.(string); ok {
|
||||||
|
panic(s)
|
||||||
|
}
|
||||||
|
err = r.(error)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
rv := reflect.ValueOf(v)
|
||||||
|
if rv.Kind() != reflect.Ptr || rv.IsNil() {
|
||||||
|
return 0, fmt.Errorf("Invalid unmarshal target (non-pointer or nil)")
|
||||||
|
}
|
||||||
|
|
||||||
|
read = d.value(rv)
|
||||||
|
return read, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *decodeState) value(v reflect.Value) int {
|
||||||
|
return valueDecoder(v)(e, v, decOpts{})
|
||||||
|
}
|
||||||
|
|
||||||
|
type decoderFunc func(e *decodeState, v reflect.Value, opts decOpts) int
|
||||||
|
|
||||||
|
func valueDecoder(v reflect.Value) decoderFunc {
|
||||||
|
return typeDecoder(v.Type().Elem())
|
||||||
|
}
|
||||||
|
|
||||||
|
func typeDecoder(t reflect.Type) decoderFunc {
|
||||||
|
// Note: Omits the caching / wait-group things that encoding/json uses
|
||||||
|
return newTypeDecoder(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTypeDecoder(t reflect.Type) decoderFunc {
|
||||||
|
// Note: Does not support Marshaler, so don't need the allowAddr argument
|
||||||
|
|
||||||
|
switch t.Kind() {
|
||||||
|
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
return uintDecoder
|
||||||
|
case reflect.Array:
|
||||||
|
return newArrayDecoder(t)
|
||||||
|
case reflect.Slice:
|
||||||
|
return newSliceDecoder(t)
|
||||||
|
case reflect.Struct:
|
||||||
|
return newStructDecoder(t)
|
||||||
|
default:
|
||||||
|
panic(fmt.Errorf("Unsupported type (%s)", t))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///// Specific decoders below
|
||||||
|
|
||||||
|
func uintDecoder(d *decodeState, v reflect.Value, opts decOpts) int {
|
||||||
|
var uintLen int
|
||||||
|
switch v.Elem().Kind() {
|
||||||
|
case reflect.Uint8:
|
||||||
|
uintLen = 1
|
||||||
|
case reflect.Uint16:
|
||||||
|
uintLen = 2
|
||||||
|
case reflect.Uint32:
|
||||||
|
uintLen = 4
|
||||||
|
case reflect.Uint64:
|
||||||
|
uintLen = 8
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, uintLen)
|
||||||
|
n, err := d.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if n != uintLen {
|
||||||
|
panic(fmt.Errorf("Insufficient data to read uint"))
|
||||||
|
}
|
||||||
|
|
||||||
|
val := uint64(0)
|
||||||
|
for _, b := range buf {
|
||||||
|
val = (val << 8) + uint64(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
v.Elem().SetUint(val)
|
||||||
|
return uintLen
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////
|
||||||
|
|
||||||
|
type arrayDecoder struct {
|
||||||
|
elemDec decoderFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ad *arrayDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
|
||||||
|
n := v.Elem().Type().Len()
|
||||||
|
read := 0
|
||||||
|
for i := 0; i < n; i += 1 {
|
||||||
|
read += ad.elemDec(d, v.Elem().Index(i).Addr(), opts)
|
||||||
|
}
|
||||||
|
return read
|
||||||
|
}
|
||||||
|
|
||||||
|
func newArrayDecoder(t reflect.Type) decoderFunc {
|
||||||
|
dec := &arrayDecoder{typeDecoder(t.Elem())}
|
||||||
|
return dec.decode
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////
|
||||||
|
|
||||||
|
type sliceDecoder struct {
|
||||||
|
elementType reflect.Type
|
||||||
|
elementDec decoderFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sd *sliceDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
|
||||||
|
if opts.head == 0 {
|
||||||
|
panic(fmt.Errorf("Cannot decode a slice without a header length"))
|
||||||
|
}
|
||||||
|
|
||||||
|
lengthBytes := make([]byte, opts.head)
|
||||||
|
n, err := d.Read(lengthBytes)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if uint(n) != opts.head {
|
||||||
|
panic(fmt.Errorf("Not enough data to read header"))
|
||||||
|
}
|
||||||
|
|
||||||
|
length := uint(0)
|
||||||
|
for _, b := range lengthBytes {
|
||||||
|
length = (length << 8) + uint(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.max > 0 && length > opts.max {
|
||||||
|
panic(fmt.Errorf("Length of vector exceeds declared max"))
|
||||||
|
}
|
||||||
|
if length < opts.min {
|
||||||
|
panic(fmt.Errorf("Length of vector below declared min"))
|
||||||
|
}
|
||||||
|
|
||||||
|
data := make([]byte, length)
|
||||||
|
n, err = d.Read(data)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if uint(n) != length {
|
||||||
|
panic(fmt.Errorf("Available data less than declared length [%04x < %04x]", n, length))
|
||||||
|
}
|
||||||
|
|
||||||
|
elemBuf := &decodeState{}
|
||||||
|
elemBuf.Write(data)
|
||||||
|
elems := []reflect.Value{}
|
||||||
|
read := int(opts.head)
|
||||||
|
for elemBuf.Len() > 0 {
|
||||||
|
elem := reflect.New(sd.elementType)
|
||||||
|
read += sd.elementDec(elemBuf, elem, opts)
|
||||||
|
elems = append(elems, elem)
|
||||||
|
}
|
||||||
|
|
||||||
|
v.Elem().Set(reflect.MakeSlice(v.Elem().Type(), len(elems), len(elems)))
|
||||||
|
for i := 0; i < len(elems); i += 1 {
|
||||||
|
v.Elem().Index(i).Set(elems[i].Elem())
|
||||||
|
}
|
||||||
|
return read
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSliceDecoder(t reflect.Type) decoderFunc {
|
||||||
|
dec := &sliceDecoder{
|
||||||
|
elementType: t.Elem(),
|
||||||
|
elementDec: typeDecoder(t.Elem()),
|
||||||
|
}
|
||||||
|
return dec.decode
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////
|
||||||
|
|
||||||
|
type structDecoder struct {
|
||||||
|
fieldOpts []decOpts
|
||||||
|
fieldDecs []decoderFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sd *structDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
|
||||||
|
read := 0
|
||||||
|
for i := range sd.fieldDecs {
|
||||||
|
read += sd.fieldDecs[i](d, v.Elem().Field(i).Addr(), sd.fieldOpts[i])
|
||||||
|
}
|
||||||
|
return read
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStructDecoder(t reflect.Type) decoderFunc {
|
||||||
|
n := t.NumField()
|
||||||
|
sd := structDecoder{
|
||||||
|
fieldOpts: make([]decOpts, n),
|
||||||
|
fieldDecs: make([]decoderFunc, n),
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < n; i += 1 {
|
||||||
|
f := t.Field(i)
|
||||||
|
|
||||||
|
tag := f.Tag.Get("tls")
|
||||||
|
tagOpts := parseTag(tag)
|
||||||
|
|
||||||
|
sd.fieldOpts[i] = decOpts{
|
||||||
|
head: tagOpts["head"],
|
||||||
|
max: tagOpts["max"],
|
||||||
|
min: tagOpts["min"],
|
||||||
|
}
|
||||||
|
|
||||||
|
sd.fieldDecs[i] = typeDecoder(f.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
return sd.decode
|
||||||
|
}
|
187
vendor/github.com/bifurcation/mint/syntax/encode.go
generated
vendored
Normal file
187
vendor/github.com/bifurcation/mint/syntax/encode.go
generated
vendored
Normal file
@ -0,0 +1,187 @@
|
|||||||
|
package syntax
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"runtime"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Marshal(v interface{}) ([]byte, error) {
|
||||||
|
e := &encodeState{}
|
||||||
|
err := e.marshal(v, encOpts{})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return e.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// These are the options that can be specified in the struct tag. Right now,
|
||||||
|
// all of them apply to variable-length vectors and nothing else
|
||||||
|
type encOpts struct {
|
||||||
|
head uint // length of length in bytes
|
||||||
|
min uint // minimum size in bytes
|
||||||
|
max uint // maximum size in bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
type encodeState struct {
|
||||||
|
bytes.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *encodeState) marshal(v interface{}, opts encOpts) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
if _, ok := r.(runtime.Error); ok {
|
||||||
|
panic(r)
|
||||||
|
}
|
||||||
|
if s, ok := r.(string); ok {
|
||||||
|
panic(s)
|
||||||
|
}
|
||||||
|
err = r.(error)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
e.reflectValue(reflect.ValueOf(v), opts)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *encodeState) reflectValue(v reflect.Value, opts encOpts) {
|
||||||
|
valueEncoder(v)(e, v, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
type encoderFunc func(e *encodeState, v reflect.Value, opts encOpts)
|
||||||
|
|
||||||
|
func valueEncoder(v reflect.Value) encoderFunc {
|
||||||
|
if !v.IsValid() {
|
||||||
|
panic(fmt.Errorf("Cannot encode an invalid value"))
|
||||||
|
}
|
||||||
|
return typeEncoder(v.Type())
|
||||||
|
}
|
||||||
|
|
||||||
|
func typeEncoder(t reflect.Type) encoderFunc {
|
||||||
|
// Note: Omits the caching / wait-group things that encoding/json uses
|
||||||
|
return newTypeEncoder(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTypeEncoder(t reflect.Type) encoderFunc {
|
||||||
|
// Note: Does not support Marshaler, so don't need the allowAddr argument
|
||||||
|
|
||||||
|
switch t.Kind() {
|
||||||
|
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
return uintEncoder
|
||||||
|
case reflect.Array:
|
||||||
|
return newArrayEncoder(t)
|
||||||
|
case reflect.Slice:
|
||||||
|
return newSliceEncoder(t)
|
||||||
|
case reflect.Struct:
|
||||||
|
return newStructEncoder(t)
|
||||||
|
default:
|
||||||
|
panic(fmt.Errorf("Unsupported type (%s)", t))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///// Specific encoders below
|
||||||
|
|
||||||
|
func uintEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
||||||
|
u := v.Uint()
|
||||||
|
switch v.Type().Kind() {
|
||||||
|
case reflect.Uint8:
|
||||||
|
e.WriteByte(byte(u))
|
||||||
|
case reflect.Uint16:
|
||||||
|
e.Write([]byte{byte(u >> 8), byte(u)})
|
||||||
|
case reflect.Uint32:
|
||||||
|
e.Write([]byte{byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)})
|
||||||
|
case reflect.Uint64:
|
||||||
|
e.Write([]byte{byte(u >> 56), byte(u >> 48), byte(u >> 40), byte(u >> 32),
|
||||||
|
byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////
|
||||||
|
|
||||||
|
type arrayEncoder struct {
|
||||||
|
elemEnc encoderFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ae *arrayEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
|
||||||
|
n := v.Len()
|
||||||
|
for i := 0; i < n; i += 1 {
|
||||||
|
ae.elemEnc(e, v.Index(i), opts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newArrayEncoder(t reflect.Type) encoderFunc {
|
||||||
|
enc := &arrayEncoder{typeEncoder(t.Elem())}
|
||||||
|
return enc.encode
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////
|
||||||
|
|
||||||
|
type sliceEncoder struct {
|
||||||
|
ae *arrayEncoder
|
||||||
|
}
|
||||||
|
|
||||||
|
func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
|
||||||
|
if opts.head == 0 {
|
||||||
|
panic(fmt.Errorf("Cannot encode a slice without a header length"))
|
||||||
|
}
|
||||||
|
|
||||||
|
arrayState := &encodeState{}
|
||||||
|
se.ae.encode(arrayState, v, opts)
|
||||||
|
|
||||||
|
n := uint(arrayState.Len())
|
||||||
|
if opts.max > 0 && n > opts.max {
|
||||||
|
panic(fmt.Errorf("Encoded length more than max [%d > %d]", n, opts.max))
|
||||||
|
}
|
||||||
|
if n>>(8*opts.head) > 0 {
|
||||||
|
panic(fmt.Errorf("Encoded length too long for header length [%d, %d]", n, opts.head))
|
||||||
|
}
|
||||||
|
if n < opts.min {
|
||||||
|
panic(fmt.Errorf("Encoded length less than min [%d < %d]", n, opts.min))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := int(opts.head - 1); i >= 0; i -= 1 {
|
||||||
|
e.WriteByte(byte(n >> (8 * uint(i))))
|
||||||
|
}
|
||||||
|
e.Write(arrayState.Bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSliceEncoder(t reflect.Type) encoderFunc {
|
||||||
|
enc := &sliceEncoder{&arrayEncoder{typeEncoder(t.Elem())}}
|
||||||
|
return enc.encode
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////
|
||||||
|
|
||||||
|
type structEncoder struct {
|
||||||
|
fieldOpts []encOpts
|
||||||
|
fieldEncs []encoderFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (se *structEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
|
||||||
|
for i := range se.fieldEncs {
|
||||||
|
se.fieldEncs[i](e, v.Field(i), se.fieldOpts[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStructEncoder(t reflect.Type) encoderFunc {
|
||||||
|
n := t.NumField()
|
||||||
|
se := structEncoder{
|
||||||
|
fieldOpts: make([]encOpts, n),
|
||||||
|
fieldEncs: make([]encoderFunc, n),
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < n; i += 1 {
|
||||||
|
f := t.Field(i)
|
||||||
|
tag := f.Tag.Get("tls")
|
||||||
|
tagOpts := parseTag(tag)
|
||||||
|
|
||||||
|
se.fieldOpts[i] = encOpts{
|
||||||
|
head: tagOpts["head"],
|
||||||
|
max: tagOpts["max"],
|
||||||
|
min: tagOpts["min"],
|
||||||
|
}
|
||||||
|
se.fieldEncs[i] = typeEncoder(f.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
return se.encode
|
||||||
|
}
|
30
vendor/github.com/bifurcation/mint/syntax/tags.go
generated
vendored
Normal file
30
vendor/github.com/bifurcation/mint/syntax/tags.go
generated
vendored
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
package syntax
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// `tls:"head=2,min=2,max=255"`
|
||||||
|
|
||||||
|
type tagOptions map[string]uint
|
||||||
|
|
||||||
|
// parseTag parses a struct field's "tls" tag as a comma-separated list of
|
||||||
|
// name=value pairs, where the values MUST be unsigned integers
|
||||||
|
func parseTag(tag string) tagOptions {
|
||||||
|
opts := tagOptions{}
|
||||||
|
for _, token := range strings.Split(tag, ",") {
|
||||||
|
if strings.Index(token, "=") == -1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(token, "=")
|
||||||
|
if len(parts[0]) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if val, err := strconv.Atoi(parts[1]); err == nil && val >= 0 {
|
||||||
|
opts[parts[0]] = uint(val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return opts
|
||||||
|
}
|
168
vendor/github.com/bifurcation/mint/tls.go
generated
vendored
Normal file
168
vendor/github.com/bifurcation/mint/tls.go
generated
vendored
Normal file
@ -0,0 +1,168 @@
|
|||||||
|
package mint
|
||||||
|
|
||||||
|
// XXX(rlb): This file is borrowed pretty much wholesale from crypto/tls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Server returns a new TLS server side connection
|
||||||
|
// using conn as the underlying transport.
|
||||||
|
// The configuration config must be non-nil and must include
|
||||||
|
// at least one certificate or else set GetCertificate.
|
||||||
|
func Server(conn net.Conn, config *Config) *Conn {
|
||||||
|
return NewConn(conn, config, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client returns a new TLS client side connection
|
||||||
|
// using conn as the underlying transport.
|
||||||
|
// The config cannot be nil: users must set either ServerName or
|
||||||
|
// InsecureSkipVerify in the config.
|
||||||
|
func Client(conn net.Conn, config *Config) *Conn {
|
||||||
|
return NewConn(conn, config, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// A listener implements a network listener (net.Listener) for TLS connections.
|
||||||
|
type Listener struct {
|
||||||
|
net.Listener
|
||||||
|
config *Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accept waits for and returns the next incoming TLS connection.
|
||||||
|
// The returned connection c is a *tls.Conn.
|
||||||
|
func (l *Listener) Accept() (c net.Conn, err error) {
|
||||||
|
c, err = l.Listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
server := Server(c, l.config)
|
||||||
|
err = server.Handshake()
|
||||||
|
if err == AlertNoAlert {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
c = server
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewListener creates a Listener which accepts connections from an inner
|
||||||
|
// Listener and wraps each connection with Server.
|
||||||
|
// The configuration config must be non-nil and must include
|
||||||
|
// at least one certificate or else set GetCertificate.
|
||||||
|
func NewListener(inner net.Listener, config *Config) net.Listener {
|
||||||
|
l := new(Listener)
|
||||||
|
l.Listener = inner
|
||||||
|
l.config = config
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
// Listen creates a TLS listener accepting connections on the
|
||||||
|
// given network address using net.Listen.
|
||||||
|
// The configuration config must be non-nil and must include
|
||||||
|
// at least one certificate or else set GetCertificate.
|
||||||
|
func Listen(network, laddr string, config *Config) (net.Listener, error) {
|
||||||
|
if config == nil || !config.ValidForServer() {
|
||||||
|
return nil, errors.New("tls: neither Certificates nor GetCertificate set in Config")
|
||||||
|
}
|
||||||
|
l, err := net.Listen(network, laddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return NewListener(l, config), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type TimeoutError struct{}
|
||||||
|
|
||||||
|
func (TimeoutError) Error() string { return "tls: DialWithDialer timed out" }
|
||||||
|
func (TimeoutError) Timeout() bool { return true }
|
||||||
|
func (TimeoutError) Temporary() bool { return true }
|
||||||
|
|
||||||
|
// DialWithDialer connects to the given network address using dialer.Dial and
|
||||||
|
// then initiates a TLS handshake, returning the resulting TLS connection. Any
|
||||||
|
// timeout or deadline given in the dialer apply to connection and TLS
|
||||||
|
// handshake as a whole.
|
||||||
|
//
|
||||||
|
// DialWithDialer interprets a nil configuration as equivalent to the zero
|
||||||
|
// configuration; see the documentation of Config for the defaults.
|
||||||
|
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
|
||||||
|
// We want the Timeout and Deadline values from dialer to cover the
|
||||||
|
// whole process: TCP connection and TLS handshake. This means that we
|
||||||
|
// also need to start our own timers now.
|
||||||
|
timeout := dialer.Timeout
|
||||||
|
|
||||||
|
if !dialer.Deadline.IsZero() {
|
||||||
|
deadlineTimeout := dialer.Deadline.Sub(time.Now())
|
||||||
|
if timeout == 0 || deadlineTimeout < timeout {
|
||||||
|
timeout = deadlineTimeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var errChannel chan error
|
||||||
|
|
||||||
|
if timeout != 0 {
|
||||||
|
errChannel = make(chan error, 2)
|
||||||
|
time.AfterFunc(timeout, func() {
|
||||||
|
errChannel <- TimeoutError{}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
rawConn, err := dialer.Dial(network, addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
colonPos := strings.LastIndex(addr, ":")
|
||||||
|
if colonPos == -1 {
|
||||||
|
colonPos = len(addr)
|
||||||
|
}
|
||||||
|
hostname := addr[:colonPos]
|
||||||
|
|
||||||
|
if config == nil {
|
||||||
|
config = &Config{}
|
||||||
|
}
|
||||||
|
// If no ServerName is set, infer the ServerName
|
||||||
|
// from the hostname we're connecting to.
|
||||||
|
if config.ServerName == "" {
|
||||||
|
// Make a copy to avoid polluting argument or default.
|
||||||
|
c := config.Clone()
|
||||||
|
c.ServerName = hostname
|
||||||
|
config = c
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := Client(rawConn, config)
|
||||||
|
|
||||||
|
if timeout == 0 {
|
||||||
|
err = conn.Handshake()
|
||||||
|
if err == AlertNoAlert {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
go func() {
|
||||||
|
errChannel <- conn.Handshake()
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = <-errChannel
|
||||||
|
if err == AlertNoAlert {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
rawConn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial connects to the given network address using net.Dial
|
||||||
|
// and then initiates a TLS handshake, returning the resulting
|
||||||
|
// TLS connection.
|
||||||
|
// Dial interprets a nil configuration as equivalent to
|
||||||
|
// the zero configuration; see the documentation of Config
|
||||||
|
// for the defaults.
|
||||||
|
func Dial(network, addr string, config *Config) (*Conn, error) {
|
||||||
|
return DialWithDialer(new(net.Dialer), network, addr, config)
|
||||||
|
}
|
32
vendor/github.com/lucas-clemente/quic-go/ackhandler/interfaces.go
generated
vendored
32
vendor/github.com/lucas-clemente/quic-go/ackhandler/interfaces.go
generated
vendored
@ -1,32 +0,0 @@
|
|||||||
package ackhandler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/frames"
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
)
|
|
||||||
|
|
||||||
// SentPacketHandler handles ACKs received for outgoing packets
|
|
||||||
type SentPacketHandler interface {
|
|
||||||
// SentPacket may modify the packet
|
|
||||||
SentPacket(packet *Packet) error
|
|
||||||
ReceivedAck(ackFrame *frames.AckFrame, withPacketNumber protocol.PacketNumber, recvTime time.Time) error
|
|
||||||
|
|
||||||
SendingAllowed() bool
|
|
||||||
GetStopWaitingFrame(force bool) *frames.StopWaitingFrame
|
|
||||||
DequeuePacketForRetransmission() (packet *Packet)
|
|
||||||
GetLeastUnacked() protocol.PacketNumber
|
|
||||||
|
|
||||||
GetAlarmTimeout() time.Time
|
|
||||||
OnAlarm()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
|
|
||||||
type ReceivedPacketHandler interface {
|
|
||||||
ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error
|
|
||||||
SetLowerLimit(protocol.PacketNumber)
|
|
||||||
|
|
||||||
GetAlarmTimeout() time.Time
|
|
||||||
GetAckFrame() *frames.AckFrame
|
|
||||||
}
|
|
2
vendor/github.com/lucas-clemente/quic-go/buffer_pool.go
generated
vendored
2
vendor/github.com/lucas-clemente/quic-go/buffer_pool.go
generated
vendored
@ -3,7 +3,7 @@ package quic
|
|||||||
import (
|
import (
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
)
|
)
|
||||||
|
|
||||||
var bufferPool sync.Pool
|
var bufferPool sync.Pool
|
||||||
|
360
vendor/github.com/lucas-clemente/quic-go/client.go
generated
vendored
360
vendor/github.com/lucas-clemente/quic-go/client.go
generated
vendored
@ -10,32 +10,39 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
"github.com/lucas-clemente/quic-go/qerr"
|
||||||
)
|
)
|
||||||
|
|
||||||
type client struct {
|
type client struct {
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
listenErr error
|
|
||||||
|
|
||||||
conn connection
|
conn connection
|
||||||
hostname string
|
hostname string
|
||||||
|
|
||||||
errorChan chan struct{}
|
versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version
|
||||||
handshakeChan <-chan handshakeEvent
|
versionNegotiated bool // has the server accepted our version
|
||||||
|
receivedVersionNegotiationPacket bool
|
||||||
|
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
|
||||||
|
|
||||||
tlsConf *tls.Config
|
tlsConf *tls.Config
|
||||||
config *Config
|
config *Config
|
||||||
versionNegotiated bool // has version negotiation completed yet
|
tls handshake.MintTLS // only used when using TLS
|
||||||
|
|
||||||
connectionID protocol.ConnectionID
|
connectionID protocol.ConnectionID
|
||||||
version protocol.VersionNumber
|
|
||||||
|
initialVersion protocol.VersionNumber
|
||||||
|
version protocol.VersionNumber
|
||||||
|
|
||||||
session packetHandler
|
session packetHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
// make it possible to mock connection ID generation in the tests
|
||||||
|
generateConnectionID = utils.GenerateConnectionID
|
||||||
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
|
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -53,71 +60,6 @@ func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error)
|
|||||||
return Dial(udpConn, udpAddr, addr, tlsConf, config)
|
return Dial(udpConn, udpAddr, addr, tlsConf, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialAddrNonFWSecure establishes a new QUIC connection to a server.
|
|
||||||
// The hostname for SNI is taken from the given address.
|
|
||||||
func DialAddrNonFWSecure(
|
|
||||||
addr string,
|
|
||||||
tlsConf *tls.Config,
|
|
||||||
config *Config,
|
|
||||||
) (NonFWSession, error) {
|
|
||||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return DialNonFWSecure(udpConn, udpAddr, addr, tlsConf, config)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn.
|
|
||||||
// The host parameter is used for SNI.
|
|
||||||
func DialNonFWSecure(
|
|
||||||
pconn net.PacketConn,
|
|
||||||
remoteAddr net.Addr,
|
|
||||||
host string,
|
|
||||||
tlsConf *tls.Config,
|
|
||||||
config *Config,
|
|
||||||
) (NonFWSession, error) {
|
|
||||||
connID, err := utils.GenerateConnectionID()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var hostname string
|
|
||||||
if tlsConf != nil {
|
|
||||||
hostname = tlsConf.ServerName
|
|
||||||
}
|
|
||||||
|
|
||||||
if hostname == "" {
|
|
||||||
hostname, _, err = net.SplitHostPort(host)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
clientConfig := populateClientConfig(config)
|
|
||||||
c := &client{
|
|
||||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
|
||||||
connectionID: connID,
|
|
||||||
hostname: hostname,
|
|
||||||
tlsConf: tlsConf,
|
|
||||||
config: clientConfig,
|
|
||||||
version: clientConfig.Versions[0],
|
|
||||||
errorChan: make(chan struct{}),
|
|
||||||
}
|
|
||||||
|
|
||||||
err = c.createNewSession(nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %d", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
|
|
||||||
|
|
||||||
return c.session.(NonFWSession), c.establishSecureConnection()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
||||||
// The host parameter is used for SNI.
|
// The host parameter is used for SNI.
|
||||||
func Dial(
|
func Dial(
|
||||||
@ -127,15 +69,39 @@ func Dial(
|
|||||||
tlsConf *tls.Config,
|
tlsConf *tls.Config,
|
||||||
config *Config,
|
config *Config,
|
||||||
) (Session, error) {
|
) (Session, error) {
|
||||||
sess, err := DialNonFWSecure(pconn, remoteAddr, host, tlsConf, config)
|
connID, err := generateConnectionID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
err = sess.WaitUntilHandshakeComplete()
|
|
||||||
if err != nil {
|
var hostname string
|
||||||
|
if tlsConf != nil {
|
||||||
|
hostname = tlsConf.ServerName
|
||||||
|
}
|
||||||
|
if hostname == "" {
|
||||||
|
hostname, _, err = net.SplitHostPort(host)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
clientConfig := populateClientConfig(config)
|
||||||
|
c := &client{
|
||||||
|
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||||
|
connectionID: connID,
|
||||||
|
hostname: hostname,
|
||||||
|
tlsConf: tlsConf,
|
||||||
|
config: clientConfig,
|
||||||
|
version: clientConfig.Versions[0],
|
||||||
|
versionNegotiationChan: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
|
||||||
|
|
||||||
|
if err := c.dial(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return sess, nil
|
return c.session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
|
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
|
||||||
@ -153,6 +119,10 @@ func populateClientConfig(config *Config) *Config {
|
|||||||
if config.HandshakeTimeout != 0 {
|
if config.HandshakeTimeout != 0 {
|
||||||
handshakeTimeout = config.HandshakeTimeout
|
handshakeTimeout = config.HandshakeTimeout
|
||||||
}
|
}
|
||||||
|
idleTimeout := protocol.DefaultIdleTimeout
|
||||||
|
if config.IdleTimeout != 0 {
|
||||||
|
idleTimeout = config.IdleTimeout
|
||||||
|
}
|
||||||
|
|
||||||
maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
|
maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
|
||||||
if maxReceiveStreamFlowControlWindow == 0 {
|
if maxReceiveStreamFlowControlWindow == 0 {
|
||||||
@ -166,32 +136,109 @@ func populateClientConfig(config *Config) *Config {
|
|||||||
return &Config{
|
return &Config{
|
||||||
Versions: versions,
|
Versions: versions,
|
||||||
HandshakeTimeout: handshakeTimeout,
|
HandshakeTimeout: handshakeTimeout,
|
||||||
RequestConnectionIDTruncation: config.RequestConnectionIDTruncation,
|
IdleTimeout: idleTimeout,
|
||||||
|
RequestConnectionIDOmission: config.RequestConnectionIDOmission,
|
||||||
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
|
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
|
||||||
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
|
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
|
||||||
KeepAlive: config.KeepAlive,
|
KeepAlive: config.KeepAlive,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// establishSecureConnection returns as soon as the connection is secure (as opposed to forward-secure)
|
func (c *client) dial() error {
|
||||||
func (c *client) establishSecureConnection() error {
|
var err error
|
||||||
|
if c.version.UsesTLS() {
|
||||||
|
err = c.dialTLS()
|
||||||
|
} else {
|
||||||
|
err = c.dialGQUIC()
|
||||||
|
}
|
||||||
|
if err == errCloseSessionForNewVersion {
|
||||||
|
return c.dial()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) dialGQUIC() error {
|
||||||
|
if err := c.createNewGQUICSession(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
go c.listen()
|
go c.listen()
|
||||||
|
return c.establishSecureConnection()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) dialTLS() error {
|
||||||
|
params := &handshake.TransportParameters{
|
||||||
|
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
|
||||||
|
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
|
||||||
|
IdleTimeout: c.config.IdleTimeout,
|
||||||
|
OmitConnectionID: c.config.RequestConnectionIDOmission,
|
||||||
|
// TODO(#523): make these values configurable
|
||||||
|
MaxBidiStreamID: protocol.MaxBidiStreamID(protocol.MaxIncomingStreams, protocol.PerspectiveClient),
|
||||||
|
MaxUniStreamID: protocol.MaxUniStreamID(protocol.MaxIncomingStreams, protocol.PerspectiveClient),
|
||||||
|
}
|
||||||
|
csc := handshake.NewCryptoStreamConn(nil)
|
||||||
|
extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version)
|
||||||
|
mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
mintConf.ExtensionHandler = extHandler
|
||||||
|
mintConf.ServerName = c.hostname
|
||||||
|
c.tls = newMintController(csc, mintConf, protocol.PerspectiveClient)
|
||||||
|
|
||||||
|
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
go c.listen()
|
||||||
|
if err := c.establishSecureConnection(); err != nil {
|
||||||
|
if err != handshake.ErrCloseSessionForRetry {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
utils.Infof("Received a Retry packet. Recreating session.")
|
||||||
|
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := c.establishSecureConnection(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// establishSecureConnection runs the session, and tries to establish a secure connection
|
||||||
|
// It returns:
|
||||||
|
// - errCloseSessionForNewVersion when the server sends a version negotiation packet
|
||||||
|
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC)
|
||||||
|
// - any other error that might occur
|
||||||
|
// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC)
|
||||||
|
func (c *client) establishSecureConnection() error {
|
||||||
|
var runErr error
|
||||||
|
errorChan := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
runErr = c.session.run() // returns as soon as the session is closed
|
||||||
|
close(errorChan)
|
||||||
|
utils.Infof("Connection %x closed.", c.connectionID)
|
||||||
|
if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion {
|
||||||
|
c.conn.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// wait until the server accepts the QUIC version (or an error occurs)
|
||||||
|
select {
|
||||||
|
case <-errorChan:
|
||||||
|
return runErr
|
||||||
|
case <-c.versionNegotiationChan:
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-c.errorChan:
|
case <-errorChan:
|
||||||
return c.listenErr
|
return runErr
|
||||||
case ev := <-c.handshakeChan:
|
case err := <-c.session.handshakeStatus():
|
||||||
if ev.err != nil {
|
return err
|
||||||
return ev.err
|
|
||||||
}
|
|
||||||
if ev.encLevel != protocol.EncryptionSecure {
|
|
||||||
return fmt.Errorf("Client BUG: Expected encryption level to be secure, was %s", ev.encLevel)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Listen listens
|
// Listen listens on the underlying connection and passes packets on for handling.
|
||||||
|
// It returns when the connection is closed.
|
||||||
func (c *client) listen() {
|
func (c *client) listen() {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
@ -205,13 +252,15 @@ func (c *client) listen() {
|
|||||||
n, addr, err = c.conn.Read(data)
|
n, addr, err = c.conn.Read(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
|
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
|
||||||
c.session.Close(err)
|
c.mutex.Lock()
|
||||||
|
if c.session != nil {
|
||||||
|
c.session.Close(err)
|
||||||
|
}
|
||||||
|
c.mutex.Unlock()
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
data = data[:n]
|
c.handlePacket(addr, data[:n])
|
||||||
|
|
||||||
c.handlePacket(addr, data)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -219,10 +268,14 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
|
|||||||
rcvTime := time.Now()
|
rcvTime := time.Now()
|
||||||
|
|
||||||
r := bytes.NewReader(packet)
|
r := bytes.NewReader(packet)
|
||||||
hdr, err := ParsePublicHeader(r, protocol.PerspectiveServer)
|
hdr, err := wire.ParseHeaderSentByServer(r, c.version)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error())
|
utils.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error())
|
||||||
// drop this packet if we can't parse the Public Header
|
// drop this packet if we can't parse the header
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// reject packets with truncated connection id if we didn't request truncation
|
||||||
|
if hdr.OmitConnectionID && !c.config.RequestConnectionIDOmission {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
hdr.Raw = packet[:len(packet)-r.Len()]
|
hdr.Raw = packet[:len(packet)-r.Len()]
|
||||||
@ -230,6 +283,11 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
|
|||||||
c.mutex.Lock()
|
c.mutex.Lock()
|
||||||
defer c.mutex.Unlock()
|
defer c.mutex.Unlock()
|
||||||
|
|
||||||
|
// reject packets with the wrong connection ID
|
||||||
|
if !hdr.OmitConnectionID && hdr.ConnectionID != c.connectionID {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if hdr.ResetFlag {
|
if hdr.ResetFlag {
|
||||||
cr := c.conn.RemoteAddr()
|
cr := c.conn.RemoteAddr()
|
||||||
// check if the remote address and the connection ID match
|
// check if the remote address and the connection ID match
|
||||||
@ -238,44 +296,48 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
|
|||||||
utils.Infof("Received a spoofed Public Reset. Ignoring.")
|
utils.Infof("Received a spoofed Public Reset. Ignoring.")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
pr, err := parsePublicReset(r)
|
pr, err := wire.ParsePublicReset(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.Infof("Received a Public Reset for connection %x. An error occurred parsing the packet.")
|
utils.Infof("Received a Public Reset. An error occurred parsing the packet: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.rejectedPacketNumber)
|
utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.RejectedPacketNumber)
|
||||||
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.rejectedPacketNumber)))
|
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// ignore delayed / duplicated version negotiation packets
|
// handle Version Negotiation Packets
|
||||||
if c.versionNegotiated && hdr.VersionFlag {
|
if hdr.IsVersionNegotiation {
|
||||||
return
|
// ignore delayed / duplicated version negotiation packets
|
||||||
}
|
if c.receivedVersionNegotiationPacket || c.versionNegotiated {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// this is the first packet after the client sent a packet with the VersionFlag set
|
|
||||||
// if the server doesn't send a version negotiation packet, it supports the suggested version
|
|
||||||
if !hdr.VersionFlag && !c.versionNegotiated {
|
|
||||||
c.versionNegotiated = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if hdr.VersionFlag {
|
|
||||||
// version negotiation packets have no payload
|
// version negotiation packets have no payload
|
||||||
if err := c.handlePacketWithVersionFlag(hdr); err != nil {
|
if err := c.handleVersionNegotiationPacket(hdr); err != nil {
|
||||||
c.session.Close(err)
|
c.session.Close(err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// this is the first packet we are receiving
|
||||||
|
// since it is not a Version Negotiation Packet, this means the server supports the suggested version
|
||||||
|
if !c.versionNegotiated {
|
||||||
|
c.versionNegotiated = true
|
||||||
|
close(c.versionNegotiationChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: validate packet number and connection ID on Retry packets (for IETF QUIC)
|
||||||
|
|
||||||
c.session.handlePacket(&receivedPacket{
|
c.session.handlePacket(&receivedPacket{
|
||||||
remoteAddr: remoteAddr,
|
remoteAddr: remoteAddr,
|
||||||
publicHeader: hdr,
|
header: hdr,
|
||||||
data: packet[len(packet)-r.Len():],
|
data: packet[len(packet)-r.Len():],
|
||||||
rcvTime: rcvTime,
|
rcvTime: rcvTime,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error {
|
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
||||||
for _, v := range hdr.SupportedVersions {
|
for _, v := range hdr.SupportedVersions {
|
||||||
if v == c.version {
|
if v == c.version {
|
||||||
// the version negotiation packet contains the version that we offered
|
// the version negotiation packet contains the version that we offered
|
||||||
@ -285,51 +347,57 @@ func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
newVersion := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
|
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
|
||||||
if newVersion == protocol.VersionUnsupported {
|
if !ok {
|
||||||
return qerr.InvalidVersion
|
return qerr.InvalidVersion
|
||||||
}
|
}
|
||||||
|
c.receivedVersionNegotiationPacket = true
|
||||||
|
c.negotiatedVersions = hdr.SupportedVersions
|
||||||
|
|
||||||
// switch to negotiated version
|
// switch to negotiated version
|
||||||
|
c.initialVersion = c.version
|
||||||
c.version = newVersion
|
c.version = newVersion
|
||||||
c.versionNegotiated = true
|
|
||||||
var err error
|
var err error
|
||||||
c.connectionID, err = utils.GenerateConnectionID()
|
c.connectionID, err = utils.GenerateConnectionID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
utils.Infof("Switching to QUIC version %d. New connection ID: %x", newVersion, c.connectionID)
|
utils.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID)
|
||||||
|
|
||||||
c.session.Close(errCloseSessionForNewVersion)
|
c.session.Close(errCloseSessionForNewVersion)
|
||||||
return c.createNewSession(hdr.SupportedVersions)
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) error {
|
func (c *client) createNewGQUICSession() (err error) {
|
||||||
var err error
|
c.mutex.Lock()
|
||||||
c.session, c.handshakeChan, err = newClientSession(
|
defer c.mutex.Unlock()
|
||||||
|
c.session, err = newClientSession(
|
||||||
c.conn,
|
c.conn,
|
||||||
c.hostname,
|
c.hostname,
|
||||||
c.version,
|
c.version,
|
||||||
c.connectionID,
|
c.connectionID,
|
||||||
c.tlsConf,
|
c.tlsConf,
|
||||||
c.config,
|
c.config,
|
||||||
negotiatedVersions,
|
c.initialVersion,
|
||||||
|
c.negotiatedVersions,
|
||||||
)
|
)
|
||||||
if err != nil {
|
return err
|
||||||
return err
|
}
|
||||||
}
|
|
||||||
|
func (c *client) createNewTLSSession(
|
||||||
go func() {
|
paramsChan <-chan handshake.TransportParameters,
|
||||||
// session.run() returns as soon as the session is closed
|
version protocol.VersionNumber,
|
||||||
err := c.session.run()
|
) (err error) {
|
||||||
if err == errCloseSessionForNewVersion {
|
c.mutex.Lock()
|
||||||
return
|
defer c.mutex.Unlock()
|
||||||
}
|
c.session, err = newTLSClientSession(
|
||||||
c.listenErr = err
|
c.conn,
|
||||||
close(c.errorChan)
|
c.hostname,
|
||||||
|
c.version,
|
||||||
utils.Infof("Connection %x closed.", c.connectionID)
|
c.connectionID,
|
||||||
c.conn.Close()
|
c.config,
|
||||||
}()
|
c.tls,
|
||||||
return nil
|
paramsChan,
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
58
vendor/github.com/lucas-clemente/quic-go/crypto/aesgcm_aead.go
generated
vendored
58
vendor/github.com/lucas-clemente/quic-go/crypto/aesgcm_aead.go
generated
vendored
@ -1,58 +0,0 @@
|
|||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/cipher"
|
|
||||||
"errors"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/aes12"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
)
|
|
||||||
|
|
||||||
type aeadAESGCM struct {
|
|
||||||
otherIV []byte
|
|
||||||
myIV []byte
|
|
||||||
encrypter cipher.AEAD
|
|
||||||
decrypter cipher.AEAD
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewAEADAESGCM creates a AEAD using AES-GCM with 12 bytes tag size
|
|
||||||
//
|
|
||||||
// AES-GCM support is a bit hacky, since the go stdlib does not support 12 byte
|
|
||||||
// tag size, and couples the cipher and aes packages closely.
|
|
||||||
// See https://github.com/lucas-clemente/aes12.
|
|
||||||
func NewAEADAESGCM(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) {
|
|
||||||
if len(myKey) != 16 || len(otherKey) != 16 || len(myIV) != 4 || len(otherIV) != 4 {
|
|
||||||
return nil, errors.New("AES-GCM: expected 16-byte keys and 4-byte IVs")
|
|
||||||
}
|
|
||||||
encrypterCipher, err := aes12.NewCipher(myKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
encrypter, err := aes12.NewGCM(encrypterCipher)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
decrypterCipher, err := aes12.NewCipher(otherKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
decrypter, err := aes12.NewGCM(decrypterCipher)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &aeadAESGCM{
|
|
||||||
otherIV: otherIV,
|
|
||||||
myIV: myIV,
|
|
||||||
encrypter: encrypter,
|
|
||||||
decrypter: decrypter,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (aead *aeadAESGCM) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
|
||||||
return aead.decrypter.Open(dst, makeNonce(aead.otherIV, packetNumber), src, associatedData)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (aead *aeadAESGCM) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
|
||||||
return aead.encrypter.Seal(dst, makeNonce(aead.myIV, packetNumber), src, associatedData)
|
|
||||||
}
|
|
14
vendor/github.com/lucas-clemente/quic-go/crypto/nonce.go
generated
vendored
14
vendor/github.com/lucas-clemente/quic-go/crypto/nonce.go
generated
vendored
@ -1,14 +0,0 @@
|
|||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
)
|
|
||||||
|
|
||||||
func makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
|
|
||||||
res := make([]byte, 12)
|
|
||||||
copy(res[0:4], iv)
|
|
||||||
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber))
|
|
||||||
return res
|
|
||||||
}
|
|
76
vendor/github.com/lucas-clemente/quic-go/crypto/source_address_token.go
generated
vendored
76
vendor/github.com/lucas-clemente/quic-go/crypto/source_address_token.go
generated
vendored
@ -1,76 +0,0 @@
|
|||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/aes"
|
|
||||||
"crypto/cipher"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/sha256"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/hkdf"
|
|
||||||
)
|
|
||||||
|
|
||||||
// StkSource is used to create and verify source address tokens
|
|
||||||
type StkSource interface {
|
|
||||||
// NewToken creates a new token
|
|
||||||
NewToken([]byte) ([]byte, error)
|
|
||||||
// DecodeToken decodes a token
|
|
||||||
DecodeToken([]byte) ([]byte, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type stkSource struct {
|
|
||||||
aead cipher.AEAD
|
|
||||||
}
|
|
||||||
|
|
||||||
const stkKeySize = 16
|
|
||||||
|
|
||||||
// Chrome currently sets this to 12, but discusses changing it to 16. We start
|
|
||||||
// at 16 :)
|
|
||||||
const stkNonceSize = 16
|
|
||||||
|
|
||||||
// NewStkSource creates a source for source address tokens
|
|
||||||
func NewStkSource() (StkSource, error) {
|
|
||||||
secret := make([]byte, 32)
|
|
||||||
if _, err := rand.Read(secret); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
key, err := deriveKey(secret)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
c, err := aes.NewCipher(key)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
aead, err := cipher.NewGCMWithNonceSize(c, stkNonceSize)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &stkSource{aead: aead}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stkSource) NewToken(data []byte) ([]byte, error) {
|
|
||||||
nonce := make([]byte, stkNonceSize)
|
|
||||||
if _, err := rand.Read(nonce); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return s.aead.Seal(nonce, nonce, data, nil), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stkSource) DecodeToken(p []byte) ([]byte, error) {
|
|
||||||
if len(p) < stkNonceSize {
|
|
||||||
return nil, fmt.Errorf("STK too short: %d", len(p))
|
|
||||||
}
|
|
||||||
nonce := p[:stkNonceSize]
|
|
||||||
return s.aead.Open(nil, nonce, p[stkNonceSize:], nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func deriveKey(secret []byte) ([]byte, error) {
|
|
||||||
r := hkdf.New(sha256.New, secret, nil, []byte("QUIC source address token key"))
|
|
||||||
key := make([]byte, stkKeySize)
|
|
||||||
if _, err := io.ReadFull(r, key); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return key, nil
|
|
||||||
}
|
|
41
vendor/github.com/lucas-clemente/quic-go/crypto_stream.go
generated
vendored
Normal file
41
vendor/github.com/lucas-clemente/quic-go/crypto_stream.go
generated
vendored
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
package quic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
|
)
|
||||||
|
|
||||||
|
type cryptoStreamI interface {
|
||||||
|
StreamID() protocol.StreamID
|
||||||
|
io.Reader
|
||||||
|
io.Writer
|
||||||
|
handleStreamFrame(*wire.StreamFrame) error
|
||||||
|
popStreamFrame(protocol.ByteCount) (*wire.StreamFrame, bool)
|
||||||
|
closeForShutdown(error)
|
||||||
|
setReadOffset(protocol.ByteCount)
|
||||||
|
// methods needed for flow control
|
||||||
|
getWindowUpdate() protocol.ByteCount
|
||||||
|
handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
|
||||||
|
}
|
||||||
|
|
||||||
|
type cryptoStream struct {
|
||||||
|
*stream
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ cryptoStreamI = &cryptoStream{}
|
||||||
|
|
||||||
|
func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStreamI {
|
||||||
|
str := newStream(version.CryptoStreamID(), sender, flowController, version)
|
||||||
|
return &cryptoStream{str}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetReadOffset sets the read offset.
|
||||||
|
// It is only needed for the crypto stream.
|
||||||
|
// It must not be called concurrently with any other stream methods, especially Read and Write.
|
||||||
|
func (s *cryptoStream) setReadOffset(offset protocol.ByteCount) {
|
||||||
|
s.receiveStream.readOffset = offset
|
||||||
|
s.receiveStream.frameQueue.readPosition = offset
|
||||||
|
}
|
14
vendor/github.com/lucas-clemente/quic-go/example/client/main.go
generated
vendored
14
vendor/github.com/lucas-clemente/quic-go/example/client/main.go
generated
vendored
@ -7,12 +7,15 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
quic "github.com/lucas-clemente/quic-go"
|
||||||
"github.com/lucas-clemente/quic-go/h2quic"
|
"github.com/lucas-clemente/quic-go/h2quic"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
verbose := flag.Bool("v", false, "verbose")
|
verbose := flag.Bool("v", false, "verbose")
|
||||||
|
tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
urls := flag.Args()
|
urls := flag.Args()
|
||||||
|
|
||||||
@ -23,8 +26,17 @@ func main() {
|
|||||||
}
|
}
|
||||||
utils.SetLogTimeFormat("")
|
utils.SetLogTimeFormat("")
|
||||||
|
|
||||||
|
versions := protocol.SupportedVersions
|
||||||
|
if *tls {
|
||||||
|
versions = append([]protocol.VersionNumber{protocol.VersionTLS}, versions...)
|
||||||
|
}
|
||||||
|
|
||||||
|
roundTripper := &h2quic.RoundTripper{
|
||||||
|
QuicConfig: &quic.Config{Versions: versions},
|
||||||
|
}
|
||||||
|
defer roundTripper.Close()
|
||||||
hclient := &http.Client{
|
hclient := &http.Client{
|
||||||
Transport: &h2quic.RoundTripper{},
|
Transport: roundTripper,
|
||||||
}
|
}
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
14
vendor/github.com/lucas-clemente/quic-go/example/main.go
generated
vendored
14
vendor/github.com/lucas-clemente/quic-go/example/main.go
generated
vendored
@ -17,7 +17,9 @@ import (
|
|||||||
|
|
||||||
_ "net/http/pprof"
|
_ "net/http/pprof"
|
||||||
|
|
||||||
|
quic "github.com/lucas-clemente/quic-go"
|
||||||
"github.com/lucas-clemente/quic-go/h2quic"
|
"github.com/lucas-clemente/quic-go/h2quic"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -121,6 +123,7 @@ func main() {
|
|||||||
certPath := flag.String("certpath", getBuildDir(), "certificate directory")
|
certPath := flag.String("certpath", getBuildDir(), "certificate directory")
|
||||||
www := flag.String("www", "/var/www", "www data")
|
www := flag.String("www", "/var/www", "www data")
|
||||||
tcp := flag.Bool("tcp", false, "also listen on TCP")
|
tcp := flag.Bool("tcp", false, "also listen on TCP")
|
||||||
|
tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
if *verbose {
|
if *verbose {
|
||||||
@ -130,6 +133,11 @@ func main() {
|
|||||||
}
|
}
|
||||||
utils.SetLogTimeFormat("")
|
utils.SetLogTimeFormat("")
|
||||||
|
|
||||||
|
versions := protocol.SupportedVersions
|
||||||
|
if *tls {
|
||||||
|
versions = append([]protocol.VersionNumber{protocol.VersionTLS}, versions...)
|
||||||
|
}
|
||||||
|
|
||||||
certFile := *certPath + "/fullchain.pem"
|
certFile := *certPath + "/fullchain.pem"
|
||||||
keyFile := *certPath + "/privkey.pem"
|
keyFile := *certPath + "/privkey.pem"
|
||||||
|
|
||||||
@ -148,7 +156,11 @@ func main() {
|
|||||||
if *tcp {
|
if *tcp {
|
||||||
err = h2quic.ListenAndServe(bCap, certFile, keyFile, nil)
|
err = h2quic.ListenAndServe(bCap, certFile, keyFile, nil)
|
||||||
} else {
|
} else {
|
||||||
err = h2quic.ListenAndServeQUIC(bCap, certFile, keyFile, nil)
|
server := h2quic.Server{
|
||||||
|
Server: &http.Server{Addr: bCap},
|
||||||
|
QuicConfig: &quic.Config{Versions: versions},
|
||||||
|
}
|
||||||
|
err = server.ListenAndServeTLS(certFile, keyFile)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
|
240
vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_control_manager.go
generated
vendored
240
vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_control_manager.go
generated
vendored
@ -1,240 +0,0 @@
|
|||||||
package flowcontrol
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/congestion"
|
|
||||||
"github.com/lucas-clemente/quic-go/handshake"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
)
|
|
||||||
|
|
||||||
type flowControlManager struct {
|
|
||||||
connectionParameters handshake.ConnectionParametersManager
|
|
||||||
rttStats *congestion.RTTStats
|
|
||||||
|
|
||||||
streamFlowController map[protocol.StreamID]*flowController
|
|
||||||
connFlowController *flowController
|
|
||||||
mutex sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ FlowControlManager = &flowControlManager{}
|
|
||||||
|
|
||||||
var errMapAccess = errors.New("Error accessing the flowController map.")
|
|
||||||
|
|
||||||
// NewFlowControlManager creates a new flow control manager
|
|
||||||
func NewFlowControlManager(connectionParameters handshake.ConnectionParametersManager, rttStats *congestion.RTTStats) FlowControlManager {
|
|
||||||
return &flowControlManager{
|
|
||||||
connectionParameters: connectionParameters,
|
|
||||||
rttStats: rttStats,
|
|
||||||
streamFlowController: make(map[protocol.StreamID]*flowController),
|
|
||||||
connFlowController: newFlowController(0, false, connectionParameters, rttStats),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewStream creates new flow controllers for a stream
|
|
||||||
// it does nothing if the stream already exists
|
|
||||||
func (f *flowControlManager) NewStream(streamID protocol.StreamID, contributesToConnection bool) {
|
|
||||||
f.mutex.Lock()
|
|
||||||
defer f.mutex.Unlock()
|
|
||||||
|
|
||||||
if _, ok := f.streamFlowController[streamID]; ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
f.streamFlowController[streamID] = newFlowController(streamID, contributesToConnection, f.connectionParameters, f.rttStats)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveStream removes a closed stream from flow control
|
|
||||||
func (f *flowControlManager) RemoveStream(streamID protocol.StreamID) {
|
|
||||||
f.mutex.Lock()
|
|
||||||
delete(f.streamFlowController, streamID)
|
|
||||||
f.mutex.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetStream should be called when receiving a RstStreamFrame
|
|
||||||
// it updates the byte offset to the value in the RstStreamFrame
|
|
||||||
// streamID must not be 0 here
|
|
||||||
func (f *flowControlManager) ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error {
|
|
||||||
f.mutex.Lock()
|
|
||||||
defer f.mutex.Unlock()
|
|
||||||
|
|
||||||
streamFlowController, err := f.getFlowController(streamID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
increment, err := streamFlowController.UpdateHighestReceived(byteOffset)
|
|
||||||
if err != nil {
|
|
||||||
return qerr.StreamDataAfterTermination
|
|
||||||
}
|
|
||||||
|
|
||||||
if streamFlowController.CheckFlowControlViolation() {
|
|
||||||
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveWindow))
|
|
||||||
}
|
|
||||||
|
|
||||||
if streamFlowController.ContributesToConnection() {
|
|
||||||
f.connFlowController.IncrementHighestReceived(increment)
|
|
||||||
if f.connFlowController.CheckFlowControlViolation() {
|
|
||||||
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", f.connFlowController.highestReceived, f.connFlowController.receiveWindow))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateHighestReceived updates the highest received byte offset for a stream
|
|
||||||
// it adds the number of additional bytes to connection level flow control
|
|
||||||
// streamID must not be 0 here
|
|
||||||
func (f *flowControlManager) UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error {
|
|
||||||
f.mutex.Lock()
|
|
||||||
defer f.mutex.Unlock()
|
|
||||||
|
|
||||||
streamFlowController, err := f.getFlowController(streamID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// UpdateHighestReceived returns an ErrReceivedSmallerByteOffset when StreamFrames got reordered
|
|
||||||
// this error can be ignored here
|
|
||||||
increment, _ := streamFlowController.UpdateHighestReceived(byteOffset)
|
|
||||||
|
|
||||||
if streamFlowController.CheckFlowControlViolation() {
|
|
||||||
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveWindow))
|
|
||||||
}
|
|
||||||
|
|
||||||
if streamFlowController.ContributesToConnection() {
|
|
||||||
f.connFlowController.IncrementHighestReceived(increment)
|
|
||||||
if f.connFlowController.CheckFlowControlViolation() {
|
|
||||||
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", f.connFlowController.highestReceived, f.connFlowController.receiveWindow))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// streamID must not be 0 here
|
|
||||||
func (f *flowControlManager) AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error {
|
|
||||||
f.mutex.Lock()
|
|
||||||
defer f.mutex.Unlock()
|
|
||||||
|
|
||||||
fc, err := f.getFlowController(streamID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
fc.AddBytesRead(n)
|
|
||||||
if fc.ContributesToConnection() {
|
|
||||||
f.connFlowController.AddBytesRead(n)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *flowControlManager) GetWindowUpdates() (res []WindowUpdate) {
|
|
||||||
f.mutex.Lock()
|
|
||||||
defer f.mutex.Unlock()
|
|
||||||
|
|
||||||
// get WindowUpdates for streams
|
|
||||||
for id, fc := range f.streamFlowController {
|
|
||||||
if necessary, newIncrement, offset := fc.MaybeUpdateWindow(); necessary {
|
|
||||||
res = append(res, WindowUpdate{StreamID: id, Offset: offset})
|
|
||||||
if fc.ContributesToConnection() && newIncrement != 0 {
|
|
||||||
f.connFlowController.EnsureMinimumWindowIncrement(protocol.ByteCount(float64(newIncrement) * protocol.ConnectionFlowControlMultiplier))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// get a WindowUpdate for the connection
|
|
||||||
if necessary, _, offset := f.connFlowController.MaybeUpdateWindow(); necessary {
|
|
||||||
res = append(res, WindowUpdate{StreamID: 0, Offset: offset})
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *flowControlManager) GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error) {
|
|
||||||
f.mutex.RLock()
|
|
||||||
defer f.mutex.RUnlock()
|
|
||||||
|
|
||||||
// StreamID can be 0 when retransmitting
|
|
||||||
if streamID == 0 {
|
|
||||||
return f.connFlowController.receiveWindow, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
flowController, err := f.getFlowController(streamID)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return flowController.receiveWindow, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// streamID must not be 0 here
|
|
||||||
func (f *flowControlManager) AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error {
|
|
||||||
f.mutex.Lock()
|
|
||||||
defer f.mutex.Unlock()
|
|
||||||
|
|
||||||
fc, err := f.getFlowController(streamID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
fc.AddBytesSent(n)
|
|
||||||
if fc.ContributesToConnection() {
|
|
||||||
f.connFlowController.AddBytesSent(n)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// must not be called with StreamID 0
|
|
||||||
func (f *flowControlManager) SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error) {
|
|
||||||
f.mutex.RLock()
|
|
||||||
defer f.mutex.RUnlock()
|
|
||||||
|
|
||||||
fc, err := f.getFlowController(streamID)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
res := fc.SendWindowSize()
|
|
||||||
|
|
||||||
if fc.ContributesToConnection() {
|
|
||||||
res = utils.MinByteCount(res, f.connFlowController.SendWindowSize())
|
|
||||||
}
|
|
||||||
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *flowControlManager) RemainingConnectionWindowSize() protocol.ByteCount {
|
|
||||||
f.mutex.RLock()
|
|
||||||
defer f.mutex.RUnlock()
|
|
||||||
|
|
||||||
return f.connFlowController.SendWindowSize()
|
|
||||||
}
|
|
||||||
|
|
||||||
// streamID may be 0 here
|
|
||||||
func (f *flowControlManager) UpdateWindow(streamID protocol.StreamID, offset protocol.ByteCount) (bool, error) {
|
|
||||||
f.mutex.Lock()
|
|
||||||
defer f.mutex.Unlock()
|
|
||||||
|
|
||||||
var fc *flowController
|
|
||||||
if streamID == 0 {
|
|
||||||
fc = f.connFlowController
|
|
||||||
} else {
|
|
||||||
var err error
|
|
||||||
fc, err = f.getFlowController(streamID)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return fc.UpdateSendWindow(offset), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *flowControlManager) getFlowController(streamID protocol.StreamID) (*flowController, error) {
|
|
||||||
streamFlowController, ok := f.streamFlowController[streamID]
|
|
||||||
if !ok {
|
|
||||||
return nil, errMapAccess
|
|
||||||
}
|
|
||||||
return streamFlowController, nil
|
|
||||||
}
|
|
198
vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_controller.go
generated
vendored
198
vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_controller.go
generated
vendored
@ -1,198 +0,0 @@
|
|||||||
package flowcontrol
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/congestion"
|
|
||||||
"github.com/lucas-clemente/quic-go/handshake"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
)
|
|
||||||
|
|
||||||
type flowController struct {
|
|
||||||
streamID protocol.StreamID
|
|
||||||
contributesToConnection bool // does the stream contribute to connection level flow control
|
|
||||||
|
|
||||||
connectionParameters handshake.ConnectionParametersManager
|
|
||||||
rttStats *congestion.RTTStats
|
|
||||||
|
|
||||||
bytesSent protocol.ByteCount
|
|
||||||
sendWindow protocol.ByteCount
|
|
||||||
|
|
||||||
lastWindowUpdateTime time.Time
|
|
||||||
|
|
||||||
bytesRead protocol.ByteCount
|
|
||||||
highestReceived protocol.ByteCount
|
|
||||||
receiveWindow protocol.ByteCount
|
|
||||||
receiveWindowIncrement protocol.ByteCount
|
|
||||||
maxReceiveWindowIncrement protocol.ByteCount
|
|
||||||
}
|
|
||||||
|
|
||||||
// ErrReceivedSmallerByteOffset occurs if the ByteOffset received is smaller than a ByteOffset that was set previously
|
|
||||||
var ErrReceivedSmallerByteOffset = errors.New("Received a smaller byte offset")
|
|
||||||
|
|
||||||
// newFlowController gets a new flow controller
|
|
||||||
func newFlowController(streamID protocol.StreamID, contributesToConnection bool, connectionParameters handshake.ConnectionParametersManager, rttStats *congestion.RTTStats) *flowController {
|
|
||||||
fc := flowController{
|
|
||||||
streamID: streamID,
|
|
||||||
contributesToConnection: contributesToConnection,
|
|
||||||
connectionParameters: connectionParameters,
|
|
||||||
rttStats: rttStats,
|
|
||||||
}
|
|
||||||
|
|
||||||
if streamID == 0 {
|
|
||||||
fc.receiveWindow = connectionParameters.GetReceiveConnectionFlowControlWindow()
|
|
||||||
fc.receiveWindowIncrement = fc.receiveWindow
|
|
||||||
fc.maxReceiveWindowIncrement = connectionParameters.GetMaxReceiveConnectionFlowControlWindow()
|
|
||||||
} else {
|
|
||||||
fc.receiveWindow = connectionParameters.GetReceiveStreamFlowControlWindow()
|
|
||||||
fc.receiveWindowIncrement = fc.receiveWindow
|
|
||||||
fc.maxReceiveWindowIncrement = connectionParameters.GetMaxReceiveStreamFlowControlWindow()
|
|
||||||
}
|
|
||||||
|
|
||||||
return &fc
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *flowController) ContributesToConnection() bool {
|
|
||||||
return c.contributesToConnection
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *flowController) getSendWindow() protocol.ByteCount {
|
|
||||||
if c.sendWindow == 0 {
|
|
||||||
if c.streamID == 0 {
|
|
||||||
return c.connectionParameters.GetSendConnectionFlowControlWindow()
|
|
||||||
}
|
|
||||||
return c.connectionParameters.GetSendStreamFlowControlWindow()
|
|
||||||
}
|
|
||||||
return c.sendWindow
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *flowController) AddBytesSent(n protocol.ByteCount) {
|
|
||||||
c.bytesSent += n
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSendWindow should be called after receiving a WindowUpdateFrame
|
|
||||||
// it returns true if the window was actually updated
|
|
||||||
func (c *flowController) UpdateSendWindow(newOffset protocol.ByteCount) bool {
|
|
||||||
if newOffset > c.sendWindow {
|
|
||||||
c.sendWindow = newOffset
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *flowController) SendWindowSize() protocol.ByteCount {
|
|
||||||
sendWindow := c.getSendWindow()
|
|
||||||
|
|
||||||
if c.bytesSent > sendWindow { // should never happen, but make sure we don't do an underflow here
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return sendWindow - c.bytesSent
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *flowController) SendWindowOffset() protocol.ByteCount {
|
|
||||||
return c.getSendWindow()
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateHighestReceived updates the highestReceived value, if the byteOffset is higher
|
|
||||||
// Should **only** be used for the stream-level FlowController
|
|
||||||
// it returns an ErrReceivedSmallerByteOffset if the received byteOffset is smaller than any byteOffset received before
|
|
||||||
// This error occurs every time StreamFrames get reordered and has to be ignored in that case
|
|
||||||
// It should only be treated as an error when resetting a stream
|
|
||||||
func (c *flowController) UpdateHighestReceived(byteOffset protocol.ByteCount) (protocol.ByteCount, error) {
|
|
||||||
if byteOffset == c.highestReceived {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
if byteOffset > c.highestReceived {
|
|
||||||
increment := byteOffset - c.highestReceived
|
|
||||||
c.highestReceived = byteOffset
|
|
||||||
return increment, nil
|
|
||||||
}
|
|
||||||
return 0, ErrReceivedSmallerByteOffset
|
|
||||||
}
|
|
||||||
|
|
||||||
// IncrementHighestReceived adds an increment to the highestReceived value
|
|
||||||
// Should **only** be used for the connection-level FlowController
|
|
||||||
func (c *flowController) IncrementHighestReceived(increment protocol.ByteCount) {
|
|
||||||
c.highestReceived += increment
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *flowController) AddBytesRead(n protocol.ByteCount) {
|
|
||||||
// pretend we sent a WindowUpdate when reading the first byte
|
|
||||||
// this way auto-tuning of the window increment already works for the first WindowUpdate
|
|
||||||
if c.bytesRead == 0 {
|
|
||||||
c.lastWindowUpdateTime = time.Now()
|
|
||||||
}
|
|
||||||
c.bytesRead += n
|
|
||||||
}
|
|
||||||
|
|
||||||
// MaybeUpdateWindow updates the receive window, if necessary
|
|
||||||
// if the receive window increment is changed, the new value is returned, otherwise a 0
|
|
||||||
// the last return value is the new offset of the receive window
|
|
||||||
func (c *flowController) MaybeUpdateWindow() (bool, protocol.ByteCount /* new increment */, protocol.ByteCount /* new offset */) {
|
|
||||||
diff := c.receiveWindow - c.bytesRead
|
|
||||||
|
|
||||||
// Chromium implements the same threshold
|
|
||||||
if diff < (c.receiveWindowIncrement / 2) {
|
|
||||||
var newWindowIncrement protocol.ByteCount
|
|
||||||
oldWindowIncrement := c.receiveWindowIncrement
|
|
||||||
|
|
||||||
c.maybeAdjustWindowIncrement()
|
|
||||||
if c.receiveWindowIncrement != oldWindowIncrement {
|
|
||||||
newWindowIncrement = c.receiveWindowIncrement
|
|
||||||
}
|
|
||||||
|
|
||||||
c.lastWindowUpdateTime = time.Now()
|
|
||||||
c.receiveWindow = c.bytesRead + c.receiveWindowIncrement
|
|
||||||
return true, newWindowIncrement, c.receiveWindow
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, 0, 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// maybeAdjustWindowIncrement increases the receiveWindowIncrement if we're sending WindowUpdates too often
|
|
||||||
func (c *flowController) maybeAdjustWindowIncrement() {
|
|
||||||
if c.lastWindowUpdateTime.IsZero() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
rtt := c.rttStats.SmoothedRTT()
|
|
||||||
if rtt == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
timeSinceLastWindowUpdate := time.Since(c.lastWindowUpdateTime)
|
|
||||||
|
|
||||||
// interval between the window updates is sufficiently large, no need to increase the increment
|
|
||||||
if timeSinceLastWindowUpdate >= 2*rtt {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
oldWindowSize := c.receiveWindowIncrement
|
|
||||||
c.receiveWindowIncrement = utils.MinByteCount(2*c.receiveWindowIncrement, c.maxReceiveWindowIncrement)
|
|
||||||
|
|
||||||
// debug log, if the window size was actually increased
|
|
||||||
if oldWindowSize < c.receiveWindowIncrement {
|
|
||||||
newWindowSize := c.receiveWindowIncrement / (1 << 10)
|
|
||||||
if c.streamID == 0 {
|
|
||||||
utils.Debugf("Increasing receive flow control window for the connection to %d kB", newWindowSize)
|
|
||||||
} else {
|
|
||||||
utils.Debugf("Increasing receive flow control window increment for stream %d to %d kB", c.streamID, newWindowSize)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// EnsureMinimumWindowIncrement sets a minimum window increment
|
|
||||||
// it is intended be used for the connection-level flow controller
|
|
||||||
// it should make sure that the connection-level window is increased when a stream-level window grows
|
|
||||||
func (c *flowController) EnsureMinimumWindowIncrement(inc protocol.ByteCount) {
|
|
||||||
if inc > c.receiveWindowIncrement {
|
|
||||||
c.receiveWindowIncrement = utils.MinByteCount(inc, c.maxReceiveWindowIncrement)
|
|
||||||
c.lastWindowUpdateTime = time.Time{} // disables autotuning for the next window update
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *flowController) CheckFlowControlViolation() bool {
|
|
||||||
return c.highestReceived > c.receiveWindow
|
|
||||||
}
|
|
26
vendor/github.com/lucas-clemente/quic-go/flowcontrol/interface.go
generated
vendored
26
vendor/github.com/lucas-clemente/quic-go/flowcontrol/interface.go
generated
vendored
@ -1,26 +0,0 @@
|
|||||||
package flowcontrol
|
|
||||||
|
|
||||||
import "github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
|
|
||||||
// WindowUpdate provides the data for WindowUpdateFrames.
|
|
||||||
type WindowUpdate struct {
|
|
||||||
StreamID protocol.StreamID
|
|
||||||
Offset protocol.ByteCount
|
|
||||||
}
|
|
||||||
|
|
||||||
// A FlowControlManager manages the flow control
|
|
||||||
type FlowControlManager interface {
|
|
||||||
NewStream(streamID protocol.StreamID, contributesToConnectionFlow bool)
|
|
||||||
RemoveStream(streamID protocol.StreamID)
|
|
||||||
// methods needed for receiving data
|
|
||||||
ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error
|
|
||||||
UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error
|
|
||||||
AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error
|
|
||||||
GetWindowUpdates() []WindowUpdate
|
|
||||||
GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error)
|
|
||||||
// methods needed for sending data
|
|
||||||
AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error
|
|
||||||
SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error)
|
|
||||||
RemainingConnectionWindowSize() protocol.ByteCount
|
|
||||||
UpdateWindow(streamID protocol.StreamID, offset protocol.ByteCount) (bool, error)
|
|
||||||
}
|
|
9
vendor/github.com/lucas-clemente/quic-go/frames/ack_range.go
generated
vendored
9
vendor/github.com/lucas-clemente/quic-go/frames/ack_range.go
generated
vendored
@ -1,9 +0,0 @@
|
|||||||
package frames
|
|
||||||
|
|
||||||
import "github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
|
|
||||||
// AckRange is an ACK range
|
|
||||||
type AckRange struct {
|
|
||||||
FirstPacketNumber protocol.PacketNumber
|
|
||||||
LastPacketNumber protocol.PacketNumber
|
|
||||||
}
|
|
44
vendor/github.com/lucas-clemente/quic-go/frames/blocked_frame.go
generated
vendored
44
vendor/github.com/lucas-clemente/quic-go/frames/blocked_frame.go
generated
vendored
@ -1,44 +0,0 @@
|
|||||||
package frames
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
)
|
|
||||||
|
|
||||||
// A BlockedFrame in QUIC
|
|
||||||
type BlockedFrame struct {
|
|
||||||
StreamID protocol.StreamID
|
|
||||||
}
|
|
||||||
|
|
||||||
//Write writes a BlockedFrame frame
|
|
||||||
func (f *BlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
|
|
||||||
b.WriteByte(0x05)
|
|
||||||
utils.WriteUint32(b, uint32(f.StreamID))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MinLength of a written frame
|
|
||||||
func (f *BlockedFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
|
|
||||||
return 1 + 4, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseBlockedFrame parses a BLOCKED frame
|
|
||||||
func ParseBlockedFrame(r *bytes.Reader) (*BlockedFrame, error) {
|
|
||||||
frame := &BlockedFrame{}
|
|
||||||
|
|
||||||
// read the TypeByte
|
|
||||||
_, err := r.ReadByte()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
sid, err := utils.ReadUint32(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
frame.StreamID = protocol.StreamID(sid)
|
|
||||||
|
|
||||||
return frame, nil
|
|
||||||
}
|
|
73
vendor/github.com/lucas-clemente/quic-go/frames/connection_close_frame.go
generated
vendored
73
vendor/github.com/lucas-clemente/quic-go/frames/connection_close_frame.go
generated
vendored
@ -1,73 +0,0 @@
|
|||||||
package frames
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"math"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
)
|
|
||||||
|
|
||||||
// A ConnectionCloseFrame in QUIC
|
|
||||||
type ConnectionCloseFrame struct {
|
|
||||||
ErrorCode qerr.ErrorCode
|
|
||||||
ReasonPhrase string
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseConnectionCloseFrame reads a CONNECTION_CLOSE frame
|
|
||||||
func ParseConnectionCloseFrame(r *bytes.Reader) (*ConnectionCloseFrame, error) {
|
|
||||||
frame := &ConnectionCloseFrame{}
|
|
||||||
|
|
||||||
// read the TypeByte
|
|
||||||
_, err := r.ReadByte()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
errorCode, err := utils.ReadUint32(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
frame.ErrorCode = qerr.ErrorCode(errorCode)
|
|
||||||
|
|
||||||
reasonPhraseLen, err := utils.ReadUint16(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if reasonPhraseLen > uint16(protocol.MaxPacketSize) {
|
|
||||||
return nil, qerr.Error(qerr.InvalidConnectionCloseData, "reason phrase too long")
|
|
||||||
}
|
|
||||||
|
|
||||||
reasonPhrase := make([]byte, reasonPhraseLen)
|
|
||||||
if _, err := io.ReadFull(r, reasonPhrase); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
frame.ReasonPhrase = string(reasonPhrase)
|
|
||||||
|
|
||||||
return frame, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MinLength of a written frame
|
|
||||||
func (f *ConnectionCloseFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
|
|
||||||
return 1 + 4 + 2 + protocol.ByteCount(len(f.ReasonPhrase)), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write writes an CONNECTION_CLOSE frame.
|
|
||||||
func (f *ConnectionCloseFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
|
|
||||||
b.WriteByte(0x02)
|
|
||||||
utils.WriteUint32(b, uint32(f.ErrorCode))
|
|
||||||
|
|
||||||
if len(f.ReasonPhrase) > math.MaxUint16 {
|
|
||||||
return errors.New("ConnectionFrame: ReasonPhrase too long")
|
|
||||||
}
|
|
||||||
|
|
||||||
reasonPhraseLen := uint16(len(f.ReasonPhrase))
|
|
||||||
utils.WriteUint16(b, reasonPhraseLen)
|
|
||||||
b.WriteString(f.ReasonPhrase)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
13
vendor/github.com/lucas-clemente/quic-go/frames/frame.go
generated
vendored
13
vendor/github.com/lucas-clemente/quic-go/frames/frame.go
generated
vendored
@ -1,13 +0,0 @@
|
|||||||
package frames
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
)
|
|
||||||
|
|
||||||
// A Frame in QUIC
|
|
||||||
type Frame interface {
|
|
||||||
Write(b *bytes.Buffer, version protocol.VersionNumber) error
|
|
||||||
MinLength(version protocol.VersionNumber) (protocol.ByteCount, error)
|
|
||||||
}
|
|
28
vendor/github.com/lucas-clemente/quic-go/frames/log.go
generated
vendored
28
vendor/github.com/lucas-clemente/quic-go/frames/log.go
generated
vendored
@ -1,28 +0,0 @@
|
|||||||
package frames
|
|
||||||
|
|
||||||
import "github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
|
|
||||||
// LogFrame logs a frame, either sent or received
|
|
||||||
func LogFrame(frame Frame, sent bool) {
|
|
||||||
if !utils.Debug() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
dir := "<-"
|
|
||||||
if sent {
|
|
||||||
dir = "->"
|
|
||||||
}
|
|
||||||
switch f := frame.(type) {
|
|
||||||
case *StreamFrame:
|
|
||||||
utils.Debugf("\t%s &frames.StreamFrame{StreamID: %d, FinBit: %t, Offset: 0x%x, Data length: 0x%x, Offset + Data length: 0x%x}", dir, f.StreamID, f.FinBit, f.Offset, f.DataLen(), f.Offset+f.DataLen())
|
|
||||||
case *StopWaitingFrame:
|
|
||||||
if sent {
|
|
||||||
utils.Debugf("\t%s &frames.StopWaitingFrame{LeastUnacked: 0x%x, PacketNumberLen: 0x%x}", dir, f.LeastUnacked, f.PacketNumberLen)
|
|
||||||
} else {
|
|
||||||
utils.Debugf("\t%s &frames.StopWaitingFrame{LeastUnacked: 0x%x}", dir, f.LeastUnacked)
|
|
||||||
}
|
|
||||||
case *AckFrame:
|
|
||||||
utils.Debugf("\t%s &frames.AckFrame{LargestAcked: 0x%x, LowestAcked: 0x%x, AckRanges: %#v, DelayTime: %s}", dir, f.LargestAcked, f.LowestAcked, f.AckRanges, f.DelayTime.String())
|
|
||||||
default:
|
|
||||||
utils.Debugf("\t%s %#v", dir, frame)
|
|
||||||
}
|
|
||||||
}
|
|
59
vendor/github.com/lucas-clemente/quic-go/frames/rst_stream_frame.go
generated
vendored
59
vendor/github.com/lucas-clemente/quic-go/frames/rst_stream_frame.go
generated
vendored
@ -1,59 +0,0 @@
|
|||||||
package frames
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
)
|
|
||||||
|
|
||||||
// A RstStreamFrame in QUIC
|
|
||||||
type RstStreamFrame struct {
|
|
||||||
StreamID protocol.StreamID
|
|
||||||
ErrorCode uint32
|
|
||||||
ByteOffset protocol.ByteCount
|
|
||||||
}
|
|
||||||
|
|
||||||
//Write writes a RST_STREAM frame
|
|
||||||
func (f *RstStreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
|
|
||||||
b.WriteByte(0x01)
|
|
||||||
utils.WriteUint32(b, uint32(f.StreamID))
|
|
||||||
utils.WriteUint64(b, uint64(f.ByteOffset))
|
|
||||||
utils.WriteUint32(b, f.ErrorCode)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MinLength of a written frame
|
|
||||||
func (f *RstStreamFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
|
|
||||||
return 1 + 4 + 8 + 4, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseRstStreamFrame parses a RST_STREAM frame
|
|
||||||
func ParseRstStreamFrame(r *bytes.Reader) (*RstStreamFrame, error) {
|
|
||||||
frame := &RstStreamFrame{}
|
|
||||||
|
|
||||||
// read the TypeByte
|
|
||||||
_, err := r.ReadByte()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
sid, err := utils.ReadUint32(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
frame.StreamID = protocol.StreamID(sid)
|
|
||||||
|
|
||||||
byteOffset, err := utils.ReadUint64(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
frame.ByteOffset = protocol.ByteCount(byteOffset)
|
|
||||||
|
|
||||||
frame.ErrorCode, err = utils.ReadUint32(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return frame, nil
|
|
||||||
}
|
|
54
vendor/github.com/lucas-clemente/quic-go/frames/window_update_frame.go
generated
vendored
54
vendor/github.com/lucas-clemente/quic-go/frames/window_update_frame.go
generated
vendored
@ -1,54 +0,0 @@
|
|||||||
package frames
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
)
|
|
||||||
|
|
||||||
// A WindowUpdateFrame in QUIC
|
|
||||||
type WindowUpdateFrame struct {
|
|
||||||
StreamID protocol.StreamID
|
|
||||||
ByteOffset protocol.ByteCount
|
|
||||||
}
|
|
||||||
|
|
||||||
//Write writes a RST_STREAM frame
|
|
||||||
func (f *WindowUpdateFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
|
|
||||||
typeByte := uint8(0x04)
|
|
||||||
b.WriteByte(typeByte)
|
|
||||||
|
|
||||||
utils.WriteUint32(b, uint32(f.StreamID))
|
|
||||||
utils.WriteUint64(b, uint64(f.ByteOffset))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MinLength of a written frame
|
|
||||||
func (f *WindowUpdateFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
|
|
||||||
return 1 + 4 + 8, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseWindowUpdateFrame parses a RST_STREAM frame
|
|
||||||
func ParseWindowUpdateFrame(r *bytes.Reader) (*WindowUpdateFrame, error) {
|
|
||||||
frame := &WindowUpdateFrame{}
|
|
||||||
|
|
||||||
// read the TypeByte
|
|
||||||
_, err := r.ReadByte()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
sid, err := utils.ReadUint32(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
frame.StreamID = protocol.StreamID(sid)
|
|
||||||
|
|
||||||
byteOffset, err := utils.ReadUint64(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
frame.ByteOffset = protocol.ByteCount(byteOffset)
|
|
||||||
|
|
||||||
return frame, nil
|
|
||||||
}
|
|
109
vendor/github.com/lucas-clemente/quic-go/h2quic/client.go
generated
vendored
109
vendor/github.com/lucas-clemente/quic-go/h2quic/client.go
generated
vendored
@ -15,8 +15,8 @@ import (
|
|||||||
"golang.org/x/net/idna"
|
"golang.org/x/net/idna"
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
quic "github.com/lucas-clemente/quic-go"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
"github.com/lucas-clemente/quic-go/qerr"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -34,10 +34,10 @@ type client struct {
|
|||||||
config *quic.Config
|
config *quic.Config
|
||||||
opts *roundTripperOpts
|
opts *roundTripperOpts
|
||||||
|
|
||||||
hostname string
|
hostname string
|
||||||
encryptionLevel protocol.EncryptionLevel
|
handshakeErr error
|
||||||
handshakeErr error
|
dialOnce sync.Once
|
||||||
dialOnce sync.Once
|
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
|
||||||
|
|
||||||
session quic.Session
|
session quic.Session
|
||||||
headerStream quic.Stream
|
headerStream quic.Stream
|
||||||
@ -51,8 +51,8 @@ type client struct {
|
|||||||
var _ http.RoundTripper = &client{}
|
var _ http.RoundTripper = &client{}
|
||||||
|
|
||||||
var defaultQuicConfig = &quic.Config{
|
var defaultQuicConfig = &quic.Config{
|
||||||
RequestConnectionIDTruncation: true,
|
RequestConnectionIDOmission: true,
|
||||||
KeepAlive: true,
|
KeepAlive: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
// newClient creates a new client
|
// newClient creates a new client
|
||||||
@ -61,26 +61,31 @@ func newClient(
|
|||||||
tlsConfig *tls.Config,
|
tlsConfig *tls.Config,
|
||||||
opts *roundTripperOpts,
|
opts *roundTripperOpts,
|
||||||
quicConfig *quic.Config,
|
quicConfig *quic.Config,
|
||||||
|
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error),
|
||||||
) *client {
|
) *client {
|
||||||
config := defaultQuicConfig
|
config := defaultQuicConfig
|
||||||
if quicConfig != nil {
|
if quicConfig != nil {
|
||||||
config = quicConfig
|
config = quicConfig
|
||||||
}
|
}
|
||||||
return &client{
|
return &client{
|
||||||
hostname: authorityAddr("https", hostname),
|
hostname: authorityAddr("https", hostname),
|
||||||
responses: make(map[protocol.StreamID]chan *http.Response),
|
responses: make(map[protocol.StreamID]chan *http.Response),
|
||||||
encryptionLevel: protocol.EncryptionUnencrypted,
|
tlsConf: tlsConfig,
|
||||||
tlsConf: tlsConfig,
|
config: config,
|
||||||
config: config,
|
opts: opts,
|
||||||
opts: opts,
|
headerErrored: make(chan struct{}),
|
||||||
headerErrored: make(chan struct{}),
|
dialer: dialer,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// dial dials the connection
|
// dial dials the connection
|
||||||
func (c *client) dial() error {
|
func (c *client) dial() error {
|
||||||
var err error
|
var err error
|
||||||
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
|
if c.dialer != nil {
|
||||||
|
c.session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
|
||||||
|
} else {
|
||||||
|
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -90,9 +95,6 @@ func (c *client) dial() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if c.headerStream.StreamID() != 3 {
|
|
||||||
return errors.New("h2quic Client BUG: StreamID of Header Stream is not 3")
|
|
||||||
}
|
|
||||||
c.requestWriter = newRequestWriter(c.headerStream)
|
c.requestWriter = newRequestWriter(c.headerStream)
|
||||||
go c.handleHeaderStream()
|
go c.handleHeaderStream()
|
||||||
return nil
|
return nil
|
||||||
@ -102,45 +104,44 @@ func (c *client) handleHeaderStream() {
|
|||||||
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
|
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
|
||||||
h2framer := http2.NewFramer(nil, c.headerStream)
|
h2framer := http2.NewFramer(nil, c.headerStream)
|
||||||
|
|
||||||
var lastStream protocol.StreamID
|
var err error
|
||||||
|
for err == nil {
|
||||||
|
err = c.readResponse(h2framer, decoder)
|
||||||
|
}
|
||||||
|
utils.Debugf("Error handling header stream: %s", err)
|
||||||
|
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, err.Error())
|
||||||
|
// stop all running request
|
||||||
|
close(c.headerErrored)
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
func (c *client) readResponse(h2framer *http2.Framer, decoder *hpack.Decoder) error {
|
||||||
frame, err := h2framer.ReadFrame()
|
frame, err := h2framer.ReadFrame()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.headerErr = qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")
|
return err
|
||||||
break
|
}
|
||||||
}
|
hframe, ok := frame.(*http2.HeadersFrame)
|
||||||
lastStream = protocol.StreamID(frame.Header().StreamID)
|
if !ok {
|
||||||
hframe, ok := frame.(*http2.HeadersFrame)
|
return errors.New("not a headers frame")
|
||||||
if !ok {
|
}
|
||||||
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame")
|
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
|
||||||
break
|
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
|
||||||
}
|
if err != nil {
|
||||||
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
|
return fmt.Errorf("cannot read header fields: %s", err.Error())
|
||||||
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
|
|
||||||
if err != nil {
|
|
||||||
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields")
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
c.mutex.RLock()
|
|
||||||
responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
|
|
||||||
c.mutex.RUnlock()
|
|
||||||
if !ok {
|
|
||||||
c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream))
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
rsp, err := responseFromHeaders(mhframe)
|
|
||||||
if err != nil {
|
|
||||||
c.headerErr = qerr.Error(qerr.InternalError, err.Error())
|
|
||||||
}
|
|
||||||
responseChan <- rsp
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// stop all running request
|
c.mutex.RLock()
|
||||||
utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error())
|
responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
|
||||||
close(c.headerErrored)
|
c.mutex.RUnlock()
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("response channel for stream %d not found", hframe.StreamID)
|
||||||
|
}
|
||||||
|
|
||||||
|
rsp, err := responseFromHeaders(mhframe)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
responseChan <- rsp
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Roundtrip executes a request and returns a response
|
// Roundtrip executes a request and returns a response
|
||||||
|
2
vendor/github.com/lucas-clemente/quic-go/h2quic/request_writer.go
generated
vendored
2
vendor/github.com/lucas-clemente/quic-go/h2quic/request_writer.go
generated
vendored
@ -13,8 +13,8 @@ import (
|
|||||||
"golang.org/x/net/lex/httplex"
|
"golang.org/x/net/lex/httplex"
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
quic "github.com/lucas-clemente/quic-go"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type requestWriter struct {
|
type requestWriter struct {
|
||||||
|
4
vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer.go
generated
vendored
4
vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer.go
generated
vendored
@ -8,8 +8,8 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
quic "github.com/lucas-clemente/quic-go"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
"golang.org/x/net/http2/hpack"
|
"golang.org/x/net/http2/hpack"
|
||||||
)
|
)
|
||||||
@ -83,7 +83,7 @@ func (w *responseWriter) Write(p []byte) (int, error) {
|
|||||||
|
|
||||||
func (w *responseWriter) Flush() {}
|
func (w *responseWriter) Flush() {}
|
||||||
|
|
||||||
// TODO: Implement a functional CloseNotify method.
|
// This is a NOP. Use http.Request.Context
|
||||||
func (w *responseWriter) CloseNotify() <-chan bool { return make(<-chan bool) }
|
func (w *responseWriter) CloseNotify() <-chan bool { return make(<-chan bool) }
|
||||||
|
|
||||||
// test that we implement http.Flusher
|
// test that we implement http.Flusher
|
||||||
|
13
vendor/github.com/lucas-clemente/quic-go/h2quic/roundtrip.go
generated
vendored
13
vendor/github.com/lucas-clemente/quic-go/h2quic/roundtrip.go
generated
vendored
@ -41,6 +41,11 @@ type RoundTripper struct {
|
|||||||
// If nil, reasonable default values will be used.
|
// If nil, reasonable default values will be used.
|
||||||
QuicConfig *quic.Config
|
QuicConfig *quic.Config
|
||||||
|
|
||||||
|
// Dial specifies an optional dial function for creating QUIC
|
||||||
|
// connections for requests.
|
||||||
|
// If Dial is nil, quic.DialAddr will be used.
|
||||||
|
Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
|
||||||
|
|
||||||
clients map[string]roundTripCloser
|
clients map[string]roundTripCloser
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -120,7 +125,13 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTr
|
|||||||
if onlyCached {
|
if onlyCached {
|
||||||
return nil, ErrNoCachedConn
|
return nil, ErrNoCachedConn
|
||||||
}
|
}
|
||||||
client = newClient(hostname, r.TLSClientConfig, &roundTripperOpts{DisableCompression: r.DisableCompression}, r.QuicConfig)
|
client = newClient(
|
||||||
|
hostname,
|
||||||
|
r.TLSClientConfig,
|
||||||
|
&roundTripperOpts{DisableCompression: r.DisableCompression},
|
||||||
|
r.QuicConfig,
|
||||||
|
r.Dial,
|
||||||
|
)
|
||||||
r.clients[hostname] = client
|
r.clients[hostname] = client
|
||||||
}
|
}
|
||||||
return client, nil
|
return client, nil
|
||||||
|
87
vendor/github.com/lucas-clemente/quic-go/h2quic/server.go
generated
vendored
87
vendor/github.com/lucas-clemente/quic-go/h2quic/server.go
generated
vendored
@ -7,14 +7,14 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
quic "github.com/lucas-clemente/quic-go"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
"github.com/lucas-clemente/quic-go/qerr"
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
"golang.org/x/net/http2/hpack"
|
"golang.org/x/net/http2/hpack"
|
||||||
@ -50,6 +50,7 @@ type Server struct {
|
|||||||
|
|
||||||
listenerMutex sync.Mutex
|
listenerMutex sync.Mutex
|
||||||
listener quic.Listener
|
listener quic.Listener
|
||||||
|
closed bool
|
||||||
|
|
||||||
supportedVersionsAsString string
|
supportedVersionsAsString string
|
||||||
}
|
}
|
||||||
@ -88,6 +89,10 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
|
|||||||
return errors.New("use of h2quic.Server without http.Server")
|
return errors.New("use of h2quic.Server without http.Server")
|
||||||
}
|
}
|
||||||
s.listenerMutex.Lock()
|
s.listenerMutex.Lock()
|
||||||
|
if s.closed {
|
||||||
|
s.listenerMutex.Unlock()
|
||||||
|
return errors.New("Server is already closed")
|
||||||
|
}
|
||||||
if s.listener != nil {
|
if s.listener != nil {
|
||||||
s.listenerMutex.Unlock()
|
s.listenerMutex.Unlock()
|
||||||
return errors.New("ListenAndServe may only be called once")
|
return errors.New("ListenAndServe may only be called once")
|
||||||
@ -122,29 +127,23 @@ func (s *Server) handleHeaderStream(session streamCreator) {
|
|||||||
session.Close(qerr.Error(qerr.InvalidHeadersStreamData, err.Error()))
|
session.Close(qerr.Error(qerr.InvalidHeadersStreamData, err.Error()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if stream.StreamID() != 3 {
|
|
||||||
session.Close(qerr.Error(qerr.InternalError, "h2quic server BUG: header stream does not have stream ID 3"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
hpackDecoder := hpack.NewDecoder(4096, nil)
|
hpackDecoder := hpack.NewDecoder(4096, nil)
|
||||||
h2framer := http2.NewFramer(nil, stream)
|
h2framer := http2.NewFramer(nil, stream)
|
||||||
|
|
||||||
go func() {
|
var headerStreamMutex sync.Mutex // Protects concurrent calls to Write()
|
||||||
var headerStreamMutex sync.Mutex // Protects concurrent calls to Write()
|
for {
|
||||||
for {
|
if err := s.handleRequest(session, stream, &headerStreamMutex, hpackDecoder, h2framer); err != nil {
|
||||||
if err := s.handleRequest(session, stream, &headerStreamMutex, hpackDecoder, h2framer); err != nil {
|
// QuicErrors must originate from stream.Read() returning an error.
|
||||||
// QuicErrors must originate from stream.Read() returning an error.
|
// In this case, the session has already logged the error, so we don't
|
||||||
// In this case, the session has already logged the error, so we don't
|
// need to log it again.
|
||||||
// need to log it again.
|
if _, ok := err.(*qerr.QuicError); !ok {
|
||||||
if _, ok := err.(*qerr.QuicError); !ok {
|
utils.Errorf("error handling h2 request: %s", err.Error())
|
||||||
utils.Errorf("error handling h2 request: %s", err.Error())
|
|
||||||
}
|
|
||||||
session.Close(err)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
session.Close(err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}()
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, headerStreamMutex *sync.Mutex, hpackDecoder *hpack.Decoder, h2framer *http2.Framer) error {
|
func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, headerStreamMutex *sync.Mutex, hpackDecoder *hpack.Decoder, h2framer *http2.Framer) error {
|
||||||
@ -170,8 +169,6 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
req.RemoteAddr = session.RemoteAddr().String()
|
|
||||||
|
|
||||||
if utils.Debug() {
|
if utils.Debug() {
|
||||||
utils.Infof("%s %s%s, on data stream %d", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID)
|
utils.Infof("%s %s%s, on data stream %d", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID)
|
||||||
} else {
|
} else {
|
||||||
@ -187,19 +184,25 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var streamEnded bool
|
// handleRequest should be as non-blocking as possible to minimize
|
||||||
if h2headersFrame.StreamEnded() {
|
// head-of-line blocking. Potentially blocking code is run in a separate
|
||||||
dataStream.(remoteCloser).CloseRemote(0)
|
// goroutine, enabling handleRequest to return before the code is executed.
|
||||||
streamEnded = true
|
|
||||||
_, _ = dataStream.Read([]byte{0}) // read the eof
|
|
||||||
}
|
|
||||||
|
|
||||||
reqBody := newRequestBody(dataStream)
|
|
||||||
req.Body = reqBody
|
|
||||||
|
|
||||||
responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID))
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
streamEnded := h2headersFrame.StreamEnded()
|
||||||
|
if streamEnded {
|
||||||
|
dataStream.(remoteCloser).CloseRemote(0)
|
||||||
|
streamEnded = true
|
||||||
|
_, _ = dataStream.Read([]byte{0}) // read the eof
|
||||||
|
}
|
||||||
|
|
||||||
|
req = req.WithContext(dataStream.Context())
|
||||||
|
reqBody := newRequestBody(dataStream)
|
||||||
|
req.Body = reqBody
|
||||||
|
|
||||||
|
req.RemoteAddr = session.RemoteAddr().String()
|
||||||
|
|
||||||
|
responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID))
|
||||||
|
|
||||||
handler := s.Handler
|
handler := s.Handler
|
||||||
if handler == nil {
|
if handler == nil {
|
||||||
handler = http.DefaultServeMux
|
handler = http.DefaultServeMux
|
||||||
@ -225,7 +228,8 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
|
|||||||
}
|
}
|
||||||
if responseWriter.dataStream != nil {
|
if responseWriter.dataStream != nil {
|
||||||
if !streamEnded && !reqBody.requestRead {
|
if !streamEnded && !reqBody.requestRead {
|
||||||
responseWriter.dataStream.Reset(nil)
|
// in gQUIC, the error code doesn't matter, so just use 0 here
|
||||||
|
responseWriter.dataStream.CancelRead(0)
|
||||||
}
|
}
|
||||||
responseWriter.dataStream.Close()
|
responseWriter.dataStream.Close()
|
||||||
}
|
}
|
||||||
@ -243,6 +247,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
|
|||||||
func (s *Server) Close() error {
|
func (s *Server) Close() error {
|
||||||
s.listenerMutex.Lock()
|
s.listenerMutex.Lock()
|
||||||
defer s.listenerMutex.Unlock()
|
defer s.listenerMutex.Unlock()
|
||||||
|
s.closed = true
|
||||||
if s.listener != nil {
|
if s.listener != nil {
|
||||||
err := s.listener.Close()
|
err := s.listener.Close()
|
||||||
s.listener = nil
|
s.listener = nil
|
||||||
@ -279,12 +284,11 @@ func (s *Server) SetQuicHeaders(hdr http.Header) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if s.supportedVersionsAsString == "" {
|
if s.supportedVersionsAsString == "" {
|
||||||
for i, v := range protocol.SupportedVersions {
|
var versions []string
|
||||||
s.supportedVersionsAsString += strconv.Itoa(int(v))
|
for _, v := range protocol.SupportedVersions {
|
||||||
if i != len(protocol.SupportedVersions)-1 {
|
versions = append(versions, v.ToAltSvc())
|
||||||
s.supportedVersionsAsString += ","
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
s.supportedVersionsAsString = strings.Join(versions, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, s.supportedVersionsAsString))
|
hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, s.supportedVersionsAsString))
|
||||||
@ -344,6 +348,9 @@ func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error
|
|||||||
}
|
}
|
||||||
defer tcpConn.Close()
|
defer tcpConn.Close()
|
||||||
|
|
||||||
|
tlsConn := tls.NewListener(tcpConn, config)
|
||||||
|
defer tlsConn.Close()
|
||||||
|
|
||||||
// Start the servers
|
// Start the servers
|
||||||
httpServer := &http.Server{
|
httpServer := &http.Server{
|
||||||
Addr: addr,
|
Addr: addr,
|
||||||
@ -365,7 +372,7 @@ func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error
|
|||||||
hErr := make(chan error)
|
hErr := make(chan error)
|
||||||
qErr := make(chan error)
|
qErr := make(chan error)
|
||||||
go func() {
|
go func() {
|
||||||
hErr <- httpServer.Serve(tcpConn)
|
hErr <- httpServer.Serve(tlsConn)
|
||||||
}()
|
}()
|
||||||
go func() {
|
go func() {
|
||||||
qErr <- quicServer.Serve(udpConn)
|
qErr <- quicServer.Serve(udpConn)
|
||||||
|
265
vendor/github.com/lucas-clemente/quic-go/handshake/connection_parameters_manager.go
generated
vendored
265
vendor/github.com/lucas-clemente/quic-go/handshake/connection_parameters_manager.go
generated
vendored
@ -1,265 +0,0 @@
|
|||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ConnectionParametersManager negotiates and stores the connection parameters
|
|
||||||
// A ConnectionParametersManager can be used for a server as well as a client
|
|
||||||
// For the server:
|
|
||||||
// 1. call SetFromMap with the values received in the CHLO. This sets the corresponding values here, subject to negotiation
|
|
||||||
// 2. call GetHelloMap to get the values to send in the SHLO
|
|
||||||
// For the client:
|
|
||||||
// 1. call GetHelloMap to get the values to send in a CHLO
|
|
||||||
// 2. call SetFromMap with the values received in the SHLO
|
|
||||||
type ConnectionParametersManager interface {
|
|
||||||
SetFromMap(map[Tag][]byte) error
|
|
||||||
GetHelloMap() (map[Tag][]byte, error)
|
|
||||||
|
|
||||||
GetSendStreamFlowControlWindow() protocol.ByteCount
|
|
||||||
GetSendConnectionFlowControlWindow() protocol.ByteCount
|
|
||||||
GetReceiveStreamFlowControlWindow() protocol.ByteCount
|
|
||||||
GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount
|
|
||||||
GetReceiveConnectionFlowControlWindow() protocol.ByteCount
|
|
||||||
GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount
|
|
||||||
GetMaxOutgoingStreams() uint32
|
|
||||||
GetMaxIncomingStreams() uint32
|
|
||||||
GetIdleConnectionStateLifetime() time.Duration
|
|
||||||
TruncateConnectionID() bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type connectionParametersManager struct {
|
|
||||||
mutex sync.RWMutex
|
|
||||||
|
|
||||||
version protocol.VersionNumber
|
|
||||||
perspective protocol.Perspective
|
|
||||||
|
|
||||||
flowControlNegotiated bool
|
|
||||||
|
|
||||||
truncateConnectionID bool
|
|
||||||
maxStreamsPerConnection uint32
|
|
||||||
maxIncomingDynamicStreamsPerConnection uint32
|
|
||||||
idleConnectionStateLifetime time.Duration
|
|
||||||
sendStreamFlowControlWindow protocol.ByteCount
|
|
||||||
sendConnectionFlowControlWindow protocol.ByteCount
|
|
||||||
receiveStreamFlowControlWindow protocol.ByteCount
|
|
||||||
receiveConnectionFlowControlWindow protocol.ByteCount
|
|
||||||
maxReceiveStreamFlowControlWindow protocol.ByteCount
|
|
||||||
maxReceiveConnectionFlowControlWindow protocol.ByteCount
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ ConnectionParametersManager = &connectionParametersManager{}
|
|
||||||
|
|
||||||
// ErrMalformedTag is returned when the tag value cannot be read
|
|
||||||
var (
|
|
||||||
ErrMalformedTag = qerr.Error(qerr.InvalidCryptoMessageParameter, "malformed Tag value")
|
|
||||||
ErrFlowControlRenegotiationNotSupported = qerr.Error(qerr.InvalidCryptoMessageParameter, "renegotiation of flow control parameters not supported")
|
|
||||||
)
|
|
||||||
|
|
||||||
// NewConnectionParamatersManager creates a new connection parameters manager
|
|
||||||
func NewConnectionParamatersManager(
|
|
||||||
pers protocol.Perspective, v protocol.VersionNumber,
|
|
||||||
maxReceiveStreamFlowControlWindow protocol.ByteCount, maxReceiveConnectionFlowControlWindow protocol.ByteCount,
|
|
||||||
) ConnectionParametersManager {
|
|
||||||
h := &connectionParametersManager{
|
|
||||||
perspective: pers,
|
|
||||||
version: v,
|
|
||||||
sendStreamFlowControlWindow: protocol.InitialStreamFlowControlWindow, // can only be changed by the client
|
|
||||||
sendConnectionFlowControlWindow: protocol.InitialConnectionFlowControlWindow, // can only be changed by the client
|
|
||||||
receiveStreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
|
|
||||||
receiveConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
|
|
||||||
maxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
|
|
||||||
maxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
|
|
||||||
}
|
|
||||||
|
|
||||||
if h.perspective == protocol.PerspectiveServer {
|
|
||||||
h.idleConnectionStateLifetime = protocol.DefaultIdleTimeout
|
|
||||||
h.maxStreamsPerConnection = protocol.MaxStreamsPerConnection // this is the value negotiated based on what the client sent
|
|
||||||
h.maxIncomingDynamicStreamsPerConnection = protocol.MaxStreamsPerConnection // "incoming" seen from the client's perspective
|
|
||||||
} else {
|
|
||||||
h.idleConnectionStateLifetime = protocol.MaxIdleTimeoutClient
|
|
||||||
h.maxStreamsPerConnection = protocol.MaxStreamsPerConnection // this is the value negotiated based on what the client sent
|
|
||||||
h.maxIncomingDynamicStreamsPerConnection = protocol.MaxStreamsPerConnection // "incoming" seen from the server's perspective
|
|
||||||
}
|
|
||||||
|
|
||||||
return h
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetFromMap reads all params
|
|
||||||
func (h *connectionParametersManager) SetFromMap(params map[Tag][]byte) error {
|
|
||||||
h.mutex.Lock()
|
|
||||||
defer h.mutex.Unlock()
|
|
||||||
|
|
||||||
if value, ok := params[TagTCID]; ok && h.perspective == protocol.PerspectiveServer {
|
|
||||||
clientValue, err := utils.ReadUint32(bytes.NewBuffer(value))
|
|
||||||
if err != nil {
|
|
||||||
return ErrMalformedTag
|
|
||||||
}
|
|
||||||
h.truncateConnectionID = (clientValue == 0)
|
|
||||||
}
|
|
||||||
if value, ok := params[TagMSPC]; ok {
|
|
||||||
clientValue, err := utils.ReadUint32(bytes.NewBuffer(value))
|
|
||||||
if err != nil {
|
|
||||||
return ErrMalformedTag
|
|
||||||
}
|
|
||||||
h.maxStreamsPerConnection = h.negotiateMaxStreamsPerConnection(clientValue)
|
|
||||||
}
|
|
||||||
if value, ok := params[TagMIDS]; ok {
|
|
||||||
clientValue, err := utils.ReadUint32(bytes.NewBuffer(value))
|
|
||||||
if err != nil {
|
|
||||||
return ErrMalformedTag
|
|
||||||
}
|
|
||||||
h.maxIncomingDynamicStreamsPerConnection = h.negotiateMaxIncomingDynamicStreamsPerConnection(clientValue)
|
|
||||||
}
|
|
||||||
if value, ok := params[TagICSL]; ok {
|
|
||||||
clientValue, err := utils.ReadUint32(bytes.NewBuffer(value))
|
|
||||||
if err != nil {
|
|
||||||
return ErrMalformedTag
|
|
||||||
}
|
|
||||||
h.idleConnectionStateLifetime = h.negotiateIdleConnectionStateLifetime(time.Duration(clientValue) * time.Second)
|
|
||||||
}
|
|
||||||
if value, ok := params[TagSFCW]; ok {
|
|
||||||
if h.flowControlNegotiated {
|
|
||||||
return ErrFlowControlRenegotiationNotSupported
|
|
||||||
}
|
|
||||||
sendStreamFlowControlWindow, err := utils.ReadUint32(bytes.NewBuffer(value))
|
|
||||||
if err != nil {
|
|
||||||
return ErrMalformedTag
|
|
||||||
}
|
|
||||||
h.sendStreamFlowControlWindow = protocol.ByteCount(sendStreamFlowControlWindow)
|
|
||||||
}
|
|
||||||
if value, ok := params[TagCFCW]; ok {
|
|
||||||
if h.flowControlNegotiated {
|
|
||||||
return ErrFlowControlRenegotiationNotSupported
|
|
||||||
}
|
|
||||||
sendConnectionFlowControlWindow, err := utils.ReadUint32(bytes.NewBuffer(value))
|
|
||||||
if err != nil {
|
|
||||||
return ErrMalformedTag
|
|
||||||
}
|
|
||||||
h.sendConnectionFlowControlWindow = protocol.ByteCount(sendConnectionFlowControlWindow)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, containsSFCW := params[TagSFCW]
|
|
||||||
_, containsCFCW := params[TagCFCW]
|
|
||||||
if containsCFCW || containsSFCW {
|
|
||||||
h.flowControlNegotiated = true
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *connectionParametersManager) negotiateMaxStreamsPerConnection(clientValue uint32) uint32 {
|
|
||||||
return utils.MinUint32(clientValue, protocol.MaxStreamsPerConnection)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *connectionParametersManager) negotiateMaxIncomingDynamicStreamsPerConnection(clientValue uint32) uint32 {
|
|
||||||
return utils.MinUint32(clientValue, protocol.MaxIncomingDynamicStreamsPerConnection)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *connectionParametersManager) negotiateIdleConnectionStateLifetime(clientValue time.Duration) time.Duration {
|
|
||||||
if h.perspective == protocol.PerspectiveServer {
|
|
||||||
return utils.MinDuration(clientValue, protocol.MaxIdleTimeoutServer)
|
|
||||||
}
|
|
||||||
return utils.MinDuration(clientValue, protocol.MaxIdleTimeoutClient)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetHelloMap gets all parameters needed for the Hello message
|
|
||||||
func (h *connectionParametersManager) GetHelloMap() (map[Tag][]byte, error) {
|
|
||||||
sfcw := bytes.NewBuffer([]byte{})
|
|
||||||
utils.WriteUint32(sfcw, uint32(h.GetReceiveStreamFlowControlWindow()))
|
|
||||||
cfcw := bytes.NewBuffer([]byte{})
|
|
||||||
utils.WriteUint32(cfcw, uint32(h.GetReceiveConnectionFlowControlWindow()))
|
|
||||||
mspc := bytes.NewBuffer([]byte{})
|
|
||||||
utils.WriteUint32(mspc, h.maxStreamsPerConnection)
|
|
||||||
mids := bytes.NewBuffer([]byte{})
|
|
||||||
utils.WriteUint32(mids, protocol.MaxIncomingDynamicStreamsPerConnection)
|
|
||||||
icsl := bytes.NewBuffer([]byte{})
|
|
||||||
utils.WriteUint32(icsl, uint32(h.GetIdleConnectionStateLifetime()/time.Second))
|
|
||||||
|
|
||||||
return map[Tag][]byte{
|
|
||||||
TagICSL: icsl.Bytes(),
|
|
||||||
TagMSPC: mspc.Bytes(),
|
|
||||||
TagMIDS: mids.Bytes(),
|
|
||||||
TagCFCW: cfcw.Bytes(),
|
|
||||||
TagSFCW: sfcw.Bytes(),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSendStreamFlowControlWindow gets the size of the stream-level flow control window for sending data
|
|
||||||
func (h *connectionParametersManager) GetSendStreamFlowControlWindow() protocol.ByteCount {
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
return h.sendStreamFlowControlWindow
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSendConnectionFlowControlWindow gets the size of the stream-level flow control window for sending data
|
|
||||||
func (h *connectionParametersManager) GetSendConnectionFlowControlWindow() protocol.ByteCount {
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
return h.sendConnectionFlowControlWindow
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetReceiveStreamFlowControlWindow gets the size of the stream-level flow control window for receiving data
|
|
||||||
func (h *connectionParametersManager) GetReceiveStreamFlowControlWindow() protocol.ByteCount {
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
return h.receiveStreamFlowControlWindow
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetMaxReceiveStreamFlowControlWindow gets the maximum size of the stream-level flow control window for sending data
|
|
||||||
func (h *connectionParametersManager) GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount {
|
|
||||||
return h.maxReceiveStreamFlowControlWindow
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetReceiveConnectionFlowControlWindow gets the size of the stream-level flow control window for receiving data
|
|
||||||
func (h *connectionParametersManager) GetReceiveConnectionFlowControlWindow() protocol.ByteCount {
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
return h.receiveConnectionFlowControlWindow
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetMaxReceiveConnectionFlowControlWindow gets the maximum size of the stream-level flow control window for sending data
|
|
||||||
func (h *connectionParametersManager) GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount {
|
|
||||||
return h.maxReceiveConnectionFlowControlWindow
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetMaxOutgoingStreams gets the maximum number of outgoing streams per connection
|
|
||||||
func (h *connectionParametersManager) GetMaxOutgoingStreams() uint32 {
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
|
|
||||||
return h.maxIncomingDynamicStreamsPerConnection
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetMaxIncomingStreams get the maximum number of incoming streams per connection
|
|
||||||
func (h *connectionParametersManager) GetMaxIncomingStreams() uint32 {
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
|
|
||||||
maxStreams := protocol.MaxIncomingDynamicStreamsPerConnection
|
|
||||||
return utils.MaxUint32(uint32(maxStreams)+protocol.MaxStreamsMinimumIncrement, uint32(float64(maxStreams)*protocol.MaxStreamsMultiplier))
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetIdleConnectionStateLifetime gets the idle timeout
|
|
||||||
func (h *connectionParametersManager) GetIdleConnectionStateLifetime() time.Duration {
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
return h.idleConnectionStateLifetime
|
|
||||||
}
|
|
||||||
|
|
||||||
// TruncateConnectionID determines if the client requests truncated ConnectionIDs
|
|
||||||
func (h *connectionParametersManager) TruncateConnectionID() bool {
|
|
||||||
if h.perspective == protocol.PerspectiveClient {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
return h.truncateConnectionID
|
|
||||||
}
|
|
24
vendor/github.com/lucas-clemente/quic-go/handshake/interface.go
generated
vendored
24
vendor/github.com/lucas-clemente/quic-go/handshake/interface.go
generated
vendored
@ -1,24 +0,0 @@
|
|||||||
package handshake
|
|
||||||
|
|
||||||
import "github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
|
|
||||||
// Sealer seals a packet
|
|
||||||
type Sealer func(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
|
|
||||||
|
|
||||||
// CryptoSetup is a crypto setup
|
|
||||||
type CryptoSetup interface {
|
|
||||||
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
|
|
||||||
HandleCryptoStream() error
|
|
||||||
// TODO: clean up this interface
|
|
||||||
DiversificationNonce() []byte // only needed for cryptoSetupServer
|
|
||||||
SetDiversificationNonce([]byte) // only needed for cryptoSetupClient
|
|
||||||
|
|
||||||
GetSealer() (protocol.EncryptionLevel, Sealer)
|
|
||||||
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
|
|
||||||
GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TransportParameters are parameters sent to the peer during the handshake
|
|
||||||
type TransportParameters struct {
|
|
||||||
RequestConnectionIDTruncation bool
|
|
||||||
}
|
|
100
vendor/github.com/lucas-clemente/quic-go/handshake/stk_generator.go
generated
vendored
100
vendor/github.com/lucas-clemente/quic-go/handshake/stk_generator.go
generated
vendored
@ -1,100 +0,0 @@
|
|||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/asn1"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/crypto"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
stkPrefixIP byte = iota
|
|
||||||
stkPrefixString
|
|
||||||
)
|
|
||||||
|
|
||||||
// An STK is a source address token
|
|
||||||
type STK struct {
|
|
||||||
RemoteAddr string
|
|
||||||
SentTime time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
// token is the struct that is used for ASN1 serialization and deserialization
|
|
||||||
type token struct {
|
|
||||||
Data []byte
|
|
||||||
Timestamp int64
|
|
||||||
}
|
|
||||||
|
|
||||||
// An STKGenerator generates STKs
|
|
||||||
type STKGenerator struct {
|
|
||||||
stkSource crypto.StkSource
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewSTKGenerator initializes a new STKGenerator
|
|
||||||
func NewSTKGenerator() (*STKGenerator, error) {
|
|
||||||
stkSource, err := crypto.NewStkSource()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &STKGenerator{
|
|
||||||
stkSource: stkSource,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewToken generates a new STK token for a given source address
|
|
||||||
func (g *STKGenerator) NewToken(raddr net.Addr) ([]byte, error) {
|
|
||||||
data, err := asn1.Marshal(token{
|
|
||||||
Data: encodeRemoteAddr(raddr),
|
|
||||||
Timestamp: time.Now().Unix(),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return g.stkSource.NewToken(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DecodeToken decodes an STK token
|
|
||||||
func (g *STKGenerator) DecodeToken(encrypted []byte) (*STK, error) {
|
|
||||||
// if the client didn't send any STK, DecodeToken will be called with a nil-slice
|
|
||||||
if len(encrypted) == 0 {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := g.stkSource.DecodeToken(encrypted)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
t := &token{}
|
|
||||||
rest, err := asn1.Unmarshal(data, t)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if len(rest) != 0 {
|
|
||||||
return nil, fmt.Errorf("rest when unpacking token: %d", len(rest))
|
|
||||||
}
|
|
||||||
return &STK{
|
|
||||||
RemoteAddr: decodeRemoteAddr(t.Data),
|
|
||||||
SentTime: time.Unix(t.Timestamp, 0),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// encodeRemoteAddr encodes a remote address such that it can be saved in the STK
|
|
||||||
func encodeRemoteAddr(remoteAddr net.Addr) []byte {
|
|
||||||
if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok {
|
|
||||||
return append([]byte{stkPrefixIP}, udpAddr.IP...)
|
|
||||||
}
|
|
||||||
return append([]byte{stkPrefixString}, []byte(remoteAddr.String())...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// decodeRemoteAddr decodes the remote address saved in the STK
|
|
||||||
func decodeRemoteAddr(data []byte) string {
|
|
||||||
// data will never be empty for an STK that we generated. Check it to be on the safe side
|
|
||||||
if len(data) == 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
if data[0] == stkPrefixIP {
|
|
||||||
return net.IP(data[1:]).String()
|
|
||||||
}
|
|
||||||
return string(data[1:])
|
|
||||||
}
|
|
1
vendor/github.com/lucas-clemente/quic-go/integrationtests/chrome/chrome.go
generated
vendored
1
vendor/github.com/lucas-clemente/quic-go/integrationtests/chrome/chrome.go
generated
vendored
@ -1 +0,0 @@
|
|||||||
package chrome
|
|
1
vendor/github.com/lucas-clemente/quic-go/integrationtests/gquic/gquic.go
generated
vendored
1
vendor/github.com/lucas-clemente/quic-go/integrationtests/gquic/gquic.go
generated
vendored
@ -1 +0,0 @@
|
|||||||
package gquic
|
|
1
vendor/github.com/lucas-clemente/quic-go/integrationtests/self/self.go
generated
vendored
1
vendor/github.com/lucas-clemente/quic-go/integrationtests/self/self.go
generated
vendored
@ -1 +0,0 @@
|
|||||||
package self
|
|
73
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/proxy/proxy.go
generated
vendored
73
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/proxy/proxy.go
generated
vendored
@ -1,14 +1,12 @@
|
|||||||
package quicproxy
|
package quicproxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Connection is a UDP connection
|
// Connection is a UDP connection
|
||||||
@ -28,21 +26,43 @@ const (
|
|||||||
DirectionIncoming Direction = iota
|
DirectionIncoming Direction = iota
|
||||||
// DirectionOutgoing is the direction from the server to the client.
|
// DirectionOutgoing is the direction from the server to the client.
|
||||||
DirectionOutgoing
|
DirectionOutgoing
|
||||||
|
// DirectionBoth is both incoming and outgoing
|
||||||
|
DirectionBoth
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (d Direction) String() string {
|
||||||
|
switch d {
|
||||||
|
case DirectionIncoming:
|
||||||
|
return "incoming"
|
||||||
|
case DirectionOutgoing:
|
||||||
|
return "outgoing"
|
||||||
|
case DirectionBoth:
|
||||||
|
return "both"
|
||||||
|
default:
|
||||||
|
panic("unknown direction")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d Direction) Is(dir Direction) bool {
|
||||||
|
if d == DirectionBoth || dir == DirectionBoth {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return d == dir
|
||||||
|
}
|
||||||
|
|
||||||
// DropCallback is a callback that determines which packet gets dropped.
|
// DropCallback is a callback that determines which packet gets dropped.
|
||||||
type DropCallback func(Direction, protocol.PacketNumber) bool
|
type DropCallback func(dir Direction, packetCount uint64) bool
|
||||||
|
|
||||||
// NoDropper doesn't drop packets.
|
// NoDropper doesn't drop packets.
|
||||||
var NoDropper DropCallback = func(Direction, protocol.PacketNumber) bool {
|
var NoDropper DropCallback = func(Direction, uint64) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// DelayCallback is a callback that determines how much delay to apply to a packet.
|
// DelayCallback is a callback that determines how much delay to apply to a packet.
|
||||||
type DelayCallback func(Direction, protocol.PacketNumber) time.Duration
|
type DelayCallback func(dir Direction, packetCount uint64) time.Duration
|
||||||
|
|
||||||
// NoDelay doesn't apply a delay.
|
// NoDelay doesn't apply a delay.
|
||||||
var NoDelay DelayCallback = func(Direction, protocol.PacketNumber) time.Duration {
|
var NoDelay DelayCallback = func(Direction, uint64) time.Duration {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -62,6 +82,8 @@ type Opts struct {
|
|||||||
type QuicProxy struct {
|
type QuicProxy struct {
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
|
|
||||||
|
version protocol.VersionNumber
|
||||||
|
|
||||||
conn *net.UDPConn
|
conn *net.UDPConn
|
||||||
serverAddr *net.UDPAddr
|
serverAddr *net.UDPAddr
|
||||||
|
|
||||||
@ -73,7 +95,10 @@ type QuicProxy struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewQuicProxy creates a new UDP proxy
|
// NewQuicProxy creates a new UDP proxy
|
||||||
func NewQuicProxy(local string, opts Opts) (*QuicProxy, error) {
|
func NewQuicProxy(local string, version protocol.VersionNumber, opts *Opts) (*QuicProxy, error) {
|
||||||
|
if opts == nil {
|
||||||
|
opts = &Opts{}
|
||||||
|
}
|
||||||
laddr, err := net.ResolveUDPAddr("udp", local)
|
laddr, err := net.ResolveUDPAddr("udp", local)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -103,6 +128,7 @@ func NewQuicProxy(local string, opts Opts) (*QuicProxy, error) {
|
|||||||
serverAddr: raddr,
|
serverAddr: raddr,
|
||||||
dropPacket: packetDropper,
|
dropPacket: packetDropper,
|
||||||
delayPacket: packetDelayer,
|
delayPacket: packetDelayer,
|
||||||
|
version: version,
|
||||||
}
|
}
|
||||||
|
|
||||||
go p.runProxy()
|
go p.runProxy()
|
||||||
@ -119,6 +145,7 @@ func (p *QuicProxy) LocalAddr() net.Addr {
|
|||||||
return p.conn.LocalAddr()
|
return p.conn.LocalAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LocalPort is the UDP port number the proxy is listening on.
|
||||||
func (p *QuicProxy) LocalPort() int {
|
func (p *QuicProxy) LocalPort() int {
|
||||||
return p.conn.LocalAddr().(*net.UDPAddr).Port
|
return p.conn.LocalAddr().(*net.UDPAddr).Port
|
||||||
}
|
}
|
||||||
@ -137,7 +164,7 @@ func (p *QuicProxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) {
|
|||||||
// runProxy listens on the proxy address and handles incoming packets.
|
// runProxy listens on the proxy address and handles incoming packets.
|
||||||
func (p *QuicProxy) runProxy() error {
|
func (p *QuicProxy) runProxy() error {
|
||||||
for {
|
for {
|
||||||
buffer := make([]byte, protocol.MaxPacketSize)
|
buffer := make([]byte, protocol.MaxReceivePacketSize)
|
||||||
n, cliaddr, err := p.conn.ReadFromUDP(buffer)
|
n, cliaddr, err := p.conn.ReadFromUDP(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -159,20 +186,14 @@ func (p *QuicProxy) runProxy() error {
|
|||||||
}
|
}
|
||||||
p.mutex.Unlock()
|
p.mutex.Unlock()
|
||||||
|
|
||||||
atomic.AddUint64(&conn.incomingPacketCounter, 1)
|
packetCount := atomic.AddUint64(&conn.incomingPacketCounter, 1)
|
||||||
|
|
||||||
r := bytes.NewReader(raw)
|
if p.dropPacket(DirectionIncoming, packetCount) {
|
||||||
hdr, err := quic.ParsePublicHeader(r, protocol.PerspectiveClient)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if p.dropPacket(DirectionIncoming, hdr.PacketNumber) {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send the packet to the server
|
// Send the packet to the server
|
||||||
delay := p.delayPacket(DirectionIncoming, hdr.PacketNumber)
|
delay := p.delayPacket(DirectionIncoming, packetCount)
|
||||||
if delay != 0 {
|
if delay != 0 {
|
||||||
time.AfterFunc(delay, func() {
|
time.AfterFunc(delay, func() {
|
||||||
// TODO: handle error
|
// TODO: handle error
|
||||||
@ -190,28 +211,20 @@ func (p *QuicProxy) runProxy() error {
|
|||||||
// runConnection handles packets from server to a single client
|
// runConnection handles packets from server to a single client
|
||||||
func (p *QuicProxy) runConnection(conn *connection) error {
|
func (p *QuicProxy) runConnection(conn *connection) error {
|
||||||
for {
|
for {
|
||||||
buffer := make([]byte, protocol.MaxPacketSize)
|
buffer := make([]byte, protocol.MaxReceivePacketSize)
|
||||||
n, err := conn.ServerConn.Read(buffer)
|
n, err := conn.ServerConn.Read(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
raw := buffer[0:n]
|
raw := buffer[0:n]
|
||||||
|
|
||||||
// TODO: Switch back to using the public header once Chrome properly sets the type byte.
|
packetCount := atomic.AddUint64(&conn.outgoingPacketCounter, 1)
|
||||||
// r := bytes.NewReader(raw)
|
|
||||||
// , err := quic.ParsePublicHeader(r, protocol.PerspectiveServer)
|
|
||||||
// if err != nil {
|
|
||||||
// return err
|
|
||||||
// }
|
|
||||||
|
|
||||||
v := atomic.AddUint64(&conn.outgoingPacketCounter, 1)
|
if p.dropPacket(DirectionOutgoing, packetCount) {
|
||||||
|
|
||||||
packetNumber := protocol.PacketNumber(v)
|
|
||||||
if p.dropPacket(DirectionOutgoing, packetNumber) {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
delay := p.delayPacket(DirectionOutgoing, packetNumber)
|
delay := p.delayPacket(DirectionOutgoing, packetCount)
|
||||||
if delay != 0 {
|
if delay != 0 {
|
||||||
time.AfterFunc(delay, func() {
|
time.AfterFunc(delay, func() {
|
||||||
// TODO: handle error
|
// TODO: handle error
|
||||||
|
2
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testlog/testlog.go
generated
vendored
2
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testlog/testlog.go
generated
vendored
@ -27,7 +27,7 @@ var _ = BeforeEach(func() {
|
|||||||
|
|
||||||
if len(logFileName) > 0 {
|
if len(logFileName) > 0 {
|
||||||
var err error
|
var err error
|
||||||
logFile, err = os.Create("./log.txt")
|
logFile, err = os.Create(logFileName)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
log.SetOutput(logFile)
|
log.SetOutput(logFile)
|
||||||
utils.SetLogLevel(utils.LogLevelDebug)
|
utils.SetLogLevel(utils.LogLevelDebug)
|
||||||
|
18
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testserver/server.go
generated
vendored
18
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testserver/server.go
generated
vendored
@ -7,7 +7,9 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
quic "github.com/lucas-clemente/quic-go"
|
||||||
"github.com/lucas-clemente/quic-go/h2quic"
|
"github.com/lucas-clemente/quic-go/h2quic"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
. "github.com/onsi/ginkgo"
|
||||||
@ -23,8 +25,9 @@ var (
|
|||||||
PRData = GeneratePRData(dataLen)
|
PRData = GeneratePRData(dataLen)
|
||||||
PRDataLong = GeneratePRData(dataLenLong)
|
PRDataLong = GeneratePRData(dataLenLong)
|
||||||
|
|
||||||
server *h2quic.Server
|
server *h2quic.Server
|
||||||
port string
|
stoppedServing chan struct{}
|
||||||
|
port string
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@ -75,11 +78,16 @@ func GeneratePRData(l int) []byte {
|
|||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func StartQuicServer() {
|
// StartQuicServer starts a h2quic.Server.
|
||||||
|
// versions is a slice of supported QUIC versions. It may be nil, then all supported versions are used.
|
||||||
|
func StartQuicServer(versions []protocol.VersionNumber) {
|
||||||
server = &h2quic.Server{
|
server = &h2quic.Server{
|
||||||
Server: &http.Server{
|
Server: &http.Server{
|
||||||
TLSConfig: testdata.GetTLSConfig(),
|
TLSConfig: testdata.GetTLSConfig(),
|
||||||
},
|
},
|
||||||
|
QuicConfig: &quic.Config{
|
||||||
|
Versions: versions,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
addr, err := net.ResolveUDPAddr("udp", "0.0.0.0:0")
|
addr, err := net.ResolveUDPAddr("udp", "0.0.0.0:0")
|
||||||
@ -88,14 +96,18 @@ func StartQuicServer() {
|
|||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
port = strconv.Itoa(conn.LocalAddr().(*net.UDPAddr).Port)
|
port = strconv.Itoa(conn.LocalAddr().(*net.UDPAddr).Port)
|
||||||
|
|
||||||
|
stoppedServing = make(chan struct{})
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
server.Serve(conn)
|
server.Serve(conn)
|
||||||
|
close(stoppedServing)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func StopQuicServer() {
|
func StopQuicServer() {
|
||||||
Expect(server.Close()).NotTo(HaveOccurred())
|
Expect(server.Close()).NotTo(HaveOccurred())
|
||||||
|
Eventually(stoppedServing).Should(BeClosed())
|
||||||
}
|
}
|
||||||
|
|
||||||
func Port() string {
|
func Port() string {
|
||||||
|
121
vendor/github.com/lucas-clemente/quic-go/interface.go
generated
vendored
121
vendor/github.com/lucas-clemente/quic-go/interface.go
generated
vendored
@ -6,23 +6,55 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// The StreamID is the ID of a QUIC stream.
|
||||||
|
type StreamID = protocol.StreamID
|
||||||
|
|
||||||
|
// A VersionNumber is a QUIC version number.
|
||||||
|
type VersionNumber = protocol.VersionNumber
|
||||||
|
|
||||||
|
// A Cookie can be used to verify the ownership of the client address.
|
||||||
|
type Cookie = handshake.Cookie
|
||||||
|
|
||||||
|
// ConnectionState records basic details about the QUIC connection.
|
||||||
|
type ConnectionState = handshake.ConnectionState
|
||||||
|
|
||||||
|
// An ErrorCode is an application-defined error code.
|
||||||
|
type ErrorCode = protocol.ApplicationErrorCode
|
||||||
|
|
||||||
// Stream is the interface implemented by QUIC streams
|
// Stream is the interface implemented by QUIC streams
|
||||||
type Stream interface {
|
type Stream interface {
|
||||||
|
// StreamID returns the stream ID.
|
||||||
|
StreamID() StreamID
|
||||||
// Read reads data from the stream.
|
// Read reads data from the stream.
|
||||||
// Read can be made to time out and return a net.Error with Timeout() == true
|
// Read can be made to time out and return a net.Error with Timeout() == true
|
||||||
// after a fixed time limit; see SetDeadline and SetReadDeadline.
|
// after a fixed time limit; see SetDeadline and SetReadDeadline.
|
||||||
|
// If the stream was canceled by the peer, the error implements the StreamError
|
||||||
|
// interface, and Canceled() == true.
|
||||||
io.Reader
|
io.Reader
|
||||||
// Write writes data to the stream.
|
// Write writes data to the stream.
|
||||||
// Write can be made to time out and return a net.Error with Timeout() == true
|
// Write can be made to time out and return a net.Error with Timeout() == true
|
||||||
// after a fixed time limit; see SetDeadline and SetWriteDeadline.
|
// after a fixed time limit; see SetDeadline and SetWriteDeadline.
|
||||||
|
// If the stream was canceled by the peer, the error implements the StreamError
|
||||||
|
// interface, and Canceled() == true.
|
||||||
io.Writer
|
io.Writer
|
||||||
|
// Close closes the write-direction of the stream.
|
||||||
|
// Future calls to Write are not permitted after calling Close.
|
||||||
|
// It must not be called concurrently with Write.
|
||||||
|
// It must not be called after calling CancelWrite.
|
||||||
io.Closer
|
io.Closer
|
||||||
StreamID() protocol.StreamID
|
// CancelWrite aborts sending on this stream.
|
||||||
// Reset closes the stream with an error.
|
// It must not be called after Close.
|
||||||
Reset(error)
|
// Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably.
|
||||||
|
// Write will unblock immediately, and future calls to Write will fail.
|
||||||
|
CancelWrite(ErrorCode) error
|
||||||
|
// CancelRead aborts receiving on this stream.
|
||||||
|
// It will ask the peer to stop transmitting stream data.
|
||||||
|
// Read will unblock immediately, and future Read calls will fail.
|
||||||
|
CancelRead(ErrorCode) error
|
||||||
// The context is canceled as soon as the write-side of the stream is closed.
|
// The context is canceled as soon as the write-side of the stream is closed.
|
||||||
// This happens when Close() is called, or when the stream is reset (either locally or remotely).
|
// This happens when Close() is called, or when the stream is reset (either locally or remotely).
|
||||||
// Warning: This API should not be considered stable and might change soon.
|
// Warning: This API should not be considered stable and might change soon.
|
||||||
@ -43,6 +75,41 @@ type Stream interface {
|
|||||||
SetDeadline(t time.Time) error
|
SetDeadline(t time.Time) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A ReceiveStream is a unidirectional Receive Stream.
|
||||||
|
type ReceiveStream interface {
|
||||||
|
// see Stream.StreamID
|
||||||
|
StreamID() StreamID
|
||||||
|
// see Stream.Read
|
||||||
|
io.Reader
|
||||||
|
// see Stream.CancelRead
|
||||||
|
CancelRead(ErrorCode) error
|
||||||
|
// see Stream.SetReadDealine
|
||||||
|
SetReadDeadline(t time.Time) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// A SendStream is a unidirectional Send Stream.
|
||||||
|
type SendStream interface {
|
||||||
|
// see Stream.StreamID
|
||||||
|
StreamID() StreamID
|
||||||
|
// see Stream.Write
|
||||||
|
io.Writer
|
||||||
|
// see Stream.Close
|
||||||
|
io.Closer
|
||||||
|
// see Stream.CancelWrite
|
||||||
|
CancelWrite(ErrorCode) error
|
||||||
|
// see Stream.Context
|
||||||
|
Context() context.Context
|
||||||
|
// see Stream.SetWriteDeadline
|
||||||
|
SetWriteDeadline(t time.Time) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamError is returned by Read and Write when the peer cancels the stream.
|
||||||
|
type StreamError interface {
|
||||||
|
error
|
||||||
|
Canceled() bool
|
||||||
|
ErrorCode() ErrorCode
|
||||||
|
}
|
||||||
|
|
||||||
// A Session is a QUIC connection between two peers.
|
// A Session is a QUIC connection between two peers.
|
||||||
type Session interface {
|
type Session interface {
|
||||||
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
|
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
|
||||||
@ -64,53 +131,41 @@ type Session interface {
|
|||||||
// The context is cancelled when the session is closed.
|
// The context is cancelled when the session is closed.
|
||||||
// Warning: This API should not be considered stable and might change soon.
|
// Warning: This API should not be considered stable and might change soon.
|
||||||
Context() context.Context
|
Context() context.Context
|
||||||
}
|
// ConnectionState returns basic details about the QUIC connection.
|
||||||
|
// Warning: This API should not be considered stable and might change soon.
|
||||||
// A NonFWSession is a QUIC connection between two peers half-way through the handshake.
|
ConnectionState() ConnectionState
|
||||||
// The communication is encrypted, but not yet forward secure.
|
|
||||||
type NonFWSession interface {
|
|
||||||
Session
|
|
||||||
WaitUntilHandshakeComplete() error
|
|
||||||
}
|
|
||||||
|
|
||||||
// An STK is a Source Address token.
|
|
||||||
// It is issued by the server and sent to the client. For the client, it is an opaque blob.
|
|
||||||
// The client can send the STK in subsequent handshakes to prove ownership of its IP address.
|
|
||||||
type STK struct {
|
|
||||||
// The remote address this token was issued for.
|
|
||||||
// If the server is run on a net.UDPConn, this is the string representation of the IP address (net.IP.String())
|
|
||||||
// Otherwise, this is the string representation of the net.Addr (net.Addr.String())
|
|
||||||
remoteAddr string
|
|
||||||
// The time that the STK was issued (resolution 1 second)
|
|
||||||
sentTime time.Time
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config contains all configuration data needed for a QUIC server or client.
|
// Config contains all configuration data needed for a QUIC server or client.
|
||||||
// More config parameters (such as timeouts) will be added soon, see e.g. https://github.com/lucas-clemente/quic-go/issues/441.
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
// The QUIC versions that can be negotiated.
|
// The QUIC versions that can be negotiated.
|
||||||
// If not set, it uses all versions available.
|
// If not set, it uses all versions available.
|
||||||
// Warning: This API should not be considered stable and will change soon.
|
// Warning: This API should not be considered stable and will change soon.
|
||||||
Versions []protocol.VersionNumber
|
Versions []VersionNumber
|
||||||
// Ask the server to truncate the connection ID sent in the Public Header.
|
// Ask the server to omit the connection ID sent in the Public Header.
|
||||||
// This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated.
|
// This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated.
|
||||||
// Currently only valid for the client.
|
// Currently only valid for the client.
|
||||||
RequestConnectionIDTruncation bool
|
RequestConnectionIDOmission bool
|
||||||
// HandshakeTimeout is the maximum duration that the cryptographic handshake may take.
|
// HandshakeTimeout is the maximum duration that the cryptographic handshake may take.
|
||||||
// If the timeout is exceeded, the connection is closed.
|
// If the timeout is exceeded, the connection is closed.
|
||||||
// If this value is zero, the timeout is set to 10 seconds.
|
// If this value is zero, the timeout is set to 10 seconds.
|
||||||
HandshakeTimeout time.Duration
|
HandshakeTimeout time.Duration
|
||||||
// AcceptSTK determines if an STK is accepted.
|
// IdleTimeout is the maximum duration that may pass without any incoming network activity.
|
||||||
// It is called with stk = nil if the client didn't send an STK.
|
// This value only applies after the handshake has completed.
|
||||||
// If not set, it verifies that the address matches, and that the STK was issued within the last 24 hours.
|
// If the timeout is exceeded, the connection is closed.
|
||||||
|
// If this value is zero, the timeout is set to 30 seconds.
|
||||||
|
IdleTimeout time.Duration
|
||||||
|
// AcceptCookie determines if a Cookie is accepted.
|
||||||
|
// It is called with cookie = nil if the client didn't send an Cookie.
|
||||||
|
// If not set, it verifies that the address matches, and that the Cookie was issued within the last 24 hours.
|
||||||
// This option is only valid for the server.
|
// This option is only valid for the server.
|
||||||
AcceptSTK func(clientAddr net.Addr, stk *STK) bool
|
AcceptCookie func(clientAddr net.Addr, cookie *Cookie) bool
|
||||||
// MaxReceiveStreamFlowControlWindow is the maximum stream-level flow control window for receiving data.
|
// MaxReceiveStreamFlowControlWindow is the maximum stream-level flow control window for receiving data.
|
||||||
// If this value is zero, it will default to 1 MB for the server and 6 MB for the client.
|
// If this value is zero, it will default to 1 MB for the server and 6 MB for the client.
|
||||||
MaxReceiveStreamFlowControlWindow protocol.ByteCount
|
MaxReceiveStreamFlowControlWindow uint64
|
||||||
// MaxReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data.
|
// MaxReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data.
|
||||||
// If this value is zero, it will default to 1.5 MB for the server and 15 MB for the client.
|
// If this value is zero, it will default to 1.5 MB for the server and 15 MB for the client.
|
||||||
MaxReceiveConnectionFlowControlWindow protocol.ByteCount
|
MaxReceiveConnectionFlowControlWindow uint64
|
||||||
// KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive.
|
// KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive.
|
||||||
KeepAlive bool
|
KeepAlive bool
|
||||||
}
|
}
|
||||||
|
48
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go
generated
vendored
Normal file
48
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go
generated
vendored
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
package ackhandler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SentPacketHandler handles ACKs received for outgoing packets
|
||||||
|
type SentPacketHandler interface {
|
||||||
|
// SentPacket may modify the packet
|
||||||
|
SentPacket(packet *Packet) error
|
||||||
|
ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error
|
||||||
|
SetHandshakeComplete()
|
||||||
|
|
||||||
|
// SendingAllowed says if a packet can be sent.
|
||||||
|
// Sending packets might not be possible because:
|
||||||
|
// * we're congestion limited
|
||||||
|
// * we're tracking the maximum number of sent packets
|
||||||
|
SendingAllowed() bool
|
||||||
|
// TimeUntilSend is the time when the next packet should be sent.
|
||||||
|
// It is used for pacing packets.
|
||||||
|
TimeUntilSend() time.Time
|
||||||
|
// ShouldSendNumPackets returns the number of packets that should be sent immediately.
|
||||||
|
// It always returns a number greater or equal than 1.
|
||||||
|
// A number greater than 1 is returned when the pacing delay is smaller than the minimum pacing delay.
|
||||||
|
// Note that the number of packets is only calculated based on the pacing algorithm.
|
||||||
|
// Before sending any packet, SendingAllowed() must be called to learn if we can actually send it.
|
||||||
|
ShouldSendNumPackets() int
|
||||||
|
|
||||||
|
GetStopWaitingFrame(force bool) *wire.StopWaitingFrame
|
||||||
|
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
|
||||||
|
DequeuePacketForRetransmission() (packet *Packet)
|
||||||
|
GetLeastUnacked() protocol.PacketNumber
|
||||||
|
|
||||||
|
GetAlarmTimeout() time.Time
|
||||||
|
OnAlarm()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
|
||||||
|
type ReceivedPacketHandler interface {
|
||||||
|
ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error
|
||||||
|
IgnoreBelow(protocol.PacketNumber)
|
||||||
|
|
||||||
|
GetAlarmTimeout() time.Time
|
||||||
|
GetAckFrame() *wire.AckFrame
|
||||||
|
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user