mirror of
https://github.com/caddyserver/caddy.git
synced 2025-06-04 22:25:31 -04:00
vendor: Updated quic-go for QUIC 39+ (#1968)
* Updated lucas-clemente/quic-go for QUIC 39+ support * Update quic-go to latest
This commit is contained in:
parent
faa5248d1f
commit
1201492222
21
vendor/github.com/bifurcation/mint/LICENSE.md
generated
vendored
Normal file
21
vendor/github.com/bifurcation/mint/LICENSE.md
generated
vendored
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
The MIT License (MIT)
|
||||||
|
|
||||||
|
Copyright (c) 2016 Richard Barnes
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in
|
||||||
|
all copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||||
|
THE SOFTWARE.
|
99
vendor/github.com/bifurcation/mint/alert.go
generated
vendored
Normal file
99
vendor/github.com/bifurcation/mint/alert.go
generated
vendored
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
// Copyright 2009 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package mint
|
||||||
|
|
||||||
|
import "strconv"
|
||||||
|
|
||||||
|
type Alert uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
// alert level
|
||||||
|
AlertLevelWarning = 1
|
||||||
|
AlertLevelError = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
AlertCloseNotify Alert = 0
|
||||||
|
AlertUnexpectedMessage Alert = 10
|
||||||
|
AlertBadRecordMAC Alert = 20
|
||||||
|
AlertDecryptionFailed Alert = 21
|
||||||
|
AlertRecordOverflow Alert = 22
|
||||||
|
AlertDecompressionFailure Alert = 30
|
||||||
|
AlertHandshakeFailure Alert = 40
|
||||||
|
AlertBadCertificate Alert = 42
|
||||||
|
AlertUnsupportedCertificate Alert = 43
|
||||||
|
AlertCertificateRevoked Alert = 44
|
||||||
|
AlertCertificateExpired Alert = 45
|
||||||
|
AlertCertificateUnknown Alert = 46
|
||||||
|
AlertIllegalParameter Alert = 47
|
||||||
|
AlertUnknownCA Alert = 48
|
||||||
|
AlertAccessDenied Alert = 49
|
||||||
|
AlertDecodeError Alert = 50
|
||||||
|
AlertDecryptError Alert = 51
|
||||||
|
AlertProtocolVersion Alert = 70
|
||||||
|
AlertInsufficientSecurity Alert = 71
|
||||||
|
AlertInternalError Alert = 80
|
||||||
|
AlertInappropriateFallback Alert = 86
|
||||||
|
AlertUserCanceled Alert = 90
|
||||||
|
AlertNoRenegotiation Alert = 100
|
||||||
|
AlertMissingExtension Alert = 109
|
||||||
|
AlertUnsupportedExtension Alert = 110
|
||||||
|
AlertCertificateUnobtainable Alert = 111
|
||||||
|
AlertUnrecognizedName Alert = 112
|
||||||
|
AlertBadCertificateStatsResponse Alert = 113
|
||||||
|
AlertBadCertificateHashValue Alert = 114
|
||||||
|
AlertUnknownPSKIdentity Alert = 115
|
||||||
|
AlertNoApplicationProtocol Alert = 120
|
||||||
|
AlertWouldBlock Alert = 254
|
||||||
|
AlertNoAlert Alert = 255
|
||||||
|
)
|
||||||
|
|
||||||
|
var alertText = map[Alert]string{
|
||||||
|
AlertCloseNotify: "close notify",
|
||||||
|
AlertUnexpectedMessage: "unexpected message",
|
||||||
|
AlertBadRecordMAC: "bad record MAC",
|
||||||
|
AlertDecryptionFailed: "decryption failed",
|
||||||
|
AlertRecordOverflow: "record overflow",
|
||||||
|
AlertDecompressionFailure: "decompression failure",
|
||||||
|
AlertHandshakeFailure: "handshake failure",
|
||||||
|
AlertBadCertificate: "bad certificate",
|
||||||
|
AlertUnsupportedCertificate: "unsupported certificate",
|
||||||
|
AlertCertificateRevoked: "revoked certificate",
|
||||||
|
AlertCertificateExpired: "expired certificate",
|
||||||
|
AlertCertificateUnknown: "unknown certificate",
|
||||||
|
AlertIllegalParameter: "illegal parameter",
|
||||||
|
AlertUnknownCA: "unknown certificate authority",
|
||||||
|
AlertAccessDenied: "access denied",
|
||||||
|
AlertDecodeError: "error decoding message",
|
||||||
|
AlertDecryptError: "error decrypting message",
|
||||||
|
AlertProtocolVersion: "protocol version not supported",
|
||||||
|
AlertInsufficientSecurity: "insufficient security level",
|
||||||
|
AlertInternalError: "internal error",
|
||||||
|
AlertInappropriateFallback: "inappropriate fallback",
|
||||||
|
AlertUserCanceled: "user canceled",
|
||||||
|
AlertMissingExtension: "missing extension",
|
||||||
|
AlertUnsupportedExtension: "unsupported extension",
|
||||||
|
AlertCertificateUnobtainable: "certificate unobtainable",
|
||||||
|
AlertUnrecognizedName: "unrecognized name",
|
||||||
|
AlertBadCertificateStatsResponse: "bad certificate status response",
|
||||||
|
AlertBadCertificateHashValue: "bad certificate hash value",
|
||||||
|
AlertUnknownPSKIdentity: "unknown PSK identity",
|
||||||
|
AlertNoApplicationProtocol: "no application protocol",
|
||||||
|
AlertNoRenegotiation: "no renegotiation",
|
||||||
|
AlertWouldBlock: "would have blocked",
|
||||||
|
AlertNoAlert: "no alert",
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e Alert) String() string {
|
||||||
|
s, ok := alertText[e]
|
||||||
|
if ok {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return "alert(" + strconv.Itoa(int(e)) + ")"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e Alert) Error() string {
|
||||||
|
return e.String()
|
||||||
|
}
|
42
vendor/github.com/bifurcation/mint/bin/mint-client-https/main.go
generated
vendored
Normal file
42
vendor/github.com/bifurcation/mint/bin/mint-client-https/main.go
generated
vendored
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/bifurcation/mint"
|
||||||
|
)
|
||||||
|
|
||||||
|
var url string
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
url := flag.String("url", "https://localhost:4430", "URL to send request")
|
||||||
|
flag.Parse()
|
||||||
|
mintdial := func(network, addr string) (net.Conn, error) {
|
||||||
|
return mint.Dial(network, addr, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
tr := &http.Transport{
|
||||||
|
DialTLS: mintdial,
|
||||||
|
DisableCompression: true,
|
||||||
|
}
|
||||||
|
client := &http.Client{Transport: tr}
|
||||||
|
|
||||||
|
response, err := client.Get(*url)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("err:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer response.Body.Close()
|
||||||
|
|
||||||
|
contents, err := ioutil.ReadAll(response.Body)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("%s", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
fmt.Printf("%s\n", string(contents))
|
||||||
|
}
|
37
vendor/github.com/bifurcation/mint/bin/mint-client/main.go
generated
vendored
Normal file
37
vendor/github.com/bifurcation/mint/bin/mint-client/main.go
generated
vendored
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/bifurcation/mint"
|
||||||
|
)
|
||||||
|
|
||||||
|
var addr string
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
flag.StringVar(&addr, "addr", "localhost:4430", "port")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
conn, err := mint.Dial("tcp", addr, nil)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("TLS handshake failed:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
request := "GET / HTTP/1.0\r\n\r\n"
|
||||||
|
conn.Write([]byte(request))
|
||||||
|
|
||||||
|
response := ""
|
||||||
|
buffer := make([]byte, 1024)
|
||||||
|
var read int
|
||||||
|
for err == nil {
|
||||||
|
read, err = conn.Read(buffer)
|
||||||
|
fmt.Println(" ~~ read: ", read)
|
||||||
|
response += string(buffer)
|
||||||
|
}
|
||||||
|
fmt.Println("err:", err)
|
||||||
|
fmt.Println("Received from server:")
|
||||||
|
fmt.Println(response)
|
||||||
|
}
|
226
vendor/github.com/bifurcation/mint/bin/mint-server-https/main.go
generated
vendored
Normal file
226
vendor/github.com/bifurcation/mint/bin/mint-server-https/main.go
generated
vendored
Normal file
@ -0,0 +1,226 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/bifurcation/mint"
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
port string
|
||||||
|
serverName string
|
||||||
|
certFile string
|
||||||
|
keyFile string
|
||||||
|
responseFile string
|
||||||
|
h2 bool
|
||||||
|
sendTickets bool
|
||||||
|
)
|
||||||
|
|
||||||
|
type responder []byte
|
||||||
|
|
||||||
|
func (rsp responder) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write(rsp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParsePrivateKeyDER parses a PKCS #1, PKCS #8, or elliptic curve
|
||||||
|
// PEM-encoded private key.
|
||||||
|
// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module
|
||||||
|
func ParsePrivateKeyPEM(keyPEM []byte) (key crypto.Signer, err error) {
|
||||||
|
keyDER, _ := pem.Decode(keyPEM)
|
||||||
|
if keyDER == nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
generalKey, err := x509.ParsePKCS8PrivateKey(keyDER.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
generalKey, err = x509.ParsePKCS1PrivateKey(keyDER.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
generalKey, err = x509.ParseECPrivateKey(keyDER.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
// We don't include the actual error into
|
||||||
|
// the final error. The reason might be
|
||||||
|
// we don't want to leak any info about
|
||||||
|
// the private key.
|
||||||
|
return nil, fmt.Errorf("No successful private key decoder")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch generalKey.(type) {
|
||||||
|
case *rsa.PrivateKey:
|
||||||
|
return generalKey.(*rsa.PrivateKey), nil
|
||||||
|
case *ecdsa.PrivateKey:
|
||||||
|
return generalKey.(*ecdsa.PrivateKey), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// should never reach here
|
||||||
|
return nil, fmt.Errorf("Should be unreachable")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseOneCertificateFromPEM attempts to parse one PEM encoded certificate object,
|
||||||
|
// either a raw x509 certificate or a PKCS #7 structure possibly containing
|
||||||
|
// multiple certificates, from the top of certsPEM, which itself may
|
||||||
|
// contain multiple PEM encoded certificate objects.
|
||||||
|
// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module
|
||||||
|
func ParseOneCertificateFromPEM(certsPEM []byte) ([]*x509.Certificate, []byte, error) {
|
||||||
|
block, rest := pem.Decode(certsPEM)
|
||||||
|
if block == nil {
|
||||||
|
return nil, rest, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cert, err := x509.ParseCertificate(block.Bytes)
|
||||||
|
var certs = []*x509.Certificate{cert}
|
||||||
|
return certs, rest, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseCertificatesPEM parses a sequence of PEM-encoded certificate and returns them,
|
||||||
|
// can handle PEM encoded PKCS #7 structures.
|
||||||
|
// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module
|
||||||
|
func ParseCertificatesPEM(certsPEM []byte) ([]*x509.Certificate, error) {
|
||||||
|
var certs []*x509.Certificate
|
||||||
|
var err error
|
||||||
|
certsPEM = bytes.TrimSpace(certsPEM)
|
||||||
|
for len(certsPEM) > 0 {
|
||||||
|
var cert []*x509.Certificate
|
||||||
|
cert, certsPEM, err = ParseOneCertificateFromPEM(certsPEM)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else if cert == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
certs = append(certs, cert...)
|
||||||
|
}
|
||||||
|
if len(certsPEM) > 0 {
|
||||||
|
return nil, fmt.Errorf("Trailing PEM data")
|
||||||
|
}
|
||||||
|
return certs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
flag.StringVar(&port, "port", "4430", "port")
|
||||||
|
flag.StringVar(&serverName, "host", "example.com", "hostname")
|
||||||
|
flag.StringVar(&certFile, "cert", "", "certificate chain in PEM or DER")
|
||||||
|
flag.StringVar(&keyFile, "key", "", "private key in PEM format")
|
||||||
|
flag.StringVar(&responseFile, "response", "", "file to serve")
|
||||||
|
flag.BoolVar(&h2, "h2", false, "whether to use HTTP/2 (exclusively)")
|
||||||
|
flag.BoolVar(&sendTickets, "tickets", true, "whether to send session tickets")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
var certChain []*x509.Certificate
|
||||||
|
var priv crypto.Signer
|
||||||
|
var response []byte
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Load the key and certificate chain
|
||||||
|
if certFile != "" {
|
||||||
|
certs, err := ioutil.ReadFile(certFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error: %v", err)
|
||||||
|
} else {
|
||||||
|
certChain, err = ParseCertificatesPEM(certs)
|
||||||
|
if err != nil {
|
||||||
|
certChain, err = x509.ParseCertificates(certs)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error parsing certificates: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if keyFile != "" {
|
||||||
|
keyPEM, err := ioutil.ReadFile(keyFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error: %v", err)
|
||||||
|
} else {
|
||||||
|
priv, err = ParsePrivateKeyPEM(keyPEM)
|
||||||
|
if priv == nil || err != nil {
|
||||||
|
log.Fatalf("Error parsing private key: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load response file
|
||||||
|
if responseFile != "" {
|
||||||
|
log.Printf("Loading response file: %v", responseFile)
|
||||||
|
response, err = ioutil.ReadFile(responseFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
response = []byte("Welcome to the TLS 1.3 zone!")
|
||||||
|
}
|
||||||
|
handler := responder(response)
|
||||||
|
|
||||||
|
config := mint.Config{
|
||||||
|
SendSessionTickets: true,
|
||||||
|
ServerName: serverName,
|
||||||
|
NextProtos: []string{"http/1.1"},
|
||||||
|
}
|
||||||
|
|
||||||
|
if h2 {
|
||||||
|
config.NextProtos = []string{"h2"}
|
||||||
|
}
|
||||||
|
|
||||||
|
config.SendSessionTickets = sendTickets
|
||||||
|
|
||||||
|
if certChain != nil && priv != nil {
|
||||||
|
log.Printf("Loading cert: %v key: %v", certFile, keyFile)
|
||||||
|
config.Certificates = []*mint.Certificate{
|
||||||
|
{
|
||||||
|
Chain: certChain,
|
||||||
|
PrivateKey: priv,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
config.Init(false)
|
||||||
|
|
||||||
|
service := "0.0.0.0:" + port
|
||||||
|
srv := &http.Server{Handler: handler}
|
||||||
|
|
||||||
|
log.Printf("Listening on port %v", port)
|
||||||
|
// Need the inner loop here because the h1 server errors on a dropped connection
|
||||||
|
// Need the outer loop here because the h2 server is per-connection
|
||||||
|
for {
|
||||||
|
listener, err := mint.Listen("tcp", service, &config)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Listen Error: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !h2 {
|
||||||
|
alert := srv.Serve(listener)
|
||||||
|
if alert != mint.AlertNoAlert {
|
||||||
|
log.Printf("Serve Error: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
srv2 := new(http2.Server)
|
||||||
|
opts := &http2.ServeConnOpts{
|
||||||
|
Handler: handler,
|
||||||
|
BaseConfig: srv,
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Accept error: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
go srv2.ServeConn(conn, opts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
65
vendor/github.com/bifurcation/mint/bin/mint-server/main.go
generated
vendored
Normal file
65
vendor/github.com/bifurcation/mint/bin/mint-server/main.go
generated
vendored
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/bifurcation/mint"
|
||||||
|
)
|
||||||
|
|
||||||
|
var port string
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
var config mint.Config
|
||||||
|
config.SendSessionTickets = true
|
||||||
|
config.ServerName = "localhost"
|
||||||
|
config.Init(false)
|
||||||
|
|
||||||
|
flag.StringVar(&port, "port", "4430", "port")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
service := "0.0.0.0:" + port
|
||||||
|
listener, err := mint.Listen("tcp", service, &config)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("server: listen: %s", err)
|
||||||
|
}
|
||||||
|
log.Print("server: listening")
|
||||||
|
|
||||||
|
for {
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("server: accept: %s", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
log.Printf("server: accepted from %s", conn.RemoteAddr())
|
||||||
|
go handleClient(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleClient(conn net.Conn) {
|
||||||
|
defer conn.Close()
|
||||||
|
buf := make([]byte, 10)
|
||||||
|
for {
|
||||||
|
log.Print("server: conn: waiting")
|
||||||
|
n, err := conn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("server: conn: read: %s", err)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err = conn.Write([]byte("hello world"))
|
||||||
|
log.Printf("server: conn: wrote %d bytes", n)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("server: write: %s", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
log.Println("server: conn: closed")
|
||||||
|
}
|
942
vendor/github.com/bifurcation/mint/client-state-machine.go
generated
vendored
Normal file
942
vendor/github.com/bifurcation/mint/client-state-machine.go
generated
vendored
Normal file
@ -0,0 +1,942 @@
|
|||||||
|
package mint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto"
|
||||||
|
"hash"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client State Machine
|
||||||
|
//
|
||||||
|
// START <----+
|
||||||
|
// Send ClientHello | | Recv HelloRetryRequest
|
||||||
|
// / v |
|
||||||
|
// | WAIT_SH ---+
|
||||||
|
// Can | | Recv ServerHello
|
||||||
|
// send | V
|
||||||
|
// early | WAIT_EE
|
||||||
|
// data | | Recv EncryptedExtensions
|
||||||
|
// | +--------+--------+
|
||||||
|
// | Using | | Using certificate
|
||||||
|
// | PSK | v
|
||||||
|
// | | WAIT_CERT_CR
|
||||||
|
// | | Recv | | Recv CertificateRequest
|
||||||
|
// | | Certificate | v
|
||||||
|
// | | | WAIT_CERT
|
||||||
|
// | | | | Recv Certificate
|
||||||
|
// | | v v
|
||||||
|
// | | WAIT_CV
|
||||||
|
// | | | Recv CertificateVerify
|
||||||
|
// | +> WAIT_FINISHED <+
|
||||||
|
// | | Recv Finished
|
||||||
|
// \ |
|
||||||
|
// | [Send EndOfEarlyData]
|
||||||
|
// | [Send Certificate [+ CertificateVerify]]
|
||||||
|
// | Send Finished
|
||||||
|
// Can send v
|
||||||
|
// app data --> CONNECTED
|
||||||
|
// after
|
||||||
|
// here
|
||||||
|
//
|
||||||
|
// State Instructions
|
||||||
|
// START Send(CH); [RekeyOut; SendEarlyData]
|
||||||
|
// WAIT_SH Send(CH) || RekeyIn
|
||||||
|
// WAIT_EE {}
|
||||||
|
// WAIT_CERT_CR {}
|
||||||
|
// WAIT_CERT {}
|
||||||
|
// WAIT_CV {}
|
||||||
|
// WAIT_FINISHED RekeyIn; [Send(EOED);] RekeyOut; [SendCert; SendCV;] SendFin; RekeyOut;
|
||||||
|
// CONNECTED StoreTicket || (RekeyIn; [RekeyOut])
|
||||||
|
|
||||||
|
type ClientStateStart struct {
|
||||||
|
Caps Capabilities
|
||||||
|
Opts ConnectionOptions
|
||||||
|
Params ConnectionParameters
|
||||||
|
|
||||||
|
cookie []byte
|
||||||
|
firstClientHello *HandshakeMessage
|
||||||
|
helloRetryRequest *HandshakeMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
if hm != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateStart] Unexpected non-nil message")
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
// key_shares
|
||||||
|
offeredDH := map[NamedGroup][]byte{}
|
||||||
|
ks := KeyShareExtension{
|
||||||
|
HandshakeType: HandshakeTypeClientHello,
|
||||||
|
Shares: make([]KeyShareEntry, len(state.Caps.Groups)),
|
||||||
|
}
|
||||||
|
for i, group := range state.Caps.Groups {
|
||||||
|
pub, priv, err := newKeyShare(group)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateStart] Error generating key share [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
ks.Shares[i].Group = group
|
||||||
|
ks.Shares[i].KeyExchange = pub
|
||||||
|
offeredDH[group] = priv
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "opts: %+v", state.Opts)
|
||||||
|
|
||||||
|
// supported_versions, supported_groups, signature_algorithms, server_name
|
||||||
|
sv := SupportedVersionsExtension{Versions: []uint16{supportedVersion}}
|
||||||
|
sni := ServerNameExtension(state.Opts.ServerName)
|
||||||
|
sg := SupportedGroupsExtension{Groups: state.Caps.Groups}
|
||||||
|
sa := SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes}
|
||||||
|
|
||||||
|
state.Params.ServerName = state.Opts.ServerName
|
||||||
|
|
||||||
|
// Application Layer Protocol Negotiation
|
||||||
|
var alpn *ALPNExtension
|
||||||
|
if (state.Opts.NextProtos != nil) && (len(state.Opts.NextProtos) > 0) {
|
||||||
|
alpn = &ALPNExtension{Protocols: state.Opts.NextProtos}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Construct base ClientHello
|
||||||
|
ch := &ClientHelloBody{
|
||||||
|
CipherSuites: state.Caps.CipherSuites,
|
||||||
|
}
|
||||||
|
_, err := prng.Read(ch.Random[:])
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateStart] Error creating ClientHello random [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
for _, ext := range []ExtensionBody{&sv, &sni, &ks, &sg, &sa} {
|
||||||
|
err := ch.Extensions.Add(ext)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateStart] Error adding extension type=[%v] [%v]", ext.Type(), err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// XXX: These optional extensions can't be folded into the above because Go
|
||||||
|
// interface-typed values are never reported as nil
|
||||||
|
if alpn != nil {
|
||||||
|
err := ch.Extensions.Add(alpn)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateStart] Error adding ALPN extension [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if state.cookie != nil {
|
||||||
|
err := ch.Extensions.Add(&CookieExtension{Cookie: state.cookie})
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateStart] Error adding ALPN extension [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run the external extension handler.
|
||||||
|
if state.Caps.ExtensionHandler != nil {
|
||||||
|
err := state.Caps.ExtensionHandler.Send(HandshakeTypeClientHello, &ch.Extensions)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateStart] Error running external extension sender [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle PSK and EarlyData just before transmitting, so that we can
|
||||||
|
// calculate the PSK binder value
|
||||||
|
var psk *PreSharedKeyExtension
|
||||||
|
var ed *EarlyDataExtension
|
||||||
|
var offeredPSK PreSharedKey
|
||||||
|
var earlyHash crypto.Hash
|
||||||
|
var earlySecret []byte
|
||||||
|
var clientEarlyTrafficKeys keySet
|
||||||
|
var clientHello *HandshakeMessage
|
||||||
|
if key, ok := state.Caps.PSKs.Get(state.Opts.ServerName); ok {
|
||||||
|
offeredPSK = key
|
||||||
|
|
||||||
|
// Narrow ciphersuites to ones that match PSK hash
|
||||||
|
params, ok := cipherSuiteMap[key.CipherSuite]
|
||||||
|
if !ok {
|
||||||
|
logf(logTypeHandshake, "[ClientStateStart] PSK for unknown ciphersuite")
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
compatibleSuites := []CipherSuite{}
|
||||||
|
for _, suite := range ch.CipherSuites {
|
||||||
|
if cipherSuiteMap[suite].Hash == params.Hash {
|
||||||
|
compatibleSuites = append(compatibleSuites, suite)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ch.CipherSuites = compatibleSuites
|
||||||
|
|
||||||
|
// Signal early data if we're going to do it
|
||||||
|
if len(state.Opts.EarlyData) > 0 {
|
||||||
|
state.Params.ClientSendingEarlyData = true
|
||||||
|
ed = &EarlyDataExtension{}
|
||||||
|
err = ch.Extensions.Add(ed)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "Error adding early data extension: %v", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Signal supported PSK key exchange modes
|
||||||
|
if len(state.Caps.PSKModes) == 0 {
|
||||||
|
logf(logTypeHandshake, "PSK selected, but no PSKModes")
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
kem := &PSKKeyExchangeModesExtension{KEModes: state.Caps.PSKModes}
|
||||||
|
err = ch.Extensions.Add(kem)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "Error adding PSKKeyExchangeModes extension: %v", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the shim PSK extension to the ClientHello
|
||||||
|
logf(logTypeHandshake, "Adding PSK extension with id = %x", key.Identity)
|
||||||
|
psk = &PreSharedKeyExtension{
|
||||||
|
HandshakeType: HandshakeTypeClientHello,
|
||||||
|
Identities: []PSKIdentity{
|
||||||
|
{
|
||||||
|
Identity: key.Identity,
|
||||||
|
ObfuscatedTicketAge: uint32(time.Since(key.ReceivedAt)/time.Millisecond) + key.TicketAgeAdd,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Binders: []PSKBinderEntry{
|
||||||
|
// Note: Stub to get the length fields right
|
||||||
|
{Binder: bytes.Repeat([]byte{0x00}, params.Hash.Size())},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ch.Extensions.Add(psk)
|
||||||
|
|
||||||
|
// Compute the binder key
|
||||||
|
h0 := params.Hash.New().Sum(nil)
|
||||||
|
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
|
||||||
|
|
||||||
|
earlyHash = params.Hash
|
||||||
|
earlySecret = HkdfExtract(params.Hash, zero, key.Key)
|
||||||
|
logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret)
|
||||||
|
|
||||||
|
binderLabel := labelExternalBinder
|
||||||
|
if key.IsResumption {
|
||||||
|
binderLabel = labelResumptionBinder
|
||||||
|
}
|
||||||
|
binderKey := deriveSecret(params, earlySecret, binderLabel, h0)
|
||||||
|
logf(logTypeCrypto, "binder key: [%d] %x", len(binderKey), binderKey)
|
||||||
|
|
||||||
|
// Compute the binder value
|
||||||
|
trunc, err := ch.Truncated()
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateStart] Error marshaling truncated ClientHello [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
truncHash := params.Hash.New()
|
||||||
|
truncHash.Write(trunc)
|
||||||
|
|
||||||
|
binder := computeFinishedData(params, binderKey, truncHash.Sum(nil))
|
||||||
|
|
||||||
|
// Replace the PSK extension
|
||||||
|
psk.Binders[0].Binder = binder
|
||||||
|
ch.Extensions.Add(psk)
|
||||||
|
|
||||||
|
// If we got here, the earlier marshal succeeded (in ch.Truncated()), so
|
||||||
|
// this one should too.
|
||||||
|
clientHello, _ = HandshakeMessageFromBody(ch)
|
||||||
|
|
||||||
|
// Compute early traffic keys
|
||||||
|
h := params.Hash.New()
|
||||||
|
h.Write(clientHello.Marshal())
|
||||||
|
chHash := h.Sum(nil)
|
||||||
|
|
||||||
|
earlyTrafficSecret := deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash)
|
||||||
|
logf(logTypeCrypto, "early traffic secret: [%d] %x", len(earlyTrafficSecret), earlyTrafficSecret)
|
||||||
|
clientEarlyTrafficKeys = makeTrafficKeys(params, earlyTrafficSecret)
|
||||||
|
} else if len(state.Opts.EarlyData) > 0 {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitSH] Early data without PSK")
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
} else {
|
||||||
|
clientHello, err = HandshakeMessageFromBody(ch)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateStart] Error marshaling ClientHello [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ClientStateStart] -> [ClientStateWaitSH]")
|
||||||
|
nextState := ClientStateWaitSH{
|
||||||
|
Caps: state.Caps,
|
||||||
|
Opts: state.Opts,
|
||||||
|
Params: state.Params,
|
||||||
|
OfferedDH: offeredDH,
|
||||||
|
OfferedPSK: offeredPSK,
|
||||||
|
|
||||||
|
earlySecret: earlySecret,
|
||||||
|
earlyHash: earlyHash,
|
||||||
|
|
||||||
|
firstClientHello: state.firstClientHello,
|
||||||
|
helloRetryRequest: state.helloRetryRequest,
|
||||||
|
clientHello: clientHello,
|
||||||
|
}
|
||||||
|
|
||||||
|
toSend := []HandshakeAction{
|
||||||
|
SendHandshakeMessage{clientHello},
|
||||||
|
}
|
||||||
|
if state.Params.ClientSendingEarlyData {
|
||||||
|
toSend = append(toSend, []HandshakeAction{
|
||||||
|
RekeyOut{Label: "early", KeySet: clientEarlyTrafficKeys},
|
||||||
|
SendEarlyData{},
|
||||||
|
}...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nextState, toSend, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClientStateWaitSH struct {
|
||||||
|
Caps Capabilities
|
||||||
|
Opts ConnectionOptions
|
||||||
|
Params ConnectionParameters
|
||||||
|
OfferedDH map[NamedGroup][]byte
|
||||||
|
OfferedPSK PreSharedKey
|
||||||
|
PSK []byte
|
||||||
|
|
||||||
|
earlySecret []byte
|
||||||
|
earlyHash crypto.Hash
|
||||||
|
|
||||||
|
firstClientHello *HandshakeMessage
|
||||||
|
helloRetryRequest *HandshakeMessage
|
||||||
|
clientHello *HandshakeMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state ClientStateWaitSH) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
if hm == nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected nil message")
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyGeneric, err := hm.ToBody()
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitSH] Error decoding message: %v", err)
|
||||||
|
return nil, nil, AlertDecodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
switch body := bodyGeneric.(type) {
|
||||||
|
case *HelloRetryRequestBody:
|
||||||
|
hrr := body
|
||||||
|
|
||||||
|
if state.helloRetryRequest != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitSH] Received a second HelloRetryRequest")
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that the version sent by the server is the one we support
|
||||||
|
if hrr.Version != supportedVersion {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", hrr.Version)
|
||||||
|
return nil, nil, AlertProtocolVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that the server provided a supported ciphersuite
|
||||||
|
supportedCipherSuite := false
|
||||||
|
for _, suite := range state.Caps.CipherSuites {
|
||||||
|
supportedCipherSuite = supportedCipherSuite || (suite == hrr.CipherSuite)
|
||||||
|
}
|
||||||
|
if !supportedCipherSuite {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", hrr.CipherSuite)
|
||||||
|
return nil, nil, AlertHandshakeFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
// Narrow the supported ciphersuites to the server-provided one
|
||||||
|
state.Caps.CipherSuites = []CipherSuite{hrr.CipherSuite}
|
||||||
|
|
||||||
|
// Handle external extensions.
|
||||||
|
if state.Caps.ExtensionHandler != nil {
|
||||||
|
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeHelloRetryRequest, &hrr.Extensions)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// The only thing we know how to respond to in an HRR is the Cookie
|
||||||
|
// extension, so if there is either no Cookie extension or anything other
|
||||||
|
// than a Cookie extension, we have to fail.
|
||||||
|
serverCookie := new(CookieExtension)
|
||||||
|
foundCookie := hrr.Extensions.Find(serverCookie)
|
||||||
|
if !foundCookie || len(hrr.Extensions) != 1 {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitSH] No Cookie or extra extensions [%v] [%d]", foundCookie, len(hrr.Extensions))
|
||||||
|
return nil, nil, AlertIllegalParameter
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hash the body into a pseudo-message
|
||||||
|
// XXX: Ignoring some errors here
|
||||||
|
params := cipherSuiteMap[hrr.CipherSuite]
|
||||||
|
h := params.Hash.New()
|
||||||
|
h.Write(state.clientHello.Marshal())
|
||||||
|
firstClientHello := &HandshakeMessage{
|
||||||
|
msgType: HandshakeTypeMessageHash,
|
||||||
|
body: h.Sum(nil),
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateStart]")
|
||||||
|
return ClientStateStart{
|
||||||
|
Caps: state.Caps,
|
||||||
|
Opts: state.Opts,
|
||||||
|
cookie: serverCookie.Cookie,
|
||||||
|
firstClientHello: firstClientHello,
|
||||||
|
helloRetryRequest: hm,
|
||||||
|
}.Next(nil)
|
||||||
|
|
||||||
|
case *ServerHelloBody:
|
||||||
|
sh := body
|
||||||
|
|
||||||
|
// Check that the version sent by the server is the one we support
|
||||||
|
if sh.Version != supportedVersion {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", sh.Version)
|
||||||
|
return nil, nil, AlertProtocolVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that the server provided a supported ciphersuite
|
||||||
|
supportedCipherSuite := false
|
||||||
|
for _, suite := range state.Caps.CipherSuites {
|
||||||
|
supportedCipherSuite = supportedCipherSuite || (suite == sh.CipherSuite)
|
||||||
|
}
|
||||||
|
if !supportedCipherSuite {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", sh.CipherSuite)
|
||||||
|
return nil, nil, AlertHandshakeFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle external extensions.
|
||||||
|
if state.Caps.ExtensionHandler != nil {
|
||||||
|
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeServerHello, &sh.Extensions)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do PSK or key agreement depending on extensions
|
||||||
|
serverPSK := PreSharedKeyExtension{HandshakeType: HandshakeTypeServerHello}
|
||||||
|
serverKeyShare := KeyShareExtension{HandshakeType: HandshakeTypeServerHello}
|
||||||
|
|
||||||
|
foundPSK := sh.Extensions.Find(&serverPSK)
|
||||||
|
foundKeyShare := sh.Extensions.Find(&serverKeyShare)
|
||||||
|
|
||||||
|
if foundPSK && (serverPSK.SelectedIdentity == 0) {
|
||||||
|
state.Params.UsingPSK = true
|
||||||
|
}
|
||||||
|
|
||||||
|
var dhSecret []byte
|
||||||
|
if foundKeyShare {
|
||||||
|
sks := serverKeyShare.Shares[0]
|
||||||
|
priv, ok := state.OfferedDH[sks.Group]
|
||||||
|
if !ok {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitSH] Key share for unknown group")
|
||||||
|
return nil, nil, AlertIllegalParameter
|
||||||
|
}
|
||||||
|
|
||||||
|
state.Params.UsingDH = true
|
||||||
|
dhSecret, _ = keyAgreement(sks.Group, sks.KeyExchange, priv)
|
||||||
|
}
|
||||||
|
|
||||||
|
suite := sh.CipherSuite
|
||||||
|
state.Params.CipherSuite = suite
|
||||||
|
|
||||||
|
params, ok := cipherSuiteMap[suite]
|
||||||
|
if !ok {
|
||||||
|
logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", suite)
|
||||||
|
return nil, nil, AlertHandshakeFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start up the handshake hash
|
||||||
|
handshakeHash := params.Hash.New()
|
||||||
|
handshakeHash.Write(state.firstClientHello.Marshal())
|
||||||
|
handshakeHash.Write(state.helloRetryRequest.Marshal())
|
||||||
|
handshakeHash.Write(state.clientHello.Marshal())
|
||||||
|
handshakeHash.Write(hm.Marshal())
|
||||||
|
|
||||||
|
// Compute handshake secrets
|
||||||
|
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
|
||||||
|
|
||||||
|
var earlySecret []byte
|
||||||
|
if state.Params.UsingPSK {
|
||||||
|
if params.Hash != state.earlyHash {
|
||||||
|
logf(logTypeCrypto, "Change of hash between early and normal init early=[%02x] suite=[%04x] hash=[%02x]",
|
||||||
|
state.earlyHash, suite, params.Hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
earlySecret = state.earlySecret
|
||||||
|
} else {
|
||||||
|
earlySecret = HkdfExtract(params.Hash, zero, zero)
|
||||||
|
}
|
||||||
|
|
||||||
|
if dhSecret == nil {
|
||||||
|
dhSecret = zero
|
||||||
|
}
|
||||||
|
|
||||||
|
h0 := params.Hash.New().Sum(nil)
|
||||||
|
h2 := handshakeHash.Sum(nil)
|
||||||
|
preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0)
|
||||||
|
handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, dhSecret)
|
||||||
|
clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2)
|
||||||
|
serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2)
|
||||||
|
preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0)
|
||||||
|
masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero)
|
||||||
|
|
||||||
|
logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret)
|
||||||
|
logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret)
|
||||||
|
logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret)
|
||||||
|
logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret)
|
||||||
|
logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret)
|
||||||
|
|
||||||
|
serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret)
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateWaitEE]")
|
||||||
|
nextState := ClientStateWaitEE{
|
||||||
|
Caps: state.Caps,
|
||||||
|
Params: state.Params,
|
||||||
|
cryptoParams: params,
|
||||||
|
handshakeHash: handshakeHash,
|
||||||
|
certificates: state.Caps.Certificates,
|
||||||
|
masterSecret: masterSecret,
|
||||||
|
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret,
|
||||||
|
serverHandshakeTrafficSecret: serverHandshakeTrafficSecret,
|
||||||
|
}
|
||||||
|
toSend := []HandshakeAction{
|
||||||
|
RekeyIn{Label: "handshake", KeySet: serverHandshakeKeys},
|
||||||
|
}
|
||||||
|
return nextState, toSend, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected message [%d]", hm.msgType)
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClientStateWaitEE struct {
|
||||||
|
Caps Capabilities
|
||||||
|
AuthCertificate func(chain []CertificateEntry) error
|
||||||
|
Params ConnectionParameters
|
||||||
|
cryptoParams CipherSuiteParams
|
||||||
|
handshakeHash hash.Hash
|
||||||
|
certificates []*Certificate
|
||||||
|
masterSecret []byte
|
||||||
|
clientHandshakeTrafficSecret []byte
|
||||||
|
serverHandshakeTrafficSecret []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state ClientStateWaitEE) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
if hm == nil || hm.msgType != HandshakeTypeEncryptedExtensions {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitEE] Unexpected message")
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
ee := EncryptedExtensionsBody{}
|
||||||
|
_, err := ee.Unmarshal(hm.body)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding message: %v", err)
|
||||||
|
return nil, nil, AlertDecodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle external extensions.
|
||||||
|
if state.Caps.ExtensionHandler != nil {
|
||||||
|
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeEncryptedExtensions, &ee.Extensions)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientWaitStateEE] Error running external extension handler [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
serverALPN := ALPNExtension{}
|
||||||
|
serverEarlyData := EarlyDataExtension{}
|
||||||
|
|
||||||
|
gotALPN := ee.Extensions.Find(&serverALPN)
|
||||||
|
state.Params.UsingEarlyData = ee.Extensions.Find(&serverEarlyData)
|
||||||
|
|
||||||
|
if gotALPN && len(serverALPN.Protocols) > 0 {
|
||||||
|
state.Params.NextProto = serverALPN.Protocols[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
state.handshakeHash.Write(hm.Marshal())
|
||||||
|
|
||||||
|
if state.Params.UsingPSK {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitFinished]")
|
||||||
|
nextState := ClientStateWaitFinished{
|
||||||
|
Params: state.Params,
|
||||||
|
cryptoParams: state.cryptoParams,
|
||||||
|
handshakeHash: state.handshakeHash,
|
||||||
|
certificates: state.certificates,
|
||||||
|
masterSecret: state.masterSecret,
|
||||||
|
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||||
|
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||||
|
}
|
||||||
|
return nextState, nil, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitCertCR]")
|
||||||
|
nextState := ClientStateWaitCertCR{
|
||||||
|
AuthCertificate: state.AuthCertificate,
|
||||||
|
Params: state.Params,
|
||||||
|
cryptoParams: state.cryptoParams,
|
||||||
|
handshakeHash: state.handshakeHash,
|
||||||
|
certificates: state.certificates,
|
||||||
|
masterSecret: state.masterSecret,
|
||||||
|
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||||
|
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||||
|
}
|
||||||
|
return nextState, nil, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClientStateWaitCertCR struct {
|
||||||
|
AuthCertificate func(chain []CertificateEntry) error
|
||||||
|
Params ConnectionParameters
|
||||||
|
cryptoParams CipherSuiteParams
|
||||||
|
handshakeHash hash.Hash
|
||||||
|
certificates []*Certificate
|
||||||
|
masterSecret []byte
|
||||||
|
clientHandshakeTrafficSecret []byte
|
||||||
|
serverHandshakeTrafficSecret []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state ClientStateWaitCertCR) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
if hm == nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitCertCR] Unexpected message")
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyGeneric, err := hm.ToBody()
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitCertCR] Error decoding message: %v", err)
|
||||||
|
return nil, nil, AlertDecodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
state.handshakeHash.Write(hm.Marshal())
|
||||||
|
|
||||||
|
switch body := bodyGeneric.(type) {
|
||||||
|
case *CertificateBody:
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCV]")
|
||||||
|
nextState := ClientStateWaitCV{
|
||||||
|
AuthCertificate: state.AuthCertificate,
|
||||||
|
Params: state.Params,
|
||||||
|
cryptoParams: state.cryptoParams,
|
||||||
|
handshakeHash: state.handshakeHash,
|
||||||
|
certificates: state.certificates,
|
||||||
|
serverCertificate: body,
|
||||||
|
masterSecret: state.masterSecret,
|
||||||
|
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||||
|
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||||
|
}
|
||||||
|
return nextState, nil, AlertNoAlert
|
||||||
|
|
||||||
|
case *CertificateRequestBody:
|
||||||
|
// A certificate request in the handshake should have a zero-length context
|
||||||
|
if len(body.CertificateRequestContext) > 0 {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitCertCR] Certificate request with non-empty context: %v", err)
|
||||||
|
return nil, nil, AlertIllegalParameter
|
||||||
|
}
|
||||||
|
|
||||||
|
state.Params.UsingClientAuth = true
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCert]")
|
||||||
|
nextState := ClientStateWaitCert{
|
||||||
|
AuthCertificate: state.AuthCertificate,
|
||||||
|
Params: state.Params,
|
||||||
|
cryptoParams: state.cryptoParams,
|
||||||
|
handshakeHash: state.handshakeHash,
|
||||||
|
certificates: state.certificates,
|
||||||
|
serverCertificateRequest: body,
|
||||||
|
masterSecret: state.masterSecret,
|
||||||
|
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||||
|
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||||
|
}
|
||||||
|
return nextState, nil, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClientStateWaitCert struct {
|
||||||
|
AuthCertificate func(chain []CertificateEntry) error
|
||||||
|
Params ConnectionParameters
|
||||||
|
cryptoParams CipherSuiteParams
|
||||||
|
handshakeHash hash.Hash
|
||||||
|
|
||||||
|
certificates []*Certificate
|
||||||
|
serverCertificateRequest *CertificateRequestBody
|
||||||
|
|
||||||
|
masterSecret []byte
|
||||||
|
clientHandshakeTrafficSecret []byte
|
||||||
|
serverHandshakeTrafficSecret []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state ClientStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
if hm == nil || hm.msgType != HandshakeTypeCertificate {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitCert] Unexpected message")
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
cert := &CertificateBody{}
|
||||||
|
_, err := cert.Unmarshal(hm.body)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitCert] Error decoding message: %v", err)
|
||||||
|
return nil, nil, AlertDecodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
state.handshakeHash.Write(hm.Marshal())
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitCert] -> [ClientStateWaitCV]")
|
||||||
|
nextState := ClientStateWaitCV{
|
||||||
|
AuthCertificate: state.AuthCertificate,
|
||||||
|
Params: state.Params,
|
||||||
|
cryptoParams: state.cryptoParams,
|
||||||
|
handshakeHash: state.handshakeHash,
|
||||||
|
certificates: state.certificates,
|
||||||
|
serverCertificate: cert,
|
||||||
|
serverCertificateRequest: state.serverCertificateRequest,
|
||||||
|
masterSecret: state.masterSecret,
|
||||||
|
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||||
|
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||||
|
}
|
||||||
|
return nextState, nil, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClientStateWaitCV struct {
|
||||||
|
AuthCertificate func(chain []CertificateEntry) error
|
||||||
|
Params ConnectionParameters
|
||||||
|
cryptoParams CipherSuiteParams
|
||||||
|
handshakeHash hash.Hash
|
||||||
|
|
||||||
|
certificates []*Certificate
|
||||||
|
serverCertificate *CertificateBody
|
||||||
|
serverCertificateRequest *CertificateRequestBody
|
||||||
|
|
||||||
|
masterSecret []byte
|
||||||
|
clientHandshakeTrafficSecret []byte
|
||||||
|
serverHandshakeTrafficSecret []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state ClientStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
if hm == nil || hm.msgType != HandshakeTypeCertificateVerify {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitCV] Unexpected message")
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
certVerify := CertificateVerifyBody{}
|
||||||
|
_, err := certVerify.Unmarshal(hm.body)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitCV] Error decoding message: %v", err)
|
||||||
|
return nil, nil, AlertDecodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
hcv := state.handshakeHash.Sum(nil)
|
||||||
|
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
|
||||||
|
|
||||||
|
serverPublicKey := state.serverCertificate.CertificateList[0].CertData.PublicKey
|
||||||
|
if err := certVerify.Verify(serverPublicKey, hcv); err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitCV] Server signature failed to verify")
|
||||||
|
return nil, nil, AlertHandshakeFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
if state.AuthCertificate != nil {
|
||||||
|
err := state.AuthCertificate(state.serverCertificate.CertificateList)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitCV] Application rejected server certificate")
|
||||||
|
return nil, nil, AlertBadCertificate
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitCV] WARNING: No verification of server certificate")
|
||||||
|
}
|
||||||
|
|
||||||
|
state.handshakeHash.Write(hm.Marshal())
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitCV] -> [ClientStateWaitFinished]")
|
||||||
|
nextState := ClientStateWaitFinished{
|
||||||
|
Params: state.Params,
|
||||||
|
cryptoParams: state.cryptoParams,
|
||||||
|
handshakeHash: state.handshakeHash,
|
||||||
|
certificates: state.certificates,
|
||||||
|
serverCertificateRequest: state.serverCertificateRequest,
|
||||||
|
masterSecret: state.masterSecret,
|
||||||
|
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||||
|
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||||
|
}
|
||||||
|
return nextState, nil, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClientStateWaitFinished struct {
|
||||||
|
Params ConnectionParameters
|
||||||
|
cryptoParams CipherSuiteParams
|
||||||
|
handshakeHash hash.Hash
|
||||||
|
|
||||||
|
certificates []*Certificate
|
||||||
|
serverCertificateRequest *CertificateRequestBody
|
||||||
|
|
||||||
|
masterSecret []byte
|
||||||
|
clientHandshakeTrafficSecret []byte
|
||||||
|
serverHandshakeTrafficSecret []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
if hm == nil || hm.msgType != HandshakeTypeFinished {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitFinished] Unexpected message")
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify server's Finished
|
||||||
|
h3 := state.handshakeHash.Sum(nil)
|
||||||
|
logf(logTypeCrypto, "handshake hash 3 [%d] %x", len(h3), h3)
|
||||||
|
logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h3), h3)
|
||||||
|
|
||||||
|
serverFinishedData := computeFinishedData(state.cryptoParams, state.serverHandshakeTrafficSecret, h3)
|
||||||
|
logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData)
|
||||||
|
|
||||||
|
fin := &FinishedBody{VerifyDataLen: len(serverFinishedData)}
|
||||||
|
_, err := fin.Unmarshal(hm.body)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitFinished] Error decoding message: %v", err)
|
||||||
|
return nil, nil, AlertDecodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(fin.VerifyData, serverFinishedData) {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitFinished] Server's Finished failed to verify [%x] != [%x]",
|
||||||
|
fin.VerifyData, serverFinishedData)
|
||||||
|
return nil, nil, AlertHandshakeFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the handshake hash with the Finished
|
||||||
|
state.handshakeHash.Write(hm.Marshal())
|
||||||
|
logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(hm.Marshal()), hm.Marshal())
|
||||||
|
h4 := state.handshakeHash.Sum(nil)
|
||||||
|
logf(logTypeCrypto, "handshake hash 4 [%d]: %x", len(h4), h4)
|
||||||
|
|
||||||
|
// Compute traffic secrets and keys
|
||||||
|
clientTrafficSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelClientApplicationTrafficSecret, h4)
|
||||||
|
serverTrafficSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelServerApplicationTrafficSecret, h4)
|
||||||
|
logf(logTypeCrypto, "client traffic secret: [%d] %x", len(clientTrafficSecret), clientTrafficSecret)
|
||||||
|
logf(logTypeCrypto, "server traffic secret: [%d] %x", len(serverTrafficSecret), serverTrafficSecret)
|
||||||
|
|
||||||
|
clientTrafficKeys := makeTrafficKeys(state.cryptoParams, clientTrafficSecret)
|
||||||
|
serverTrafficKeys := makeTrafficKeys(state.cryptoParams, serverTrafficSecret)
|
||||||
|
|
||||||
|
exporterSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelExporterSecret, h4)
|
||||||
|
logf(logTypeCrypto, "client exporter secret: [%d] %x", len(exporterSecret), exporterSecret)
|
||||||
|
|
||||||
|
// Assemble client's second flight
|
||||||
|
toSend := []HandshakeAction{}
|
||||||
|
|
||||||
|
if state.Params.UsingEarlyData {
|
||||||
|
// Note: We only send EOED if the server is actually going to use the early
|
||||||
|
// data. Otherwise, it will never see it, and the transcripts will
|
||||||
|
// mismatch.
|
||||||
|
// EOED marshal is infallible
|
||||||
|
eoedm, _ := HandshakeMessageFromBody(&EndOfEarlyDataBody{})
|
||||||
|
toSend = append(toSend, SendHandshakeMessage{eoedm})
|
||||||
|
state.handshakeHash.Write(eoedm.Marshal())
|
||||||
|
logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(eoedm.Marshal()), eoedm.Marshal())
|
||||||
|
}
|
||||||
|
|
||||||
|
clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)
|
||||||
|
toSend = append(toSend, RekeyOut{Label: "handshake", KeySet: clientHandshakeKeys})
|
||||||
|
|
||||||
|
if state.Params.UsingClientAuth {
|
||||||
|
// Extract constraints from certicateRequest
|
||||||
|
schemes := SignatureAlgorithmsExtension{}
|
||||||
|
gotSchemes := state.serverCertificateRequest.Extensions.Find(&schemes)
|
||||||
|
if !gotSchemes {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err)
|
||||||
|
return nil, nil, AlertIllegalParameter
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select a certificate
|
||||||
|
cert, certScheme, err := CertificateSelection(nil, schemes.Algorithms, state.certificates)
|
||||||
|
if err != nil {
|
||||||
|
// XXX: Signal this to the application layer?
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err)
|
||||||
|
|
||||||
|
certificate := &CertificateBody{}
|
||||||
|
certm, err := HandshakeMessageFromBody(certificate)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
toSend = append(toSend, SendHandshakeMessage{certm})
|
||||||
|
state.handshakeHash.Write(certm.Marshal())
|
||||||
|
} else {
|
||||||
|
// Create and send Certificate, CertificateVerify
|
||||||
|
certificate := &CertificateBody{
|
||||||
|
CertificateList: make([]CertificateEntry, len(cert.Chain)),
|
||||||
|
}
|
||||||
|
for i, entry := range cert.Chain {
|
||||||
|
certificate.CertificateList[i] = CertificateEntry{CertData: entry}
|
||||||
|
}
|
||||||
|
certm, err := HandshakeMessageFromBody(certificate)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
toSend = append(toSend, SendHandshakeMessage{certm})
|
||||||
|
state.handshakeHash.Write(certm.Marshal())
|
||||||
|
|
||||||
|
hcv := state.handshakeHash.Sum(nil)
|
||||||
|
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
|
||||||
|
|
||||||
|
certificateVerify := &CertificateVerifyBody{Algorithm: certScheme}
|
||||||
|
logf(logTypeHandshake, "Creating CertVerify: %04x %v", certScheme, state.cryptoParams.Hash)
|
||||||
|
|
||||||
|
err = certificateVerify.Sign(cert.PrivateKey, hcv)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitFinished] Error signing CertificateVerify [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
certvm, err := HandshakeMessageFromBody(certificateVerify)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling CertificateVerify [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
toSend = append(toSend, SendHandshakeMessage{certvm})
|
||||||
|
state.handshakeHash.Write(certvm.Marshal())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the client's Finished message
|
||||||
|
h5 := state.handshakeHash.Sum(nil)
|
||||||
|
logf(logTypeCrypto, "handshake hash for client Finished: [%d] %x", len(h5), h5)
|
||||||
|
|
||||||
|
clientFinishedData := computeFinishedData(state.cryptoParams, state.clientHandshakeTrafficSecret, h5)
|
||||||
|
logf(logTypeCrypto, "client Finished data: [%d] %x", len(clientFinishedData), clientFinishedData)
|
||||||
|
|
||||||
|
fin = &FinishedBody{
|
||||||
|
VerifyDataLen: len(clientFinishedData),
|
||||||
|
VerifyData: clientFinishedData,
|
||||||
|
}
|
||||||
|
finm, err := HandshakeMessageFromBody(fin)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling client Finished [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the resumption secret
|
||||||
|
state.handshakeHash.Write(finm.Marshal())
|
||||||
|
h6 := state.handshakeHash.Sum(nil)
|
||||||
|
|
||||||
|
resumptionSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelResumptionSecret, h6)
|
||||||
|
logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret)
|
||||||
|
|
||||||
|
toSend = append(toSend, []HandshakeAction{
|
||||||
|
SendHandshakeMessage{finm},
|
||||||
|
RekeyIn{Label: "application", KeySet: serverTrafficKeys},
|
||||||
|
RekeyOut{Label: "application", KeySet: clientTrafficKeys},
|
||||||
|
}...)
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitFinished] -> [StateConnected]")
|
||||||
|
nextState := StateConnected{
|
||||||
|
Params: state.Params,
|
||||||
|
isClient: true,
|
||||||
|
cryptoParams: state.cryptoParams,
|
||||||
|
resumptionSecret: resumptionSecret,
|
||||||
|
clientTrafficSecret: clientTrafficSecret,
|
||||||
|
serverTrafficSecret: serverTrafficSecret,
|
||||||
|
exporterSecret: exporterSecret,
|
||||||
|
}
|
||||||
|
return nextState, toSend, AlertNoAlert
|
||||||
|
}
|
152
vendor/github.com/bifurcation/mint/common.go
generated
vendored
Normal file
152
vendor/github.com/bifurcation/mint/common.go
generated
vendored
Normal file
@ -0,0 +1,152 @@
|
|||||||
|
package mint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
supportedVersion uint16 = 0x7f15 // draft-21
|
||||||
|
|
||||||
|
// Flags for some minor compat issues
|
||||||
|
allowWrongVersionNumber = true
|
||||||
|
allowPKCS1 = true
|
||||||
|
)
|
||||||
|
|
||||||
|
// enum {...} ContentType;
|
||||||
|
type RecordType byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
RecordTypeAlert RecordType = 21
|
||||||
|
RecordTypeHandshake RecordType = 22
|
||||||
|
RecordTypeApplicationData RecordType = 23
|
||||||
|
)
|
||||||
|
|
||||||
|
// enum {...} HandshakeType;
|
||||||
|
type HandshakeType byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Omitted: *_RESERVED
|
||||||
|
HandshakeTypeClientHello HandshakeType = 1
|
||||||
|
HandshakeTypeServerHello HandshakeType = 2
|
||||||
|
HandshakeTypeNewSessionTicket HandshakeType = 4
|
||||||
|
HandshakeTypeEndOfEarlyData HandshakeType = 5
|
||||||
|
HandshakeTypeHelloRetryRequest HandshakeType = 6
|
||||||
|
HandshakeTypeEncryptedExtensions HandshakeType = 8
|
||||||
|
HandshakeTypeCertificate HandshakeType = 11
|
||||||
|
HandshakeTypeCertificateRequest HandshakeType = 13
|
||||||
|
HandshakeTypeCertificateVerify HandshakeType = 15
|
||||||
|
HandshakeTypeServerConfiguration HandshakeType = 17
|
||||||
|
HandshakeTypeFinished HandshakeType = 20
|
||||||
|
HandshakeTypeKeyUpdate HandshakeType = 24
|
||||||
|
HandshakeTypeMessageHash HandshakeType = 254
|
||||||
|
)
|
||||||
|
|
||||||
|
// uint8 CipherSuite[2];
|
||||||
|
type CipherSuite uint16
|
||||||
|
|
||||||
|
const (
|
||||||
|
// XXX: Actually TLS_NULL_WITH_NULL_NULL, but we need a way to label the zero
|
||||||
|
// value for this type so that we can detect when a field is set.
|
||||||
|
CIPHER_SUITE_UNKNOWN CipherSuite = 0x0000
|
||||||
|
TLS_AES_128_GCM_SHA256 CipherSuite = 0x1301
|
||||||
|
TLS_AES_256_GCM_SHA384 CipherSuite = 0x1302
|
||||||
|
TLS_CHACHA20_POLY1305_SHA256 CipherSuite = 0x1303
|
||||||
|
TLS_AES_128_CCM_SHA256 CipherSuite = 0x1304
|
||||||
|
TLS_AES_256_CCM_8_SHA256 CipherSuite = 0x1305
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c CipherSuite) String() string {
|
||||||
|
switch c {
|
||||||
|
case CIPHER_SUITE_UNKNOWN:
|
||||||
|
return "unknown"
|
||||||
|
case TLS_AES_128_GCM_SHA256:
|
||||||
|
return "TLS_AES_128_GCM_SHA256"
|
||||||
|
case TLS_AES_256_GCM_SHA384:
|
||||||
|
return "TLS_AES_256_GCM_SHA384"
|
||||||
|
case TLS_CHACHA20_POLY1305_SHA256:
|
||||||
|
return "TLS_CHACHA20_POLY1305_SHA256"
|
||||||
|
case TLS_AES_128_CCM_SHA256:
|
||||||
|
return "TLS_AES_128_CCM_SHA256"
|
||||||
|
case TLS_AES_256_CCM_8_SHA256:
|
||||||
|
return "TLS_AES_256_CCM_8_SHA256"
|
||||||
|
}
|
||||||
|
// cannot use %x here, since it calls String(), leading to infinite recursion
|
||||||
|
return fmt.Sprintf("invalid CipherSuite value: 0x%s", strconv.FormatUint(uint64(c), 16))
|
||||||
|
}
|
||||||
|
|
||||||
|
// enum {...} SignatureScheme
|
||||||
|
type SignatureScheme uint16
|
||||||
|
|
||||||
|
const (
|
||||||
|
// RSASSA-PKCS1-v1_5 algorithms
|
||||||
|
RSA_PKCS1_SHA1 SignatureScheme = 0x0201
|
||||||
|
RSA_PKCS1_SHA256 SignatureScheme = 0x0401
|
||||||
|
RSA_PKCS1_SHA384 SignatureScheme = 0x0501
|
||||||
|
RSA_PKCS1_SHA512 SignatureScheme = 0x0601
|
||||||
|
// ECDSA algorithms
|
||||||
|
ECDSA_P256_SHA256 SignatureScheme = 0x0403
|
||||||
|
ECDSA_P384_SHA384 SignatureScheme = 0x0503
|
||||||
|
ECDSA_P521_SHA512 SignatureScheme = 0x0603
|
||||||
|
// RSASSA-PSS algorithms
|
||||||
|
RSA_PSS_SHA256 SignatureScheme = 0x0804
|
||||||
|
RSA_PSS_SHA384 SignatureScheme = 0x0805
|
||||||
|
RSA_PSS_SHA512 SignatureScheme = 0x0806
|
||||||
|
// EdDSA algorithms
|
||||||
|
Ed25519 SignatureScheme = 0x0807
|
||||||
|
Ed448 SignatureScheme = 0x0808
|
||||||
|
)
|
||||||
|
|
||||||
|
// enum {...} ExtensionType
|
||||||
|
type ExtensionType uint16
|
||||||
|
|
||||||
|
const (
|
||||||
|
ExtensionTypeServerName ExtensionType = 0
|
||||||
|
ExtensionTypeSupportedGroups ExtensionType = 10
|
||||||
|
ExtensionTypeSignatureAlgorithms ExtensionType = 13
|
||||||
|
ExtensionTypeALPN ExtensionType = 16
|
||||||
|
ExtensionTypeKeyShare ExtensionType = 40
|
||||||
|
ExtensionTypePreSharedKey ExtensionType = 41
|
||||||
|
ExtensionTypeEarlyData ExtensionType = 42
|
||||||
|
ExtensionTypeSupportedVersions ExtensionType = 43
|
||||||
|
ExtensionTypeCookie ExtensionType = 44
|
||||||
|
ExtensionTypePSKKeyExchangeModes ExtensionType = 45
|
||||||
|
ExtensionTypeTicketEarlyDataInfo ExtensionType = 46
|
||||||
|
)
|
||||||
|
|
||||||
|
// enum {...} NamedGroup
|
||||||
|
type NamedGroup uint16
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Elliptic Curve Groups.
|
||||||
|
P256 NamedGroup = 23
|
||||||
|
P384 NamedGroup = 24
|
||||||
|
P521 NamedGroup = 25
|
||||||
|
// ECDH functions.
|
||||||
|
X25519 NamedGroup = 29
|
||||||
|
X448 NamedGroup = 30
|
||||||
|
// Finite field groups.
|
||||||
|
FFDHE2048 NamedGroup = 256
|
||||||
|
FFDHE3072 NamedGroup = 257
|
||||||
|
FFDHE4096 NamedGroup = 258
|
||||||
|
FFDHE6144 NamedGroup = 259
|
||||||
|
FFDHE8192 NamedGroup = 260
|
||||||
|
)
|
||||||
|
|
||||||
|
// enum {...} PskKeyExchangeMode;
|
||||||
|
type PSKKeyExchangeMode uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
PSKModeKE PSKKeyExchangeMode = 0
|
||||||
|
PSKModeDHEKE PSKKeyExchangeMode = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
// enum {
|
||||||
|
// update_not_requested(0), update_requested(1), (255)
|
||||||
|
// } KeyUpdateRequest;
|
||||||
|
type KeyUpdateRequest uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
KeyUpdateNotRequested KeyUpdateRequest = 0
|
||||||
|
KeyUpdateRequested KeyUpdateRequest = 1
|
||||||
|
)
|
819
vendor/github.com/bifurcation/mint/conn.go
generated
vendored
Normal file
819
vendor/github.com/bifurcation/mint/conn.go
generated
vendored
Normal file
@ -0,0 +1,819 @@
|
|||||||
|
package mint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"reflect"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var WouldBlock = fmt.Errorf("Would have blocked")
|
||||||
|
|
||||||
|
type Certificate struct {
|
||||||
|
Chain []*x509.Certificate
|
||||||
|
PrivateKey crypto.Signer
|
||||||
|
}
|
||||||
|
|
||||||
|
type PreSharedKey struct {
|
||||||
|
CipherSuite CipherSuite
|
||||||
|
IsResumption bool
|
||||||
|
Identity []byte
|
||||||
|
Key []byte
|
||||||
|
NextProto string
|
||||||
|
ReceivedAt time.Time
|
||||||
|
ExpiresAt time.Time
|
||||||
|
TicketAgeAdd uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
type PreSharedKeyCache interface {
|
||||||
|
Get(string) (PreSharedKey, bool)
|
||||||
|
Put(string, PreSharedKey)
|
||||||
|
Size() int
|
||||||
|
}
|
||||||
|
|
||||||
|
type PSKMapCache map[string]PreSharedKey
|
||||||
|
|
||||||
|
// A CookieHandler does two things:
|
||||||
|
// - generates a byte string that is sent as a part of a cookie to the client in the HelloRetryRequest
|
||||||
|
// - validates this byte string echoed by the client in the ClientHello
|
||||||
|
type CookieHandler interface {
|
||||||
|
Generate(*Conn) ([]byte, error)
|
||||||
|
Validate(*Conn, []byte) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cache PSKMapCache) Get(key string) (psk PreSharedKey, ok bool) {
|
||||||
|
psk, ok = cache[key]
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cache *PSKMapCache) Put(key string, psk PreSharedKey) {
|
||||||
|
(*cache)[key] = psk
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cache PSKMapCache) Size() int {
|
||||||
|
return len(cache)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config is the struct used to pass configuration settings to a TLS client or
|
||||||
|
// server instance. The settings for client and server are pretty different,
|
||||||
|
// but we just throw them all in here.
|
||||||
|
type Config struct {
|
||||||
|
// Client fields
|
||||||
|
ServerName string
|
||||||
|
|
||||||
|
// Server fields
|
||||||
|
SendSessionTickets bool
|
||||||
|
TicketLifetime uint32
|
||||||
|
TicketLen int
|
||||||
|
EarlyDataLifetime uint32
|
||||||
|
AllowEarlyData bool
|
||||||
|
// Require the client to echo a cookie.
|
||||||
|
RequireCookie bool
|
||||||
|
// If cookies are required and no CookieHandler is set, a default cookie handler is used.
|
||||||
|
// The default cookie handler uses 32 random bytes as a cookie.
|
||||||
|
CookieHandler CookieHandler
|
||||||
|
RequireClientAuth bool
|
||||||
|
|
||||||
|
// Shared fields
|
||||||
|
Certificates []*Certificate
|
||||||
|
AuthCertificate func(chain []CertificateEntry) error
|
||||||
|
CipherSuites []CipherSuite
|
||||||
|
Groups []NamedGroup
|
||||||
|
SignatureSchemes []SignatureScheme
|
||||||
|
NextProtos []string
|
||||||
|
PSKs PreSharedKeyCache
|
||||||
|
PSKModes []PSKKeyExchangeMode
|
||||||
|
NonBlocking bool
|
||||||
|
|
||||||
|
// The same config object can be shared among different connections, so it
|
||||||
|
// needs its own mutex
|
||||||
|
mutex sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clone returns a shallow clone of c. It is safe to clone a Config that is
|
||||||
|
// being used concurrently by a TLS client or server.
|
||||||
|
func (c *Config) Clone() *Config {
|
||||||
|
c.mutex.Lock()
|
||||||
|
defer c.mutex.Unlock()
|
||||||
|
|
||||||
|
return &Config{
|
||||||
|
ServerName: c.ServerName,
|
||||||
|
|
||||||
|
SendSessionTickets: c.SendSessionTickets,
|
||||||
|
TicketLifetime: c.TicketLifetime,
|
||||||
|
TicketLen: c.TicketLen,
|
||||||
|
EarlyDataLifetime: c.EarlyDataLifetime,
|
||||||
|
AllowEarlyData: c.AllowEarlyData,
|
||||||
|
RequireCookie: c.RequireCookie,
|
||||||
|
RequireClientAuth: c.RequireClientAuth,
|
||||||
|
|
||||||
|
Certificates: c.Certificates,
|
||||||
|
AuthCertificate: c.AuthCertificate,
|
||||||
|
CipherSuites: c.CipherSuites,
|
||||||
|
Groups: c.Groups,
|
||||||
|
SignatureSchemes: c.SignatureSchemes,
|
||||||
|
NextProtos: c.NextProtos,
|
||||||
|
PSKs: c.PSKs,
|
||||||
|
PSKModes: c.PSKModes,
|
||||||
|
NonBlocking: c.NonBlocking,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) Init(isClient bool) error {
|
||||||
|
c.mutex.Lock()
|
||||||
|
defer c.mutex.Unlock()
|
||||||
|
|
||||||
|
// Set defaults
|
||||||
|
if len(c.CipherSuites) == 0 {
|
||||||
|
c.CipherSuites = defaultSupportedCipherSuites
|
||||||
|
}
|
||||||
|
if len(c.Groups) == 0 {
|
||||||
|
c.Groups = defaultSupportedGroups
|
||||||
|
}
|
||||||
|
if len(c.SignatureSchemes) == 0 {
|
||||||
|
c.SignatureSchemes = defaultSignatureSchemes
|
||||||
|
}
|
||||||
|
if c.TicketLen == 0 {
|
||||||
|
c.TicketLen = defaultTicketLen
|
||||||
|
}
|
||||||
|
if !reflect.ValueOf(c.PSKs).IsValid() {
|
||||||
|
c.PSKs = &PSKMapCache{}
|
||||||
|
}
|
||||||
|
if len(c.PSKModes) == 0 {
|
||||||
|
c.PSKModes = defaultPSKModes
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there is no certificate, generate one
|
||||||
|
if !isClient && len(c.Certificates) == 0 {
|
||||||
|
logf(logTypeHandshake, "Generating key name=%v", c.ServerName)
|
||||||
|
priv, err := newSigningKey(RSA_PSS_SHA256)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cert, err := newSelfSigned(c.ServerName, RSA_PKCS1_SHA256, priv)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Certificates = []*Certificate{
|
||||||
|
{
|
||||||
|
Chain: []*x509.Certificate{cert},
|
||||||
|
PrivateKey: priv,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) ValidForServer() bool {
|
||||||
|
return (reflect.ValueOf(c.PSKs).IsValid() && c.PSKs.Size() > 0) ||
|
||||||
|
(len(c.Certificates) > 0 &&
|
||||||
|
len(c.Certificates[0].Chain) > 0 &&
|
||||||
|
c.Certificates[0].PrivateKey != nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) ValidForClient() bool {
|
||||||
|
return len(c.ServerName) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
defaultSupportedCipherSuites = []CipherSuite{
|
||||||
|
TLS_AES_128_GCM_SHA256,
|
||||||
|
TLS_AES_256_GCM_SHA384,
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultSupportedGroups = []NamedGroup{
|
||||||
|
P256,
|
||||||
|
P384,
|
||||||
|
FFDHE2048,
|
||||||
|
X25519,
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultSignatureSchemes = []SignatureScheme{
|
||||||
|
RSA_PSS_SHA256,
|
||||||
|
RSA_PSS_SHA384,
|
||||||
|
RSA_PSS_SHA512,
|
||||||
|
ECDSA_P256_SHA256,
|
||||||
|
ECDSA_P384_SHA384,
|
||||||
|
ECDSA_P521_SHA512,
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultTicketLen = 16
|
||||||
|
|
||||||
|
defaultPSKModes = []PSKKeyExchangeMode{
|
||||||
|
PSKModeKE,
|
||||||
|
PSKModeDHEKE,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
type ConnectionState struct {
|
||||||
|
HandshakeState string // string representation of the handshake state.
|
||||||
|
CipherSuite CipherSuiteParams // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...)
|
||||||
|
PeerCertificates []*x509.Certificate // certificate chain presented by remote peer TODO(ekr@rtfm.com): implement
|
||||||
|
NextProto string // Selected ALPN proto
|
||||||
|
}
|
||||||
|
|
||||||
|
// Conn implements the net.Conn interface, as with "crypto/tls"
|
||||||
|
// * Read, Write, and Close are provided locally
|
||||||
|
// * LocalAddr, RemoteAddr, and Set*Deadline are forwarded to the inner Conn
|
||||||
|
type Conn struct {
|
||||||
|
config *Config
|
||||||
|
conn net.Conn
|
||||||
|
isClient bool
|
||||||
|
|
||||||
|
EarlyData []byte
|
||||||
|
|
||||||
|
state StateConnected
|
||||||
|
hState HandshakeState
|
||||||
|
handshakeMutex sync.Mutex
|
||||||
|
handshakeAlert Alert
|
||||||
|
handshakeComplete bool
|
||||||
|
|
||||||
|
readBuffer []byte
|
||||||
|
in, out *RecordLayer
|
||||||
|
hIn, hOut *HandshakeLayer
|
||||||
|
|
||||||
|
extHandler AppExtensionHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConn(conn net.Conn, config *Config, isClient bool) *Conn {
|
||||||
|
c := &Conn{conn: conn, config: config, isClient: isClient}
|
||||||
|
c.in = NewRecordLayer(c.conn)
|
||||||
|
c.out = NewRecordLayer(c.conn)
|
||||||
|
c.hIn = NewHandshakeLayer(c.in)
|
||||||
|
c.hIn.nonblocking = c.config.NonBlocking
|
||||||
|
c.hOut = NewHandshakeLayer(c.out)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read up
|
||||||
|
func (c *Conn) consumeRecord() error {
|
||||||
|
pt, err := c.in.ReadRecord()
|
||||||
|
if pt == nil {
|
||||||
|
logf(logTypeIO, "extendBuffer returns error %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch pt.contentType {
|
||||||
|
case RecordTypeHandshake:
|
||||||
|
logf(logTypeHandshake, "Received post-handshake message")
|
||||||
|
// We do not support fragmentation of post-handshake handshake messages.
|
||||||
|
// TODO: Factor this more elegantly; coalesce with handshakeLayer.ReadMessage()
|
||||||
|
start := 0
|
||||||
|
for start < len(pt.fragment) {
|
||||||
|
if len(pt.fragment[start:]) < handshakeHeaderLen {
|
||||||
|
return fmt.Errorf("Post-handshake handshake message too short for header")
|
||||||
|
}
|
||||||
|
|
||||||
|
hm := &HandshakeMessage{}
|
||||||
|
hm.msgType = HandshakeType(pt.fragment[start])
|
||||||
|
hmLen := (int(pt.fragment[start+1]) << 16) + (int(pt.fragment[start+2]) << 8) + int(pt.fragment[start+3])
|
||||||
|
|
||||||
|
if len(pt.fragment[start+handshakeHeaderLen:]) < hmLen {
|
||||||
|
return fmt.Errorf("Post-handshake handshake message too short for body")
|
||||||
|
}
|
||||||
|
hm.body = pt.fragment[start+handshakeHeaderLen : start+handshakeHeaderLen+hmLen]
|
||||||
|
|
||||||
|
// Advance state machine
|
||||||
|
state, actions, alert := c.state.Next(hm)
|
||||||
|
|
||||||
|
if alert != AlertNoAlert {
|
||||||
|
logf(logTypeHandshake, "Error in state transition: %v", alert)
|
||||||
|
c.sendAlert(alert)
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, action := range actions {
|
||||||
|
alert = c.takeAction(action)
|
||||||
|
if alert != AlertNoAlert {
|
||||||
|
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
|
||||||
|
c.sendAlert(alert)
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// XXX: If we want to support more advanced cases, e.g., post-handshake
|
||||||
|
// authentication, we'll need to allow transitions other than
|
||||||
|
// Connected -> Connected
|
||||||
|
var connected bool
|
||||||
|
c.state, connected = state.(StateConnected)
|
||||||
|
if !connected {
|
||||||
|
logf(logTypeHandshake, "Disconnected after state transition: %v", alert)
|
||||||
|
c.sendAlert(alert)
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
start += handshakeHeaderLen + hmLen
|
||||||
|
}
|
||||||
|
case RecordTypeAlert:
|
||||||
|
logf(logTypeIO, "extended buffer (for alert): [%d] %x", len(c.readBuffer), c.readBuffer)
|
||||||
|
if len(pt.fragment) != 2 {
|
||||||
|
c.sendAlert(AlertUnexpectedMessage)
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
if Alert(pt.fragment[1]) == AlertCloseNotify {
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
switch pt.fragment[0] {
|
||||||
|
case AlertLevelWarning:
|
||||||
|
// drop on the floor
|
||||||
|
case AlertLevelError:
|
||||||
|
return Alert(pt.fragment[1])
|
||||||
|
default:
|
||||||
|
c.sendAlert(AlertUnexpectedMessage)
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
case RecordTypeApplicationData:
|
||||||
|
c.readBuffer = append(c.readBuffer, pt.fragment...)
|
||||||
|
logf(logTypeIO, "extended buffer: [%d] %x", len(c.readBuffer), c.readBuffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read application data up to the size of buffer. Handshake and alert records
|
||||||
|
// are consumed by the Conn object directly.
|
||||||
|
func (c *Conn) Read(buffer []byte) (int, error) {
|
||||||
|
logf(logTypeHandshake, "conn.Read with buffer = %d", len(buffer))
|
||||||
|
if alert := c.Handshake(); alert != AlertNoAlert {
|
||||||
|
return 0, alert
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(buffer) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lock the input channel
|
||||||
|
c.in.Lock()
|
||||||
|
defer c.in.Unlock()
|
||||||
|
for len(c.readBuffer) == 0 {
|
||||||
|
err := c.consumeRecord()
|
||||||
|
|
||||||
|
// err can be nil if consumeRecord processed a non app-data
|
||||||
|
// record.
|
||||||
|
if err != nil {
|
||||||
|
if c.config.NonBlocking || err != WouldBlock {
|
||||||
|
logf(logTypeIO, "conn.Read returns err=%v", err)
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var read int
|
||||||
|
n := len(buffer)
|
||||||
|
logf(logTypeIO, "conn.Read input buffer now has len %d", len(c.readBuffer))
|
||||||
|
if len(c.readBuffer) <= n {
|
||||||
|
buffer = buffer[:len(c.readBuffer)]
|
||||||
|
copy(buffer, c.readBuffer)
|
||||||
|
read = len(c.readBuffer)
|
||||||
|
c.readBuffer = c.readBuffer[:0]
|
||||||
|
} else {
|
||||||
|
logf(logTypeIO, "read buffer larger than input buffer (%d > %d)", len(c.readBuffer), n)
|
||||||
|
copy(buffer[:n], c.readBuffer[:n])
|
||||||
|
c.readBuffer = c.readBuffer[n:]
|
||||||
|
read = n
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeVerbose, "Returning %v", string(buffer))
|
||||||
|
return read, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write application data
|
||||||
|
func (c *Conn) Write(buffer []byte) (int, error) {
|
||||||
|
// Lock the output channel
|
||||||
|
c.out.Lock()
|
||||||
|
defer c.out.Unlock()
|
||||||
|
|
||||||
|
// Send full-size fragments
|
||||||
|
var start int
|
||||||
|
sent := 0
|
||||||
|
for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen {
|
||||||
|
err := c.out.WriteRecord(&TLSPlaintext{
|
||||||
|
contentType: RecordTypeApplicationData,
|
||||||
|
fragment: buffer[start : start+maxFragmentLen],
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return sent, err
|
||||||
|
}
|
||||||
|
sent += maxFragmentLen
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send a final partial fragment if necessary
|
||||||
|
if start < len(buffer) {
|
||||||
|
err := c.out.WriteRecord(&TLSPlaintext{
|
||||||
|
contentType: RecordTypeApplicationData,
|
||||||
|
fragment: buffer[start:],
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return sent, err
|
||||||
|
}
|
||||||
|
sent += len(buffer[start:])
|
||||||
|
}
|
||||||
|
return sent, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendAlert sends a TLS alert message.
|
||||||
|
// c.out.Mutex <= L.
|
||||||
|
func (c *Conn) sendAlert(err Alert) error {
|
||||||
|
c.handshakeMutex.Lock()
|
||||||
|
defer c.handshakeMutex.Unlock()
|
||||||
|
|
||||||
|
var level int
|
||||||
|
switch err {
|
||||||
|
case AlertNoRenegotiation, AlertCloseNotify:
|
||||||
|
level = AlertLevelWarning
|
||||||
|
default:
|
||||||
|
level = AlertLevelError
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := []byte{byte(err), byte(level)}
|
||||||
|
c.out.WriteRecord(&TLSPlaintext{
|
||||||
|
contentType: RecordTypeAlert,
|
||||||
|
fragment: buf,
|
||||||
|
})
|
||||||
|
|
||||||
|
// close_notify and end_of_early_data are not actually errors
|
||||||
|
if level == AlertLevelWarning {
|
||||||
|
return &net.OpError{Op: "local error", Err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the connection.
|
||||||
|
func (c *Conn) Close() error {
|
||||||
|
// XXX crypto/tls has an interlock with Write here. Do we need that?
|
||||||
|
|
||||||
|
return c.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// LocalAddr returns the local network address.
|
||||||
|
func (c *Conn) LocalAddr() net.Addr {
|
||||||
|
return c.conn.LocalAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoteAddr returns the remote network address.
|
||||||
|
func (c *Conn) RemoteAddr() net.Addr {
|
||||||
|
return c.conn.RemoteAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDeadline sets the read and write deadlines associated with the connection.
|
||||||
|
// A zero value for t means Read and Write will not time out.
|
||||||
|
// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
|
||||||
|
func (c *Conn) SetDeadline(t time.Time) error {
|
||||||
|
return c.conn.SetDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetReadDeadline sets the read deadline on the underlying connection.
|
||||||
|
// A zero value for t means Read will not time out.
|
||||||
|
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||||
|
return c.conn.SetReadDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetWriteDeadline sets the write deadline on the underlying connection.
|
||||||
|
// A zero value for t means Write will not time out.
|
||||||
|
// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
|
||||||
|
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||||
|
return c.conn.SetWriteDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert {
|
||||||
|
label := "[server]"
|
||||||
|
if c.isClient {
|
||||||
|
label = "[client]"
|
||||||
|
}
|
||||||
|
|
||||||
|
switch action := actionGeneric.(type) {
|
||||||
|
case SendHandshakeMessage:
|
||||||
|
err := c.hOut.WriteMessage(action.Message)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err)
|
||||||
|
return AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
case RekeyIn:
|
||||||
|
logf(logTypeHandshake, "%s Rekeying in to %s: %+v", label, action.Label, action.KeySet)
|
||||||
|
err := c.in.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "%s Unable to rekey inbound: %v", label, err)
|
||||||
|
return AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
case RekeyOut:
|
||||||
|
logf(logTypeHandshake, "%s Rekeying out to %s: %+v", label, action.Label, action.KeySet)
|
||||||
|
err := c.out.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "%s Unable to rekey outbound: %v", label, err)
|
||||||
|
return AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
case SendEarlyData:
|
||||||
|
logf(logTypeHandshake, "%s Sending early data...", label)
|
||||||
|
_, err := c.Write(c.EarlyData)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "%s Error writing early data: %v", label, err)
|
||||||
|
return AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
case ReadPastEarlyData:
|
||||||
|
logf(logTypeHandshake, "%s Reading past early data...", label)
|
||||||
|
// Scan past all records that fail to decrypt
|
||||||
|
_, err := c.in.PeekRecordType(!c.config.NonBlocking)
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
_, ok := err.(DecryptError)
|
||||||
|
|
||||||
|
for ok {
|
||||||
|
_, err = c.in.PeekRecordType(!c.config.NonBlocking)
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
_, ok = err.(DecryptError)
|
||||||
|
}
|
||||||
|
|
||||||
|
case ReadEarlyData:
|
||||||
|
logf(logTypeHandshake, "%s Reading early data...", label)
|
||||||
|
t, err := c.in.PeekRecordType(!c.config.NonBlocking)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "%s Error reading record type (1): %v", label, err)
|
||||||
|
return AlertInternalError
|
||||||
|
}
|
||||||
|
logf(logTypeHandshake, "%s Got record type(1): %v", label, t)
|
||||||
|
|
||||||
|
for t == RecordTypeApplicationData {
|
||||||
|
// Read a record into the buffer. Note that this is safe
|
||||||
|
// in blocking mode because we read the record in in
|
||||||
|
// PeekRecordType.
|
||||||
|
pt, err := c.in.ReadRecord()
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "%s Error reading early data record: %v", label, err)
|
||||||
|
return AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "%s Read early data: %x", label, pt.fragment)
|
||||||
|
c.EarlyData = append(c.EarlyData, pt.fragment...)
|
||||||
|
|
||||||
|
t, err = c.in.PeekRecordType(!c.config.NonBlocking)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "%s Error reading record type (2): %v", label, err)
|
||||||
|
return AlertInternalError
|
||||||
|
}
|
||||||
|
logf(logTypeHandshake, "%s Got record type (2): %v", label, t)
|
||||||
|
}
|
||||||
|
logf(logTypeHandshake, "%s Done reading early data", label)
|
||||||
|
|
||||||
|
case StorePSK:
|
||||||
|
logf(logTypeHandshake, "%s Storing new session ticket with identity [%x]", label, action.PSK.Identity)
|
||||||
|
if c.isClient {
|
||||||
|
// Clients look up PSKs based on server name
|
||||||
|
c.config.PSKs.Put(c.config.ServerName, action.PSK)
|
||||||
|
} else {
|
||||||
|
// Servers look them up based on the identity in the extension
|
||||||
|
c.config.PSKs.Put(hex.EncodeToString(action.PSK.Identity), action.PSK)
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
logf(logTypeHandshake, "%s Unknown actionuction type", label)
|
||||||
|
return AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
return AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) HandshakeSetup() Alert {
|
||||||
|
var state HandshakeState
|
||||||
|
var actions []HandshakeAction
|
||||||
|
var alert Alert
|
||||||
|
|
||||||
|
if err := c.config.Init(c.isClient); err != nil {
|
||||||
|
logf(logTypeHandshake, "Error initializing config: %v", err)
|
||||||
|
return AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set things up
|
||||||
|
caps := Capabilities{
|
||||||
|
CipherSuites: c.config.CipherSuites,
|
||||||
|
Groups: c.config.Groups,
|
||||||
|
SignatureSchemes: c.config.SignatureSchemes,
|
||||||
|
PSKs: c.config.PSKs,
|
||||||
|
PSKModes: c.config.PSKModes,
|
||||||
|
AllowEarlyData: c.config.AllowEarlyData,
|
||||||
|
RequireCookie: c.config.RequireCookie,
|
||||||
|
CookieHandler: c.config.CookieHandler,
|
||||||
|
RequireClientAuth: c.config.RequireClientAuth,
|
||||||
|
NextProtos: c.config.NextProtos,
|
||||||
|
Certificates: c.config.Certificates,
|
||||||
|
ExtensionHandler: c.extHandler,
|
||||||
|
}
|
||||||
|
opts := ConnectionOptions{
|
||||||
|
ServerName: c.config.ServerName,
|
||||||
|
NextProtos: c.config.NextProtos,
|
||||||
|
EarlyData: c.EarlyData,
|
||||||
|
}
|
||||||
|
|
||||||
|
if caps.RequireCookie && caps.CookieHandler == nil {
|
||||||
|
caps.CookieHandler = &defaultCookieHandler{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.isClient {
|
||||||
|
state, actions, alert = ClientStateStart{Caps: caps, Opts: opts}.Next(nil)
|
||||||
|
if alert != AlertNoAlert {
|
||||||
|
logf(logTypeHandshake, "Error initializing client state: %v", alert)
|
||||||
|
return alert
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, action := range actions {
|
||||||
|
alert = c.takeAction(action)
|
||||||
|
if alert != AlertNoAlert {
|
||||||
|
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
|
||||||
|
return alert
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
state = ServerStateStart{Caps: caps, conn: c}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.hState = state
|
||||||
|
|
||||||
|
return AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handshake causes a TLS handshake on the connection. The `isClient` member
|
||||||
|
// determines whether a client or server handshake is performed. If a
|
||||||
|
// handshake has already been performed, then its result will be returned.
|
||||||
|
func (c *Conn) Handshake() Alert {
|
||||||
|
label := "[server]"
|
||||||
|
if c.isClient {
|
||||||
|
label = "[client]"
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO Lock handshakeMutex
|
||||||
|
// TODO Remove CloseNotify hack
|
||||||
|
if c.handshakeAlert != AlertNoAlert && c.handshakeAlert != AlertCloseNotify {
|
||||||
|
logf(logTypeHandshake, "Pre-existing handshake error: %v", c.handshakeAlert)
|
||||||
|
return c.handshakeAlert
|
||||||
|
}
|
||||||
|
if c.handshakeComplete {
|
||||||
|
return AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
var alert Alert
|
||||||
|
if c.hState == nil {
|
||||||
|
logf(logTypeHandshake, "%s First time through handshake, setting up", label)
|
||||||
|
alert = c.HandshakeSetup()
|
||||||
|
if alert != AlertNoAlert {
|
||||||
|
return alert
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logf(logTypeHandshake, "Re-entering handshake, state=%v", c.hState)
|
||||||
|
}
|
||||||
|
|
||||||
|
state := c.hState
|
||||||
|
_, connected := state.(StateConnected)
|
||||||
|
|
||||||
|
var actions []HandshakeAction
|
||||||
|
|
||||||
|
for !connected {
|
||||||
|
// Read a handshake message
|
||||||
|
hm, err := c.hIn.ReadMessage()
|
||||||
|
if err == WouldBlock {
|
||||||
|
logf(logTypeHandshake, "%s Would block reading message: %v", label, err)
|
||||||
|
return AlertWouldBlock
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "%s Error reading message: %v", label, err)
|
||||||
|
c.sendAlert(AlertCloseNotify)
|
||||||
|
return AlertCloseNotify
|
||||||
|
}
|
||||||
|
logf(logTypeHandshake, "Read message with type: %v", hm.msgType)
|
||||||
|
|
||||||
|
// Advance the state machine
|
||||||
|
state, actions, alert = state.Next(hm)
|
||||||
|
|
||||||
|
if alert != AlertNoAlert {
|
||||||
|
logf(logTypeHandshake, "Error in state transition: %v", alert)
|
||||||
|
return alert
|
||||||
|
}
|
||||||
|
|
||||||
|
for index, action := range actions {
|
||||||
|
logf(logTypeHandshake, "%s taking next action (%d)", label, index)
|
||||||
|
alert = c.takeAction(action)
|
||||||
|
if alert != AlertNoAlert {
|
||||||
|
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
|
||||||
|
c.sendAlert(alert)
|
||||||
|
return alert
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.hState = state
|
||||||
|
logf(logTypeHandshake, "state is now %s", c.GetHsState())
|
||||||
|
|
||||||
|
_, connected = state.(StateConnected)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.state = state.(StateConnected)
|
||||||
|
|
||||||
|
// Send NewSessionTicket if acting as server
|
||||||
|
if !c.isClient && c.config.SendSessionTickets {
|
||||||
|
actions, alert := c.state.NewSessionTicket(
|
||||||
|
c.config.TicketLen,
|
||||||
|
c.config.TicketLifetime,
|
||||||
|
c.config.EarlyDataLifetime)
|
||||||
|
|
||||||
|
for _, action := range actions {
|
||||||
|
alert = c.takeAction(action)
|
||||||
|
if alert != AlertNoAlert {
|
||||||
|
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
|
||||||
|
c.sendAlert(alert)
|
||||||
|
return alert
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.handshakeComplete = true
|
||||||
|
return AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) SendKeyUpdate(requestUpdate bool) error {
|
||||||
|
if !c.handshakeComplete {
|
||||||
|
return fmt.Errorf("Cannot update keys until after handshake")
|
||||||
|
}
|
||||||
|
|
||||||
|
request := KeyUpdateNotRequested
|
||||||
|
if requestUpdate {
|
||||||
|
request = KeyUpdateRequested
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the key update and update state
|
||||||
|
actions, alert := c.state.KeyUpdate(request)
|
||||||
|
if alert != AlertNoAlert {
|
||||||
|
c.sendAlert(alert)
|
||||||
|
return fmt.Errorf("Alert while generating key update: %v", alert)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Take actions (send key update and rekey)
|
||||||
|
for _, action := range actions {
|
||||||
|
alert = c.takeAction(action)
|
||||||
|
if alert != AlertNoAlert {
|
||||||
|
c.sendAlert(alert)
|
||||||
|
return fmt.Errorf("Alert during key update actions: %v", alert)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) GetHsState() string {
|
||||||
|
return reflect.TypeOf(c.hState).Name()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
|
||||||
|
_, connected := c.hState.(StateConnected)
|
||||||
|
if !connected {
|
||||||
|
return nil, fmt.Errorf("Cannot compute exporter when state is not connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.state.exporterSecret == nil {
|
||||||
|
return nil, fmt.Errorf("Internal error: no exporter secret")
|
||||||
|
}
|
||||||
|
|
||||||
|
h0 := c.state.cryptoParams.Hash.New().Sum(nil)
|
||||||
|
tmpSecret := deriveSecret(c.state.cryptoParams, c.state.exporterSecret, label, h0)
|
||||||
|
|
||||||
|
hc := c.state.cryptoParams.Hash.New().Sum(context)
|
||||||
|
return HkdfExpandLabel(c.state.cryptoParams.Hash, tmpSecret, "exporter", hc, keyLength), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) State() ConnectionState {
|
||||||
|
state := ConnectionState{
|
||||||
|
HandshakeState: c.GetHsState(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.handshakeComplete {
|
||||||
|
state.CipherSuite = cipherSuiteMap[c.state.Params.CipherSuite]
|
||||||
|
state.NextProto = c.state.Params.NextProto
|
||||||
|
}
|
||||||
|
|
||||||
|
return state
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) SetExtensionHandler(h AppExtensionHandler) error {
|
||||||
|
if c.hState != nil {
|
||||||
|
return fmt.Errorf("Can't set extension handler after setup")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.extHandler = h
|
||||||
|
return nil
|
||||||
|
}
|
654
vendor/github.com/bifurcation/mint/crypto.go
generated
vendored
Normal file
654
vendor/github.com/bifurcation/mint/crypto.go
generated
vendored
Normal file
@ -0,0 +1,654 @@
|
|||||||
|
package mint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto"
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/asn1"
|
||||||
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/curve25519"
|
||||||
|
|
||||||
|
// Blank includes to ensure hash support
|
||||||
|
_ "crypto/sha1"
|
||||||
|
_ "crypto/sha256"
|
||||||
|
_ "crypto/sha512"
|
||||||
|
)
|
||||||
|
|
||||||
|
var prng = rand.Reader
|
||||||
|
|
||||||
|
type aeadFactory func(key []byte) (cipher.AEAD, error)
|
||||||
|
|
||||||
|
type CipherSuiteParams struct {
|
||||||
|
Suite CipherSuite
|
||||||
|
Cipher aeadFactory // Cipher factory
|
||||||
|
Hash crypto.Hash // Hash function
|
||||||
|
KeyLen int // Key length in octets
|
||||||
|
IvLen int // IV length in octets
|
||||||
|
}
|
||||||
|
|
||||||
|
type signatureAlgorithm uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
signatureAlgorithmUnknown = iota
|
||||||
|
signatureAlgorithmRSA_PKCS1
|
||||||
|
signatureAlgorithmRSA_PSS
|
||||||
|
signatureAlgorithmECDSA
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
hashMap = map[SignatureScheme]crypto.Hash{
|
||||||
|
RSA_PKCS1_SHA1: crypto.SHA1,
|
||||||
|
RSA_PKCS1_SHA256: crypto.SHA256,
|
||||||
|
RSA_PKCS1_SHA384: crypto.SHA384,
|
||||||
|
RSA_PKCS1_SHA512: crypto.SHA512,
|
||||||
|
ECDSA_P256_SHA256: crypto.SHA256,
|
||||||
|
ECDSA_P384_SHA384: crypto.SHA384,
|
||||||
|
ECDSA_P521_SHA512: crypto.SHA512,
|
||||||
|
RSA_PSS_SHA256: crypto.SHA256,
|
||||||
|
RSA_PSS_SHA384: crypto.SHA384,
|
||||||
|
RSA_PSS_SHA512: crypto.SHA512,
|
||||||
|
}
|
||||||
|
|
||||||
|
sigMap = map[SignatureScheme]signatureAlgorithm{
|
||||||
|
RSA_PKCS1_SHA1: signatureAlgorithmRSA_PKCS1,
|
||||||
|
RSA_PKCS1_SHA256: signatureAlgorithmRSA_PKCS1,
|
||||||
|
RSA_PKCS1_SHA384: signatureAlgorithmRSA_PKCS1,
|
||||||
|
RSA_PKCS1_SHA512: signatureAlgorithmRSA_PKCS1,
|
||||||
|
ECDSA_P256_SHA256: signatureAlgorithmECDSA,
|
||||||
|
ECDSA_P384_SHA384: signatureAlgorithmECDSA,
|
||||||
|
ECDSA_P521_SHA512: signatureAlgorithmECDSA,
|
||||||
|
RSA_PSS_SHA256: signatureAlgorithmRSA_PSS,
|
||||||
|
RSA_PSS_SHA384: signatureAlgorithmRSA_PSS,
|
||||||
|
RSA_PSS_SHA512: signatureAlgorithmRSA_PSS,
|
||||||
|
}
|
||||||
|
|
||||||
|
curveMap = map[SignatureScheme]NamedGroup{
|
||||||
|
ECDSA_P256_SHA256: P256,
|
||||||
|
ECDSA_P384_SHA384: P384,
|
||||||
|
ECDSA_P521_SHA512: P521,
|
||||||
|
}
|
||||||
|
|
||||||
|
newAESGCM = func(key []byte) (cipher.AEAD, error) {
|
||||||
|
block, err := aes.NewCipher(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TLS always uses 12-byte nonces
|
||||||
|
return cipher.NewGCMWithNonceSize(block, 12)
|
||||||
|
}
|
||||||
|
|
||||||
|
cipherSuiteMap = map[CipherSuite]CipherSuiteParams{
|
||||||
|
TLS_AES_128_GCM_SHA256: {
|
||||||
|
Suite: TLS_AES_128_GCM_SHA256,
|
||||||
|
Cipher: newAESGCM,
|
||||||
|
Hash: crypto.SHA256,
|
||||||
|
KeyLen: 16,
|
||||||
|
IvLen: 12,
|
||||||
|
},
|
||||||
|
TLS_AES_256_GCM_SHA384: {
|
||||||
|
Suite: TLS_AES_256_GCM_SHA384,
|
||||||
|
Cipher: newAESGCM,
|
||||||
|
Hash: crypto.SHA384,
|
||||||
|
KeyLen: 32,
|
||||||
|
IvLen: 12,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
x509AlgMap = map[SignatureScheme]x509.SignatureAlgorithm{
|
||||||
|
RSA_PKCS1_SHA1: x509.SHA1WithRSA,
|
||||||
|
RSA_PKCS1_SHA256: x509.SHA256WithRSA,
|
||||||
|
RSA_PKCS1_SHA384: x509.SHA384WithRSA,
|
||||||
|
RSA_PKCS1_SHA512: x509.SHA512WithRSA,
|
||||||
|
ECDSA_P256_SHA256: x509.ECDSAWithSHA256,
|
||||||
|
ECDSA_P384_SHA384: x509.ECDSAWithSHA384,
|
||||||
|
ECDSA_P521_SHA512: x509.ECDSAWithSHA512,
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultRSAKeySize = 2048
|
||||||
|
)
|
||||||
|
|
||||||
|
func curveFromNamedGroup(group NamedGroup) (crv elliptic.Curve) {
|
||||||
|
switch group {
|
||||||
|
case P256:
|
||||||
|
crv = elliptic.P256()
|
||||||
|
case P384:
|
||||||
|
crv = elliptic.P384()
|
||||||
|
case P521:
|
||||||
|
crv = elliptic.P521()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func namedGroupFromECDSAKey(key *ecdsa.PublicKey) (g NamedGroup) {
|
||||||
|
switch key.Curve.Params().Name {
|
||||||
|
case elliptic.P256().Params().Name:
|
||||||
|
g = P256
|
||||||
|
case elliptic.P384().Params().Name:
|
||||||
|
g = P384
|
||||||
|
case elliptic.P521().Params().Name:
|
||||||
|
g = P521
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func keyExchangeSizeFromNamedGroup(group NamedGroup) (size int) {
|
||||||
|
size = 0
|
||||||
|
switch group {
|
||||||
|
case X25519:
|
||||||
|
size = 32
|
||||||
|
case P256:
|
||||||
|
size = 65
|
||||||
|
case P384:
|
||||||
|
size = 97
|
||||||
|
case P521:
|
||||||
|
size = 133
|
||||||
|
case FFDHE2048:
|
||||||
|
size = 256
|
||||||
|
case FFDHE3072:
|
||||||
|
size = 384
|
||||||
|
case FFDHE4096:
|
||||||
|
size = 512
|
||||||
|
case FFDHE6144:
|
||||||
|
size = 768
|
||||||
|
case FFDHE8192:
|
||||||
|
size = 1024
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func primeFromNamedGroup(group NamedGroup) (p *big.Int) {
|
||||||
|
switch group {
|
||||||
|
case FFDHE2048:
|
||||||
|
p = finiteFieldPrime2048
|
||||||
|
case FFDHE3072:
|
||||||
|
p = finiteFieldPrime3072
|
||||||
|
case FFDHE4096:
|
||||||
|
p = finiteFieldPrime4096
|
||||||
|
case FFDHE6144:
|
||||||
|
p = finiteFieldPrime6144
|
||||||
|
case FFDHE8192:
|
||||||
|
p = finiteFieldPrime8192
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func schemeValidForKey(alg SignatureScheme, key crypto.Signer) bool {
|
||||||
|
sigType := sigMap[alg]
|
||||||
|
switch key.(type) {
|
||||||
|
case *rsa.PrivateKey:
|
||||||
|
return sigType == signatureAlgorithmRSA_PKCS1 || sigType == signatureAlgorithmRSA_PSS
|
||||||
|
case *ecdsa.PrivateKey:
|
||||||
|
return sigType == signatureAlgorithmECDSA
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ffdheKeyShareFromPrime(p *big.Int) (priv, pub *big.Int, err error) {
|
||||||
|
primeLen := len(p.Bytes())
|
||||||
|
for {
|
||||||
|
// g = 2 for all ffdhe groups
|
||||||
|
priv, err = rand.Int(prng, p)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pub = big.NewInt(0)
|
||||||
|
pub.Exp(big.NewInt(2), priv, p)
|
||||||
|
|
||||||
|
if len(pub.Bytes()) == primeLen {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newKeyShare(group NamedGroup) (pub []byte, priv []byte, err error) {
|
||||||
|
switch group {
|
||||||
|
case P256, P384, P521:
|
||||||
|
var x, y *big.Int
|
||||||
|
crv := curveFromNamedGroup(group)
|
||||||
|
priv, x, y, err = elliptic.GenerateKey(crv, prng)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pub = elliptic.Marshal(crv, x, y)
|
||||||
|
return
|
||||||
|
|
||||||
|
case FFDHE2048, FFDHE3072, FFDHE4096, FFDHE6144, FFDHE8192:
|
||||||
|
p := primeFromNamedGroup(group)
|
||||||
|
x, X, err2 := ffdheKeyShareFromPrime(p)
|
||||||
|
if err2 != nil {
|
||||||
|
err = err2
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
priv = x.Bytes()
|
||||||
|
pubBytes := X.Bytes()
|
||||||
|
|
||||||
|
numBytes := keyExchangeSizeFromNamedGroup(group)
|
||||||
|
|
||||||
|
pub = make([]byte, numBytes)
|
||||||
|
copy(pub[numBytes-len(pubBytes):], pubBytes)
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
case X25519:
|
||||||
|
var private, public [32]byte
|
||||||
|
_, err = prng.Read(private[:])
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
curve25519.ScalarBaseMult(&public, &private)
|
||||||
|
priv = private[:]
|
||||||
|
pub = public[:]
|
||||||
|
return
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, nil, fmt.Errorf("tls.newkeyshare: Unsupported group %v", group)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func keyAgreement(group NamedGroup, pub []byte, priv []byte) ([]byte, error) {
|
||||||
|
switch group {
|
||||||
|
case P256, P384, P521:
|
||||||
|
if len(pub) != keyExchangeSizeFromNamedGroup(group) {
|
||||||
|
return nil, fmt.Errorf("tls.keyagreement: Wrong public key size")
|
||||||
|
}
|
||||||
|
|
||||||
|
crv := curveFromNamedGroup(group)
|
||||||
|
pubX, pubY := elliptic.Unmarshal(crv, pub)
|
||||||
|
x, _ := crv.Params().ScalarMult(pubX, pubY, priv)
|
||||||
|
xBytes := x.Bytes()
|
||||||
|
|
||||||
|
numBytes := len(crv.Params().P.Bytes())
|
||||||
|
|
||||||
|
ret := make([]byte, numBytes)
|
||||||
|
copy(ret[numBytes-len(xBytes):], xBytes)
|
||||||
|
|
||||||
|
return ret, nil
|
||||||
|
|
||||||
|
case FFDHE2048, FFDHE3072, FFDHE4096, FFDHE6144, FFDHE8192:
|
||||||
|
numBytes := keyExchangeSizeFromNamedGroup(group)
|
||||||
|
if len(pub) != numBytes {
|
||||||
|
return nil, fmt.Errorf("tls.keyagreement: Wrong public key size")
|
||||||
|
}
|
||||||
|
p := primeFromNamedGroup(group)
|
||||||
|
x := big.NewInt(0).SetBytes(priv)
|
||||||
|
Y := big.NewInt(0).SetBytes(pub)
|
||||||
|
ZBytes := big.NewInt(0).Exp(Y, x, p).Bytes()
|
||||||
|
|
||||||
|
ret := make([]byte, numBytes)
|
||||||
|
copy(ret[numBytes-len(ZBytes):], ZBytes)
|
||||||
|
|
||||||
|
return ret, nil
|
||||||
|
|
||||||
|
case X25519:
|
||||||
|
if len(pub) != keyExchangeSizeFromNamedGroup(group) {
|
||||||
|
return nil, fmt.Errorf("tls.keyagreement: Wrong public key size")
|
||||||
|
}
|
||||||
|
|
||||||
|
var private, public, ret [32]byte
|
||||||
|
copy(private[:], priv)
|
||||||
|
copy(public[:], pub)
|
||||||
|
curve25519.ScalarMult(&ret, &private, &public)
|
||||||
|
|
||||||
|
return ret[:], nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("tls.keyagreement: Unsupported group %v", group)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSigningKey(sig SignatureScheme) (crypto.Signer, error) {
|
||||||
|
switch sig {
|
||||||
|
case RSA_PKCS1_SHA1, RSA_PKCS1_SHA256,
|
||||||
|
RSA_PKCS1_SHA384, RSA_PKCS1_SHA512,
|
||||||
|
RSA_PSS_SHA256, RSA_PSS_SHA384,
|
||||||
|
RSA_PSS_SHA512:
|
||||||
|
return rsa.GenerateKey(prng, defaultRSAKeySize)
|
||||||
|
case ECDSA_P256_SHA256:
|
||||||
|
return ecdsa.GenerateKey(elliptic.P256(), prng)
|
||||||
|
case ECDSA_P384_SHA384:
|
||||||
|
return ecdsa.GenerateKey(elliptic.P384(), prng)
|
||||||
|
case ECDSA_P521_SHA512:
|
||||||
|
return ecdsa.GenerateKey(elliptic.P521(), prng)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("tls.newsigningkey: Unsupported signature algorithm [%04x]", sig)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSelfSigned(name string, alg SignatureScheme, priv crypto.Signer) (*x509.Certificate, error) {
|
||||||
|
sigAlg, ok := x509AlgMap[alg]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("tls.selfsigned: Unknown signature algorithm [%04x]", alg)
|
||||||
|
}
|
||||||
|
if len(name) == 0 {
|
||||||
|
return nil, fmt.Errorf("tls.selfsigned: No name provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
serial, err := rand.Int(rand.Reader, big.NewInt(0xA0A0A0A0))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
template := &x509.Certificate{
|
||||||
|
SerialNumber: serial,
|
||||||
|
NotBefore: time.Now(),
|
||||||
|
NotAfter: time.Now().AddDate(0, 0, 1),
|
||||||
|
SignatureAlgorithm: sigAlg,
|
||||||
|
Subject: pkix.Name{CommonName: name},
|
||||||
|
DNSNames: []string{name},
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||||
|
}
|
||||||
|
der, err := x509.CreateCertificate(prng, template, template, priv.Public(), priv)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// It is safe to ignore the error here because we're parsing known-good data
|
||||||
|
cert, _ := x509.ParseCertificate(der)
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// XXX(rlb): Copied from crypto/x509
|
||||||
|
type ecdsaSignature struct {
|
||||||
|
R, S *big.Int
|
||||||
|
}
|
||||||
|
|
||||||
|
func sign(alg SignatureScheme, privateKey crypto.Signer, sigInput []byte) ([]byte, error) {
|
||||||
|
var opts crypto.SignerOpts
|
||||||
|
|
||||||
|
hash := hashMap[alg]
|
||||||
|
if hash == crypto.SHA1 {
|
||||||
|
return nil, fmt.Errorf("tls.crypt.sign: Use of SHA-1 is forbidden")
|
||||||
|
}
|
||||||
|
|
||||||
|
sigType := sigMap[alg]
|
||||||
|
var realInput []byte
|
||||||
|
switch key := privateKey.(type) {
|
||||||
|
case *rsa.PrivateKey:
|
||||||
|
switch {
|
||||||
|
case allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
|
||||||
|
logf(logTypeCrypto, "signing with PKCS1, hashSize=[%d]", hash.Size())
|
||||||
|
opts = hash
|
||||||
|
case !allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
|
||||||
|
fallthrough
|
||||||
|
case sigType == signatureAlgorithmRSA_PSS:
|
||||||
|
logf(logTypeCrypto, "signing with PSS, hashSize=[%d]", hash.Size())
|
||||||
|
opts = &rsa.PSSOptions{SaltLength: hash.Size(), Hash: hash}
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("tls.crypto.sign: Unsupported algorithm for RSA key")
|
||||||
|
}
|
||||||
|
|
||||||
|
h := hash.New()
|
||||||
|
h.Write(sigInput)
|
||||||
|
realInput = h.Sum(nil)
|
||||||
|
case *ecdsa.PrivateKey:
|
||||||
|
if sigType != signatureAlgorithmECDSA {
|
||||||
|
return nil, fmt.Errorf("tls.crypto.sign: Unsupported algorithm for ECDSA key")
|
||||||
|
}
|
||||||
|
|
||||||
|
algGroup := curveMap[alg]
|
||||||
|
keyGroup := namedGroupFromECDSAKey(key.Public().(*ecdsa.PublicKey))
|
||||||
|
if algGroup != keyGroup {
|
||||||
|
return nil, fmt.Errorf("tls.crypto.sign: Unsupported hash/curve combination")
|
||||||
|
}
|
||||||
|
|
||||||
|
h := hash.New()
|
||||||
|
h.Write(sigInput)
|
||||||
|
realInput = h.Sum(nil)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("tls.crypto.sign: Unsupported private key type")
|
||||||
|
}
|
||||||
|
|
||||||
|
sig, err := privateKey.Sign(prng, realInput, opts)
|
||||||
|
logf(logTypeCrypto, "signature: %x", sig)
|
||||||
|
return sig, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func verify(alg SignatureScheme, publicKey crypto.PublicKey, sigInput []byte, sig []byte) error {
|
||||||
|
hash := hashMap[alg]
|
||||||
|
|
||||||
|
if hash == crypto.SHA1 {
|
||||||
|
return fmt.Errorf("tls.crypt.sign: Use of SHA-1 is forbidden")
|
||||||
|
}
|
||||||
|
|
||||||
|
sigType := sigMap[alg]
|
||||||
|
switch pub := publicKey.(type) {
|
||||||
|
case *rsa.PublicKey:
|
||||||
|
switch {
|
||||||
|
case allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
|
||||||
|
logf(logTypeCrypto, "verifying with PKCS1, hashSize=[%d]", hash.Size())
|
||||||
|
|
||||||
|
h := hash.New()
|
||||||
|
h.Write(sigInput)
|
||||||
|
realInput := h.Sum(nil)
|
||||||
|
return rsa.VerifyPKCS1v15(pub, hash, realInput, sig)
|
||||||
|
case !allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
|
||||||
|
fallthrough
|
||||||
|
case sigType == signatureAlgorithmRSA_PSS:
|
||||||
|
logf(logTypeCrypto, "verifying with PSS, hashSize=[%d]", hash.Size())
|
||||||
|
opts := &rsa.PSSOptions{SaltLength: hash.Size(), Hash: hash}
|
||||||
|
|
||||||
|
h := hash.New()
|
||||||
|
h.Write(sigInput)
|
||||||
|
realInput := h.Sum(nil)
|
||||||
|
return rsa.VerifyPSS(pub, hash, realInput, sig, opts)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("tls.verify: Unsupported algorithm for RSA key")
|
||||||
|
}
|
||||||
|
|
||||||
|
case *ecdsa.PublicKey:
|
||||||
|
if sigType != signatureAlgorithmECDSA {
|
||||||
|
return fmt.Errorf("tls.verify: Unsupported algorithm for ECDSA key")
|
||||||
|
}
|
||||||
|
|
||||||
|
if curveMap[alg] != namedGroupFromECDSAKey(pub) {
|
||||||
|
return fmt.Errorf("tls.verify: Unsupported curve for ECDSA key")
|
||||||
|
}
|
||||||
|
|
||||||
|
ecdsaSig := new(ecdsaSignature)
|
||||||
|
if rest, err := asn1.Unmarshal(sig, ecdsaSig); err != nil {
|
||||||
|
return err
|
||||||
|
} else if len(rest) != 0 {
|
||||||
|
return fmt.Errorf("tls.verify: trailing data after ECDSA signature")
|
||||||
|
}
|
||||||
|
if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 {
|
||||||
|
return fmt.Errorf("tls.verify: ECDSA signature contained zero or negative values")
|
||||||
|
}
|
||||||
|
|
||||||
|
h := hash.New()
|
||||||
|
h.Write(sigInput)
|
||||||
|
realInput := h.Sum(nil)
|
||||||
|
if !ecdsa.Verify(pub, realInput, ecdsaSig.R, ecdsaSig.S) {
|
||||||
|
return fmt.Errorf("tls.verify: ECDSA verification failure")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("tls.verify: Unsupported key type")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 0
|
||||||
|
// |
|
||||||
|
// v
|
||||||
|
// PSK -> HKDF-Extract = Early Secret
|
||||||
|
// |
|
||||||
|
// +-----> Derive-Secret(.,
|
||||||
|
// | "ext binder" |
|
||||||
|
// | "res binder",
|
||||||
|
// | "")
|
||||||
|
// | = binder_key
|
||||||
|
// |
|
||||||
|
// +-----> Derive-Secret(., "c e traffic",
|
||||||
|
// | ClientHello)
|
||||||
|
// | = client_early_traffic_secret
|
||||||
|
// |
|
||||||
|
// +-----> Derive-Secret(., "e exp master",
|
||||||
|
// | ClientHello)
|
||||||
|
// | = early_exporter_master_secret
|
||||||
|
// v
|
||||||
|
// Derive-Secret(., "derived", "")
|
||||||
|
// |
|
||||||
|
// v
|
||||||
|
// (EC)DHE -> HKDF-Extract = Handshake Secret
|
||||||
|
// |
|
||||||
|
// +-----> Derive-Secret(., "c hs traffic",
|
||||||
|
// | ClientHello...ServerHello)
|
||||||
|
// | = client_handshake_traffic_secret
|
||||||
|
// |
|
||||||
|
// +-----> Derive-Secret(., "s hs traffic",
|
||||||
|
// | ClientHello...ServerHello)
|
||||||
|
// | = server_handshake_traffic_secret
|
||||||
|
// v
|
||||||
|
// Derive-Secret(., "derived", "")
|
||||||
|
// |
|
||||||
|
// v
|
||||||
|
// 0 -> HKDF-Extract = Master Secret
|
||||||
|
// |
|
||||||
|
// +-----> Derive-Secret(., "c ap traffic",
|
||||||
|
// | ClientHello...server Finished)
|
||||||
|
// | = client_application_traffic_secret_0
|
||||||
|
// |
|
||||||
|
// +-----> Derive-Secret(., "s ap traffic",
|
||||||
|
// | ClientHello...server Finished)
|
||||||
|
// | = server_application_traffic_secret_0
|
||||||
|
// |
|
||||||
|
// +-----> Derive-Secret(., "exp master",
|
||||||
|
// | ClientHello...server Finished)
|
||||||
|
// | = exporter_master_secret
|
||||||
|
// |
|
||||||
|
// +-----> Derive-Secret(., "res master",
|
||||||
|
// ClientHello...client Finished)
|
||||||
|
// = resumption_master_secret
|
||||||
|
|
||||||
|
// From RFC 5869
|
||||||
|
// PRK = HMAC-Hash(salt, IKM)
|
||||||
|
func HkdfExtract(hash crypto.Hash, saltIn, input []byte) []byte {
|
||||||
|
salt := saltIn
|
||||||
|
|
||||||
|
// if [salt is] not provided, it is set to a string of HashLen zeros
|
||||||
|
if salt == nil {
|
||||||
|
salt = bytes.Repeat([]byte{0}, hash.Size())
|
||||||
|
}
|
||||||
|
|
||||||
|
h := hmac.New(hash.New, salt)
|
||||||
|
h.Write(input)
|
||||||
|
out := h.Sum(nil)
|
||||||
|
|
||||||
|
logf(logTypeCrypto, "HKDF Extract:\n")
|
||||||
|
logf(logTypeCrypto, "Salt [%d]: %x\n", len(salt), salt)
|
||||||
|
logf(logTypeCrypto, "Input [%d]: %x\n", len(input), input)
|
||||||
|
logf(logTypeCrypto, "Output [%d]: %x\n", len(out), out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
labelExternalBinder = "ext binder"
|
||||||
|
labelResumptionBinder = "res binder"
|
||||||
|
labelEarlyTrafficSecret = "c e traffic"
|
||||||
|
labelEarlyExporterSecret = "e exp master"
|
||||||
|
labelClientHandshakeTrafficSecret = "c hs traffic"
|
||||||
|
labelServerHandshakeTrafficSecret = "s hs traffic"
|
||||||
|
labelClientApplicationTrafficSecret = "c ap traffic"
|
||||||
|
labelServerApplicationTrafficSecret = "s ap traffic"
|
||||||
|
labelExporterSecret = "exp master"
|
||||||
|
labelResumptionSecret = "res master"
|
||||||
|
labelDerived = "derived"
|
||||||
|
labelFinished = "finished"
|
||||||
|
labelResumption = "resumption"
|
||||||
|
)
|
||||||
|
|
||||||
|
// struct HkdfLabel {
|
||||||
|
// uint16 length;
|
||||||
|
// opaque label<9..255>;
|
||||||
|
// opaque hash_value<0..255>;
|
||||||
|
// };
|
||||||
|
func hkdfEncodeLabel(labelIn string, hashValue []byte, outLen int) []byte {
|
||||||
|
label := "tls13 " + labelIn
|
||||||
|
|
||||||
|
labelLen := len(label)
|
||||||
|
hashLen := len(hashValue)
|
||||||
|
hkdfLabel := make([]byte, 2+1+labelLen+1+hashLen)
|
||||||
|
hkdfLabel[0] = byte(outLen >> 8)
|
||||||
|
hkdfLabel[1] = byte(outLen)
|
||||||
|
hkdfLabel[2] = byte(labelLen)
|
||||||
|
copy(hkdfLabel[3:3+labelLen], []byte(label))
|
||||||
|
hkdfLabel[3+labelLen] = byte(hashLen)
|
||||||
|
copy(hkdfLabel[3+labelLen+1:], hashValue)
|
||||||
|
|
||||||
|
return hkdfLabel
|
||||||
|
}
|
||||||
|
|
||||||
|
func HkdfExpand(hash crypto.Hash, prk, info []byte, outLen int) []byte {
|
||||||
|
out := []byte{}
|
||||||
|
T := []byte{}
|
||||||
|
i := byte(1)
|
||||||
|
for len(out) < outLen {
|
||||||
|
block := append(T, info...)
|
||||||
|
block = append(block, i)
|
||||||
|
|
||||||
|
h := hmac.New(hash.New, prk)
|
||||||
|
h.Write(block)
|
||||||
|
|
||||||
|
T = h.Sum(nil)
|
||||||
|
out = append(out, T...)
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
return out[:outLen]
|
||||||
|
}
|
||||||
|
|
||||||
|
func HkdfExpandLabel(hash crypto.Hash, secret []byte, label string, hashValue []byte, outLen int) []byte {
|
||||||
|
info := hkdfEncodeLabel(label, hashValue, outLen)
|
||||||
|
derived := HkdfExpand(hash, secret, info, outLen)
|
||||||
|
|
||||||
|
logf(logTypeCrypto, "HKDF Expand: label=[tls13 ] + '%s',requested length=%d\n", label, outLen)
|
||||||
|
logf(logTypeCrypto, "PRK [%d]: %x\n", len(secret), secret)
|
||||||
|
logf(logTypeCrypto, "Hash [%d]: %x\n", len(hashValue), hashValue)
|
||||||
|
logf(logTypeCrypto, "Info [%d]: %x\n", len(info), info)
|
||||||
|
logf(logTypeCrypto, "Derived key [%d]: %x\n", len(derived), derived)
|
||||||
|
|
||||||
|
return derived
|
||||||
|
}
|
||||||
|
|
||||||
|
func deriveSecret(params CipherSuiteParams, secret []byte, label string, messageHash []byte) []byte {
|
||||||
|
return HkdfExpandLabel(params.Hash, secret, label, messageHash, params.Hash.Size())
|
||||||
|
}
|
||||||
|
|
||||||
|
func computeFinishedData(params CipherSuiteParams, baseKey []byte, input []byte) []byte {
|
||||||
|
macKey := HkdfExpandLabel(params.Hash, baseKey, labelFinished, []byte{}, params.Hash.Size())
|
||||||
|
mac := hmac.New(params.Hash.New, macKey)
|
||||||
|
mac.Write(input)
|
||||||
|
return mac.Sum(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
type keySet struct {
|
||||||
|
cipher aeadFactory
|
||||||
|
key []byte
|
||||||
|
iv []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeTrafficKeys(params CipherSuiteParams, secret []byte) keySet {
|
||||||
|
logf(logTypeCrypto, "making traffic keys: secret=%x", secret)
|
||||||
|
return keySet{
|
||||||
|
cipher: params.Cipher,
|
||||||
|
key: HkdfExpandLabel(params.Hash, secret, "key", []byte{}, params.KeyLen),
|
||||||
|
iv: HkdfExpandLabel(params.Hash, secret, "iv", []byte{}, params.IvLen),
|
||||||
|
}
|
||||||
|
}
|
586
vendor/github.com/bifurcation/mint/extensions.go
generated
vendored
Normal file
586
vendor/github.com/bifurcation/mint/extensions.go
generated
vendored
Normal file
@ -0,0 +1,586 @@
|
|||||||
|
package mint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/bifurcation/mint/syntax"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ExtensionBody interface {
|
||||||
|
Type() ExtensionType
|
||||||
|
Marshal() ([]byte, error)
|
||||||
|
Unmarshal(data []byte) (int, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// ExtensionType extension_type;
|
||||||
|
// opaque extension_data<0..2^16-1>;
|
||||||
|
// } Extension;
|
||||||
|
type Extension struct {
|
||||||
|
ExtensionType ExtensionType
|
||||||
|
ExtensionData []byte `tls:"head=2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ext Extension) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(ext)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ext *Extension) Unmarshal(data []byte) (int, error) {
|
||||||
|
return syntax.Unmarshal(data, ext)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ExtensionList []Extension
|
||||||
|
|
||||||
|
type extensionListInner struct {
|
||||||
|
List []Extension `tls:"head=2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (el ExtensionList) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(extensionListInner{el})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (el *ExtensionList) Unmarshal(data []byte) (int, error) {
|
||||||
|
var list extensionListInner
|
||||||
|
read, err := syntax.Unmarshal(data, &list)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
*el = list.List
|
||||||
|
return read, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (el *ExtensionList) Add(src ExtensionBody) error {
|
||||||
|
data, err := src.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if el == nil {
|
||||||
|
el = new(ExtensionList)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If one already exists with this type, replace it
|
||||||
|
for i := range *el {
|
||||||
|
if (*el)[i].ExtensionType == src.Type() {
|
||||||
|
(*el)[i].ExtensionData = data
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise append
|
||||||
|
*el = append(*el, Extension{
|
||||||
|
ExtensionType: src.Type(),
|
||||||
|
ExtensionData: data,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (el ExtensionList) Find(dst ExtensionBody) bool {
|
||||||
|
for _, ext := range el {
|
||||||
|
if ext.ExtensionType == dst.Type() {
|
||||||
|
_, err := dst.Unmarshal(ext.ExtensionData)
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// NameType name_type;
|
||||||
|
// select (name_type) {
|
||||||
|
// case host_name: HostName;
|
||||||
|
// } name;
|
||||||
|
// } ServerName;
|
||||||
|
//
|
||||||
|
// enum {
|
||||||
|
// host_name(0), (255)
|
||||||
|
// } NameType;
|
||||||
|
//
|
||||||
|
// opaque HostName<1..2^16-1>;
|
||||||
|
//
|
||||||
|
// struct {
|
||||||
|
// ServerName server_name_list<1..2^16-1>
|
||||||
|
// } ServerNameList;
|
||||||
|
//
|
||||||
|
// But we only care about the case where there's a single DNS hostname. We
|
||||||
|
// will never create anything else, and throw if we receive something else
|
||||||
|
//
|
||||||
|
// 2 1 2
|
||||||
|
// | listLen | NameType | nameLen | name |
|
||||||
|
type ServerNameExtension string
|
||||||
|
|
||||||
|
type serverNameInner struct {
|
||||||
|
NameType uint8
|
||||||
|
HostName []byte `tls:"head=2,min=1"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type serverNameListInner struct {
|
||||||
|
ServerNameList []serverNameInner `tls:"head=2,min=1"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sni ServerNameExtension) Type() ExtensionType {
|
||||||
|
return ExtensionTypeServerName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sni ServerNameExtension) Marshal() ([]byte, error) {
|
||||||
|
list := serverNameListInner{
|
||||||
|
ServerNameList: []serverNameInner{{
|
||||||
|
NameType: 0x00, // host_name
|
||||||
|
HostName: []byte(sni),
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
return syntax.Marshal(list)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sni *ServerNameExtension) Unmarshal(data []byte) (int, error) {
|
||||||
|
var list serverNameListInner
|
||||||
|
read, err := syntax.Unmarshal(data, &list)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Syntax requires at least one entry
|
||||||
|
// Entries beyond the first are ignored
|
||||||
|
if nameType := list.ServerNameList[0].NameType; nameType != 0x00 {
|
||||||
|
return 0, fmt.Errorf("tls.servername: Unsupported name type [%x]", nameType)
|
||||||
|
}
|
||||||
|
|
||||||
|
*sni = ServerNameExtension(list.ServerNameList[0].HostName)
|
||||||
|
return read, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// NamedGroup group;
|
||||||
|
// opaque key_exchange<1..2^16-1>;
|
||||||
|
// } KeyShareEntry;
|
||||||
|
//
|
||||||
|
// struct {
|
||||||
|
// select (Handshake.msg_type) {
|
||||||
|
// case client_hello:
|
||||||
|
// KeyShareEntry client_shares<0..2^16-1>;
|
||||||
|
//
|
||||||
|
// case hello_retry_request:
|
||||||
|
// NamedGroup selected_group;
|
||||||
|
//
|
||||||
|
// case server_hello:
|
||||||
|
// KeyShareEntry server_share;
|
||||||
|
// };
|
||||||
|
// } KeyShare;
|
||||||
|
type KeyShareEntry struct {
|
||||||
|
Group NamedGroup
|
||||||
|
KeyExchange []byte `tls:"head=2,min=1"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kse KeyShareEntry) SizeValid() bool {
|
||||||
|
return len(kse.KeyExchange) == keyExchangeSizeFromNamedGroup(kse.Group)
|
||||||
|
}
|
||||||
|
|
||||||
|
type KeyShareExtension struct {
|
||||||
|
HandshakeType HandshakeType
|
||||||
|
SelectedGroup NamedGroup
|
||||||
|
Shares []KeyShareEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
type KeyShareClientHelloInner struct {
|
||||||
|
ClientShares []KeyShareEntry `tls:"head=2,min=0"`
|
||||||
|
}
|
||||||
|
type KeyShareHelloRetryInner struct {
|
||||||
|
SelectedGroup NamedGroup
|
||||||
|
}
|
||||||
|
type KeyShareServerHelloInner struct {
|
||||||
|
ServerShare KeyShareEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ks KeyShareExtension) Type() ExtensionType {
|
||||||
|
return ExtensionTypeKeyShare
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ks KeyShareExtension) Marshal() ([]byte, error) {
|
||||||
|
switch ks.HandshakeType {
|
||||||
|
case HandshakeTypeClientHello:
|
||||||
|
for _, share := range ks.Shares {
|
||||||
|
if !share.SizeValid() {
|
||||||
|
return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return syntax.Marshal(KeyShareClientHelloInner{ks.Shares})
|
||||||
|
|
||||||
|
case HandshakeTypeHelloRetryRequest:
|
||||||
|
if len(ks.Shares) > 0 {
|
||||||
|
return nil, fmt.Errorf("tls.keyshare: Key shares not allowed for HelloRetryRequest")
|
||||||
|
}
|
||||||
|
|
||||||
|
return syntax.Marshal(KeyShareHelloRetryInner{ks.SelectedGroup})
|
||||||
|
|
||||||
|
case HandshakeTypeServerHello:
|
||||||
|
if len(ks.Shares) != 1 {
|
||||||
|
return nil, fmt.Errorf("tls.keyshare: Server must send exactly one key share")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ks.Shares[0].SizeValid() {
|
||||||
|
return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
|
||||||
|
}
|
||||||
|
|
||||||
|
return syntax.Marshal(KeyShareServerHelloInner{ks.Shares[0]})
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("tls.keyshare: Handshake type not allowed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ks *KeyShareExtension) Unmarshal(data []byte) (int, error) {
|
||||||
|
switch ks.HandshakeType {
|
||||||
|
case HandshakeTypeClientHello:
|
||||||
|
var inner KeyShareClientHelloInner
|
||||||
|
read, err := syntax.Unmarshal(data, &inner)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, share := range inner.ClientShares {
|
||||||
|
if !share.SizeValid() {
|
||||||
|
return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ks.Shares = inner.ClientShares
|
||||||
|
return read, nil
|
||||||
|
|
||||||
|
case HandshakeTypeHelloRetryRequest:
|
||||||
|
var inner KeyShareHelloRetryInner
|
||||||
|
read, err := syntax.Unmarshal(data, &inner)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ks.SelectedGroup = inner.SelectedGroup
|
||||||
|
return read, nil
|
||||||
|
|
||||||
|
case HandshakeTypeServerHello:
|
||||||
|
var inner KeyShareServerHelloInner
|
||||||
|
read, err := syntax.Unmarshal(data, &inner)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !inner.ServerShare.SizeValid() {
|
||||||
|
return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
|
||||||
|
}
|
||||||
|
|
||||||
|
ks.Shares = []KeyShareEntry{inner.ServerShare}
|
||||||
|
return read, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("tls.keyshare: Handshake type not allowed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// NamedGroup named_group_list<2..2^16-1>;
|
||||||
|
// } NamedGroupList;
|
||||||
|
type SupportedGroupsExtension struct {
|
||||||
|
Groups []NamedGroup `tls:"head=2,min=2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sg SupportedGroupsExtension) Type() ExtensionType {
|
||||||
|
return ExtensionTypeSupportedGroups
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sg SupportedGroupsExtension) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(sg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sg *SupportedGroupsExtension) Unmarshal(data []byte) (int, error) {
|
||||||
|
return syntax.Unmarshal(data, sg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// SignatureScheme supported_signature_algorithms<2..2^16-2>;
|
||||||
|
// } SignatureSchemeList
|
||||||
|
type SignatureAlgorithmsExtension struct {
|
||||||
|
Algorithms []SignatureScheme `tls:"head=2,min=2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sa SignatureAlgorithmsExtension) Type() ExtensionType {
|
||||||
|
return ExtensionTypeSignatureAlgorithms
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sa SignatureAlgorithmsExtension) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(sa)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sa *SignatureAlgorithmsExtension) Unmarshal(data []byte) (int, error) {
|
||||||
|
return syntax.Unmarshal(data, sa)
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// opaque identity<1..2^16-1>;
|
||||||
|
// uint32 obfuscated_ticket_age;
|
||||||
|
// } PskIdentity;
|
||||||
|
//
|
||||||
|
// opaque PskBinderEntry<32..255>;
|
||||||
|
//
|
||||||
|
// struct {
|
||||||
|
// select (Handshake.msg_type) {
|
||||||
|
// case client_hello:
|
||||||
|
// PskIdentity identities<7..2^16-1>;
|
||||||
|
// PskBinderEntry binders<33..2^16-1>;
|
||||||
|
//
|
||||||
|
// case server_hello:
|
||||||
|
// uint16 selected_identity;
|
||||||
|
// };
|
||||||
|
//
|
||||||
|
// } PreSharedKeyExtension;
|
||||||
|
type PSKIdentity struct {
|
||||||
|
Identity []byte `tls:"head=2,min=1"`
|
||||||
|
ObfuscatedTicketAge uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
type PSKBinderEntry struct {
|
||||||
|
Binder []byte `tls:"head=1,min=32"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type PreSharedKeyExtension struct {
|
||||||
|
HandshakeType HandshakeType
|
||||||
|
Identities []PSKIdentity
|
||||||
|
Binders []PSKBinderEntry
|
||||||
|
SelectedIdentity uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
type preSharedKeyClientInner struct {
|
||||||
|
Identities []PSKIdentity `tls:"head=2,min=7"`
|
||||||
|
Binders []PSKBinderEntry `tls:"head=2,min=33"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type preSharedKeyServerInner struct {
|
||||||
|
SelectedIdentity uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
func (psk PreSharedKeyExtension) Type() ExtensionType {
|
||||||
|
return ExtensionTypePreSharedKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func (psk PreSharedKeyExtension) Marshal() ([]byte, error) {
|
||||||
|
switch psk.HandshakeType {
|
||||||
|
case HandshakeTypeClientHello:
|
||||||
|
return syntax.Marshal(preSharedKeyClientInner{
|
||||||
|
Identities: psk.Identities,
|
||||||
|
Binders: psk.Binders,
|
||||||
|
})
|
||||||
|
|
||||||
|
case HandshakeTypeServerHello:
|
||||||
|
if len(psk.Identities) > 0 || len(psk.Binders) > 0 {
|
||||||
|
return nil, fmt.Errorf("tls.presharedkey: Server can only provide an index")
|
||||||
|
}
|
||||||
|
return syntax.Marshal(preSharedKeyServerInner{psk.SelectedIdentity})
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("tls.presharedkey: Handshake type not supported")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (psk *PreSharedKeyExtension) Unmarshal(data []byte) (int, error) {
|
||||||
|
switch psk.HandshakeType {
|
||||||
|
case HandshakeTypeClientHello:
|
||||||
|
var inner preSharedKeyClientInner
|
||||||
|
read, err := syntax.Unmarshal(data, &inner)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(inner.Identities) != len(inner.Binders) {
|
||||||
|
return 0, fmt.Errorf("Lengths of identities and binders not equal")
|
||||||
|
}
|
||||||
|
|
||||||
|
psk.Identities = inner.Identities
|
||||||
|
psk.Binders = inner.Binders
|
||||||
|
return read, nil
|
||||||
|
|
||||||
|
case HandshakeTypeServerHello:
|
||||||
|
var inner preSharedKeyServerInner
|
||||||
|
read, err := syntax.Unmarshal(data, &inner)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
psk.SelectedIdentity = inner.SelectedIdentity
|
||||||
|
return read, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("tls.presharedkey: Handshake type not supported")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (psk PreSharedKeyExtension) HasIdentity(id []byte) ([]byte, bool) {
|
||||||
|
for i, localID := range psk.Identities {
|
||||||
|
if bytes.Equal(localID.Identity, id) {
|
||||||
|
return psk.Binders[i].Binder, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// enum { psk_ke(0), psk_dhe_ke(1), (255) } PskKeyExchangeMode;
|
||||||
|
//
|
||||||
|
// struct {
|
||||||
|
// PskKeyExchangeMode ke_modes<1..255>;
|
||||||
|
// } PskKeyExchangeModes;
|
||||||
|
type PSKKeyExchangeModesExtension struct {
|
||||||
|
KEModes []PSKKeyExchangeMode `tls:"head=1,min=1"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pkem PSKKeyExchangeModesExtension) Type() ExtensionType {
|
||||||
|
return ExtensionTypePSKKeyExchangeModes
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pkem PSKKeyExchangeModesExtension) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(pkem)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pkem *PSKKeyExchangeModesExtension) Unmarshal(data []byte) (int, error) {
|
||||||
|
return syntax.Unmarshal(data, pkem)
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// } EarlyDataIndication;
|
||||||
|
|
||||||
|
type EarlyDataExtension struct{}
|
||||||
|
|
||||||
|
func (ed EarlyDataExtension) Type() ExtensionType {
|
||||||
|
return ExtensionTypeEarlyData
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ed EarlyDataExtension) Marshal() ([]byte, error) {
|
||||||
|
return []byte{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ed *EarlyDataExtension) Unmarshal(data []byte) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// uint32 max_early_data_size;
|
||||||
|
// } TicketEarlyDataInfo;
|
||||||
|
|
||||||
|
type TicketEarlyDataInfoExtension struct {
|
||||||
|
MaxEarlyDataSize uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tedi TicketEarlyDataInfoExtension) Type() ExtensionType {
|
||||||
|
return ExtensionTypeTicketEarlyDataInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tedi TicketEarlyDataInfoExtension) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(tedi)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tedi *TicketEarlyDataInfoExtension) Unmarshal(data []byte) (int, error) {
|
||||||
|
return syntax.Unmarshal(data, tedi)
|
||||||
|
}
|
||||||
|
|
||||||
|
// opaque ProtocolName<1..2^8-1>;
|
||||||
|
//
|
||||||
|
// struct {
|
||||||
|
// ProtocolName protocol_name_list<2..2^16-1>
|
||||||
|
// } ProtocolNameList;
|
||||||
|
type ALPNExtension struct {
|
||||||
|
Protocols []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type protocolNameInner struct {
|
||||||
|
Name []byte `tls:"head=1,min=1"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type alpnExtensionInner struct {
|
||||||
|
Protocols []protocolNameInner `tls:"head=2,min=2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (alpn ALPNExtension) Type() ExtensionType {
|
||||||
|
return ExtensionTypeALPN
|
||||||
|
}
|
||||||
|
|
||||||
|
func (alpn ALPNExtension) Marshal() ([]byte, error) {
|
||||||
|
protocols := make([]protocolNameInner, len(alpn.Protocols))
|
||||||
|
for i, protocol := range alpn.Protocols {
|
||||||
|
protocols[i] = protocolNameInner{[]byte(protocol)}
|
||||||
|
}
|
||||||
|
return syntax.Marshal(alpnExtensionInner{protocols})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (alpn *ALPNExtension) Unmarshal(data []byte) (int, error) {
|
||||||
|
var inner alpnExtensionInner
|
||||||
|
read, err := syntax.Unmarshal(data, &inner)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
alpn.Protocols = make([]string, len(inner.Protocols))
|
||||||
|
for i, protocol := range inner.Protocols {
|
||||||
|
alpn.Protocols[i] = string(protocol.Name)
|
||||||
|
}
|
||||||
|
return read, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// ProtocolVersion versions<2..254>;
|
||||||
|
// } SupportedVersions;
|
||||||
|
type SupportedVersionsExtension struct {
|
||||||
|
Versions []uint16 `tls:"head=1,min=2,max=254"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sv SupportedVersionsExtension) Type() ExtensionType {
|
||||||
|
return ExtensionTypeSupportedVersions
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sv SupportedVersionsExtension) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(sv)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sv *SupportedVersionsExtension) Unmarshal(data []byte) (int, error) {
|
||||||
|
return syntax.Unmarshal(data, sv)
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// opaque cookie<1..2^16-1>;
|
||||||
|
// } Cookie;
|
||||||
|
type CookieExtension struct {
|
||||||
|
Cookie []byte `tls:"head=2,min=1"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c CookieExtension) Type() ExtensionType {
|
||||||
|
return ExtensionTypeCookie
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c CookieExtension) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CookieExtension) Unmarshal(data []byte) (int, error) {
|
||||||
|
return syntax.Unmarshal(data, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultCookieLength is the default length of a cookie
|
||||||
|
const defaultCookieLength = 32
|
||||||
|
|
||||||
|
type defaultCookieHandler struct {
|
||||||
|
data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ CookieHandler = &defaultCookieHandler{}
|
||||||
|
|
||||||
|
// NewRandomCookie generates a cookie with DefaultCookieLength bytes of random data
|
||||||
|
func (h *defaultCookieHandler) Generate(*Conn) ([]byte, error) {
|
||||||
|
h.data = make([]byte, defaultCookieLength)
|
||||||
|
if _, err := prng.Read(h.data); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return h.data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *defaultCookieHandler) Validate(_ *Conn, data []byte) bool {
|
||||||
|
return bytes.Equal(h.data, data)
|
||||||
|
}
|
147
vendor/github.com/bifurcation/mint/ffdhe.go
generated
vendored
Normal file
147
vendor/github.com/bifurcation/mint/ffdhe.go
generated
vendored
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
package mint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"math/big"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
finiteFieldPrime2048hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
|
||||||
|
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
|
||||||
|
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
|
||||||
|
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
|
||||||
|
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
|
||||||
|
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
|
||||||
|
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
|
||||||
|
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
|
||||||
|
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
|
||||||
|
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
|
||||||
|
"886B423861285C97FFFFFFFFFFFFFFFF"
|
||||||
|
finiteFieldPrime2048bytes, _ = hex.DecodeString(finiteFieldPrime2048hex)
|
||||||
|
finiteFieldPrime2048 = big.NewInt(0).SetBytes(finiteFieldPrime2048bytes)
|
||||||
|
|
||||||
|
finiteFieldPrime3072hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
|
||||||
|
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
|
||||||
|
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
|
||||||
|
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
|
||||||
|
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
|
||||||
|
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
|
||||||
|
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
|
||||||
|
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
|
||||||
|
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
|
||||||
|
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
|
||||||
|
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
|
||||||
|
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
|
||||||
|
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
|
||||||
|
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
|
||||||
|
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
|
||||||
|
"3C1B20EE3FD59D7C25E41D2B66C62E37FFFFFFFFFFFFFFFF"
|
||||||
|
finiteFieldPrime3072bytes, _ = hex.DecodeString(finiteFieldPrime3072hex)
|
||||||
|
finiteFieldPrime3072 = big.NewInt(0).SetBytes(finiteFieldPrime3072bytes)
|
||||||
|
|
||||||
|
finiteFieldPrime4096hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
|
||||||
|
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
|
||||||
|
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
|
||||||
|
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
|
||||||
|
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
|
||||||
|
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
|
||||||
|
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
|
||||||
|
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
|
||||||
|
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
|
||||||
|
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
|
||||||
|
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
|
||||||
|
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
|
||||||
|
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
|
||||||
|
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
|
||||||
|
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
|
||||||
|
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" +
|
||||||
|
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" +
|
||||||
|
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" +
|
||||||
|
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" +
|
||||||
|
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" +
|
||||||
|
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E655F6A" +
|
||||||
|
"FFFFFFFFFFFFFFFF"
|
||||||
|
finiteFieldPrime4096bytes, _ = hex.DecodeString(finiteFieldPrime4096hex)
|
||||||
|
finiteFieldPrime4096 = big.NewInt(0).SetBytes(finiteFieldPrime4096bytes)
|
||||||
|
|
||||||
|
finiteFieldPrime6144hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
|
||||||
|
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
|
||||||
|
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
|
||||||
|
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
|
||||||
|
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
|
||||||
|
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
|
||||||
|
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
|
||||||
|
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
|
||||||
|
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
|
||||||
|
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
|
||||||
|
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
|
||||||
|
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
|
||||||
|
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
|
||||||
|
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
|
||||||
|
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
|
||||||
|
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" +
|
||||||
|
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" +
|
||||||
|
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" +
|
||||||
|
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" +
|
||||||
|
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" +
|
||||||
|
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" +
|
||||||
|
"0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" +
|
||||||
|
"3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" +
|
||||||
|
"CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" +
|
||||||
|
"A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" +
|
||||||
|
"0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" +
|
||||||
|
"763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" +
|
||||||
|
"B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" +
|
||||||
|
"D72B03746AE77F5E62292C311562A846505DC82DB854338A" +
|
||||||
|
"E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" +
|
||||||
|
"5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" +
|
||||||
|
"A41D570D7938DAD4A40E329CD0E40E65FFFFFFFFFFFFFFFF"
|
||||||
|
finiteFieldPrime6144bytes, _ = hex.DecodeString(finiteFieldPrime6144hex)
|
||||||
|
finiteFieldPrime6144 = big.NewInt(0).SetBytes(finiteFieldPrime6144bytes)
|
||||||
|
|
||||||
|
finiteFieldPrime8192hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
|
||||||
|
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
|
||||||
|
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
|
||||||
|
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
|
||||||
|
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
|
||||||
|
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
|
||||||
|
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
|
||||||
|
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
|
||||||
|
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
|
||||||
|
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
|
||||||
|
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
|
||||||
|
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
|
||||||
|
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
|
||||||
|
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
|
||||||
|
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
|
||||||
|
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" +
|
||||||
|
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" +
|
||||||
|
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" +
|
||||||
|
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" +
|
||||||
|
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" +
|
||||||
|
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" +
|
||||||
|
"0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" +
|
||||||
|
"3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" +
|
||||||
|
"CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" +
|
||||||
|
"A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" +
|
||||||
|
"0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" +
|
||||||
|
"763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" +
|
||||||
|
"B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" +
|
||||||
|
"D72B03746AE77F5E62292C311562A846505DC82DB854338A" +
|
||||||
|
"E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" +
|
||||||
|
"5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" +
|
||||||
|
"A41D570D7938DAD4A40E329CCFF46AAA36AD004CF600C838" +
|
||||||
|
"1E425A31D951AE64FDB23FCEC9509D43687FEB69EDD1CC5E" +
|
||||||
|
"0B8CC3BDF64B10EF86B63142A3AB8829555B2F747C932665" +
|
||||||
|
"CB2C0F1CC01BD70229388839D2AF05E454504AC78B758282" +
|
||||||
|
"2846C0BA35C35F5C59160CC046FD8251541FC68C9C86B022" +
|
||||||
|
"BB7099876A460E7451A8A93109703FEE1C217E6C3826E52C" +
|
||||||
|
"51AA691E0E423CFC99E9E31650C1217B624816CDAD9A95F9" +
|
||||||
|
"D5B8019488D9C0A0A1FE3075A577E23183F81D4A3F2FA457" +
|
||||||
|
"1EFC8CE0BA8A4FE8B6855DFE72B0A66EDED2FBABFBE58A30" +
|
||||||
|
"FAFABE1C5D71A87E2F741EF8C1FE86FEA6BBFDE530677F0D" +
|
||||||
|
"97D11D49F7A8443D0822E506A9F4614E011E2A94838FF88C" +
|
||||||
|
"D68C8BB7C5C6424CFFFFFFFFFFFFFFFF"
|
||||||
|
finiteFieldPrime8192bytes, _ = hex.DecodeString(finiteFieldPrime8192hex)
|
||||||
|
finiteFieldPrime8192 = big.NewInt(0).SetBytes(finiteFieldPrime8192bytes)
|
||||||
|
)
|
98
vendor/github.com/bifurcation/mint/frame-reader.go
generated
vendored
Normal file
98
vendor/github.com/bifurcation/mint/frame-reader.go
generated
vendored
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
// Read a generic "framed" packet consisting of a header and a
|
||||||
|
// This is used for both TLS Records and TLS Handshake Messages
|
||||||
|
package mint
|
||||||
|
|
||||||
|
type framing interface {
|
||||||
|
headerLen() int
|
||||||
|
defaultReadLen() int
|
||||||
|
frameLen(hdr []byte) (int, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
kFrameReaderHdr = 0
|
||||||
|
kFrameReaderBody = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
type frameNextAction func(f *frameReader) error
|
||||||
|
|
||||||
|
type frameReader struct {
|
||||||
|
details framing
|
||||||
|
state uint8
|
||||||
|
header []byte
|
||||||
|
body []byte
|
||||||
|
working []byte
|
||||||
|
writeOffset int
|
||||||
|
remainder []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFrameReader(d framing) *frameReader {
|
||||||
|
hdr := make([]byte, d.headerLen())
|
||||||
|
return &frameReader{
|
||||||
|
d,
|
||||||
|
kFrameReaderHdr,
|
||||||
|
hdr,
|
||||||
|
nil,
|
||||||
|
hdr,
|
||||||
|
0,
|
||||||
|
nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func dup(a []byte) []byte {
|
||||||
|
r := make([]byte, len(a))
|
||||||
|
copy(r, a)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *frameReader) needed() int {
|
||||||
|
tmp := (len(f.working) - f.writeOffset) - len(f.remainder)
|
||||||
|
if tmp < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return tmp
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *frameReader) addChunk(in []byte) {
|
||||||
|
// Append to the buffer.
|
||||||
|
logf(logTypeFrameReader, "Appending %v", len(in))
|
||||||
|
f.remainder = append(f.remainder, in...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *frameReader) process() (hdr []byte, body []byte, err error) {
|
||||||
|
for f.needed() == 0 {
|
||||||
|
logf(logTypeFrameReader, "%v bytes needed for next block", len(f.working)-f.writeOffset)
|
||||||
|
// Fill out our working block
|
||||||
|
copied := copy(f.working[f.writeOffset:], f.remainder)
|
||||||
|
f.remainder = f.remainder[copied:]
|
||||||
|
f.writeOffset += copied
|
||||||
|
if f.writeOffset < len(f.working) {
|
||||||
|
logf(logTypeFrameReader, "Read would have blocked 1")
|
||||||
|
return nil, nil, WouldBlock
|
||||||
|
}
|
||||||
|
// Reset the write offset, because we are now full.
|
||||||
|
f.writeOffset = 0
|
||||||
|
|
||||||
|
// We have read a full frame
|
||||||
|
if f.state == kFrameReaderBody {
|
||||||
|
logf(logTypeFrameReader, "Returning frame hdr=%#x len=%d buffered=%d", f.header, len(f.body), len(f.remainder))
|
||||||
|
f.state = kFrameReaderHdr
|
||||||
|
f.working = f.header
|
||||||
|
return dup(f.header), dup(f.body), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// We have read the header
|
||||||
|
bodyLen, err := f.details.frameLen(f.header)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
logf(logTypeFrameReader, "Processed header, body len = %v", bodyLen)
|
||||||
|
|
||||||
|
f.body = make([]byte, bodyLen)
|
||||||
|
f.working = f.body
|
||||||
|
f.writeOffset = 0
|
||||||
|
f.state = kFrameReaderBody
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeFrameReader, "Read would have blocked 2")
|
||||||
|
return nil, nil, WouldBlock
|
||||||
|
}
|
253
vendor/github.com/bifurcation/mint/handshake-layer.go
generated
vendored
Normal file
253
vendor/github.com/bifurcation/mint/handshake-layer.go
generated
vendored
Normal file
@ -0,0 +1,253 @@
|
|||||||
|
package mint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
handshakeHeaderLen = 4 // handshake message header length
|
||||||
|
maxHandshakeMessageLen = 1 << 24 // max handshake message length
|
||||||
|
)
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// HandshakeType msg_type; /* handshake type */
|
||||||
|
// uint24 length; /* bytes in message */
|
||||||
|
// select (HandshakeType) {
|
||||||
|
// ...
|
||||||
|
// } body;
|
||||||
|
// } Handshake;
|
||||||
|
//
|
||||||
|
// We do the select{...} part in a different layer, so we treat the
|
||||||
|
// actual message body as opaque:
|
||||||
|
//
|
||||||
|
// struct {
|
||||||
|
// HandshakeType msg_type;
|
||||||
|
// opaque msg<0..2^24-1>
|
||||||
|
// } Handshake;
|
||||||
|
//
|
||||||
|
// TODO: File a spec bug
|
||||||
|
type HandshakeMessage struct {
|
||||||
|
// Omitted: length
|
||||||
|
msgType HandshakeType
|
||||||
|
body []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note: This could be done with the `syntax` module, using the simplified
|
||||||
|
// syntax as discussed above. However, since this is so simple, there's not
|
||||||
|
// much benefit to doing so.
|
||||||
|
func (hm *HandshakeMessage) Marshal() []byte {
|
||||||
|
if hm == nil {
|
||||||
|
return []byte{}
|
||||||
|
}
|
||||||
|
|
||||||
|
msgLen := len(hm.body)
|
||||||
|
data := make([]byte, 4+len(hm.body))
|
||||||
|
data[0] = byte(hm.msgType)
|
||||||
|
data[1] = byte(msgLen >> 16)
|
||||||
|
data[2] = byte(msgLen >> 8)
|
||||||
|
data[3] = byte(msgLen)
|
||||||
|
copy(data[4:], hm.body)
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) {
|
||||||
|
logf(logTypeHandshake, "HandshakeMessage.toBody [%d] [%x]", hm.msgType, hm.body)
|
||||||
|
|
||||||
|
var body HandshakeMessageBody
|
||||||
|
switch hm.msgType {
|
||||||
|
case HandshakeTypeClientHello:
|
||||||
|
body = new(ClientHelloBody)
|
||||||
|
case HandshakeTypeServerHello:
|
||||||
|
body = new(ServerHelloBody)
|
||||||
|
case HandshakeTypeHelloRetryRequest:
|
||||||
|
body = new(HelloRetryRequestBody)
|
||||||
|
case HandshakeTypeEncryptedExtensions:
|
||||||
|
body = new(EncryptedExtensionsBody)
|
||||||
|
case HandshakeTypeCertificate:
|
||||||
|
body = new(CertificateBody)
|
||||||
|
case HandshakeTypeCertificateRequest:
|
||||||
|
body = new(CertificateRequestBody)
|
||||||
|
case HandshakeTypeCertificateVerify:
|
||||||
|
body = new(CertificateVerifyBody)
|
||||||
|
case HandshakeTypeFinished:
|
||||||
|
body = &FinishedBody{VerifyDataLen: len(hm.body)}
|
||||||
|
case HandshakeTypeNewSessionTicket:
|
||||||
|
body = new(NewSessionTicketBody)
|
||||||
|
case HandshakeTypeKeyUpdate:
|
||||||
|
body = new(KeyUpdateBody)
|
||||||
|
case HandshakeTypeEndOfEarlyData:
|
||||||
|
body = new(EndOfEarlyDataBody)
|
||||||
|
default:
|
||||||
|
return body, fmt.Errorf("tls.handshakemessage: Unsupported body type")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := body.Unmarshal(hm.body)
|
||||||
|
return body, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func HandshakeMessageFromBody(body HandshakeMessageBody) (*HandshakeMessage, error) {
|
||||||
|
data, err := body.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &HandshakeMessage{
|
||||||
|
msgType: body.Type(),
|
||||||
|
body: data,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type HandshakeLayer struct {
|
||||||
|
nonblocking bool // Should we operate in nonblocking mode
|
||||||
|
conn *RecordLayer // Used for reading/writing records
|
||||||
|
frame *frameReader // The buffered frame reader
|
||||||
|
}
|
||||||
|
|
||||||
|
type handshakeLayerFrameDetails struct{}
|
||||||
|
|
||||||
|
func (d handshakeLayerFrameDetails) headerLen() int {
|
||||||
|
return handshakeHeaderLen
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d handshakeLayerFrameDetails) defaultReadLen() int {
|
||||||
|
return handshakeHeaderLen + maxFragmentLen
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) {
|
||||||
|
logf(logTypeIO, "Header=%x", hdr)
|
||||||
|
return (int(hdr[1]) << 16) | (int(hdr[2]) << 8) | int(hdr[3]), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHandshakeLayer(r *RecordLayer) *HandshakeLayer {
|
||||||
|
h := HandshakeLayer{}
|
||||||
|
h.conn = r
|
||||||
|
h.frame = newFrameReader(&handshakeLayerFrameDetails{})
|
||||||
|
return &h
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HandshakeLayer) readRecord() error {
|
||||||
|
logf(logTypeIO, "Trying to read record")
|
||||||
|
pt, err := h.conn.ReadRecord()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if pt.contentType != RecordTypeHandshake &&
|
||||||
|
pt.contentType != RecordTypeAlert {
|
||||||
|
return fmt.Errorf("tls.handshakelayer: Unexpected record type %d", pt.contentType)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pt.contentType == RecordTypeAlert {
|
||||||
|
logf(logTypeIO, "read alert %v", pt.fragment[1])
|
||||||
|
if len(pt.fragment) < 2 {
|
||||||
|
h.sendAlert(AlertUnexpectedMessage)
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
return Alert(pt.fragment[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeIO, "read handshake record of len %v", len(pt.fragment))
|
||||||
|
h.frame.addChunk(pt.fragment)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendAlert sends a TLS alert message.
|
||||||
|
func (h *HandshakeLayer) sendAlert(err Alert) error {
|
||||||
|
tmp := make([]byte, 2)
|
||||||
|
tmp[0] = AlertLevelError
|
||||||
|
tmp[1] = byte(err)
|
||||||
|
h.conn.WriteRecord(&TLSPlaintext{
|
||||||
|
contentType: RecordTypeAlert,
|
||||||
|
fragment: tmp},
|
||||||
|
)
|
||||||
|
|
||||||
|
// closeNotify is a special case in that it isn't an error:
|
||||||
|
if err != AlertCloseNotify {
|
||||||
|
return &net.OpError{Op: "local error", Err: err}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) {
|
||||||
|
var hdr, body []byte
|
||||||
|
var err error
|
||||||
|
|
||||||
|
for {
|
||||||
|
logf(logTypeHandshake, "ReadMessage() buffered=%v", len(h.frame.remainder))
|
||||||
|
if h.frame.needed() > 0 {
|
||||||
|
logf(logTypeHandshake, "Trying to read a new record")
|
||||||
|
err = h.readRecord()
|
||||||
|
}
|
||||||
|
if err != nil && (h.nonblocking || err != WouldBlock) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
hdr, body, err = h.frame.process()
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil && (h.nonblocking || err != WouldBlock) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "read handshake message")
|
||||||
|
|
||||||
|
hm := &HandshakeMessage{}
|
||||||
|
hm.msgType = HandshakeType(hdr[0])
|
||||||
|
|
||||||
|
hm.body = make([]byte, len(body))
|
||||||
|
copy(hm.body, body)
|
||||||
|
|
||||||
|
return hm, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) error {
|
||||||
|
return h.WriteMessages([]*HandshakeMessage{hm})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) error {
|
||||||
|
for _, hm := range hms {
|
||||||
|
logf(logTypeHandshake, "WriteMessage [%d] %x", hm.msgType, hm.body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write out headers and bodies
|
||||||
|
buffer := []byte{}
|
||||||
|
for _, msg := range hms {
|
||||||
|
msgLen := len(msg.body)
|
||||||
|
if msgLen > maxHandshakeMessageLen {
|
||||||
|
return fmt.Errorf("tls.handshakelayer: Message too large to send")
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer = append(buffer, msg.Marshal()...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send full-size fragments
|
||||||
|
var start int
|
||||||
|
for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen {
|
||||||
|
err := h.conn.WriteRecord(&TLSPlaintext{
|
||||||
|
contentType: RecordTypeHandshake,
|
||||||
|
fragment: buffer[start : start+maxFragmentLen],
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send a final partial fragment if necessary
|
||||||
|
if start < len(buffer) {
|
||||||
|
err := h.conn.WriteRecord(&TLSPlaintext{
|
||||||
|
contentType: RecordTypeHandshake,
|
||||||
|
fragment: buffer[start:],
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
450
vendor/github.com/bifurcation/mint/handshake-messages.go
generated
vendored
Normal file
450
vendor/github.com/bifurcation/mint/handshake-messages.go
generated
vendored
Normal file
@ -0,0 +1,450 @@
|
|||||||
|
package mint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/bifurcation/mint/syntax"
|
||||||
|
)
|
||||||
|
|
||||||
|
type HandshakeMessageBody interface {
|
||||||
|
Type() HandshakeType
|
||||||
|
Marshal() ([]byte, error)
|
||||||
|
Unmarshal(data []byte) (int, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// ProtocolVersion legacy_version = 0x0303; /* TLS v1.2 */
|
||||||
|
// Random random;
|
||||||
|
// opaque legacy_session_id<0..32>;
|
||||||
|
// CipherSuite cipher_suites<2..2^16-2>;
|
||||||
|
// opaque legacy_compression_methods<1..2^8-1>;
|
||||||
|
// Extension extensions<0..2^16-1>;
|
||||||
|
// } ClientHello;
|
||||||
|
type ClientHelloBody struct {
|
||||||
|
// Omitted: clientVersion
|
||||||
|
// Omitted: legacySessionID
|
||||||
|
// Omitted: legacyCompressionMethods
|
||||||
|
Random [32]byte
|
||||||
|
CipherSuites []CipherSuite
|
||||||
|
Extensions ExtensionList
|
||||||
|
}
|
||||||
|
|
||||||
|
type clientHelloBodyInner struct {
|
||||||
|
LegacyVersion uint16
|
||||||
|
Random [32]byte
|
||||||
|
LegacySessionID []byte `tls:"head=1,max=32"`
|
||||||
|
CipherSuites []CipherSuite `tls:"head=2,min=2"`
|
||||||
|
LegacyCompressionMethods []byte `tls:"head=1,min=1"`
|
||||||
|
Extensions []Extension `tls:"head=2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ch ClientHelloBody) Type() HandshakeType {
|
||||||
|
return HandshakeTypeClientHello
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ch ClientHelloBody) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(clientHelloBodyInner{
|
||||||
|
LegacyVersion: 0x0303,
|
||||||
|
Random: ch.Random,
|
||||||
|
LegacySessionID: []byte{},
|
||||||
|
CipherSuites: ch.CipherSuites,
|
||||||
|
LegacyCompressionMethods: []byte{0},
|
||||||
|
Extensions: ch.Extensions,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ch *ClientHelloBody) Unmarshal(data []byte) (int, error) {
|
||||||
|
var inner clientHelloBodyInner
|
||||||
|
read, err := syntax.Unmarshal(data, &inner)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// We are strict about these things because we only support 1.3
|
||||||
|
if inner.LegacyVersion != 0x0303 {
|
||||||
|
return 0, fmt.Errorf("tls.clienthello: Incorrect version number")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 {
|
||||||
|
return 0, fmt.Errorf("tls.clienthello: Invalid compression method")
|
||||||
|
}
|
||||||
|
|
||||||
|
ch.Random = inner.Random
|
||||||
|
ch.CipherSuites = inner.CipherSuites
|
||||||
|
ch.Extensions = inner.Extensions
|
||||||
|
return read, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: File a spec bug to clarify this
|
||||||
|
func (ch ClientHelloBody) Truncated() ([]byte, error) {
|
||||||
|
if len(ch.Extensions) == 0 {
|
||||||
|
return nil, fmt.Errorf("tls.clienthello.truncate: No extensions")
|
||||||
|
}
|
||||||
|
|
||||||
|
pskExt := ch.Extensions[len(ch.Extensions)-1]
|
||||||
|
if pskExt.ExtensionType != ExtensionTypePreSharedKey {
|
||||||
|
return nil, fmt.Errorf("tls.clienthello.truncate: Last extension is not PSK")
|
||||||
|
}
|
||||||
|
|
||||||
|
chm, err := HandshakeMessageFromBody(&ch)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
chData := chm.Marshal()
|
||||||
|
|
||||||
|
psk := PreSharedKeyExtension{
|
||||||
|
HandshakeType: HandshakeTypeClientHello,
|
||||||
|
}
|
||||||
|
_, err = psk.Unmarshal(pskExt.ExtensionData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal just the binders so that we know how much to truncate
|
||||||
|
binders := struct {
|
||||||
|
Binders []PSKBinderEntry `tls:"head=2,min=33"`
|
||||||
|
}{Binders: psk.Binders}
|
||||||
|
binderData, _ := syntax.Marshal(binders)
|
||||||
|
binderLen := len(binderData)
|
||||||
|
|
||||||
|
chLen := len(chData)
|
||||||
|
return chData[:chLen-binderLen], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// ProtocolVersion server_version;
|
||||||
|
// CipherSuite cipher_suite;
|
||||||
|
// Extension extensions<2..2^16-1>;
|
||||||
|
// } HelloRetryRequest;
|
||||||
|
type HelloRetryRequestBody struct {
|
||||||
|
Version uint16
|
||||||
|
CipherSuite CipherSuite
|
||||||
|
Extensions ExtensionList `tls:"head=2,min=2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hrr HelloRetryRequestBody) Type() HandshakeType {
|
||||||
|
return HandshakeTypeHelloRetryRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hrr HelloRetryRequestBody) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(hrr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hrr *HelloRetryRequestBody) Unmarshal(data []byte) (int, error) {
|
||||||
|
return syntax.Unmarshal(data, hrr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// ProtocolVersion version;
|
||||||
|
// Random random;
|
||||||
|
// CipherSuite cipher_suite;
|
||||||
|
// Extension extensions<0..2^16-1>;
|
||||||
|
// } ServerHello;
|
||||||
|
type ServerHelloBody struct {
|
||||||
|
Version uint16
|
||||||
|
Random [32]byte
|
||||||
|
CipherSuite CipherSuite
|
||||||
|
Extensions ExtensionList `tls:"head=2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sh ServerHelloBody) Type() HandshakeType {
|
||||||
|
return HandshakeTypeServerHello
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sh ServerHelloBody) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(sh)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sh *ServerHelloBody) Unmarshal(data []byte) (int, error) {
|
||||||
|
return syntax.Unmarshal(data, sh)
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// opaque verify_data[verify_data_length];
|
||||||
|
// } Finished;
|
||||||
|
//
|
||||||
|
// verifyDataLen is not a field in the TLS struct, but we add it here so
|
||||||
|
// that calling code can tell us how much data to expect when we marshal /
|
||||||
|
// unmarshal. (We could add this to the marshal/unmarshal methods, but let's
|
||||||
|
// try to keep the signature consistent for now.)
|
||||||
|
//
|
||||||
|
// For similar reasons, we don't use the `syntax` module here, because this
|
||||||
|
// struct doesn't map well to standard TLS presentation language concepts.
|
||||||
|
//
|
||||||
|
// TODO: File a spec bug
|
||||||
|
type FinishedBody struct {
|
||||||
|
VerifyDataLen int
|
||||||
|
VerifyData []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fin FinishedBody) Type() HandshakeType {
|
||||||
|
return HandshakeTypeFinished
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fin FinishedBody) Marshal() ([]byte, error) {
|
||||||
|
if len(fin.VerifyData) != fin.VerifyDataLen {
|
||||||
|
return nil, fmt.Errorf("tls.finished: data length mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
body := make([]byte, len(fin.VerifyData))
|
||||||
|
copy(body, fin.VerifyData)
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fin *FinishedBody) Unmarshal(data []byte) (int, error) {
|
||||||
|
if len(data) < fin.VerifyDataLen {
|
||||||
|
return 0, fmt.Errorf("tls.finished: Malformed finished; too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
fin.VerifyData = make([]byte, fin.VerifyDataLen)
|
||||||
|
copy(fin.VerifyData, data[:fin.VerifyDataLen])
|
||||||
|
return fin.VerifyDataLen, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// Extension extensions<0..2^16-1>;
|
||||||
|
// } EncryptedExtensions;
|
||||||
|
//
|
||||||
|
// Marshal() and Unmarshal() are handled by ExtensionList
|
||||||
|
type EncryptedExtensionsBody struct {
|
||||||
|
Extensions ExtensionList `tls:"head=2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ee EncryptedExtensionsBody) Type() HandshakeType {
|
||||||
|
return HandshakeTypeEncryptedExtensions
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ee EncryptedExtensionsBody) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(ee)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ee *EncryptedExtensionsBody) Unmarshal(data []byte) (int, error) {
|
||||||
|
return syntax.Unmarshal(data, ee)
|
||||||
|
}
|
||||||
|
|
||||||
|
// opaque ASN1Cert<1..2^24-1>;
|
||||||
|
//
|
||||||
|
// struct {
|
||||||
|
// ASN1Cert cert_data;
|
||||||
|
// Extension extensions<0..2^16-1>
|
||||||
|
// } CertificateEntry;
|
||||||
|
//
|
||||||
|
// struct {
|
||||||
|
// opaque certificate_request_context<0..2^8-1>;
|
||||||
|
// CertificateEntry certificate_list<0..2^24-1>;
|
||||||
|
// } Certificate;
|
||||||
|
type CertificateEntry struct {
|
||||||
|
CertData *x509.Certificate
|
||||||
|
Extensions ExtensionList
|
||||||
|
}
|
||||||
|
|
||||||
|
type CertificateBody struct {
|
||||||
|
CertificateRequestContext []byte
|
||||||
|
CertificateList []CertificateEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
type certificateEntryInner struct {
|
||||||
|
CertData []byte `tls:"head=3,min=1"`
|
||||||
|
Extensions ExtensionList `tls:"head=2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type certificateBodyInner struct {
|
||||||
|
CertificateRequestContext []byte `tls:"head=1"`
|
||||||
|
CertificateList []certificateEntryInner `tls:"head=3"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c CertificateBody) Type() HandshakeType {
|
||||||
|
return HandshakeTypeCertificate
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c CertificateBody) Marshal() ([]byte, error) {
|
||||||
|
inner := certificateBodyInner{
|
||||||
|
CertificateRequestContext: c.CertificateRequestContext,
|
||||||
|
CertificateList: make([]certificateEntryInner, len(c.CertificateList)),
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, entry := range c.CertificateList {
|
||||||
|
inner.CertificateList[i] = certificateEntryInner{
|
||||||
|
CertData: entry.CertData.Raw,
|
||||||
|
Extensions: entry.Extensions,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return syntax.Marshal(inner)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CertificateBody) Unmarshal(data []byte) (int, error) {
|
||||||
|
inner := certificateBodyInner{}
|
||||||
|
read, err := syntax.Unmarshal(data, &inner)
|
||||||
|
if err != nil {
|
||||||
|
return read, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.CertificateRequestContext = inner.CertificateRequestContext
|
||||||
|
c.CertificateList = make([]CertificateEntry, len(inner.CertificateList))
|
||||||
|
|
||||||
|
for i, entry := range inner.CertificateList {
|
||||||
|
c.CertificateList[i].CertData, err = x509.ParseCertificate(entry.CertData)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("tls:certificate: Certificate failed to parse: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.CertificateList[i].Extensions = entry.Extensions
|
||||||
|
}
|
||||||
|
|
||||||
|
return read, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// SignatureScheme algorithm;
|
||||||
|
// opaque signature<0..2^16-1>;
|
||||||
|
// } CertificateVerify;
|
||||||
|
type CertificateVerifyBody struct {
|
||||||
|
Algorithm SignatureScheme
|
||||||
|
Signature []byte `tls:"head=2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cv CertificateVerifyBody) Type() HandshakeType {
|
||||||
|
return HandshakeTypeCertificateVerify
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cv CertificateVerifyBody) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(cv)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cv *CertificateVerifyBody) Unmarshal(data []byte) (int, error) {
|
||||||
|
return syntax.Unmarshal(data, cv)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cv *CertificateVerifyBody) EncodeSignatureInput(data []byte) []byte {
|
||||||
|
// TODO: Change context for client auth
|
||||||
|
// TODO: Put this in a const
|
||||||
|
const context = "TLS 1.3, server CertificateVerify"
|
||||||
|
sigInput := bytes.Repeat([]byte{0x20}, 64)
|
||||||
|
sigInput = append(sigInput, []byte(context)...)
|
||||||
|
sigInput = append(sigInput, []byte{0}...)
|
||||||
|
sigInput = append(sigInput, data...)
|
||||||
|
return sigInput
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cv *CertificateVerifyBody) Sign(privateKey crypto.Signer, handshakeHash []byte) (err error) {
|
||||||
|
sigInput := cv.EncodeSignatureInput(handshakeHash)
|
||||||
|
cv.Signature, err = sign(cv.Algorithm, privateKey, sigInput)
|
||||||
|
logf(logTypeHandshake, "Signed: alg=[%04x] sigInput=[%x], sig=[%x]", cv.Algorithm, sigInput, cv.Signature)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cv *CertificateVerifyBody) Verify(publicKey crypto.PublicKey, handshakeHash []byte) error {
|
||||||
|
sigInput := cv.EncodeSignatureInput(handshakeHash)
|
||||||
|
logf(logTypeHandshake, "About to verify: alg=[%04x] sigInput=[%x], sig=[%x]", cv.Algorithm, sigInput, cv.Signature)
|
||||||
|
return verify(cv.Algorithm, publicKey, sigInput, cv.Signature)
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// opaque certificate_request_context<0..2^8-1>;
|
||||||
|
// Extension extensions<2..2^16-1>;
|
||||||
|
// } CertificateRequest;
|
||||||
|
type CertificateRequestBody struct {
|
||||||
|
CertificateRequestContext []byte `tls:"head=1"`
|
||||||
|
Extensions ExtensionList `tls:"head=2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cr CertificateRequestBody) Type() HandshakeType {
|
||||||
|
return HandshakeTypeCertificateRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cr CertificateRequestBody) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(cr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cr *CertificateRequestBody) Unmarshal(data []byte) (int, error) {
|
||||||
|
return syntax.Unmarshal(data, cr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// uint32 ticket_lifetime;
|
||||||
|
// uint32 ticket_age_add;
|
||||||
|
// opaque ticket_nonce<1..255>;
|
||||||
|
// opaque ticket<1..2^16-1>;
|
||||||
|
// Extension extensions<0..2^16-2>;
|
||||||
|
// } NewSessionTicket;
|
||||||
|
type NewSessionTicketBody struct {
|
||||||
|
TicketLifetime uint32
|
||||||
|
TicketAgeAdd uint32
|
||||||
|
TicketNonce []byte `tls:"head=1,min=1"`
|
||||||
|
Ticket []byte `tls:"head=2,min=1"`
|
||||||
|
Extensions ExtensionList `tls:"head=2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
const ticketNonceLen = 16
|
||||||
|
|
||||||
|
func NewSessionTicket(ticketLen int, ticketLifetime uint32) (*NewSessionTicketBody, error) {
|
||||||
|
buf := make([]byte, 4+ticketNonceLen+ticketLen)
|
||||||
|
_, err := prng.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tkt := &NewSessionTicketBody{
|
||||||
|
TicketLifetime: ticketLifetime,
|
||||||
|
TicketAgeAdd: binary.BigEndian.Uint32(buf[:4]),
|
||||||
|
TicketNonce: buf[4 : 4+ticketNonceLen],
|
||||||
|
Ticket: buf[4+ticketNonceLen:],
|
||||||
|
}
|
||||||
|
|
||||||
|
return tkt, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tkt NewSessionTicketBody) Type() HandshakeType {
|
||||||
|
return HandshakeTypeNewSessionTicket
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tkt NewSessionTicketBody) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(tkt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tkt *NewSessionTicketBody) Unmarshal(data []byte) (int, error) {
|
||||||
|
return syntax.Unmarshal(data, tkt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// enum {
|
||||||
|
// update_not_requested(0), update_requested(1), (255)
|
||||||
|
// } KeyUpdateRequest;
|
||||||
|
//
|
||||||
|
// struct {
|
||||||
|
// KeyUpdateRequest request_update;
|
||||||
|
// } KeyUpdate;
|
||||||
|
type KeyUpdateBody struct {
|
||||||
|
KeyUpdateRequest KeyUpdateRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ku KeyUpdateBody) Type() HandshakeType {
|
||||||
|
return HandshakeTypeKeyUpdate
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ku KeyUpdateBody) Marshal() ([]byte, error) {
|
||||||
|
return syntax.Marshal(ku)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ku *KeyUpdateBody) Unmarshal(data []byte) (int, error) {
|
||||||
|
return syntax.Unmarshal(data, ku)
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {} EndOfEarlyData;
|
||||||
|
type EndOfEarlyDataBody struct{}
|
||||||
|
|
||||||
|
func (eoed EndOfEarlyDataBody) Type() HandshakeType {
|
||||||
|
return HandshakeTypeEndOfEarlyData
|
||||||
|
}
|
||||||
|
|
||||||
|
func (eoed EndOfEarlyDataBody) Marshal() ([]byte, error) {
|
||||||
|
return []byte{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (eoed *EndOfEarlyDataBody) Unmarshal(data []byte) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
55
vendor/github.com/bifurcation/mint/log.go
generated
vendored
Normal file
55
vendor/github.com/bifurcation/mint/log.go
generated
vendored
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
package mint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// We use this environment variable to control logging. It should be a
|
||||||
|
// comma-separated list of log tags (see below) or "*" to enable all logging.
|
||||||
|
const logConfigVar = "MINT_LOG"
|
||||||
|
|
||||||
|
// Pre-defined log types
|
||||||
|
const (
|
||||||
|
logTypeCrypto = "crypto"
|
||||||
|
logTypeHandshake = "handshake"
|
||||||
|
logTypeNegotiation = "negotiation"
|
||||||
|
logTypeIO = "io"
|
||||||
|
logTypeFrameReader = "frame"
|
||||||
|
logTypeVerbose = "verbose"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
logFunction = log.Printf
|
||||||
|
logAll = false
|
||||||
|
logSettings = map[string]bool{}
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
parseLogEnv(os.Environ())
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseLogEnv(env []string) {
|
||||||
|
for _, stmt := range env {
|
||||||
|
if strings.HasPrefix(stmt, logConfigVar+"=") {
|
||||||
|
val := stmt[len(logConfigVar)+1:]
|
||||||
|
|
||||||
|
if val == "*" {
|
||||||
|
logAll = true
|
||||||
|
} else {
|
||||||
|
for _, t := range strings.Split(val, ",") {
|
||||||
|
logSettings[t] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func logf(tag string, format string, args ...interface{}) {
|
||||||
|
if logAll || logSettings[tag] {
|
||||||
|
fullFormat := fmt.Sprintf("[%s] %s", tag, format)
|
||||||
|
logFunction(fullFormat, args...)
|
||||||
|
}
|
||||||
|
}
|
217
vendor/github.com/bifurcation/mint/negotiation.go
generated
vendored
Normal file
217
vendor/github.com/bifurcation/mint/negotiation.go
generated
vendored
Normal file
@ -0,0 +1,217 @@
|
|||||||
|
package mint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func VersionNegotiation(offered, supported []uint16) (bool, uint16) {
|
||||||
|
for _, offeredVersion := range offered {
|
||||||
|
for _, supportedVersion := range supported {
|
||||||
|
logf(logTypeHandshake, "[server] version offered by client [%04x] <> [%04x]", offeredVersion, supportedVersion)
|
||||||
|
if offeredVersion == supportedVersion {
|
||||||
|
// XXX: Should probably be highest supported version, but for now, we
|
||||||
|
// only support one version, so it doesn't really matter.
|
||||||
|
return true, offeredVersion
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func DHNegotiation(keyShares []KeyShareEntry, groups []NamedGroup) (bool, NamedGroup, []byte, []byte) {
|
||||||
|
for _, share := range keyShares {
|
||||||
|
for _, group := range groups {
|
||||||
|
if group != share.Group {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
pub, priv, err := newKeyShare(share.Group)
|
||||||
|
if err != nil {
|
||||||
|
// If we encounter an error, just keep looking
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
dhSecret, err := keyAgreement(share.Group, share.KeyExchange, priv)
|
||||||
|
if err != nil {
|
||||||
|
// If we encounter an error, just keep looking
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, group, pub, dhSecret
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, 0, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
ticketAgeTolerance uint32 = 5 * 1000 // five seconds in milliseconds
|
||||||
|
)
|
||||||
|
|
||||||
|
func PSKNegotiation(identities []PSKIdentity, binders []PSKBinderEntry, context []byte, psks PreSharedKeyCache) (bool, int, *PreSharedKey, CipherSuiteParams, error) {
|
||||||
|
logf(logTypeNegotiation, "Negotiating PSK offered=[%d] supported=[%d]", len(identities), psks.Size())
|
||||||
|
for i, id := range identities {
|
||||||
|
identityHex := hex.EncodeToString(id.Identity)
|
||||||
|
|
||||||
|
psk, ok := psks.Get(identityHex)
|
||||||
|
if !ok {
|
||||||
|
logf(logTypeNegotiation, "No PSK for identity %x", identityHex)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// For resumption, make sure the ticket age is correct
|
||||||
|
if psk.IsResumption {
|
||||||
|
extTicketAge := id.ObfuscatedTicketAge - psk.TicketAgeAdd
|
||||||
|
knownTicketAge := uint32(time.Since(psk.ReceivedAt) / time.Millisecond)
|
||||||
|
ticketAgeDelta := knownTicketAge - extTicketAge
|
||||||
|
if knownTicketAge < extTicketAge {
|
||||||
|
ticketAgeDelta = extTicketAge - knownTicketAge
|
||||||
|
}
|
||||||
|
if ticketAgeDelta > ticketAgeTolerance {
|
||||||
|
logf(logTypeNegotiation, "WARNING potential replay [%x]", psk.Identity)
|
||||||
|
logf(logTypeNegotiation, "Ticket age exceeds tolerance |%d - %d| = [%d] > [%d]",
|
||||||
|
extTicketAge, knownTicketAge, ticketAgeDelta, ticketAgeTolerance)
|
||||||
|
return false, 0, nil, CipherSuiteParams{}, fmt.Errorf("WARNING Potential replay for identity %x", psk.Identity)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
params, ok := cipherSuiteMap[psk.CipherSuite]
|
||||||
|
if !ok {
|
||||||
|
err := fmt.Errorf("tls.cryptoinit: Unsupported ciphersuite from PSK [%04x]", psk.CipherSuite)
|
||||||
|
return false, 0, nil, CipherSuiteParams{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute binder
|
||||||
|
binderLabel := labelExternalBinder
|
||||||
|
if psk.IsResumption {
|
||||||
|
binderLabel = labelResumptionBinder
|
||||||
|
}
|
||||||
|
|
||||||
|
h0 := params.Hash.New().Sum(nil)
|
||||||
|
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
|
||||||
|
earlySecret := HkdfExtract(params.Hash, zero, psk.Key)
|
||||||
|
binderKey := deriveSecret(params, earlySecret, binderLabel, h0)
|
||||||
|
|
||||||
|
// context = ClientHello[truncated]
|
||||||
|
// context = ClientHello1 + HelloRetryRequest + ClientHello2[truncated]
|
||||||
|
ctxHash := params.Hash.New()
|
||||||
|
ctxHash.Write(context)
|
||||||
|
|
||||||
|
binder := computeFinishedData(params, binderKey, ctxHash.Sum(nil))
|
||||||
|
if !bytes.Equal(binder, binders[i].Binder) {
|
||||||
|
logf(logTypeNegotiation, "Binder check failed for identity %x; [%x] != [%x]", psk.Identity, binder, binders[i].Binder)
|
||||||
|
return false, 0, nil, CipherSuiteParams{}, fmt.Errorf("Binder check failed identity %x", psk.Identity)
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeNegotiation, "Using PSK with identity %x", psk.Identity)
|
||||||
|
return true, i, &psk, params, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeNegotiation, "Failed to find a usable PSK")
|
||||||
|
return false, 0, nil, CipherSuiteParams{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func PSKModeNegotiation(canDoDH, canDoPSK bool, modes []PSKKeyExchangeMode) (bool, bool) {
|
||||||
|
logf(logTypeNegotiation, "Negotiating PSK modes [%v] [%v] [%+v]", canDoDH, canDoPSK, modes)
|
||||||
|
dhAllowed := false
|
||||||
|
dhRequired := true
|
||||||
|
for _, mode := range modes {
|
||||||
|
dhAllowed = dhAllowed || (mode == PSKModeDHEKE)
|
||||||
|
dhRequired = dhRequired && (mode == PSKModeDHEKE)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use PSK if we can meet DH requirement and modes were provided
|
||||||
|
usingPSK := canDoPSK && (!dhRequired || canDoDH) && (len(modes) > 0)
|
||||||
|
|
||||||
|
// Use DH if allowed
|
||||||
|
usingDH := canDoDH && (dhAllowed || !usingPSK)
|
||||||
|
|
||||||
|
logf(logTypeNegotiation, "Results of PSK mode negotiation: usingDH=[%v] usingPSK=[%v]", usingDH, usingPSK)
|
||||||
|
return usingDH, usingPSK
|
||||||
|
}
|
||||||
|
|
||||||
|
func CertificateSelection(serverName *string, signatureSchemes []SignatureScheme, certs []*Certificate) (*Certificate, SignatureScheme, error) {
|
||||||
|
// Select for server name if provided
|
||||||
|
candidates := certs
|
||||||
|
if serverName != nil {
|
||||||
|
candidatesByName := []*Certificate{}
|
||||||
|
for _, cert := range certs {
|
||||||
|
for _, name := range cert.Chain[0].DNSNames {
|
||||||
|
if len(*serverName) > 0 && name == *serverName {
|
||||||
|
candidatesByName = append(candidatesByName, cert)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(candidatesByName) == 0 {
|
||||||
|
return nil, 0, fmt.Errorf("No certificates available for server name")
|
||||||
|
}
|
||||||
|
|
||||||
|
candidates = candidatesByName
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select for signature scheme
|
||||||
|
for _, cert := range candidates {
|
||||||
|
for _, scheme := range signatureSchemes {
|
||||||
|
if !schemeValidForKey(scheme, cert.PrivateKey) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return cert, scheme, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, 0, fmt.Errorf("No certificates compatible with signature schemes")
|
||||||
|
}
|
||||||
|
|
||||||
|
func EarlyDataNegotiation(usingPSK, gotEarlyData, allowEarlyData bool) bool {
|
||||||
|
usingEarlyData := gotEarlyData && usingPSK && allowEarlyData
|
||||||
|
logf(logTypeNegotiation, "Early data negotiation (%v, %v, %v) => %v", usingPSK, gotEarlyData, allowEarlyData, usingEarlyData)
|
||||||
|
return usingEarlyData
|
||||||
|
}
|
||||||
|
|
||||||
|
func CipherSuiteNegotiation(psk *PreSharedKey, offered, supported []CipherSuite) (CipherSuite, error) {
|
||||||
|
for _, s1 := range offered {
|
||||||
|
if psk != nil {
|
||||||
|
if s1 == psk.CipherSuite {
|
||||||
|
return s1, nil
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, s2 := range supported {
|
||||||
|
if s1 == s2 {
|
||||||
|
return s1, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, fmt.Errorf("No overlap between offered and supproted ciphersuites (psk? [%v])", psk != nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ALPNNegotiation(psk *PreSharedKey, offered, supported []string) (string, error) {
|
||||||
|
for _, p1 := range offered {
|
||||||
|
if psk != nil {
|
||||||
|
if p1 != psk.NextProto {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p2 := range supported {
|
||||||
|
if p1 == p2 {
|
||||||
|
return p1, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the client offers ALPN on resumption, it must match the earlier one
|
||||||
|
var err error
|
||||||
|
if psk != nil && psk.IsResumption && (len(offered) > 0) {
|
||||||
|
err = fmt.Errorf("ALPN for PSK not provided")
|
||||||
|
}
|
||||||
|
return "", err
|
||||||
|
}
|
296
vendor/github.com/bifurcation/mint/record-layer.go
generated
vendored
Normal file
296
vendor/github.com/bifurcation/mint/record-layer.go
generated
vendored
Normal file
@ -0,0 +1,296 @@
|
|||||||
|
package mint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
sequenceNumberLen = 8 // sequence number length
|
||||||
|
recordHeaderLen = 5 // record header length
|
||||||
|
maxFragmentLen = 1 << 14 // max number of bytes in a record
|
||||||
|
)
|
||||||
|
|
||||||
|
type DecryptError string
|
||||||
|
|
||||||
|
func (err DecryptError) Error() string {
|
||||||
|
return string(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// struct {
|
||||||
|
// ContentType type;
|
||||||
|
// ProtocolVersion record_version = { 3, 1 }; /* TLS v1.x */
|
||||||
|
// uint16 length;
|
||||||
|
// opaque fragment[TLSPlaintext.length];
|
||||||
|
// } TLSPlaintext;
|
||||||
|
type TLSPlaintext struct {
|
||||||
|
// Omitted: record_version (static)
|
||||||
|
// Omitted: length (computed from fragment)
|
||||||
|
contentType RecordType
|
||||||
|
fragment []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type RecordLayer struct {
|
||||||
|
sync.Mutex
|
||||||
|
|
||||||
|
conn io.ReadWriter // The underlying connection
|
||||||
|
frame *frameReader // The buffered frame reader
|
||||||
|
nextData []byte // The next record to send
|
||||||
|
cachedRecord *TLSPlaintext // Last record read, cached to enable "peek"
|
||||||
|
cachedError error // Error on the last record read
|
||||||
|
|
||||||
|
ivLength int // Length of the seq and nonce fields
|
||||||
|
seq []byte // Zero-padded sequence number
|
||||||
|
nonce []byte // Buffer for per-record nonces
|
||||||
|
cipher cipher.AEAD // AEAD cipher
|
||||||
|
}
|
||||||
|
|
||||||
|
type recordLayerFrameDetails struct{}
|
||||||
|
|
||||||
|
func (d recordLayerFrameDetails) headerLen() int {
|
||||||
|
return recordHeaderLen
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d recordLayerFrameDetails) defaultReadLen() int {
|
||||||
|
return recordHeaderLen + maxFragmentLen
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) {
|
||||||
|
return (int(hdr[3]) << 8) | int(hdr[4]), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRecordLayer(conn io.ReadWriter) *RecordLayer {
|
||||||
|
r := RecordLayer{}
|
||||||
|
r.conn = conn
|
||||||
|
r.frame = newFrameReader(recordLayerFrameDetails{})
|
||||||
|
r.ivLength = 0
|
||||||
|
return &r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RecordLayer) Rekey(cipher aeadFactory, key []byte, iv []byte) error {
|
||||||
|
var err error
|
||||||
|
r.cipher, err = cipher(key)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
r.ivLength = len(iv)
|
||||||
|
r.seq = bytes.Repeat([]byte{0}, r.ivLength)
|
||||||
|
r.nonce = make([]byte, r.ivLength)
|
||||||
|
copy(r.nonce, iv)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RecordLayer) incrementSequenceNumber() {
|
||||||
|
if r.ivLength == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := r.ivLength - 1; i > r.ivLength-sequenceNumberLen; i-- {
|
||||||
|
r.seq[i]++
|
||||||
|
r.nonce[i] ^= (r.seq[i] - 1) ^ r.seq[i]
|
||||||
|
if r.seq[i] != 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not allowed to let sequence number wrap.
|
||||||
|
// Instead, must renegotiate before it does.
|
||||||
|
// Not likely enough to bother.
|
||||||
|
panic("TLS: sequence number wraparound")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RecordLayer) encrypt(pt *TLSPlaintext, padLen int) *TLSPlaintext {
|
||||||
|
// Expand the fragment to hold contentType, padding, and overhead
|
||||||
|
originalLen := len(pt.fragment)
|
||||||
|
plaintextLen := originalLen + 1 + padLen
|
||||||
|
ciphertextLen := plaintextLen + r.cipher.Overhead()
|
||||||
|
|
||||||
|
// Assemble the revised plaintext
|
||||||
|
out := &TLSPlaintext{
|
||||||
|
contentType: RecordTypeApplicationData,
|
||||||
|
fragment: make([]byte, ciphertextLen),
|
||||||
|
}
|
||||||
|
copy(out.fragment, pt.fragment)
|
||||||
|
out.fragment[originalLen] = byte(pt.contentType)
|
||||||
|
for i := 1; i <= padLen; i++ {
|
||||||
|
out.fragment[originalLen+i] = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt the fragment
|
||||||
|
payload := out.fragment[:plaintextLen]
|
||||||
|
r.cipher.Seal(payload[:0], r.nonce, payload, nil)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RecordLayer) decrypt(pt *TLSPlaintext) (*TLSPlaintext, int, error) {
|
||||||
|
if len(pt.fragment) < r.cipher.Overhead() {
|
||||||
|
msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.Overhead())
|
||||||
|
return nil, 0, DecryptError(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
decryptLen := len(pt.fragment) - r.cipher.Overhead()
|
||||||
|
out := &TLSPlaintext{
|
||||||
|
contentType: pt.contentType,
|
||||||
|
fragment: make([]byte, decryptLen),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrypt
|
||||||
|
_, err := r.cipher.Open(out.fragment[:0], r.nonce, pt.fragment, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, DecryptError("tls.record.decrypt: AEAD decrypt failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the padding boundary
|
||||||
|
padLen := 0
|
||||||
|
for ; padLen < decryptLen+1 && out.fragment[decryptLen-padLen-1] == 0; padLen++ {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transfer the content type
|
||||||
|
newLen := decryptLen - padLen - 1
|
||||||
|
out.contentType = RecordType(out.fragment[newLen])
|
||||||
|
|
||||||
|
// Truncate the message to remove contentType, padding, overhead
|
||||||
|
out.fragment = out.fragment[:newLen]
|
||||||
|
return out, padLen, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) {
|
||||||
|
var pt *TLSPlaintext
|
||||||
|
var err error
|
||||||
|
|
||||||
|
for {
|
||||||
|
pt, err = r.nextRecord()
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if !block || err != WouldBlock {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return pt.contentType, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) {
|
||||||
|
pt, err := r.nextRecord()
|
||||||
|
|
||||||
|
// Consume the cached record if there was one
|
||||||
|
r.cachedRecord = nil
|
||||||
|
r.cachedError = nil
|
||||||
|
|
||||||
|
return pt, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) {
|
||||||
|
if r.cachedRecord != nil {
|
||||||
|
logf(logTypeIO, "Returning cached record")
|
||||||
|
return r.cachedRecord, r.cachedError
|
||||||
|
}
|
||||||
|
|
||||||
|
// Loop until one of three things happens:
|
||||||
|
//
|
||||||
|
// 1. We get a frame
|
||||||
|
// 2. We try to read off the socket and get nothing, in which case
|
||||||
|
// return WouldBlock
|
||||||
|
// 3. We get an error.
|
||||||
|
err := WouldBlock
|
||||||
|
var header, body []byte
|
||||||
|
|
||||||
|
for err != nil {
|
||||||
|
if r.frame.needed() > 0 {
|
||||||
|
buf := make([]byte, recordHeaderLen+maxFragmentLen)
|
||||||
|
n, err := r.conn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeIO, "Error reading, %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if n == 0 {
|
||||||
|
return nil, WouldBlock
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeIO, "Read %v bytes", n)
|
||||||
|
|
||||||
|
buf = buf[:n]
|
||||||
|
r.frame.addChunk(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
header, body, err = r.frame.process()
|
||||||
|
// Loop around on WouldBlock to see if some
|
||||||
|
// data is now available.
|
||||||
|
if err != nil && err != WouldBlock {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pt := &TLSPlaintext{}
|
||||||
|
// Validate content type
|
||||||
|
switch RecordType(header[0]) {
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("tls.record: Unknown content type %02x", header[0])
|
||||||
|
case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData:
|
||||||
|
pt.contentType = RecordType(header[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate version
|
||||||
|
if !allowWrongVersionNumber && (header[1] != 0x03 || header[2] != 0x01) {
|
||||||
|
return nil, fmt.Errorf("tls.record: Invalid version %02x%02x", header[1], header[2])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate size < max
|
||||||
|
size := (int(header[3]) << 8) + int(header[4])
|
||||||
|
if size > maxFragmentLen+256 {
|
||||||
|
return nil, fmt.Errorf("tls.record: Ciphertext size too big")
|
||||||
|
}
|
||||||
|
|
||||||
|
pt.fragment = make([]byte, size)
|
||||||
|
copy(pt.fragment, body)
|
||||||
|
|
||||||
|
// Attempt to decrypt fragment
|
||||||
|
if r.cipher != nil {
|
||||||
|
pt, _, err = r.decrypt(pt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that plaintext length is not too long
|
||||||
|
if len(pt.fragment) > maxFragmentLen {
|
||||||
|
return nil, fmt.Errorf("tls.record: Plaintext size too big")
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeIO, "RecordLayer.ReadRecord [%d] [%x]", pt.contentType, pt.fragment)
|
||||||
|
|
||||||
|
r.cachedRecord = pt
|
||||||
|
r.incrementSequenceNumber()
|
||||||
|
return pt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RecordLayer) WriteRecord(pt *TLSPlaintext) error {
|
||||||
|
return r.WriteRecordWithPadding(pt, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error {
|
||||||
|
if r.cipher != nil {
|
||||||
|
pt = r.encrypt(pt, padLen)
|
||||||
|
} else if padLen > 0 {
|
||||||
|
return fmt.Errorf("tls.record: Padding can only be done on encrypted records")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(pt.fragment) > maxFragmentLen {
|
||||||
|
return fmt.Errorf("tls.record: Record size too big")
|
||||||
|
}
|
||||||
|
|
||||||
|
length := len(pt.fragment)
|
||||||
|
header := []byte{byte(pt.contentType), 0x03, 0x01, byte(length >> 8), byte(length)}
|
||||||
|
record := append(header, pt.fragment...)
|
||||||
|
|
||||||
|
logf(logTypeIO, "RecordLayer.WriteRecord [%d] [%x]", pt.contentType, pt.fragment)
|
||||||
|
|
||||||
|
r.incrementSequenceNumber()
|
||||||
|
_, err := r.conn.Write(record)
|
||||||
|
return err
|
||||||
|
}
|
898
vendor/github.com/bifurcation/mint/server-state-machine.go
generated
vendored
Normal file
898
vendor/github.com/bifurcation/mint/server-state-machine.go
generated
vendored
Normal file
@ -0,0 +1,898 @@
|
|||||||
|
package mint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"hash"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Server State Machine
|
||||||
|
//
|
||||||
|
// START <-----+
|
||||||
|
// Recv ClientHello | | Send HelloRetryRequest
|
||||||
|
// v |
|
||||||
|
// RECVD_CH ----+
|
||||||
|
// | Select parameters
|
||||||
|
// | Send ServerHello
|
||||||
|
// v
|
||||||
|
// NEGOTIATED
|
||||||
|
// | Send EncryptedExtensions
|
||||||
|
// | [Send CertificateRequest]
|
||||||
|
// Can send | [Send Certificate + CertificateVerify]
|
||||||
|
// app data --> | Send Finished
|
||||||
|
// after +--------+--------+
|
||||||
|
// here No 0-RTT | | 0-RTT
|
||||||
|
// | v
|
||||||
|
// | WAIT_EOED <---+
|
||||||
|
// | Recv | | | Recv
|
||||||
|
// | EndOfEarlyData | | | early data
|
||||||
|
// | | +-----+
|
||||||
|
// +> WAIT_FLIGHT2 <-+
|
||||||
|
// |
|
||||||
|
// +--------+--------+
|
||||||
|
// No auth | | Client auth
|
||||||
|
// | |
|
||||||
|
// | v
|
||||||
|
// | WAIT_CERT
|
||||||
|
// | Recv | | Recv Certificate
|
||||||
|
// | empty | v
|
||||||
|
// | Certificate | WAIT_CV
|
||||||
|
// | | | Recv
|
||||||
|
// | v | CertificateVerify
|
||||||
|
// +-> WAIT_FINISHED <---+
|
||||||
|
// | Recv Finished
|
||||||
|
// v
|
||||||
|
// CONNECTED
|
||||||
|
//
|
||||||
|
// NB: Not using state RECVD_CH
|
||||||
|
//
|
||||||
|
// State Instructions
|
||||||
|
// START {}
|
||||||
|
// NEGOTIATED Send(SH); [RekeyIn;] RekeyOut; Send(EE); [Send(CertReq);] [Send(Cert); Send(CV)]
|
||||||
|
// WAIT_EOED RekeyIn;
|
||||||
|
// WAIT_FLIGHT2 {}
|
||||||
|
// WAIT_CERT_CR {}
|
||||||
|
// WAIT_CERT {}
|
||||||
|
// WAIT_CV {}
|
||||||
|
// WAIT_FINISHED RekeyIn; RekeyOut;
|
||||||
|
// CONNECTED StoreTicket || (RekeyIn; [RekeyOut])
|
||||||
|
|
||||||
|
type ServerStateStart struct {
|
||||||
|
Caps Capabilities
|
||||||
|
conn *Conn
|
||||||
|
|
||||||
|
cookieSent bool
|
||||||
|
firstClientHello *HandshakeMessage
|
||||||
|
helloRetryRequest *HandshakeMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state ServerStateStart) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
if hm == nil || hm.msgType != HandshakeTypeClientHello {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] unexpected message")
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
ch := &ClientHelloBody{}
|
||||||
|
_, err := ch.Unmarshal(hm.body)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] Error decoding message: %v", err)
|
||||||
|
return nil, nil, AlertDecodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
clientHello := hm
|
||||||
|
connParams := ConnectionParameters{}
|
||||||
|
|
||||||
|
supportedVersions := new(SupportedVersionsExtension)
|
||||||
|
serverName := new(ServerNameExtension)
|
||||||
|
supportedGroups := new(SupportedGroupsExtension)
|
||||||
|
signatureAlgorithms := new(SignatureAlgorithmsExtension)
|
||||||
|
clientKeyShares := &KeyShareExtension{HandshakeType: HandshakeTypeClientHello}
|
||||||
|
clientPSK := &PreSharedKeyExtension{HandshakeType: HandshakeTypeClientHello}
|
||||||
|
clientEarlyData := &EarlyDataExtension{}
|
||||||
|
clientALPN := new(ALPNExtension)
|
||||||
|
clientPSKModes := new(PSKKeyExchangeModesExtension)
|
||||||
|
clientCookie := new(CookieExtension)
|
||||||
|
|
||||||
|
// Handle external extensions.
|
||||||
|
if state.Caps.ExtensionHandler != nil {
|
||||||
|
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeClientHello, &ch.Extensions)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] Error running external extension handler [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
gotSupportedVersions := ch.Extensions.Find(supportedVersions)
|
||||||
|
gotServerName := ch.Extensions.Find(serverName)
|
||||||
|
gotSupportedGroups := ch.Extensions.Find(supportedGroups)
|
||||||
|
gotSignatureAlgorithms := ch.Extensions.Find(signatureAlgorithms)
|
||||||
|
gotEarlyData := ch.Extensions.Find(clientEarlyData)
|
||||||
|
ch.Extensions.Find(clientKeyShares)
|
||||||
|
ch.Extensions.Find(clientPSK)
|
||||||
|
ch.Extensions.Find(clientALPN)
|
||||||
|
ch.Extensions.Find(clientPSKModes)
|
||||||
|
ch.Extensions.Find(clientCookie)
|
||||||
|
|
||||||
|
if gotServerName {
|
||||||
|
connParams.ServerName = string(*serverName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the client didn't send supportedVersions or doesn't support 1.3,
|
||||||
|
// then we're done here.
|
||||||
|
if !gotSupportedVersions {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] Client did not send supported_versions")
|
||||||
|
return nil, nil, AlertProtocolVersion
|
||||||
|
}
|
||||||
|
versionOK, _ := VersionNegotiation(supportedVersions.Versions, []uint16{supportedVersion})
|
||||||
|
if !versionOK {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] Client does not support the same version")
|
||||||
|
return nil, nil, AlertProtocolVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
if state.Caps.RequireCookie && state.cookieSent && !state.Caps.CookieHandler.Validate(state.conn, clientCookie.Cookie) {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] Cookie mismatch")
|
||||||
|
return nil, nil, AlertAccessDenied
|
||||||
|
}
|
||||||
|
|
||||||
|
// Figure out if we can do DH
|
||||||
|
canDoDH, dhGroup, dhPublic, dhSecret := DHNegotiation(clientKeyShares.Shares, state.Caps.Groups)
|
||||||
|
|
||||||
|
// Figure out if we can do PSK
|
||||||
|
canDoPSK := false
|
||||||
|
var selectedPSK int
|
||||||
|
var psk *PreSharedKey
|
||||||
|
var params CipherSuiteParams
|
||||||
|
if len(clientPSK.Identities) > 0 {
|
||||||
|
contextBase := []byte{}
|
||||||
|
if state.helloRetryRequest != nil {
|
||||||
|
chBytes := state.firstClientHello.Marshal()
|
||||||
|
hrrBytes := state.helloRetryRequest.Marshal()
|
||||||
|
contextBase = append(chBytes, hrrBytes...)
|
||||||
|
}
|
||||||
|
|
||||||
|
chTrunc, err := ch.Truncated()
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] Error computing truncated ClientHello [%v]", err)
|
||||||
|
return nil, nil, AlertDecodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
context := append(contextBase, chTrunc...)
|
||||||
|
|
||||||
|
canDoPSK, selectedPSK, psk, params, err = PSKNegotiation(clientPSK.Identities, clientPSK.Binders, context, state.Caps.PSKs)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] Error in PSK negotiation [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Figure out if we actually should do DH / PSK
|
||||||
|
connParams.UsingDH, connParams.UsingPSK = PSKModeNegotiation(canDoDH, canDoPSK, clientPSKModes.KEModes)
|
||||||
|
|
||||||
|
// Select a ciphersuite
|
||||||
|
connParams.CipherSuite, err = CipherSuiteNegotiation(psk, ch.CipherSuites, state.Caps.CipherSuites)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] No common ciphersuite found [%v]", err)
|
||||||
|
return nil, nil, AlertHandshakeFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send a cookie if required
|
||||||
|
// NB: Need to do this here because it's after ciphersuite selection, which
|
||||||
|
// has to be after PSK selection.
|
||||||
|
// XXX: Doing this statefully for now, could be stateless
|
||||||
|
var cookieData []byte
|
||||||
|
if state.Caps.RequireCookie && !state.cookieSent {
|
||||||
|
var err error
|
||||||
|
cookieData, err = state.Caps.CookieHandler.Generate(state.conn)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] Error generating cookie [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cookieData != nil {
|
||||||
|
// Ignoring errors because everything here is newly constructed, so there
|
||||||
|
// shouldn't be marshal errors
|
||||||
|
hrr := &HelloRetryRequestBody{
|
||||||
|
Version: supportedVersion,
|
||||||
|
CipherSuite: connParams.CipherSuite,
|
||||||
|
}
|
||||||
|
hrr.Extensions.Add(&CookieExtension{Cookie: cookieData})
|
||||||
|
|
||||||
|
// Run the external extension handler.
|
||||||
|
if state.Caps.ExtensionHandler != nil {
|
||||||
|
err := state.Caps.ExtensionHandler.Send(HandshakeTypeHelloRetryRequest, &hrr.Extensions)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] Error running external extension sender [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
helloRetryRequest, err := HandshakeMessageFromBody(hrr)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] Error marshaling HRR [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
params := cipherSuiteMap[connParams.CipherSuite]
|
||||||
|
h := params.Hash.New()
|
||||||
|
h.Write(clientHello.Marshal())
|
||||||
|
firstClientHello := &HandshakeMessage{
|
||||||
|
msgType: HandshakeTypeMessageHash,
|
||||||
|
body: h.Sum(nil),
|
||||||
|
}
|
||||||
|
|
||||||
|
nextState := ServerStateStart{
|
||||||
|
Caps: state.Caps,
|
||||||
|
conn: state.conn,
|
||||||
|
cookieSent: true,
|
||||||
|
firstClientHello: firstClientHello,
|
||||||
|
helloRetryRequest: helloRetryRequest,
|
||||||
|
}
|
||||||
|
toSend := []HandshakeAction{SendHandshakeMessage{helloRetryRequest}}
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateStart]")
|
||||||
|
return nextState, toSend, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we've got no entropy to make keys from, fail
|
||||||
|
if !connParams.UsingDH && !connParams.UsingPSK {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] Neither DH nor PSK negotiated")
|
||||||
|
return nil, nil, AlertHandshakeFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
var pskSecret []byte
|
||||||
|
var cert *Certificate
|
||||||
|
var certScheme SignatureScheme
|
||||||
|
if connParams.UsingPSK {
|
||||||
|
pskSecret = psk.Key
|
||||||
|
} else {
|
||||||
|
psk = nil
|
||||||
|
|
||||||
|
// If we're not using a PSK mode, then we need to have certain extensions
|
||||||
|
if !gotServerName || !gotSupportedGroups || !gotSignatureAlgorithms {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] Insufficient extensions (%v %v %v)",
|
||||||
|
gotServerName, gotSupportedGroups, gotSignatureAlgorithms)
|
||||||
|
return nil, nil, AlertMissingExtension
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select a certificate
|
||||||
|
name := string(*serverName)
|
||||||
|
var err error
|
||||||
|
cert, certScheme, err = CertificateSelection(&name, signatureAlgorithms.Algorithms, state.Caps.Certificates)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] No appropriate certificate found [%v]", err)
|
||||||
|
return nil, nil, AlertAccessDenied
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !connParams.UsingDH {
|
||||||
|
dhSecret = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Figure out if we're going to do early data
|
||||||
|
var clientEarlyTrafficSecret []byte
|
||||||
|
connParams.ClientSendingEarlyData = gotEarlyData
|
||||||
|
connParams.UsingEarlyData = EarlyDataNegotiation(connParams.UsingPSK, gotEarlyData, state.Caps.AllowEarlyData)
|
||||||
|
if connParams.UsingEarlyData {
|
||||||
|
|
||||||
|
h := params.Hash.New()
|
||||||
|
h.Write(clientHello.Marshal())
|
||||||
|
chHash := h.Sum(nil)
|
||||||
|
|
||||||
|
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
|
||||||
|
earlySecret := HkdfExtract(params.Hash, zero, pskSecret)
|
||||||
|
clientEarlyTrafficSecret = deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select a next protocol
|
||||||
|
connParams.NextProto, err = ALPNNegotiation(psk, clientALPN.Protocols, state.Caps.NextProtos)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] No common application-layer protocol found [%v]", err)
|
||||||
|
return nil, nil, AlertNoApplicationProtocol
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateNegotiated]")
|
||||||
|
return ServerStateNegotiated{
|
||||||
|
Caps: state.Caps,
|
||||||
|
Params: connParams,
|
||||||
|
|
||||||
|
dhGroup: dhGroup,
|
||||||
|
dhPublic: dhPublic,
|
||||||
|
dhSecret: dhSecret,
|
||||||
|
pskSecret: pskSecret,
|
||||||
|
selectedPSK: selectedPSK,
|
||||||
|
cert: cert,
|
||||||
|
certScheme: certScheme,
|
||||||
|
clientEarlyTrafficSecret: clientEarlyTrafficSecret,
|
||||||
|
|
||||||
|
firstClientHello: state.firstClientHello,
|
||||||
|
helloRetryRequest: state.helloRetryRequest,
|
||||||
|
clientHello: clientHello,
|
||||||
|
}.Next(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerStateNegotiated struct {
|
||||||
|
Caps Capabilities
|
||||||
|
Params ConnectionParameters
|
||||||
|
|
||||||
|
dhGroup NamedGroup
|
||||||
|
dhPublic []byte
|
||||||
|
dhSecret []byte
|
||||||
|
pskSecret []byte
|
||||||
|
clientEarlyTrafficSecret []byte
|
||||||
|
selectedPSK int
|
||||||
|
cert *Certificate
|
||||||
|
certScheme SignatureScheme
|
||||||
|
|
||||||
|
firstClientHello *HandshakeMessage
|
||||||
|
helloRetryRequest *HandshakeMessage
|
||||||
|
clientHello *HandshakeMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state ServerStateNegotiated) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
if hm != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] Unexpected message")
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the ServerHello
|
||||||
|
sh := &ServerHelloBody{
|
||||||
|
Version: supportedVersion,
|
||||||
|
CipherSuite: state.Params.CipherSuite,
|
||||||
|
}
|
||||||
|
_, err := prng.Read(sh.Random[:])
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] Error creating server random [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
if state.Params.UsingDH {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] sending DH extension")
|
||||||
|
err = sh.Extensions.Add(&KeyShareExtension{
|
||||||
|
HandshakeType: HandshakeTypeServerHello,
|
||||||
|
Shares: []KeyShareEntry{{Group: state.dhGroup, KeyExchange: state.dhPublic}},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding key_shares extension [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if state.Params.UsingPSK {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] sending PSK extension")
|
||||||
|
err = sh.Extensions.Add(&PreSharedKeyExtension{
|
||||||
|
HandshakeType: HandshakeTypeServerHello,
|
||||||
|
SelectedIdentity: uint16(state.selectedPSK),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding PSK extension [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run the external extension handler.
|
||||||
|
if state.Caps.ExtensionHandler != nil {
|
||||||
|
err := state.Caps.ExtensionHandler.Send(HandshakeTypeServerHello, &sh.Extensions)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
serverHello, err := HandshakeMessageFromBody(sh)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling ServerHello [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look up crypto params
|
||||||
|
params, ok := cipherSuiteMap[sh.CipherSuite]
|
||||||
|
if !ok {
|
||||||
|
logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", sh.CipherSuite)
|
||||||
|
return nil, nil, AlertHandshakeFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start up the handshake hash
|
||||||
|
handshakeHash := params.Hash.New()
|
||||||
|
handshakeHash.Write(state.firstClientHello.Marshal())
|
||||||
|
handshakeHash.Write(state.helloRetryRequest.Marshal())
|
||||||
|
handshakeHash.Write(state.clientHello.Marshal())
|
||||||
|
handshakeHash.Write(serverHello.Marshal())
|
||||||
|
|
||||||
|
// Compute handshake secrets
|
||||||
|
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
|
||||||
|
|
||||||
|
var earlySecret []byte
|
||||||
|
if state.Params.UsingPSK {
|
||||||
|
earlySecret = HkdfExtract(params.Hash, zero, state.pskSecret)
|
||||||
|
} else {
|
||||||
|
earlySecret = HkdfExtract(params.Hash, zero, zero)
|
||||||
|
}
|
||||||
|
|
||||||
|
if state.dhSecret == nil {
|
||||||
|
state.dhSecret = zero
|
||||||
|
}
|
||||||
|
|
||||||
|
h0 := params.Hash.New().Sum(nil)
|
||||||
|
h2 := handshakeHash.Sum(nil)
|
||||||
|
preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0)
|
||||||
|
handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, state.dhSecret)
|
||||||
|
clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2)
|
||||||
|
serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2)
|
||||||
|
preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0)
|
||||||
|
masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero)
|
||||||
|
|
||||||
|
logf(logTypeCrypto, "early secret (init!): [%d] %x", len(earlySecret), earlySecret)
|
||||||
|
logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret)
|
||||||
|
logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret)
|
||||||
|
logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret)
|
||||||
|
logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret)
|
||||||
|
|
||||||
|
clientHandshakeKeys := makeTrafficKeys(params, clientHandshakeTrafficSecret)
|
||||||
|
serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret)
|
||||||
|
|
||||||
|
// Send an EncryptedExtensions message (even if it's empty)
|
||||||
|
eeList := ExtensionList{}
|
||||||
|
if state.Params.NextProto != "" {
|
||||||
|
logf(logTypeHandshake, "[server] sending ALPN extension")
|
||||||
|
err = eeList.Add(&ALPNExtension{Protocols: []string{state.Params.NextProto}})
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding ALPN to EncryptedExtensions [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if state.Params.UsingEarlyData {
|
||||||
|
logf(logTypeHandshake, "[server] sending EDI extension")
|
||||||
|
err = eeList.Add(&EarlyDataExtension{})
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding EDI to EncryptedExtensions [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ee := &EncryptedExtensionsBody{eeList}
|
||||||
|
|
||||||
|
// Run the external extension handler.
|
||||||
|
if state.Caps.ExtensionHandler != nil {
|
||||||
|
err := state.Caps.ExtensionHandler.Send(HandshakeTypeEncryptedExtensions, &ee.Extensions)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
eem, err := HandshakeMessageFromBody(ee)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling EncryptedExtensions [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
handshakeHash.Write(eem.Marshal())
|
||||||
|
|
||||||
|
toSend := []HandshakeAction{
|
||||||
|
SendHandshakeMessage{serverHello},
|
||||||
|
RekeyOut{Label: "handshake", KeySet: serverHandshakeKeys},
|
||||||
|
SendHandshakeMessage{eem},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authenticate with a certificate if required
|
||||||
|
if !state.Params.UsingPSK {
|
||||||
|
// Send a CertificateRequest message if we want client auth
|
||||||
|
if state.Caps.RequireClientAuth {
|
||||||
|
state.Params.UsingClientAuth = true
|
||||||
|
|
||||||
|
// XXX: We don't support sending any constraints besides a list of
|
||||||
|
// supported signature algorithms
|
||||||
|
cr := &CertificateRequestBody{}
|
||||||
|
schemes := &SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes}
|
||||||
|
err := cr.Extensions.Add(schemes)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding supported schemes to CertificateRequest [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
crm, err := HandshakeMessageFromBody(cr)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling CertificateRequest [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
//TODO state.state.serverCertificateRequest = cr
|
||||||
|
|
||||||
|
toSend = append(toSend, SendHandshakeMessage{crm})
|
||||||
|
handshakeHash.Write(crm.Marshal())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create and send Certificate, CertificateVerify
|
||||||
|
certificate := &CertificateBody{
|
||||||
|
CertificateList: make([]CertificateEntry, len(state.cert.Chain)),
|
||||||
|
}
|
||||||
|
for i, entry := range state.cert.Chain {
|
||||||
|
certificate.CertificateList[i] = CertificateEntry{CertData: entry}
|
||||||
|
}
|
||||||
|
certm, err := HandshakeMessageFromBody(certificate)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling Certificate [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
toSend = append(toSend, SendHandshakeMessage{certm})
|
||||||
|
handshakeHash.Write(certm.Marshal())
|
||||||
|
|
||||||
|
certificateVerify := &CertificateVerifyBody{Algorithm: state.certScheme}
|
||||||
|
logf(logTypeHandshake, "Creating CertVerify: %04x %v", state.certScheme, params.Hash)
|
||||||
|
|
||||||
|
hcv := handshakeHash.Sum(nil)
|
||||||
|
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
|
||||||
|
|
||||||
|
err = certificateVerify.Sign(state.cert.PrivateKey, hcv)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] Error signing CertificateVerify [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
certvm, err := HandshakeMessageFromBody(certificateVerify)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling CertificateVerify [%v]", err)
|
||||||
|
return nil, nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
toSend = append(toSend, SendHandshakeMessage{certvm})
|
||||||
|
handshakeHash.Write(certvm.Marshal())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute secrets resulting from the server's first flight
|
||||||
|
h3 := handshakeHash.Sum(nil)
|
||||||
|
logf(logTypeCrypto, "handshake hash 3 [%d] %x", len(h3), h3)
|
||||||
|
logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h3), h3)
|
||||||
|
|
||||||
|
serverFinishedData := computeFinishedData(params, serverHandshakeTrafficSecret, h3)
|
||||||
|
logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData)
|
||||||
|
|
||||||
|
// Assemble the Finished message
|
||||||
|
fin := &FinishedBody{
|
||||||
|
VerifyDataLen: len(serverFinishedData),
|
||||||
|
VerifyData: serverFinishedData,
|
||||||
|
}
|
||||||
|
finm, _ := HandshakeMessageFromBody(fin)
|
||||||
|
|
||||||
|
toSend = append(toSend, SendHandshakeMessage{finm})
|
||||||
|
handshakeHash.Write(finm.Marshal())
|
||||||
|
|
||||||
|
// Compute traffic secrets
|
||||||
|
h4 := handshakeHash.Sum(nil)
|
||||||
|
logf(logTypeCrypto, "handshake hash 4 [%d] %x", len(h4), h4)
|
||||||
|
logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h4), h4)
|
||||||
|
|
||||||
|
clientTrafficSecret := deriveSecret(params, masterSecret, labelClientApplicationTrafficSecret, h4)
|
||||||
|
serverTrafficSecret := deriveSecret(params, masterSecret, labelServerApplicationTrafficSecret, h4)
|
||||||
|
logf(logTypeCrypto, "client traffic secret: [%d] %x", len(clientTrafficSecret), clientTrafficSecret)
|
||||||
|
logf(logTypeCrypto, "server traffic secret: [%d] %x", len(serverTrafficSecret), serverTrafficSecret)
|
||||||
|
|
||||||
|
serverTrafficKeys := makeTrafficKeys(params, serverTrafficSecret)
|
||||||
|
toSend = append(toSend, RekeyOut{Label: "application", KeySet: serverTrafficKeys})
|
||||||
|
|
||||||
|
exporterSecret := deriveSecret(params, masterSecret, labelExporterSecret, h4)
|
||||||
|
logf(logTypeCrypto, "server exporter secret: [%d] %x", len(exporterSecret), exporterSecret)
|
||||||
|
|
||||||
|
if state.Params.UsingEarlyData {
|
||||||
|
clientEarlyTrafficKeys := makeTrafficKeys(params, state.clientEarlyTrafficSecret)
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitEOED]")
|
||||||
|
nextState := ServerStateWaitEOED{
|
||||||
|
AuthCertificate: state.Caps.AuthCertificate,
|
||||||
|
Params: state.Params,
|
||||||
|
cryptoParams: params,
|
||||||
|
handshakeHash: handshakeHash,
|
||||||
|
masterSecret: masterSecret,
|
||||||
|
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret,
|
||||||
|
clientTrafficSecret: clientTrafficSecret,
|
||||||
|
serverTrafficSecret: serverTrafficSecret,
|
||||||
|
exporterSecret: exporterSecret,
|
||||||
|
}
|
||||||
|
toSend = append(toSend, []HandshakeAction{
|
||||||
|
RekeyIn{Label: "early", KeySet: clientEarlyTrafficKeys},
|
||||||
|
ReadEarlyData{},
|
||||||
|
}...)
|
||||||
|
return nextState, toSend, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitFlight2]")
|
||||||
|
toSend = append(toSend, []HandshakeAction{
|
||||||
|
RekeyIn{Label: "handshake", KeySet: clientHandshakeKeys},
|
||||||
|
ReadPastEarlyData{},
|
||||||
|
}...)
|
||||||
|
waitFlight2 := ServerStateWaitFlight2{
|
||||||
|
AuthCertificate: state.Caps.AuthCertificate,
|
||||||
|
Params: state.Params,
|
||||||
|
cryptoParams: params,
|
||||||
|
handshakeHash: handshakeHash,
|
||||||
|
masterSecret: masterSecret,
|
||||||
|
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret,
|
||||||
|
clientTrafficSecret: clientTrafficSecret,
|
||||||
|
serverTrafficSecret: serverTrafficSecret,
|
||||||
|
exporterSecret: exporterSecret,
|
||||||
|
}
|
||||||
|
nextState, moreToSend, alert := waitFlight2.Next(nil)
|
||||||
|
toSend = append(toSend, moreToSend...)
|
||||||
|
return nextState, toSend, alert
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerStateWaitEOED struct {
|
||||||
|
AuthCertificate func(chain []CertificateEntry) error
|
||||||
|
Params ConnectionParameters
|
||||||
|
cryptoParams CipherSuiteParams
|
||||||
|
masterSecret []byte
|
||||||
|
clientHandshakeTrafficSecret []byte
|
||||||
|
handshakeHash hash.Hash
|
||||||
|
clientTrafficSecret []byte
|
||||||
|
serverTrafficSecret []byte
|
||||||
|
exporterSecret []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state ServerStateWaitEOED) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
if hm == nil || hm.msgType != HandshakeTypeEndOfEarlyData {
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitEOED] Unexpected message")
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(hm.body) > 0 {
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitEOED] Error decoding message [len > 0]")
|
||||||
|
return nil, nil, AlertDecodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
state.handshakeHash.Write(hm.Marshal())
|
||||||
|
|
||||||
|
clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitEOED] -> [ServerStateWaitFlight2]")
|
||||||
|
toSend := []HandshakeAction{
|
||||||
|
RekeyIn{Label: "handshake", KeySet: clientHandshakeKeys},
|
||||||
|
}
|
||||||
|
waitFlight2 := ServerStateWaitFlight2{
|
||||||
|
AuthCertificate: state.AuthCertificate,
|
||||||
|
Params: state.Params,
|
||||||
|
cryptoParams: state.cryptoParams,
|
||||||
|
handshakeHash: state.handshakeHash,
|
||||||
|
masterSecret: state.masterSecret,
|
||||||
|
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||||
|
clientTrafficSecret: state.clientTrafficSecret,
|
||||||
|
serverTrafficSecret: state.serverTrafficSecret,
|
||||||
|
exporterSecret: state.exporterSecret,
|
||||||
|
}
|
||||||
|
nextState, moreToSend, alert := waitFlight2.Next(nil)
|
||||||
|
toSend = append(toSend, moreToSend...)
|
||||||
|
return nextState, toSend, alert
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerStateWaitFlight2 struct {
|
||||||
|
AuthCertificate func(chain []CertificateEntry) error
|
||||||
|
Params ConnectionParameters
|
||||||
|
cryptoParams CipherSuiteParams
|
||||||
|
masterSecret []byte
|
||||||
|
clientHandshakeTrafficSecret []byte
|
||||||
|
handshakeHash hash.Hash
|
||||||
|
clientTrafficSecret []byte
|
||||||
|
serverTrafficSecret []byte
|
||||||
|
exporterSecret []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state ServerStateWaitFlight2) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
if hm != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitFlight2] Unexpected message")
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
if state.Params.UsingClientAuth {
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitFlight2] -> [ServerStateWaitCert]")
|
||||||
|
nextState := ServerStateWaitCert{
|
||||||
|
AuthCertificate: state.AuthCertificate,
|
||||||
|
Params: state.Params,
|
||||||
|
cryptoParams: state.cryptoParams,
|
||||||
|
handshakeHash: state.handshakeHash,
|
||||||
|
masterSecret: state.masterSecret,
|
||||||
|
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||||
|
clientTrafficSecret: state.clientTrafficSecret,
|
||||||
|
serverTrafficSecret: state.serverTrafficSecret,
|
||||||
|
exporterSecret: state.exporterSecret,
|
||||||
|
}
|
||||||
|
return nextState, nil, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitFlight2] -> [ServerStateWaitFinished]")
|
||||||
|
nextState := ServerStateWaitFinished{
|
||||||
|
Params: state.Params,
|
||||||
|
cryptoParams: state.cryptoParams,
|
||||||
|
masterSecret: state.masterSecret,
|
||||||
|
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||||
|
handshakeHash: state.handshakeHash,
|
||||||
|
clientTrafficSecret: state.clientTrafficSecret,
|
||||||
|
serverTrafficSecret: state.serverTrafficSecret,
|
||||||
|
exporterSecret: state.exporterSecret,
|
||||||
|
}
|
||||||
|
return nextState, nil, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerStateWaitCert struct {
|
||||||
|
AuthCertificate func(chain []CertificateEntry) error
|
||||||
|
Params ConnectionParameters
|
||||||
|
cryptoParams CipherSuiteParams
|
||||||
|
masterSecret []byte
|
||||||
|
clientHandshakeTrafficSecret []byte
|
||||||
|
handshakeHash hash.Hash
|
||||||
|
clientTrafficSecret []byte
|
||||||
|
serverTrafficSecret []byte
|
||||||
|
exporterSecret []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state ServerStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
if hm == nil || hm.msgType != HandshakeTypeCertificate {
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message")
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
cert := &CertificateBody{}
|
||||||
|
_, err := cert.Unmarshal(hm.body)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message")
|
||||||
|
return nil, nil, AlertDecodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
state.handshakeHash.Write(hm.Marshal())
|
||||||
|
|
||||||
|
if len(cert.CertificateList) == 0 {
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitCert] WARNING client did not provide a certificate")
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitCert] -> [ServerStateWaitFinished]")
|
||||||
|
nextState := ServerStateWaitFinished{
|
||||||
|
Params: state.Params,
|
||||||
|
cryptoParams: state.cryptoParams,
|
||||||
|
masterSecret: state.masterSecret,
|
||||||
|
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||||
|
handshakeHash: state.handshakeHash,
|
||||||
|
clientTrafficSecret: state.clientTrafficSecret,
|
||||||
|
serverTrafficSecret: state.serverTrafficSecret,
|
||||||
|
exporterSecret: state.exporterSecret,
|
||||||
|
}
|
||||||
|
return nextState, nil, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitCert] -> [ServerStateWaitCV]")
|
||||||
|
nextState := ServerStateWaitCV{
|
||||||
|
AuthCertificate: state.AuthCertificate,
|
||||||
|
Params: state.Params,
|
||||||
|
cryptoParams: state.cryptoParams,
|
||||||
|
masterSecret: state.masterSecret,
|
||||||
|
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||||
|
handshakeHash: state.handshakeHash,
|
||||||
|
clientTrafficSecret: state.clientTrafficSecret,
|
||||||
|
serverTrafficSecret: state.serverTrafficSecret,
|
||||||
|
clientCertificate: cert,
|
||||||
|
exporterSecret: state.exporterSecret,
|
||||||
|
}
|
||||||
|
return nextState, nil, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerStateWaitCV struct {
|
||||||
|
AuthCertificate func(chain []CertificateEntry) error
|
||||||
|
Params ConnectionParameters
|
||||||
|
cryptoParams CipherSuiteParams
|
||||||
|
|
||||||
|
masterSecret []byte
|
||||||
|
clientHandshakeTrafficSecret []byte
|
||||||
|
|
||||||
|
handshakeHash hash.Hash
|
||||||
|
clientTrafficSecret []byte
|
||||||
|
serverTrafficSecret []byte
|
||||||
|
exporterSecret []byte
|
||||||
|
|
||||||
|
clientCertificate *CertificateBody
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state ServerStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
if hm == nil || hm.msgType != HandshakeTypeCertificateVerify {
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitCV] Unexpected message [%+v] [%s]", hm, reflect.TypeOf(hm))
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
certVerify := &CertificateVerifyBody{}
|
||||||
|
_, err := certVerify.Unmarshal(hm.body)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitCert] Error decoding message %v", err)
|
||||||
|
return nil, nil, AlertDecodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify client signature over handshake hash
|
||||||
|
hcv := state.handshakeHash.Sum(nil)
|
||||||
|
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
|
||||||
|
|
||||||
|
clientPublicKey := state.clientCertificate.CertificateList[0].CertData.PublicKey
|
||||||
|
if err := certVerify.Verify(clientPublicKey, hcv); err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitCV] Failure in client auth verification [%v]", err)
|
||||||
|
return nil, nil, AlertHandshakeFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
if state.AuthCertificate != nil {
|
||||||
|
err := state.AuthCertificate(state.clientCertificate.CertificateList)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitCV] Application rejected client certificate")
|
||||||
|
return nil, nil, AlertBadCertificate
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitCV] WARNING: No verification of client certificate")
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it passes, record the certificateVerify in the transcript hash
|
||||||
|
state.handshakeHash.Write(hm.Marshal())
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitCV] -> [ServerStateWaitFinished]")
|
||||||
|
nextState := ServerStateWaitFinished{
|
||||||
|
Params: state.Params,
|
||||||
|
cryptoParams: state.cryptoParams,
|
||||||
|
masterSecret: state.masterSecret,
|
||||||
|
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||||
|
handshakeHash: state.handshakeHash,
|
||||||
|
clientTrafficSecret: state.clientTrafficSecret,
|
||||||
|
serverTrafficSecret: state.serverTrafficSecret,
|
||||||
|
exporterSecret: state.exporterSecret,
|
||||||
|
}
|
||||||
|
return nextState, nil, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerStateWaitFinished struct {
|
||||||
|
Params ConnectionParameters
|
||||||
|
cryptoParams CipherSuiteParams
|
||||||
|
|
||||||
|
masterSecret []byte
|
||||||
|
clientHandshakeTrafficSecret []byte
|
||||||
|
|
||||||
|
handshakeHash hash.Hash
|
||||||
|
clientTrafficSecret []byte
|
||||||
|
serverTrafficSecret []byte
|
||||||
|
exporterSecret []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state ServerStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
if hm == nil || hm.msgType != HandshakeTypeFinished {
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitFinished] Unexpected message")
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
fin := &FinishedBody{VerifyDataLen: state.cryptoParams.Hash.Size()}
|
||||||
|
_, err := fin.Unmarshal(hm.body)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitFinished] Error decoding message %v", err)
|
||||||
|
return nil, nil, AlertDecodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify client Finished data
|
||||||
|
h5 := state.handshakeHash.Sum(nil)
|
||||||
|
logf(logTypeCrypto, "handshake hash for client Finished: [%d] %x", len(h5), h5)
|
||||||
|
|
||||||
|
clientFinishedData := computeFinishedData(state.cryptoParams, state.clientHandshakeTrafficSecret, h5)
|
||||||
|
logf(logTypeCrypto, "client Finished data: [%d] %x", len(clientFinishedData), clientFinishedData)
|
||||||
|
|
||||||
|
if !bytes.Equal(fin.VerifyData, clientFinishedData) {
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitFinished] Client's Finished failed to verify")
|
||||||
|
return nil, nil, AlertHandshakeFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the resumption secret
|
||||||
|
state.handshakeHash.Write(hm.Marshal())
|
||||||
|
h6 := state.handshakeHash.Sum(nil)
|
||||||
|
logf(logTypeCrypto, "handshake hash 6 [%d]: %x", len(h6), h6)
|
||||||
|
|
||||||
|
resumptionSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelResumptionSecret, h6)
|
||||||
|
logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret)
|
||||||
|
|
||||||
|
// Compute client traffic keys
|
||||||
|
clientTrafficKeys := makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret)
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[ServerStateWaitFinished] -> [StateConnected]")
|
||||||
|
nextState := StateConnected{
|
||||||
|
Params: state.Params,
|
||||||
|
isClient: false,
|
||||||
|
cryptoParams: state.cryptoParams,
|
||||||
|
resumptionSecret: resumptionSecret,
|
||||||
|
clientTrafficSecret: state.clientTrafficSecret,
|
||||||
|
serverTrafficSecret: state.serverTrafficSecret,
|
||||||
|
exporterSecret: state.exporterSecret,
|
||||||
|
}
|
||||||
|
toSend := []HandshakeAction{
|
||||||
|
RekeyIn{Label: "application", KeySet: clientTrafficKeys},
|
||||||
|
}
|
||||||
|
return nextState, toSend, AlertNoAlert
|
||||||
|
}
|
230
vendor/github.com/bifurcation/mint/state-machine.go
generated
vendored
Normal file
230
vendor/github.com/bifurcation/mint/state-machine.go
generated
vendored
Normal file
@ -0,0 +1,230 @@
|
|||||||
|
package mint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Marker interface for actions that an implementation should take based on
|
||||||
|
// state transitions.
|
||||||
|
type HandshakeAction interface{}
|
||||||
|
|
||||||
|
type SendHandshakeMessage struct {
|
||||||
|
Message *HandshakeMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
type SendEarlyData struct{}
|
||||||
|
|
||||||
|
type ReadEarlyData struct{}
|
||||||
|
|
||||||
|
type ReadPastEarlyData struct{}
|
||||||
|
|
||||||
|
type RekeyIn struct {
|
||||||
|
Label string
|
||||||
|
KeySet keySet
|
||||||
|
}
|
||||||
|
|
||||||
|
type RekeyOut struct {
|
||||||
|
Label string
|
||||||
|
KeySet keySet
|
||||||
|
}
|
||||||
|
|
||||||
|
type StorePSK struct {
|
||||||
|
PSK PreSharedKey
|
||||||
|
}
|
||||||
|
|
||||||
|
type HandshakeState interface {
|
||||||
|
Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert)
|
||||||
|
}
|
||||||
|
|
||||||
|
type AppExtensionHandler interface {
|
||||||
|
Send(hs HandshakeType, el *ExtensionList) error
|
||||||
|
Receive(hs HandshakeType, el *ExtensionList) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Capabilities objects represent the capabilities of a TLS client or server,
|
||||||
|
// as an input to TLS negotiation
|
||||||
|
type Capabilities struct {
|
||||||
|
// For both client and server
|
||||||
|
CipherSuites []CipherSuite
|
||||||
|
Groups []NamedGroup
|
||||||
|
SignatureSchemes []SignatureScheme
|
||||||
|
PSKs PreSharedKeyCache
|
||||||
|
Certificates []*Certificate
|
||||||
|
AuthCertificate func(chain []CertificateEntry) error
|
||||||
|
ExtensionHandler AppExtensionHandler
|
||||||
|
|
||||||
|
// For client
|
||||||
|
PSKModes []PSKKeyExchangeMode
|
||||||
|
|
||||||
|
// For server
|
||||||
|
NextProtos []string
|
||||||
|
AllowEarlyData bool
|
||||||
|
RequireCookie bool
|
||||||
|
CookieHandler CookieHandler
|
||||||
|
RequireClientAuth bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnectionOptions objects represent per-connection settings for a client
|
||||||
|
// initiating a connection
|
||||||
|
type ConnectionOptions struct {
|
||||||
|
ServerName string
|
||||||
|
NextProtos []string
|
||||||
|
EarlyData []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnectionParameters objects represent the parameters negotiated for a
|
||||||
|
// connection.
|
||||||
|
type ConnectionParameters struct {
|
||||||
|
UsingPSK bool
|
||||||
|
UsingDH bool
|
||||||
|
ClientSendingEarlyData bool
|
||||||
|
UsingEarlyData bool
|
||||||
|
UsingClientAuth bool
|
||||||
|
|
||||||
|
CipherSuite CipherSuite
|
||||||
|
ServerName string
|
||||||
|
NextProto string
|
||||||
|
}
|
||||||
|
|
||||||
|
// StateConnected is symmetric between client and server
|
||||||
|
type StateConnected struct {
|
||||||
|
Params ConnectionParameters
|
||||||
|
isClient bool
|
||||||
|
cryptoParams CipherSuiteParams
|
||||||
|
resumptionSecret []byte
|
||||||
|
clientTrafficSecret []byte
|
||||||
|
serverTrafficSecret []byte
|
||||||
|
exporterSecret []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state *StateConnected) KeyUpdate(request KeyUpdateRequest) ([]HandshakeAction, Alert) {
|
||||||
|
var trafficKeys keySet
|
||||||
|
if state.isClient {
|
||||||
|
state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret,
|
||||||
|
labelClientApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
|
||||||
|
trafficKeys = makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret)
|
||||||
|
} else {
|
||||||
|
state.serverTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.serverTrafficSecret,
|
||||||
|
labelServerApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
|
||||||
|
trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret)
|
||||||
|
}
|
||||||
|
|
||||||
|
kum, err := HandshakeMessageFromBody(&KeyUpdateBody{KeyUpdateRequest: request})
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[StateConnected] Error marshaling key update message: %v", err)
|
||||||
|
return nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
toSend := []HandshakeAction{
|
||||||
|
SendHandshakeMessage{kum},
|
||||||
|
RekeyOut{Label: "update", KeySet: trafficKeys},
|
||||||
|
}
|
||||||
|
return toSend, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state *StateConnected) NewSessionTicket(length int, lifetime, earlyDataLifetime uint32) ([]HandshakeAction, Alert) {
|
||||||
|
tkt, err := NewSessionTicket(length, lifetime)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[StateConnected] Error generating NewSessionTicket: %v", err)
|
||||||
|
return nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tkt.Extensions.Add(&TicketEarlyDataInfoExtension{earlyDataLifetime})
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[StateConnected] Error adding extension to NewSessionTicket: %v", err)
|
||||||
|
return nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret,
|
||||||
|
labelResumption, tkt.TicketNonce, state.cryptoParams.Hash.Size())
|
||||||
|
|
||||||
|
newPSK := PreSharedKey{
|
||||||
|
CipherSuite: state.cryptoParams.Suite,
|
||||||
|
IsResumption: true,
|
||||||
|
Identity: tkt.Ticket,
|
||||||
|
Key: resumptionKey,
|
||||||
|
NextProto: state.Params.NextProto,
|
||||||
|
ReceivedAt: time.Now(),
|
||||||
|
ExpiresAt: time.Now().Add(time.Duration(tkt.TicketLifetime) * time.Second),
|
||||||
|
TicketAgeAdd: tkt.TicketAgeAdd,
|
||||||
|
}
|
||||||
|
|
||||||
|
tktm, err := HandshakeMessageFromBody(tkt)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[StateConnected] Error marshaling NewSessionTicket: %v", err)
|
||||||
|
return nil, AlertInternalError
|
||||||
|
}
|
||||||
|
|
||||||
|
toSend := []HandshakeAction{
|
||||||
|
StorePSK{newPSK},
|
||||||
|
SendHandshakeMessage{tktm},
|
||||||
|
}
|
||||||
|
return toSend, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state StateConnected) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
if hm == nil {
|
||||||
|
logf(logTypeHandshake, "[StateConnected] Unexpected message")
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyGeneric, err := hm.ToBody()
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[StateConnected] Error decoding message: %v", err)
|
||||||
|
return nil, nil, AlertDecodeError
|
||||||
|
}
|
||||||
|
|
||||||
|
switch body := bodyGeneric.(type) {
|
||||||
|
case *KeyUpdateBody:
|
||||||
|
var trafficKeys keySet
|
||||||
|
if !state.isClient {
|
||||||
|
state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret,
|
||||||
|
labelClientApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
|
||||||
|
trafficKeys = makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret)
|
||||||
|
} else {
|
||||||
|
state.serverTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.serverTrafficSecret,
|
||||||
|
labelServerApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
|
||||||
|
trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret)
|
||||||
|
}
|
||||||
|
|
||||||
|
toSend := []HandshakeAction{RekeyIn{Label: "update", KeySet: trafficKeys}}
|
||||||
|
|
||||||
|
// If requested, roll outbound keys and send a KeyUpdate
|
||||||
|
if body.KeyUpdateRequest == KeyUpdateRequested {
|
||||||
|
moreToSend, alert := state.KeyUpdate(KeyUpdateNotRequested)
|
||||||
|
if alert != AlertNoAlert {
|
||||||
|
return nil, nil, alert
|
||||||
|
}
|
||||||
|
|
||||||
|
toSend = append(toSend, moreToSend...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return state, toSend, AlertNoAlert
|
||||||
|
|
||||||
|
case *NewSessionTicketBody:
|
||||||
|
// XXX: Allow NewSessionTicket in both directions?
|
||||||
|
if !state.isClient {
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret,
|
||||||
|
labelResumption, body.TicketNonce, state.cryptoParams.Hash.Size())
|
||||||
|
|
||||||
|
psk := PreSharedKey{
|
||||||
|
CipherSuite: state.cryptoParams.Suite,
|
||||||
|
IsResumption: true,
|
||||||
|
Identity: body.Ticket,
|
||||||
|
Key: resumptionKey,
|
||||||
|
NextProto: state.Params.NextProto,
|
||||||
|
ReceivedAt: time.Now(),
|
||||||
|
ExpiresAt: time.Now().Add(time.Duration(body.TicketLifetime) * time.Second),
|
||||||
|
TicketAgeAdd: body.TicketAgeAdd,
|
||||||
|
}
|
||||||
|
|
||||||
|
toSend := []HandshakeAction{StorePSK{psk}}
|
||||||
|
return state, toSend, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "[StateConnected] Unexpected message type %v", hm.msgType)
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
243
vendor/github.com/bifurcation/mint/syntax/decode.go
generated
vendored
Normal file
243
vendor/github.com/bifurcation/mint/syntax/decode.go
generated
vendored
Normal file
@ -0,0 +1,243 @@
|
|||||||
|
package syntax
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"runtime"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Unmarshal(data []byte, v interface{}) (int, error) {
|
||||||
|
// Check for well-formedness.
|
||||||
|
// Avoids filling out half a data structure
|
||||||
|
// before discovering a JSON syntax error.
|
||||||
|
d := decodeState{}
|
||||||
|
d.Write(data)
|
||||||
|
return d.unmarshal(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// These are the options that can be specified in the struct tag. Right now,
|
||||||
|
// all of them apply to variable-length vectors and nothing else
|
||||||
|
type decOpts struct {
|
||||||
|
head uint // length of length in bytes
|
||||||
|
min uint // minimum size in bytes
|
||||||
|
max uint // maximum size in bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
type decodeState struct {
|
||||||
|
bytes.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *decodeState) unmarshal(v interface{}) (read int, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
if _, ok := r.(runtime.Error); ok {
|
||||||
|
panic(r)
|
||||||
|
}
|
||||||
|
if s, ok := r.(string); ok {
|
||||||
|
panic(s)
|
||||||
|
}
|
||||||
|
err = r.(error)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
rv := reflect.ValueOf(v)
|
||||||
|
if rv.Kind() != reflect.Ptr || rv.IsNil() {
|
||||||
|
return 0, fmt.Errorf("Invalid unmarshal target (non-pointer or nil)")
|
||||||
|
}
|
||||||
|
|
||||||
|
read = d.value(rv)
|
||||||
|
return read, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *decodeState) value(v reflect.Value) int {
|
||||||
|
return valueDecoder(v)(e, v, decOpts{})
|
||||||
|
}
|
||||||
|
|
||||||
|
type decoderFunc func(e *decodeState, v reflect.Value, opts decOpts) int
|
||||||
|
|
||||||
|
func valueDecoder(v reflect.Value) decoderFunc {
|
||||||
|
return typeDecoder(v.Type().Elem())
|
||||||
|
}
|
||||||
|
|
||||||
|
func typeDecoder(t reflect.Type) decoderFunc {
|
||||||
|
// Note: Omits the caching / wait-group things that encoding/json uses
|
||||||
|
return newTypeDecoder(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTypeDecoder(t reflect.Type) decoderFunc {
|
||||||
|
// Note: Does not support Marshaler, so don't need the allowAddr argument
|
||||||
|
|
||||||
|
switch t.Kind() {
|
||||||
|
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
return uintDecoder
|
||||||
|
case reflect.Array:
|
||||||
|
return newArrayDecoder(t)
|
||||||
|
case reflect.Slice:
|
||||||
|
return newSliceDecoder(t)
|
||||||
|
case reflect.Struct:
|
||||||
|
return newStructDecoder(t)
|
||||||
|
default:
|
||||||
|
panic(fmt.Errorf("Unsupported type (%s)", t))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///// Specific decoders below
|
||||||
|
|
||||||
|
func uintDecoder(d *decodeState, v reflect.Value, opts decOpts) int {
|
||||||
|
var uintLen int
|
||||||
|
switch v.Elem().Kind() {
|
||||||
|
case reflect.Uint8:
|
||||||
|
uintLen = 1
|
||||||
|
case reflect.Uint16:
|
||||||
|
uintLen = 2
|
||||||
|
case reflect.Uint32:
|
||||||
|
uintLen = 4
|
||||||
|
case reflect.Uint64:
|
||||||
|
uintLen = 8
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, uintLen)
|
||||||
|
n, err := d.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if n != uintLen {
|
||||||
|
panic(fmt.Errorf("Insufficient data to read uint"))
|
||||||
|
}
|
||||||
|
|
||||||
|
val := uint64(0)
|
||||||
|
for _, b := range buf {
|
||||||
|
val = (val << 8) + uint64(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
v.Elem().SetUint(val)
|
||||||
|
return uintLen
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////
|
||||||
|
|
||||||
|
type arrayDecoder struct {
|
||||||
|
elemDec decoderFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ad *arrayDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
|
||||||
|
n := v.Elem().Type().Len()
|
||||||
|
read := 0
|
||||||
|
for i := 0; i < n; i += 1 {
|
||||||
|
read += ad.elemDec(d, v.Elem().Index(i).Addr(), opts)
|
||||||
|
}
|
||||||
|
return read
|
||||||
|
}
|
||||||
|
|
||||||
|
func newArrayDecoder(t reflect.Type) decoderFunc {
|
||||||
|
dec := &arrayDecoder{typeDecoder(t.Elem())}
|
||||||
|
return dec.decode
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////
|
||||||
|
|
||||||
|
type sliceDecoder struct {
|
||||||
|
elementType reflect.Type
|
||||||
|
elementDec decoderFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sd *sliceDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
|
||||||
|
if opts.head == 0 {
|
||||||
|
panic(fmt.Errorf("Cannot decode a slice without a header length"))
|
||||||
|
}
|
||||||
|
|
||||||
|
lengthBytes := make([]byte, opts.head)
|
||||||
|
n, err := d.Read(lengthBytes)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if uint(n) != opts.head {
|
||||||
|
panic(fmt.Errorf("Not enough data to read header"))
|
||||||
|
}
|
||||||
|
|
||||||
|
length := uint(0)
|
||||||
|
for _, b := range lengthBytes {
|
||||||
|
length = (length << 8) + uint(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.max > 0 && length > opts.max {
|
||||||
|
panic(fmt.Errorf("Length of vector exceeds declared max"))
|
||||||
|
}
|
||||||
|
if length < opts.min {
|
||||||
|
panic(fmt.Errorf("Length of vector below declared min"))
|
||||||
|
}
|
||||||
|
|
||||||
|
data := make([]byte, length)
|
||||||
|
n, err = d.Read(data)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if uint(n) != length {
|
||||||
|
panic(fmt.Errorf("Available data less than declared length [%04x < %04x]", n, length))
|
||||||
|
}
|
||||||
|
|
||||||
|
elemBuf := &decodeState{}
|
||||||
|
elemBuf.Write(data)
|
||||||
|
elems := []reflect.Value{}
|
||||||
|
read := int(opts.head)
|
||||||
|
for elemBuf.Len() > 0 {
|
||||||
|
elem := reflect.New(sd.elementType)
|
||||||
|
read += sd.elementDec(elemBuf, elem, opts)
|
||||||
|
elems = append(elems, elem)
|
||||||
|
}
|
||||||
|
|
||||||
|
v.Elem().Set(reflect.MakeSlice(v.Elem().Type(), len(elems), len(elems)))
|
||||||
|
for i := 0; i < len(elems); i += 1 {
|
||||||
|
v.Elem().Index(i).Set(elems[i].Elem())
|
||||||
|
}
|
||||||
|
return read
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSliceDecoder(t reflect.Type) decoderFunc {
|
||||||
|
dec := &sliceDecoder{
|
||||||
|
elementType: t.Elem(),
|
||||||
|
elementDec: typeDecoder(t.Elem()),
|
||||||
|
}
|
||||||
|
return dec.decode
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////
|
||||||
|
|
||||||
|
type structDecoder struct {
|
||||||
|
fieldOpts []decOpts
|
||||||
|
fieldDecs []decoderFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sd *structDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
|
||||||
|
read := 0
|
||||||
|
for i := range sd.fieldDecs {
|
||||||
|
read += sd.fieldDecs[i](d, v.Elem().Field(i).Addr(), sd.fieldOpts[i])
|
||||||
|
}
|
||||||
|
return read
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStructDecoder(t reflect.Type) decoderFunc {
|
||||||
|
n := t.NumField()
|
||||||
|
sd := structDecoder{
|
||||||
|
fieldOpts: make([]decOpts, n),
|
||||||
|
fieldDecs: make([]decoderFunc, n),
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < n; i += 1 {
|
||||||
|
f := t.Field(i)
|
||||||
|
|
||||||
|
tag := f.Tag.Get("tls")
|
||||||
|
tagOpts := parseTag(tag)
|
||||||
|
|
||||||
|
sd.fieldOpts[i] = decOpts{
|
||||||
|
head: tagOpts["head"],
|
||||||
|
max: tagOpts["max"],
|
||||||
|
min: tagOpts["min"],
|
||||||
|
}
|
||||||
|
|
||||||
|
sd.fieldDecs[i] = typeDecoder(f.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
return sd.decode
|
||||||
|
}
|
187
vendor/github.com/bifurcation/mint/syntax/encode.go
generated
vendored
Normal file
187
vendor/github.com/bifurcation/mint/syntax/encode.go
generated
vendored
Normal file
@ -0,0 +1,187 @@
|
|||||||
|
package syntax
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"runtime"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Marshal(v interface{}) ([]byte, error) {
|
||||||
|
e := &encodeState{}
|
||||||
|
err := e.marshal(v, encOpts{})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return e.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// These are the options that can be specified in the struct tag. Right now,
|
||||||
|
// all of them apply to variable-length vectors and nothing else
|
||||||
|
type encOpts struct {
|
||||||
|
head uint // length of length in bytes
|
||||||
|
min uint // minimum size in bytes
|
||||||
|
max uint // maximum size in bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
type encodeState struct {
|
||||||
|
bytes.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *encodeState) marshal(v interface{}, opts encOpts) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
if _, ok := r.(runtime.Error); ok {
|
||||||
|
panic(r)
|
||||||
|
}
|
||||||
|
if s, ok := r.(string); ok {
|
||||||
|
panic(s)
|
||||||
|
}
|
||||||
|
err = r.(error)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
e.reflectValue(reflect.ValueOf(v), opts)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *encodeState) reflectValue(v reflect.Value, opts encOpts) {
|
||||||
|
valueEncoder(v)(e, v, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
type encoderFunc func(e *encodeState, v reflect.Value, opts encOpts)
|
||||||
|
|
||||||
|
func valueEncoder(v reflect.Value) encoderFunc {
|
||||||
|
if !v.IsValid() {
|
||||||
|
panic(fmt.Errorf("Cannot encode an invalid value"))
|
||||||
|
}
|
||||||
|
return typeEncoder(v.Type())
|
||||||
|
}
|
||||||
|
|
||||||
|
func typeEncoder(t reflect.Type) encoderFunc {
|
||||||
|
// Note: Omits the caching / wait-group things that encoding/json uses
|
||||||
|
return newTypeEncoder(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTypeEncoder(t reflect.Type) encoderFunc {
|
||||||
|
// Note: Does not support Marshaler, so don't need the allowAddr argument
|
||||||
|
|
||||||
|
switch t.Kind() {
|
||||||
|
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
return uintEncoder
|
||||||
|
case reflect.Array:
|
||||||
|
return newArrayEncoder(t)
|
||||||
|
case reflect.Slice:
|
||||||
|
return newSliceEncoder(t)
|
||||||
|
case reflect.Struct:
|
||||||
|
return newStructEncoder(t)
|
||||||
|
default:
|
||||||
|
panic(fmt.Errorf("Unsupported type (%s)", t))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///// Specific encoders below
|
||||||
|
|
||||||
|
func uintEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
||||||
|
u := v.Uint()
|
||||||
|
switch v.Type().Kind() {
|
||||||
|
case reflect.Uint8:
|
||||||
|
e.WriteByte(byte(u))
|
||||||
|
case reflect.Uint16:
|
||||||
|
e.Write([]byte{byte(u >> 8), byte(u)})
|
||||||
|
case reflect.Uint32:
|
||||||
|
e.Write([]byte{byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)})
|
||||||
|
case reflect.Uint64:
|
||||||
|
e.Write([]byte{byte(u >> 56), byte(u >> 48), byte(u >> 40), byte(u >> 32),
|
||||||
|
byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////
|
||||||
|
|
||||||
|
type arrayEncoder struct {
|
||||||
|
elemEnc encoderFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ae *arrayEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
|
||||||
|
n := v.Len()
|
||||||
|
for i := 0; i < n; i += 1 {
|
||||||
|
ae.elemEnc(e, v.Index(i), opts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newArrayEncoder(t reflect.Type) encoderFunc {
|
||||||
|
enc := &arrayEncoder{typeEncoder(t.Elem())}
|
||||||
|
return enc.encode
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////
|
||||||
|
|
||||||
|
type sliceEncoder struct {
|
||||||
|
ae *arrayEncoder
|
||||||
|
}
|
||||||
|
|
||||||
|
func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
|
||||||
|
if opts.head == 0 {
|
||||||
|
panic(fmt.Errorf("Cannot encode a slice without a header length"))
|
||||||
|
}
|
||||||
|
|
||||||
|
arrayState := &encodeState{}
|
||||||
|
se.ae.encode(arrayState, v, opts)
|
||||||
|
|
||||||
|
n := uint(arrayState.Len())
|
||||||
|
if opts.max > 0 && n > opts.max {
|
||||||
|
panic(fmt.Errorf("Encoded length more than max [%d > %d]", n, opts.max))
|
||||||
|
}
|
||||||
|
if n>>(8*opts.head) > 0 {
|
||||||
|
panic(fmt.Errorf("Encoded length too long for header length [%d, %d]", n, opts.head))
|
||||||
|
}
|
||||||
|
if n < opts.min {
|
||||||
|
panic(fmt.Errorf("Encoded length less than min [%d < %d]", n, opts.min))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := int(opts.head - 1); i >= 0; i -= 1 {
|
||||||
|
e.WriteByte(byte(n >> (8 * uint(i))))
|
||||||
|
}
|
||||||
|
e.Write(arrayState.Bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSliceEncoder(t reflect.Type) encoderFunc {
|
||||||
|
enc := &sliceEncoder{&arrayEncoder{typeEncoder(t.Elem())}}
|
||||||
|
return enc.encode
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////
|
||||||
|
|
||||||
|
type structEncoder struct {
|
||||||
|
fieldOpts []encOpts
|
||||||
|
fieldEncs []encoderFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (se *structEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
|
||||||
|
for i := range se.fieldEncs {
|
||||||
|
se.fieldEncs[i](e, v.Field(i), se.fieldOpts[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStructEncoder(t reflect.Type) encoderFunc {
|
||||||
|
n := t.NumField()
|
||||||
|
se := structEncoder{
|
||||||
|
fieldOpts: make([]encOpts, n),
|
||||||
|
fieldEncs: make([]encoderFunc, n),
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < n; i += 1 {
|
||||||
|
f := t.Field(i)
|
||||||
|
tag := f.Tag.Get("tls")
|
||||||
|
tagOpts := parseTag(tag)
|
||||||
|
|
||||||
|
se.fieldOpts[i] = encOpts{
|
||||||
|
head: tagOpts["head"],
|
||||||
|
max: tagOpts["max"],
|
||||||
|
min: tagOpts["min"],
|
||||||
|
}
|
||||||
|
se.fieldEncs[i] = typeEncoder(f.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
return se.encode
|
||||||
|
}
|
30
vendor/github.com/bifurcation/mint/syntax/tags.go
generated
vendored
Normal file
30
vendor/github.com/bifurcation/mint/syntax/tags.go
generated
vendored
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
package syntax
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// `tls:"head=2,min=2,max=255"`
|
||||||
|
|
||||||
|
type tagOptions map[string]uint
|
||||||
|
|
||||||
|
// parseTag parses a struct field's "tls" tag as a comma-separated list of
|
||||||
|
// name=value pairs, where the values MUST be unsigned integers
|
||||||
|
func parseTag(tag string) tagOptions {
|
||||||
|
opts := tagOptions{}
|
||||||
|
for _, token := range strings.Split(tag, ",") {
|
||||||
|
if strings.Index(token, "=") == -1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(token, "=")
|
||||||
|
if len(parts[0]) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if val, err := strconv.Atoi(parts[1]); err == nil && val >= 0 {
|
||||||
|
opts[parts[0]] = uint(val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return opts
|
||||||
|
}
|
168
vendor/github.com/bifurcation/mint/tls.go
generated
vendored
Normal file
168
vendor/github.com/bifurcation/mint/tls.go
generated
vendored
Normal file
@ -0,0 +1,168 @@
|
|||||||
|
package mint
|
||||||
|
|
||||||
|
// XXX(rlb): This file is borrowed pretty much wholesale from crypto/tls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Server returns a new TLS server side connection
|
||||||
|
// using conn as the underlying transport.
|
||||||
|
// The configuration config must be non-nil and must include
|
||||||
|
// at least one certificate or else set GetCertificate.
|
||||||
|
func Server(conn net.Conn, config *Config) *Conn {
|
||||||
|
return NewConn(conn, config, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client returns a new TLS client side connection
|
||||||
|
// using conn as the underlying transport.
|
||||||
|
// The config cannot be nil: users must set either ServerName or
|
||||||
|
// InsecureSkipVerify in the config.
|
||||||
|
func Client(conn net.Conn, config *Config) *Conn {
|
||||||
|
return NewConn(conn, config, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// A listener implements a network listener (net.Listener) for TLS connections.
|
||||||
|
type Listener struct {
|
||||||
|
net.Listener
|
||||||
|
config *Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accept waits for and returns the next incoming TLS connection.
|
||||||
|
// The returned connection c is a *tls.Conn.
|
||||||
|
func (l *Listener) Accept() (c net.Conn, err error) {
|
||||||
|
c, err = l.Listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
server := Server(c, l.config)
|
||||||
|
err = server.Handshake()
|
||||||
|
if err == AlertNoAlert {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
c = server
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewListener creates a Listener which accepts connections from an inner
|
||||||
|
// Listener and wraps each connection with Server.
|
||||||
|
// The configuration config must be non-nil and must include
|
||||||
|
// at least one certificate or else set GetCertificate.
|
||||||
|
func NewListener(inner net.Listener, config *Config) net.Listener {
|
||||||
|
l := new(Listener)
|
||||||
|
l.Listener = inner
|
||||||
|
l.config = config
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
// Listen creates a TLS listener accepting connections on the
|
||||||
|
// given network address using net.Listen.
|
||||||
|
// The configuration config must be non-nil and must include
|
||||||
|
// at least one certificate or else set GetCertificate.
|
||||||
|
func Listen(network, laddr string, config *Config) (net.Listener, error) {
|
||||||
|
if config == nil || !config.ValidForServer() {
|
||||||
|
return nil, errors.New("tls: neither Certificates nor GetCertificate set in Config")
|
||||||
|
}
|
||||||
|
l, err := net.Listen(network, laddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return NewListener(l, config), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type TimeoutError struct{}
|
||||||
|
|
||||||
|
func (TimeoutError) Error() string { return "tls: DialWithDialer timed out" }
|
||||||
|
func (TimeoutError) Timeout() bool { return true }
|
||||||
|
func (TimeoutError) Temporary() bool { return true }
|
||||||
|
|
||||||
|
// DialWithDialer connects to the given network address using dialer.Dial and
|
||||||
|
// then initiates a TLS handshake, returning the resulting TLS connection. Any
|
||||||
|
// timeout or deadline given in the dialer apply to connection and TLS
|
||||||
|
// handshake as a whole.
|
||||||
|
//
|
||||||
|
// DialWithDialer interprets a nil configuration as equivalent to the zero
|
||||||
|
// configuration; see the documentation of Config for the defaults.
|
||||||
|
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
|
||||||
|
// We want the Timeout and Deadline values from dialer to cover the
|
||||||
|
// whole process: TCP connection and TLS handshake. This means that we
|
||||||
|
// also need to start our own timers now.
|
||||||
|
timeout := dialer.Timeout
|
||||||
|
|
||||||
|
if !dialer.Deadline.IsZero() {
|
||||||
|
deadlineTimeout := dialer.Deadline.Sub(time.Now())
|
||||||
|
if timeout == 0 || deadlineTimeout < timeout {
|
||||||
|
timeout = deadlineTimeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var errChannel chan error
|
||||||
|
|
||||||
|
if timeout != 0 {
|
||||||
|
errChannel = make(chan error, 2)
|
||||||
|
time.AfterFunc(timeout, func() {
|
||||||
|
errChannel <- TimeoutError{}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
rawConn, err := dialer.Dial(network, addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
colonPos := strings.LastIndex(addr, ":")
|
||||||
|
if colonPos == -1 {
|
||||||
|
colonPos = len(addr)
|
||||||
|
}
|
||||||
|
hostname := addr[:colonPos]
|
||||||
|
|
||||||
|
if config == nil {
|
||||||
|
config = &Config{}
|
||||||
|
}
|
||||||
|
// If no ServerName is set, infer the ServerName
|
||||||
|
// from the hostname we're connecting to.
|
||||||
|
if config.ServerName == "" {
|
||||||
|
// Make a copy to avoid polluting argument or default.
|
||||||
|
c := config.Clone()
|
||||||
|
c.ServerName = hostname
|
||||||
|
config = c
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := Client(rawConn, config)
|
||||||
|
|
||||||
|
if timeout == 0 {
|
||||||
|
err = conn.Handshake()
|
||||||
|
if err == AlertNoAlert {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
go func() {
|
||||||
|
errChannel <- conn.Handshake()
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = <-errChannel
|
||||||
|
if err == AlertNoAlert {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
rawConn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial connects to the given network address using net.Dial
|
||||||
|
// and then initiates a TLS handshake, returning the resulting
|
||||||
|
// TLS connection.
|
||||||
|
// Dial interprets a nil configuration as equivalent to
|
||||||
|
// the zero configuration; see the documentation of Config
|
||||||
|
// for the defaults.
|
||||||
|
func Dial(network, addr string, config *Config) (*Conn, error) {
|
||||||
|
return DialWithDialer(new(net.Dialer), network, addr, config)
|
||||||
|
}
|
32
vendor/github.com/lucas-clemente/quic-go/ackhandler/interfaces.go
generated
vendored
32
vendor/github.com/lucas-clemente/quic-go/ackhandler/interfaces.go
generated
vendored
@ -1,32 +0,0 @@
|
|||||||
package ackhandler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/frames"
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
)
|
|
||||||
|
|
||||||
// SentPacketHandler handles ACKs received for outgoing packets
|
|
||||||
type SentPacketHandler interface {
|
|
||||||
// SentPacket may modify the packet
|
|
||||||
SentPacket(packet *Packet) error
|
|
||||||
ReceivedAck(ackFrame *frames.AckFrame, withPacketNumber protocol.PacketNumber, recvTime time.Time) error
|
|
||||||
|
|
||||||
SendingAllowed() bool
|
|
||||||
GetStopWaitingFrame(force bool) *frames.StopWaitingFrame
|
|
||||||
DequeuePacketForRetransmission() (packet *Packet)
|
|
||||||
GetLeastUnacked() protocol.PacketNumber
|
|
||||||
|
|
||||||
GetAlarmTimeout() time.Time
|
|
||||||
OnAlarm()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
|
|
||||||
type ReceivedPacketHandler interface {
|
|
||||||
ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error
|
|
||||||
SetLowerLimit(protocol.PacketNumber)
|
|
||||||
|
|
||||||
GetAlarmTimeout() time.Time
|
|
||||||
GetAckFrame() *frames.AckFrame
|
|
||||||
}
|
|
2
vendor/github.com/lucas-clemente/quic-go/buffer_pool.go
generated
vendored
2
vendor/github.com/lucas-clemente/quic-go/buffer_pool.go
generated
vendored
@ -3,7 +3,7 @@ package quic
|
|||||||
import (
|
import (
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
)
|
)
|
||||||
|
|
||||||
var bufferPool sync.Pool
|
var bufferPool sync.Pool
|
||||||
|
298
vendor/github.com/lucas-clemente/quic-go/client.go
generated
vendored
298
vendor/github.com/lucas-clemente/quic-go/client.go
generated
vendored
@ -10,32 +10,39 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
"github.com/lucas-clemente/quic-go/qerr"
|
||||||
)
|
)
|
||||||
|
|
||||||
type client struct {
|
type client struct {
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
listenErr error
|
|
||||||
|
|
||||||
conn connection
|
conn connection
|
||||||
hostname string
|
hostname string
|
||||||
|
|
||||||
errorChan chan struct{}
|
versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version
|
||||||
handshakeChan <-chan handshakeEvent
|
versionNegotiated bool // has the server accepted our version
|
||||||
|
receivedVersionNegotiationPacket bool
|
||||||
|
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
|
||||||
|
|
||||||
tlsConf *tls.Config
|
tlsConf *tls.Config
|
||||||
config *Config
|
config *Config
|
||||||
versionNegotiated bool // has version negotiation completed yet
|
tls handshake.MintTLS // only used when using TLS
|
||||||
|
|
||||||
connectionID protocol.ConnectionID
|
connectionID protocol.ConnectionID
|
||||||
|
|
||||||
|
initialVersion protocol.VersionNumber
|
||||||
version protocol.VersionNumber
|
version protocol.VersionNumber
|
||||||
|
|
||||||
session packetHandler
|
session packetHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
// make it possible to mock connection ID generation in the tests
|
||||||
|
generateConnectionID = utils.GenerateConnectionID
|
||||||
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
|
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -53,34 +60,16 @@ func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error)
|
|||||||
return Dial(udpConn, udpAddr, addr, tlsConf, config)
|
return Dial(udpConn, udpAddr, addr, tlsConf, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialAddrNonFWSecure establishes a new QUIC connection to a server.
|
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
||||||
// The hostname for SNI is taken from the given address.
|
|
||||||
func DialAddrNonFWSecure(
|
|
||||||
addr string,
|
|
||||||
tlsConf *tls.Config,
|
|
||||||
config *Config,
|
|
||||||
) (NonFWSession, error) {
|
|
||||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return DialNonFWSecure(udpConn, udpAddr, addr, tlsConf, config)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn.
|
|
||||||
// The host parameter is used for SNI.
|
// The host parameter is used for SNI.
|
||||||
func DialNonFWSecure(
|
func Dial(
|
||||||
pconn net.PacketConn,
|
pconn net.PacketConn,
|
||||||
remoteAddr net.Addr,
|
remoteAddr net.Addr,
|
||||||
host string,
|
host string,
|
||||||
tlsConf *tls.Config,
|
tlsConf *tls.Config,
|
||||||
config *Config,
|
config *Config,
|
||||||
) (NonFWSession, error) {
|
) (Session, error) {
|
||||||
connID, err := utils.GenerateConnectionID()
|
connID, err := generateConnectionID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -89,7 +78,6 @@ func DialNonFWSecure(
|
|||||||
if tlsConf != nil {
|
if tlsConf != nil {
|
||||||
hostname = tlsConf.ServerName
|
hostname = tlsConf.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
if hostname == "" {
|
if hostname == "" {
|
||||||
hostname, _, err = net.SplitHostPort(host)
|
hostname, _, err = net.SplitHostPort(host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -105,37 +93,15 @@ func DialNonFWSecure(
|
|||||||
tlsConf: tlsConf,
|
tlsConf: tlsConf,
|
||||||
config: clientConfig,
|
config: clientConfig,
|
||||||
version: clientConfig.Versions[0],
|
version: clientConfig.Versions[0],
|
||||||
errorChan: make(chan struct{}),
|
versionNegotiationChan: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
err = c.createNewSession(nil)
|
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
|
||||||
if err != nil {
|
|
||||||
|
if err := c.dial(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
return c.session, nil
|
||||||
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %d", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
|
|
||||||
|
|
||||||
return c.session.(NonFWSession), c.establishSecureConnection()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
|
||||||
// The host parameter is used for SNI.
|
|
||||||
func Dial(
|
|
||||||
pconn net.PacketConn,
|
|
||||||
remoteAddr net.Addr,
|
|
||||||
host string,
|
|
||||||
tlsConf *tls.Config,
|
|
||||||
config *Config,
|
|
||||||
) (Session, error) {
|
|
||||||
sess, err := DialNonFWSecure(pconn, remoteAddr, host, tlsConf, config)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
err = sess.WaitUntilHandshakeComplete()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return sess, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
|
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
|
||||||
@ -153,6 +119,10 @@ func populateClientConfig(config *Config) *Config {
|
|||||||
if config.HandshakeTimeout != 0 {
|
if config.HandshakeTimeout != 0 {
|
||||||
handshakeTimeout = config.HandshakeTimeout
|
handshakeTimeout = config.HandshakeTimeout
|
||||||
}
|
}
|
||||||
|
idleTimeout := protocol.DefaultIdleTimeout
|
||||||
|
if config.IdleTimeout != 0 {
|
||||||
|
idleTimeout = config.IdleTimeout
|
||||||
|
}
|
||||||
|
|
||||||
maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
|
maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
|
||||||
if maxReceiveStreamFlowControlWindow == 0 {
|
if maxReceiveStreamFlowControlWindow == 0 {
|
||||||
@ -166,32 +136,109 @@ func populateClientConfig(config *Config) *Config {
|
|||||||
return &Config{
|
return &Config{
|
||||||
Versions: versions,
|
Versions: versions,
|
||||||
HandshakeTimeout: handshakeTimeout,
|
HandshakeTimeout: handshakeTimeout,
|
||||||
RequestConnectionIDTruncation: config.RequestConnectionIDTruncation,
|
IdleTimeout: idleTimeout,
|
||||||
|
RequestConnectionIDOmission: config.RequestConnectionIDOmission,
|
||||||
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
|
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
|
||||||
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
|
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
|
||||||
KeepAlive: config.KeepAlive,
|
KeepAlive: config.KeepAlive,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// establishSecureConnection returns as soon as the connection is secure (as opposed to forward-secure)
|
func (c *client) dial() error {
|
||||||
func (c *client) establishSecureConnection() error {
|
var err error
|
||||||
go c.listen()
|
if c.version.UsesTLS() {
|
||||||
|
err = c.dialTLS()
|
||||||
select {
|
} else {
|
||||||
case <-c.errorChan:
|
err = c.dialGQUIC()
|
||||||
return c.listenErr
|
}
|
||||||
case ev := <-c.handshakeChan:
|
if err == errCloseSessionForNewVersion {
|
||||||
if ev.err != nil {
|
return c.dial()
|
||||||
return ev.err
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) dialGQUIC() error {
|
||||||
|
if err := c.createNewGQUICSession(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
go c.listen()
|
||||||
|
return c.establishSecureConnection()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) dialTLS() error {
|
||||||
|
params := &handshake.TransportParameters{
|
||||||
|
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
|
||||||
|
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
|
||||||
|
IdleTimeout: c.config.IdleTimeout,
|
||||||
|
OmitConnectionID: c.config.RequestConnectionIDOmission,
|
||||||
|
// TODO(#523): make these values configurable
|
||||||
|
MaxBidiStreamID: protocol.MaxBidiStreamID(protocol.MaxIncomingStreams, protocol.PerspectiveClient),
|
||||||
|
MaxUniStreamID: protocol.MaxUniStreamID(protocol.MaxIncomingStreams, protocol.PerspectiveClient),
|
||||||
|
}
|
||||||
|
csc := handshake.NewCryptoStreamConn(nil)
|
||||||
|
extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version)
|
||||||
|
mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
mintConf.ExtensionHandler = extHandler
|
||||||
|
mintConf.ServerName = c.hostname
|
||||||
|
c.tls = newMintController(csc, mintConf, protocol.PerspectiveClient)
|
||||||
|
|
||||||
|
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
go c.listen()
|
||||||
|
if err := c.establishSecureConnection(); err != nil {
|
||||||
|
if err != handshake.ErrCloseSessionForRetry {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
utils.Infof("Received a Retry packet. Recreating session.")
|
||||||
|
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := c.establishSecureConnection(); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
if ev.encLevel != protocol.EncryptionSecure {
|
|
||||||
return fmt.Errorf("Client BUG: Expected encryption level to be secure, was %s", ev.encLevel)
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// establishSecureConnection runs the session, and tries to establish a secure connection
|
||||||
|
// It returns:
|
||||||
|
// - errCloseSessionForNewVersion when the server sends a version negotiation packet
|
||||||
|
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC)
|
||||||
|
// - any other error that might occur
|
||||||
|
// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC)
|
||||||
|
func (c *client) establishSecureConnection() error {
|
||||||
|
var runErr error
|
||||||
|
errorChan := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
runErr = c.session.run() // returns as soon as the session is closed
|
||||||
|
close(errorChan)
|
||||||
|
utils.Infof("Connection %x closed.", c.connectionID)
|
||||||
|
if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion {
|
||||||
|
c.conn.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// wait until the server accepts the QUIC version (or an error occurs)
|
||||||
|
select {
|
||||||
|
case <-errorChan:
|
||||||
|
return runErr
|
||||||
|
case <-c.versionNegotiationChan:
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-errorChan:
|
||||||
|
return runErr
|
||||||
|
case err := <-c.session.handshakeStatus():
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Listen listens
|
// Listen listens on the underlying connection and passes packets on for handling.
|
||||||
|
// It returns when the connection is closed.
|
||||||
func (c *client) listen() {
|
func (c *client) listen() {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
@ -205,13 +252,15 @@ func (c *client) listen() {
|
|||||||
n, addr, err = c.conn.Read(data)
|
n, addr, err = c.conn.Read(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
|
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
|
||||||
|
c.mutex.Lock()
|
||||||
|
if c.session != nil {
|
||||||
c.session.Close(err)
|
c.session.Close(err)
|
||||||
}
|
}
|
||||||
|
c.mutex.Unlock()
|
||||||
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
data = data[:n]
|
c.handlePacket(addr, data[:n])
|
||||||
|
|
||||||
c.handlePacket(addr, data)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -219,10 +268,14 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
|
|||||||
rcvTime := time.Now()
|
rcvTime := time.Now()
|
||||||
|
|
||||||
r := bytes.NewReader(packet)
|
r := bytes.NewReader(packet)
|
||||||
hdr, err := ParsePublicHeader(r, protocol.PerspectiveServer)
|
hdr, err := wire.ParseHeaderSentByServer(r, c.version)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error())
|
utils.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error())
|
||||||
// drop this packet if we can't parse the Public Header
|
// drop this packet if we can't parse the header
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// reject packets with truncated connection id if we didn't request truncation
|
||||||
|
if hdr.OmitConnectionID && !c.config.RequestConnectionIDOmission {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
hdr.Raw = packet[:len(packet)-r.Len()]
|
hdr.Raw = packet[:len(packet)-r.Len()]
|
||||||
@ -230,6 +283,11 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
|
|||||||
c.mutex.Lock()
|
c.mutex.Lock()
|
||||||
defer c.mutex.Unlock()
|
defer c.mutex.Unlock()
|
||||||
|
|
||||||
|
// reject packets with the wrong connection ID
|
||||||
|
if !hdr.OmitConnectionID && hdr.ConnectionID != c.connectionID {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if hdr.ResetFlag {
|
if hdr.ResetFlag {
|
||||||
cr := c.conn.RemoteAddr()
|
cr := c.conn.RemoteAddr()
|
||||||
// check if the remote address and the connection ID match
|
// check if the remote address and the connection ID match
|
||||||
@ -238,44 +296,48 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
|
|||||||
utils.Infof("Received a spoofed Public Reset. Ignoring.")
|
utils.Infof("Received a spoofed Public Reset. Ignoring.")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
pr, err := parsePublicReset(r)
|
pr, err := wire.ParsePublicReset(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.Infof("Received a Public Reset for connection %x. An error occurred parsing the packet.")
|
utils.Infof("Received a Public Reset. An error occurred parsing the packet: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.rejectedPacketNumber)
|
utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.RejectedPacketNumber)
|
||||||
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.rejectedPacketNumber)))
|
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handle Version Negotiation Packets
|
||||||
|
if hdr.IsVersionNegotiation {
|
||||||
// ignore delayed / duplicated version negotiation packets
|
// ignore delayed / duplicated version negotiation packets
|
||||||
if c.versionNegotiated && hdr.VersionFlag {
|
if c.receivedVersionNegotiationPacket || c.versionNegotiated {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// this is the first packet after the client sent a packet with the VersionFlag set
|
|
||||||
// if the server doesn't send a version negotiation packet, it supports the suggested version
|
|
||||||
if !hdr.VersionFlag && !c.versionNegotiated {
|
|
||||||
c.versionNegotiated = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if hdr.VersionFlag {
|
|
||||||
// version negotiation packets have no payload
|
// version negotiation packets have no payload
|
||||||
if err := c.handlePacketWithVersionFlag(hdr); err != nil {
|
if err := c.handleVersionNegotiationPacket(hdr); err != nil {
|
||||||
c.session.Close(err)
|
c.session.Close(err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// this is the first packet we are receiving
|
||||||
|
// since it is not a Version Negotiation Packet, this means the server supports the suggested version
|
||||||
|
if !c.versionNegotiated {
|
||||||
|
c.versionNegotiated = true
|
||||||
|
close(c.versionNegotiationChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: validate packet number and connection ID on Retry packets (for IETF QUIC)
|
||||||
|
|
||||||
c.session.handlePacket(&receivedPacket{
|
c.session.handlePacket(&receivedPacket{
|
||||||
remoteAddr: remoteAddr,
|
remoteAddr: remoteAddr,
|
||||||
publicHeader: hdr,
|
header: hdr,
|
||||||
data: packet[len(packet)-r.Len():],
|
data: packet[len(packet)-r.Len():],
|
||||||
rcvTime: rcvTime,
|
rcvTime: rcvTime,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error {
|
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
||||||
for _, v := range hdr.SupportedVersions {
|
for _, v := range hdr.SupportedVersions {
|
||||||
if v == c.version {
|
if v == c.version {
|
||||||
// the version negotiation packet contains the version that we offered
|
// the version negotiation packet contains the version that we offered
|
||||||
@ -285,51 +347,57 @@ func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
newVersion := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
|
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
|
||||||
if newVersion == protocol.VersionUnsupported {
|
if !ok {
|
||||||
return qerr.InvalidVersion
|
return qerr.InvalidVersion
|
||||||
}
|
}
|
||||||
|
c.receivedVersionNegotiationPacket = true
|
||||||
|
c.negotiatedVersions = hdr.SupportedVersions
|
||||||
|
|
||||||
// switch to negotiated version
|
// switch to negotiated version
|
||||||
|
c.initialVersion = c.version
|
||||||
c.version = newVersion
|
c.version = newVersion
|
||||||
c.versionNegotiated = true
|
|
||||||
var err error
|
var err error
|
||||||
c.connectionID, err = utils.GenerateConnectionID()
|
c.connectionID, err = utils.GenerateConnectionID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
utils.Infof("Switching to QUIC version %d. New connection ID: %x", newVersion, c.connectionID)
|
utils.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID)
|
||||||
|
|
||||||
c.session.Close(errCloseSessionForNewVersion)
|
c.session.Close(errCloseSessionForNewVersion)
|
||||||
return c.createNewSession(hdr.SupportedVersions)
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) error {
|
func (c *client) createNewGQUICSession() (err error) {
|
||||||
var err error
|
c.mutex.Lock()
|
||||||
c.session, c.handshakeChan, err = newClientSession(
|
defer c.mutex.Unlock()
|
||||||
|
c.session, err = newClientSession(
|
||||||
c.conn,
|
c.conn,
|
||||||
c.hostname,
|
c.hostname,
|
||||||
c.version,
|
c.version,
|
||||||
c.connectionID,
|
c.connectionID,
|
||||||
c.tlsConf,
|
c.tlsConf,
|
||||||
c.config,
|
c.config,
|
||||||
negotiatedVersions,
|
c.initialVersion,
|
||||||
|
c.negotiatedVersions,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) createNewTLSSession(
|
||||||
|
paramsChan <-chan handshake.TransportParameters,
|
||||||
|
version protocol.VersionNumber,
|
||||||
|
) (err error) {
|
||||||
|
c.mutex.Lock()
|
||||||
|
defer c.mutex.Unlock()
|
||||||
|
c.session, err = newTLSClientSession(
|
||||||
|
c.conn,
|
||||||
|
c.hostname,
|
||||||
|
c.version,
|
||||||
|
c.connectionID,
|
||||||
|
c.config,
|
||||||
|
c.tls,
|
||||||
|
paramsChan,
|
||||||
|
1,
|
||||||
)
|
)
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
// session.run() returns as soon as the session is closed
|
|
||||||
err := c.session.run()
|
|
||||||
if err == errCloseSessionForNewVersion {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.listenErr = err
|
|
||||||
close(c.errorChan)
|
|
||||||
|
|
||||||
utils.Infof("Connection %x closed.", c.connectionID)
|
|
||||||
c.conn.Close()
|
|
||||||
}()
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
58
vendor/github.com/lucas-clemente/quic-go/crypto/aesgcm_aead.go
generated
vendored
58
vendor/github.com/lucas-clemente/quic-go/crypto/aesgcm_aead.go
generated
vendored
@ -1,58 +0,0 @@
|
|||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/cipher"
|
|
||||||
"errors"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/aes12"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
)
|
|
||||||
|
|
||||||
type aeadAESGCM struct {
|
|
||||||
otherIV []byte
|
|
||||||
myIV []byte
|
|
||||||
encrypter cipher.AEAD
|
|
||||||
decrypter cipher.AEAD
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewAEADAESGCM creates a AEAD using AES-GCM with 12 bytes tag size
|
|
||||||
//
|
|
||||||
// AES-GCM support is a bit hacky, since the go stdlib does not support 12 byte
|
|
||||||
// tag size, and couples the cipher and aes packages closely.
|
|
||||||
// See https://github.com/lucas-clemente/aes12.
|
|
||||||
func NewAEADAESGCM(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) {
|
|
||||||
if len(myKey) != 16 || len(otherKey) != 16 || len(myIV) != 4 || len(otherIV) != 4 {
|
|
||||||
return nil, errors.New("AES-GCM: expected 16-byte keys and 4-byte IVs")
|
|
||||||
}
|
|
||||||
encrypterCipher, err := aes12.NewCipher(myKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
encrypter, err := aes12.NewGCM(encrypterCipher)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
decrypterCipher, err := aes12.NewCipher(otherKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
decrypter, err := aes12.NewGCM(decrypterCipher)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &aeadAESGCM{
|
|
||||||
otherIV: otherIV,
|
|
||||||
myIV: myIV,
|
|
||||||
encrypter: encrypter,
|
|
||||||
decrypter: decrypter,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (aead *aeadAESGCM) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
|
||||||
return aead.decrypter.Open(dst, makeNonce(aead.otherIV, packetNumber), src, associatedData)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (aead *aeadAESGCM) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
|
||||||
return aead.encrypter.Seal(dst, makeNonce(aead.myIV, packetNumber), src, associatedData)
|
|
||||||
}
|
|
14
vendor/github.com/lucas-clemente/quic-go/crypto/nonce.go
generated
vendored
14
vendor/github.com/lucas-clemente/quic-go/crypto/nonce.go
generated
vendored
@ -1,14 +0,0 @@
|
|||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
)
|
|
||||||
|
|
||||||
func makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
|
|
||||||
res := make([]byte, 12)
|
|
||||||
copy(res[0:4], iv)
|
|
||||||
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber))
|
|
||||||
return res
|
|
||||||
}
|
|
76
vendor/github.com/lucas-clemente/quic-go/crypto/source_address_token.go
generated
vendored
76
vendor/github.com/lucas-clemente/quic-go/crypto/source_address_token.go
generated
vendored
@ -1,76 +0,0 @@
|
|||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/aes"
|
|
||||||
"crypto/cipher"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/sha256"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/hkdf"
|
|
||||||
)
|
|
||||||
|
|
||||||
// StkSource is used to create and verify source address tokens
|
|
||||||
type StkSource interface {
|
|
||||||
// NewToken creates a new token
|
|
||||||
NewToken([]byte) ([]byte, error)
|
|
||||||
// DecodeToken decodes a token
|
|
||||||
DecodeToken([]byte) ([]byte, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type stkSource struct {
|
|
||||||
aead cipher.AEAD
|
|
||||||
}
|
|
||||||
|
|
||||||
const stkKeySize = 16
|
|
||||||
|
|
||||||
// Chrome currently sets this to 12, but discusses changing it to 16. We start
|
|
||||||
// at 16 :)
|
|
||||||
const stkNonceSize = 16
|
|
||||||
|
|
||||||
// NewStkSource creates a source for source address tokens
|
|
||||||
func NewStkSource() (StkSource, error) {
|
|
||||||
secret := make([]byte, 32)
|
|
||||||
if _, err := rand.Read(secret); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
key, err := deriveKey(secret)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
c, err := aes.NewCipher(key)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
aead, err := cipher.NewGCMWithNonceSize(c, stkNonceSize)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &stkSource{aead: aead}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stkSource) NewToken(data []byte) ([]byte, error) {
|
|
||||||
nonce := make([]byte, stkNonceSize)
|
|
||||||
if _, err := rand.Read(nonce); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return s.aead.Seal(nonce, nonce, data, nil), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stkSource) DecodeToken(p []byte) ([]byte, error) {
|
|
||||||
if len(p) < stkNonceSize {
|
|
||||||
return nil, fmt.Errorf("STK too short: %d", len(p))
|
|
||||||
}
|
|
||||||
nonce := p[:stkNonceSize]
|
|
||||||
return s.aead.Open(nil, nonce, p[stkNonceSize:], nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func deriveKey(secret []byte) ([]byte, error) {
|
|
||||||
r := hkdf.New(sha256.New, secret, nil, []byte("QUIC source address token key"))
|
|
||||||
key := make([]byte, stkKeySize)
|
|
||||||
if _, err := io.ReadFull(r, key); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return key, nil
|
|
||||||
}
|
|
41
vendor/github.com/lucas-clemente/quic-go/crypto_stream.go
generated
vendored
Normal file
41
vendor/github.com/lucas-clemente/quic-go/crypto_stream.go
generated
vendored
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
package quic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
|
)
|
||||||
|
|
||||||
|
type cryptoStreamI interface {
|
||||||
|
StreamID() protocol.StreamID
|
||||||
|
io.Reader
|
||||||
|
io.Writer
|
||||||
|
handleStreamFrame(*wire.StreamFrame) error
|
||||||
|
popStreamFrame(protocol.ByteCount) (*wire.StreamFrame, bool)
|
||||||
|
closeForShutdown(error)
|
||||||
|
setReadOffset(protocol.ByteCount)
|
||||||
|
// methods needed for flow control
|
||||||
|
getWindowUpdate() protocol.ByteCount
|
||||||
|
handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
|
||||||
|
}
|
||||||
|
|
||||||
|
type cryptoStream struct {
|
||||||
|
*stream
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ cryptoStreamI = &cryptoStream{}
|
||||||
|
|
||||||
|
func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStreamI {
|
||||||
|
str := newStream(version.CryptoStreamID(), sender, flowController, version)
|
||||||
|
return &cryptoStream{str}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetReadOffset sets the read offset.
|
||||||
|
// It is only needed for the crypto stream.
|
||||||
|
// It must not be called concurrently with any other stream methods, especially Read and Write.
|
||||||
|
func (s *cryptoStream) setReadOffset(offset protocol.ByteCount) {
|
||||||
|
s.receiveStream.readOffset = offset
|
||||||
|
s.receiveStream.frameQueue.readPosition = offset
|
||||||
|
}
|
14
vendor/github.com/lucas-clemente/quic-go/example/client/main.go
generated
vendored
14
vendor/github.com/lucas-clemente/quic-go/example/client/main.go
generated
vendored
@ -7,12 +7,15 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
quic "github.com/lucas-clemente/quic-go"
|
||||||
"github.com/lucas-clemente/quic-go/h2quic"
|
"github.com/lucas-clemente/quic-go/h2quic"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
verbose := flag.Bool("v", false, "verbose")
|
verbose := flag.Bool("v", false, "verbose")
|
||||||
|
tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
urls := flag.Args()
|
urls := flag.Args()
|
||||||
|
|
||||||
@ -23,8 +26,17 @@ func main() {
|
|||||||
}
|
}
|
||||||
utils.SetLogTimeFormat("")
|
utils.SetLogTimeFormat("")
|
||||||
|
|
||||||
|
versions := protocol.SupportedVersions
|
||||||
|
if *tls {
|
||||||
|
versions = append([]protocol.VersionNumber{protocol.VersionTLS}, versions...)
|
||||||
|
}
|
||||||
|
|
||||||
|
roundTripper := &h2quic.RoundTripper{
|
||||||
|
QuicConfig: &quic.Config{Versions: versions},
|
||||||
|
}
|
||||||
|
defer roundTripper.Close()
|
||||||
hclient := &http.Client{
|
hclient := &http.Client{
|
||||||
Transport: &h2quic.RoundTripper{},
|
Transport: roundTripper,
|
||||||
}
|
}
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
14
vendor/github.com/lucas-clemente/quic-go/example/main.go
generated
vendored
14
vendor/github.com/lucas-clemente/quic-go/example/main.go
generated
vendored
@ -17,7 +17,9 @@ import (
|
|||||||
|
|
||||||
_ "net/http/pprof"
|
_ "net/http/pprof"
|
||||||
|
|
||||||
|
quic "github.com/lucas-clemente/quic-go"
|
||||||
"github.com/lucas-clemente/quic-go/h2quic"
|
"github.com/lucas-clemente/quic-go/h2quic"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -121,6 +123,7 @@ func main() {
|
|||||||
certPath := flag.String("certpath", getBuildDir(), "certificate directory")
|
certPath := flag.String("certpath", getBuildDir(), "certificate directory")
|
||||||
www := flag.String("www", "/var/www", "www data")
|
www := flag.String("www", "/var/www", "www data")
|
||||||
tcp := flag.Bool("tcp", false, "also listen on TCP")
|
tcp := flag.Bool("tcp", false, "also listen on TCP")
|
||||||
|
tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
if *verbose {
|
if *verbose {
|
||||||
@ -130,6 +133,11 @@ func main() {
|
|||||||
}
|
}
|
||||||
utils.SetLogTimeFormat("")
|
utils.SetLogTimeFormat("")
|
||||||
|
|
||||||
|
versions := protocol.SupportedVersions
|
||||||
|
if *tls {
|
||||||
|
versions = append([]protocol.VersionNumber{protocol.VersionTLS}, versions...)
|
||||||
|
}
|
||||||
|
|
||||||
certFile := *certPath + "/fullchain.pem"
|
certFile := *certPath + "/fullchain.pem"
|
||||||
keyFile := *certPath + "/privkey.pem"
|
keyFile := *certPath + "/privkey.pem"
|
||||||
|
|
||||||
@ -148,7 +156,11 @@ func main() {
|
|||||||
if *tcp {
|
if *tcp {
|
||||||
err = h2quic.ListenAndServe(bCap, certFile, keyFile, nil)
|
err = h2quic.ListenAndServe(bCap, certFile, keyFile, nil)
|
||||||
} else {
|
} else {
|
||||||
err = h2quic.ListenAndServeQUIC(bCap, certFile, keyFile, nil)
|
server := h2quic.Server{
|
||||||
|
Server: &http.Server{Addr: bCap},
|
||||||
|
QuicConfig: &quic.Config{Versions: versions},
|
||||||
|
}
|
||||||
|
err = server.ListenAndServeTLS(certFile, keyFile)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
|
240
vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_control_manager.go
generated
vendored
240
vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_control_manager.go
generated
vendored
@ -1,240 +0,0 @@
|
|||||||
package flowcontrol
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/congestion"
|
|
||||||
"github.com/lucas-clemente/quic-go/handshake"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
)
|
|
||||||
|
|
||||||
type flowControlManager struct {
|
|
||||||
connectionParameters handshake.ConnectionParametersManager
|
|
||||||
rttStats *congestion.RTTStats
|
|
||||||
|
|
||||||
streamFlowController map[protocol.StreamID]*flowController
|
|
||||||
connFlowController *flowController
|
|
||||||
mutex sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ FlowControlManager = &flowControlManager{}
|
|
||||||
|
|
||||||
var errMapAccess = errors.New("Error accessing the flowController map.")
|
|
||||||
|
|
||||||
// NewFlowControlManager creates a new flow control manager
|
|
||||||
func NewFlowControlManager(connectionParameters handshake.ConnectionParametersManager, rttStats *congestion.RTTStats) FlowControlManager {
|
|
||||||
return &flowControlManager{
|
|
||||||
connectionParameters: connectionParameters,
|
|
||||||
rttStats: rttStats,
|
|
||||||
streamFlowController: make(map[protocol.StreamID]*flowController),
|
|
||||||
connFlowController: newFlowController(0, false, connectionParameters, rttStats),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewStream creates new flow controllers for a stream
|
|
||||||
// it does nothing if the stream already exists
|
|
||||||
func (f *flowControlManager) NewStream(streamID protocol.StreamID, contributesToConnection bool) {
|
|
||||||
f.mutex.Lock()
|
|
||||||
defer f.mutex.Unlock()
|
|
||||||
|
|
||||||
if _, ok := f.streamFlowController[streamID]; ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
f.streamFlowController[streamID] = newFlowController(streamID, contributesToConnection, f.connectionParameters, f.rttStats)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveStream removes a closed stream from flow control
|
|
||||||
func (f *flowControlManager) RemoveStream(streamID protocol.StreamID) {
|
|
||||||
f.mutex.Lock()
|
|
||||||
delete(f.streamFlowController, streamID)
|
|
||||||
f.mutex.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetStream should be called when receiving a RstStreamFrame
|
|
||||||
// it updates the byte offset to the value in the RstStreamFrame
|
|
||||||
// streamID must not be 0 here
|
|
||||||
func (f *flowControlManager) ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error {
|
|
||||||
f.mutex.Lock()
|
|
||||||
defer f.mutex.Unlock()
|
|
||||||
|
|
||||||
streamFlowController, err := f.getFlowController(streamID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
increment, err := streamFlowController.UpdateHighestReceived(byteOffset)
|
|
||||||
if err != nil {
|
|
||||||
return qerr.StreamDataAfterTermination
|
|
||||||
}
|
|
||||||
|
|
||||||
if streamFlowController.CheckFlowControlViolation() {
|
|
||||||
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveWindow))
|
|
||||||
}
|
|
||||||
|
|
||||||
if streamFlowController.ContributesToConnection() {
|
|
||||||
f.connFlowController.IncrementHighestReceived(increment)
|
|
||||||
if f.connFlowController.CheckFlowControlViolation() {
|
|
||||||
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", f.connFlowController.highestReceived, f.connFlowController.receiveWindow))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateHighestReceived updates the highest received byte offset for a stream
|
|
||||||
// it adds the number of additional bytes to connection level flow control
|
|
||||||
// streamID must not be 0 here
|
|
||||||
func (f *flowControlManager) UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error {
|
|
||||||
f.mutex.Lock()
|
|
||||||
defer f.mutex.Unlock()
|
|
||||||
|
|
||||||
streamFlowController, err := f.getFlowController(streamID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// UpdateHighestReceived returns an ErrReceivedSmallerByteOffset when StreamFrames got reordered
|
|
||||||
// this error can be ignored here
|
|
||||||
increment, _ := streamFlowController.UpdateHighestReceived(byteOffset)
|
|
||||||
|
|
||||||
if streamFlowController.CheckFlowControlViolation() {
|
|
||||||
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveWindow))
|
|
||||||
}
|
|
||||||
|
|
||||||
if streamFlowController.ContributesToConnection() {
|
|
||||||
f.connFlowController.IncrementHighestReceived(increment)
|
|
||||||
if f.connFlowController.CheckFlowControlViolation() {
|
|
||||||
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", f.connFlowController.highestReceived, f.connFlowController.receiveWindow))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// streamID must not be 0 here
|
|
||||||
func (f *flowControlManager) AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error {
|
|
||||||
f.mutex.Lock()
|
|
||||||
defer f.mutex.Unlock()
|
|
||||||
|
|
||||||
fc, err := f.getFlowController(streamID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
fc.AddBytesRead(n)
|
|
||||||
if fc.ContributesToConnection() {
|
|
||||||
f.connFlowController.AddBytesRead(n)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *flowControlManager) GetWindowUpdates() (res []WindowUpdate) {
|
|
||||||
f.mutex.Lock()
|
|
||||||
defer f.mutex.Unlock()
|
|
||||||
|
|
||||||
// get WindowUpdates for streams
|
|
||||||
for id, fc := range f.streamFlowController {
|
|
||||||
if necessary, newIncrement, offset := fc.MaybeUpdateWindow(); necessary {
|
|
||||||
res = append(res, WindowUpdate{StreamID: id, Offset: offset})
|
|
||||||
if fc.ContributesToConnection() && newIncrement != 0 {
|
|
||||||
f.connFlowController.EnsureMinimumWindowIncrement(protocol.ByteCount(float64(newIncrement) * protocol.ConnectionFlowControlMultiplier))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// get a WindowUpdate for the connection
|
|
||||||
if necessary, _, offset := f.connFlowController.MaybeUpdateWindow(); necessary {
|
|
||||||
res = append(res, WindowUpdate{StreamID: 0, Offset: offset})
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *flowControlManager) GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error) {
|
|
||||||
f.mutex.RLock()
|
|
||||||
defer f.mutex.RUnlock()
|
|
||||||
|
|
||||||
// StreamID can be 0 when retransmitting
|
|
||||||
if streamID == 0 {
|
|
||||||
return f.connFlowController.receiveWindow, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
flowController, err := f.getFlowController(streamID)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return flowController.receiveWindow, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// streamID must not be 0 here
|
|
||||||
func (f *flowControlManager) AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error {
|
|
||||||
f.mutex.Lock()
|
|
||||||
defer f.mutex.Unlock()
|
|
||||||
|
|
||||||
fc, err := f.getFlowController(streamID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
fc.AddBytesSent(n)
|
|
||||||
if fc.ContributesToConnection() {
|
|
||||||
f.connFlowController.AddBytesSent(n)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// must not be called with StreamID 0
|
|
||||||
func (f *flowControlManager) SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error) {
|
|
||||||
f.mutex.RLock()
|
|
||||||
defer f.mutex.RUnlock()
|
|
||||||
|
|
||||||
fc, err := f.getFlowController(streamID)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
res := fc.SendWindowSize()
|
|
||||||
|
|
||||||
if fc.ContributesToConnection() {
|
|
||||||
res = utils.MinByteCount(res, f.connFlowController.SendWindowSize())
|
|
||||||
}
|
|
||||||
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *flowControlManager) RemainingConnectionWindowSize() protocol.ByteCount {
|
|
||||||
f.mutex.RLock()
|
|
||||||
defer f.mutex.RUnlock()
|
|
||||||
|
|
||||||
return f.connFlowController.SendWindowSize()
|
|
||||||
}
|
|
||||||
|
|
||||||
// streamID may be 0 here
|
|
||||||
func (f *flowControlManager) UpdateWindow(streamID protocol.StreamID, offset protocol.ByteCount) (bool, error) {
|
|
||||||
f.mutex.Lock()
|
|
||||||
defer f.mutex.Unlock()
|
|
||||||
|
|
||||||
var fc *flowController
|
|
||||||
if streamID == 0 {
|
|
||||||
fc = f.connFlowController
|
|
||||||
} else {
|
|
||||||
var err error
|
|
||||||
fc, err = f.getFlowController(streamID)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return fc.UpdateSendWindow(offset), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *flowControlManager) getFlowController(streamID protocol.StreamID) (*flowController, error) {
|
|
||||||
streamFlowController, ok := f.streamFlowController[streamID]
|
|
||||||
if !ok {
|
|
||||||
return nil, errMapAccess
|
|
||||||
}
|
|
||||||
return streamFlowController, nil
|
|
||||||
}
|
|
198
vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_controller.go
generated
vendored
198
vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_controller.go
generated
vendored
@ -1,198 +0,0 @@
|
|||||||
package flowcontrol
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/congestion"
|
|
||||||
"github.com/lucas-clemente/quic-go/handshake"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
)
|
|
||||||
|
|
||||||
type flowController struct {
|
|
||||||
streamID protocol.StreamID
|
|
||||||
contributesToConnection bool // does the stream contribute to connection level flow control
|
|
||||||
|
|
||||||
connectionParameters handshake.ConnectionParametersManager
|
|
||||||
rttStats *congestion.RTTStats
|
|
||||||
|
|
||||||
bytesSent protocol.ByteCount
|
|
||||||
sendWindow protocol.ByteCount
|
|
||||||
|
|
||||||
lastWindowUpdateTime time.Time
|
|
||||||
|
|
||||||
bytesRead protocol.ByteCount
|
|
||||||
highestReceived protocol.ByteCount
|
|
||||||
receiveWindow protocol.ByteCount
|
|
||||||
receiveWindowIncrement protocol.ByteCount
|
|
||||||
maxReceiveWindowIncrement protocol.ByteCount
|
|
||||||
}
|
|
||||||
|
|
||||||
// ErrReceivedSmallerByteOffset occurs if the ByteOffset received is smaller than a ByteOffset that was set previously
|
|
||||||
var ErrReceivedSmallerByteOffset = errors.New("Received a smaller byte offset")
|
|
||||||
|
|
||||||
// newFlowController gets a new flow controller
|
|
||||||
func newFlowController(streamID protocol.StreamID, contributesToConnection bool, connectionParameters handshake.ConnectionParametersManager, rttStats *congestion.RTTStats) *flowController {
|
|
||||||
fc := flowController{
|
|
||||||
streamID: streamID,
|
|
||||||
contributesToConnection: contributesToConnection,
|
|
||||||
connectionParameters: connectionParameters,
|
|
||||||
rttStats: rttStats,
|
|
||||||
}
|
|
||||||
|
|
||||||
if streamID == 0 {
|
|
||||||
fc.receiveWindow = connectionParameters.GetReceiveConnectionFlowControlWindow()
|
|
||||||
fc.receiveWindowIncrement = fc.receiveWindow
|
|
||||||
fc.maxReceiveWindowIncrement = connectionParameters.GetMaxReceiveConnectionFlowControlWindow()
|
|
||||||
} else {
|
|
||||||
fc.receiveWindow = connectionParameters.GetReceiveStreamFlowControlWindow()
|
|
||||||
fc.receiveWindowIncrement = fc.receiveWindow
|
|
||||||
fc.maxReceiveWindowIncrement = connectionParameters.GetMaxReceiveStreamFlowControlWindow()
|
|
||||||
}
|
|
||||||
|
|
||||||
return &fc
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *flowController) ContributesToConnection() bool {
|
|
||||||
return c.contributesToConnection
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *flowController) getSendWindow() protocol.ByteCount {
|
|
||||||
if c.sendWindow == 0 {
|
|
||||||
if c.streamID == 0 {
|
|
||||||
return c.connectionParameters.GetSendConnectionFlowControlWindow()
|
|
||||||
}
|
|
||||||
return c.connectionParameters.GetSendStreamFlowControlWindow()
|
|
||||||
}
|
|
||||||
return c.sendWindow
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *flowController) AddBytesSent(n protocol.ByteCount) {
|
|
||||||
c.bytesSent += n
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSendWindow should be called after receiving a WindowUpdateFrame
|
|
||||||
// it returns true if the window was actually updated
|
|
||||||
func (c *flowController) UpdateSendWindow(newOffset protocol.ByteCount) bool {
|
|
||||||
if newOffset > c.sendWindow {
|
|
||||||
c.sendWindow = newOffset
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *flowController) SendWindowSize() protocol.ByteCount {
|
|
||||||
sendWindow := c.getSendWindow()
|
|
||||||
|
|
||||||
if c.bytesSent > sendWindow { // should never happen, but make sure we don't do an underflow here
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return sendWindow - c.bytesSent
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *flowController) SendWindowOffset() protocol.ByteCount {
|
|
||||||
return c.getSendWindow()
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateHighestReceived updates the highestReceived value, if the byteOffset is higher
|
|
||||||
// Should **only** be used for the stream-level FlowController
|
|
||||||
// it returns an ErrReceivedSmallerByteOffset if the received byteOffset is smaller than any byteOffset received before
|
|
||||||
// This error occurs every time StreamFrames get reordered and has to be ignored in that case
|
|
||||||
// It should only be treated as an error when resetting a stream
|
|
||||||
func (c *flowController) UpdateHighestReceived(byteOffset protocol.ByteCount) (protocol.ByteCount, error) {
|
|
||||||
if byteOffset == c.highestReceived {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
if byteOffset > c.highestReceived {
|
|
||||||
increment := byteOffset - c.highestReceived
|
|
||||||
c.highestReceived = byteOffset
|
|
||||||
return increment, nil
|
|
||||||
}
|
|
||||||
return 0, ErrReceivedSmallerByteOffset
|
|
||||||
}
|
|
||||||
|
|
||||||
// IncrementHighestReceived adds an increment to the highestReceived value
|
|
||||||
// Should **only** be used for the connection-level FlowController
|
|
||||||
func (c *flowController) IncrementHighestReceived(increment protocol.ByteCount) {
|
|
||||||
c.highestReceived += increment
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *flowController) AddBytesRead(n protocol.ByteCount) {
|
|
||||||
// pretend we sent a WindowUpdate when reading the first byte
|
|
||||||
// this way auto-tuning of the window increment already works for the first WindowUpdate
|
|
||||||
if c.bytesRead == 0 {
|
|
||||||
c.lastWindowUpdateTime = time.Now()
|
|
||||||
}
|
|
||||||
c.bytesRead += n
|
|
||||||
}
|
|
||||||
|
|
||||||
// MaybeUpdateWindow updates the receive window, if necessary
|
|
||||||
// if the receive window increment is changed, the new value is returned, otherwise a 0
|
|
||||||
// the last return value is the new offset of the receive window
|
|
||||||
func (c *flowController) MaybeUpdateWindow() (bool, protocol.ByteCount /* new increment */, protocol.ByteCount /* new offset */) {
|
|
||||||
diff := c.receiveWindow - c.bytesRead
|
|
||||||
|
|
||||||
// Chromium implements the same threshold
|
|
||||||
if diff < (c.receiveWindowIncrement / 2) {
|
|
||||||
var newWindowIncrement protocol.ByteCount
|
|
||||||
oldWindowIncrement := c.receiveWindowIncrement
|
|
||||||
|
|
||||||
c.maybeAdjustWindowIncrement()
|
|
||||||
if c.receiveWindowIncrement != oldWindowIncrement {
|
|
||||||
newWindowIncrement = c.receiveWindowIncrement
|
|
||||||
}
|
|
||||||
|
|
||||||
c.lastWindowUpdateTime = time.Now()
|
|
||||||
c.receiveWindow = c.bytesRead + c.receiveWindowIncrement
|
|
||||||
return true, newWindowIncrement, c.receiveWindow
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, 0, 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// maybeAdjustWindowIncrement increases the receiveWindowIncrement if we're sending WindowUpdates too often
|
|
||||||
func (c *flowController) maybeAdjustWindowIncrement() {
|
|
||||||
if c.lastWindowUpdateTime.IsZero() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
rtt := c.rttStats.SmoothedRTT()
|
|
||||||
if rtt == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
timeSinceLastWindowUpdate := time.Since(c.lastWindowUpdateTime)
|
|
||||||
|
|
||||||
// interval between the window updates is sufficiently large, no need to increase the increment
|
|
||||||
if timeSinceLastWindowUpdate >= 2*rtt {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
oldWindowSize := c.receiveWindowIncrement
|
|
||||||
c.receiveWindowIncrement = utils.MinByteCount(2*c.receiveWindowIncrement, c.maxReceiveWindowIncrement)
|
|
||||||
|
|
||||||
// debug log, if the window size was actually increased
|
|
||||||
if oldWindowSize < c.receiveWindowIncrement {
|
|
||||||
newWindowSize := c.receiveWindowIncrement / (1 << 10)
|
|
||||||
if c.streamID == 0 {
|
|
||||||
utils.Debugf("Increasing receive flow control window for the connection to %d kB", newWindowSize)
|
|
||||||
} else {
|
|
||||||
utils.Debugf("Increasing receive flow control window increment for stream %d to %d kB", c.streamID, newWindowSize)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// EnsureMinimumWindowIncrement sets a minimum window increment
|
|
||||||
// it is intended be used for the connection-level flow controller
|
|
||||||
// it should make sure that the connection-level window is increased when a stream-level window grows
|
|
||||||
func (c *flowController) EnsureMinimumWindowIncrement(inc protocol.ByteCount) {
|
|
||||||
if inc > c.receiveWindowIncrement {
|
|
||||||
c.receiveWindowIncrement = utils.MinByteCount(inc, c.maxReceiveWindowIncrement)
|
|
||||||
c.lastWindowUpdateTime = time.Time{} // disables autotuning for the next window update
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *flowController) CheckFlowControlViolation() bool {
|
|
||||||
return c.highestReceived > c.receiveWindow
|
|
||||||
}
|
|
26
vendor/github.com/lucas-clemente/quic-go/flowcontrol/interface.go
generated
vendored
26
vendor/github.com/lucas-clemente/quic-go/flowcontrol/interface.go
generated
vendored
@ -1,26 +0,0 @@
|
|||||||
package flowcontrol
|
|
||||||
|
|
||||||
import "github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
|
|
||||||
// WindowUpdate provides the data for WindowUpdateFrames.
|
|
||||||
type WindowUpdate struct {
|
|
||||||
StreamID protocol.StreamID
|
|
||||||
Offset protocol.ByteCount
|
|
||||||
}
|
|
||||||
|
|
||||||
// A FlowControlManager manages the flow control
|
|
||||||
type FlowControlManager interface {
|
|
||||||
NewStream(streamID protocol.StreamID, contributesToConnectionFlow bool)
|
|
||||||
RemoveStream(streamID protocol.StreamID)
|
|
||||||
// methods needed for receiving data
|
|
||||||
ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error
|
|
||||||
UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error
|
|
||||||
AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error
|
|
||||||
GetWindowUpdates() []WindowUpdate
|
|
||||||
GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error)
|
|
||||||
// methods needed for sending data
|
|
||||||
AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error
|
|
||||||
SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error)
|
|
||||||
RemainingConnectionWindowSize() protocol.ByteCount
|
|
||||||
UpdateWindow(streamID protocol.StreamID, offset protocol.ByteCount) (bool, error)
|
|
||||||
}
|
|
9
vendor/github.com/lucas-clemente/quic-go/frames/ack_range.go
generated
vendored
9
vendor/github.com/lucas-clemente/quic-go/frames/ack_range.go
generated
vendored
@ -1,9 +0,0 @@
|
|||||||
package frames
|
|
||||||
|
|
||||||
import "github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
|
|
||||||
// AckRange is an ACK range
|
|
||||||
type AckRange struct {
|
|
||||||
FirstPacketNumber protocol.PacketNumber
|
|
||||||
LastPacketNumber protocol.PacketNumber
|
|
||||||
}
|
|
44
vendor/github.com/lucas-clemente/quic-go/frames/blocked_frame.go
generated
vendored
44
vendor/github.com/lucas-clemente/quic-go/frames/blocked_frame.go
generated
vendored
@ -1,44 +0,0 @@
|
|||||||
package frames
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
)
|
|
||||||
|
|
||||||
// A BlockedFrame in QUIC
|
|
||||||
type BlockedFrame struct {
|
|
||||||
StreamID protocol.StreamID
|
|
||||||
}
|
|
||||||
|
|
||||||
//Write writes a BlockedFrame frame
|
|
||||||
func (f *BlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
|
|
||||||
b.WriteByte(0x05)
|
|
||||||
utils.WriteUint32(b, uint32(f.StreamID))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MinLength of a written frame
|
|
||||||
func (f *BlockedFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
|
|
||||||
return 1 + 4, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseBlockedFrame parses a BLOCKED frame
|
|
||||||
func ParseBlockedFrame(r *bytes.Reader) (*BlockedFrame, error) {
|
|
||||||
frame := &BlockedFrame{}
|
|
||||||
|
|
||||||
// read the TypeByte
|
|
||||||
_, err := r.ReadByte()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
sid, err := utils.ReadUint32(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
frame.StreamID = protocol.StreamID(sid)
|
|
||||||
|
|
||||||
return frame, nil
|
|
||||||
}
|
|
73
vendor/github.com/lucas-clemente/quic-go/frames/connection_close_frame.go
generated
vendored
73
vendor/github.com/lucas-clemente/quic-go/frames/connection_close_frame.go
generated
vendored
@ -1,73 +0,0 @@
|
|||||||
package frames
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"math"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
)
|
|
||||||
|
|
||||||
// A ConnectionCloseFrame in QUIC
|
|
||||||
type ConnectionCloseFrame struct {
|
|
||||||
ErrorCode qerr.ErrorCode
|
|
||||||
ReasonPhrase string
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseConnectionCloseFrame reads a CONNECTION_CLOSE frame
|
|
||||||
func ParseConnectionCloseFrame(r *bytes.Reader) (*ConnectionCloseFrame, error) {
|
|
||||||
frame := &ConnectionCloseFrame{}
|
|
||||||
|
|
||||||
// read the TypeByte
|
|
||||||
_, err := r.ReadByte()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
errorCode, err := utils.ReadUint32(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
frame.ErrorCode = qerr.ErrorCode(errorCode)
|
|
||||||
|
|
||||||
reasonPhraseLen, err := utils.ReadUint16(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if reasonPhraseLen > uint16(protocol.MaxPacketSize) {
|
|
||||||
return nil, qerr.Error(qerr.InvalidConnectionCloseData, "reason phrase too long")
|
|
||||||
}
|
|
||||||
|
|
||||||
reasonPhrase := make([]byte, reasonPhraseLen)
|
|
||||||
if _, err := io.ReadFull(r, reasonPhrase); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
frame.ReasonPhrase = string(reasonPhrase)
|
|
||||||
|
|
||||||
return frame, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MinLength of a written frame
|
|
||||||
func (f *ConnectionCloseFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
|
|
||||||
return 1 + 4 + 2 + protocol.ByteCount(len(f.ReasonPhrase)), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write writes an CONNECTION_CLOSE frame.
|
|
||||||
func (f *ConnectionCloseFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
|
|
||||||
b.WriteByte(0x02)
|
|
||||||
utils.WriteUint32(b, uint32(f.ErrorCode))
|
|
||||||
|
|
||||||
if len(f.ReasonPhrase) > math.MaxUint16 {
|
|
||||||
return errors.New("ConnectionFrame: ReasonPhrase too long")
|
|
||||||
}
|
|
||||||
|
|
||||||
reasonPhraseLen := uint16(len(f.ReasonPhrase))
|
|
||||||
utils.WriteUint16(b, reasonPhraseLen)
|
|
||||||
b.WriteString(f.ReasonPhrase)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
13
vendor/github.com/lucas-clemente/quic-go/frames/frame.go
generated
vendored
13
vendor/github.com/lucas-clemente/quic-go/frames/frame.go
generated
vendored
@ -1,13 +0,0 @@
|
|||||||
package frames
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
)
|
|
||||||
|
|
||||||
// A Frame in QUIC
|
|
||||||
type Frame interface {
|
|
||||||
Write(b *bytes.Buffer, version protocol.VersionNumber) error
|
|
||||||
MinLength(version protocol.VersionNumber) (protocol.ByteCount, error)
|
|
||||||
}
|
|
28
vendor/github.com/lucas-clemente/quic-go/frames/log.go
generated
vendored
28
vendor/github.com/lucas-clemente/quic-go/frames/log.go
generated
vendored
@ -1,28 +0,0 @@
|
|||||||
package frames
|
|
||||||
|
|
||||||
import "github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
|
|
||||||
// LogFrame logs a frame, either sent or received
|
|
||||||
func LogFrame(frame Frame, sent bool) {
|
|
||||||
if !utils.Debug() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
dir := "<-"
|
|
||||||
if sent {
|
|
||||||
dir = "->"
|
|
||||||
}
|
|
||||||
switch f := frame.(type) {
|
|
||||||
case *StreamFrame:
|
|
||||||
utils.Debugf("\t%s &frames.StreamFrame{StreamID: %d, FinBit: %t, Offset: 0x%x, Data length: 0x%x, Offset + Data length: 0x%x}", dir, f.StreamID, f.FinBit, f.Offset, f.DataLen(), f.Offset+f.DataLen())
|
|
||||||
case *StopWaitingFrame:
|
|
||||||
if sent {
|
|
||||||
utils.Debugf("\t%s &frames.StopWaitingFrame{LeastUnacked: 0x%x, PacketNumberLen: 0x%x}", dir, f.LeastUnacked, f.PacketNumberLen)
|
|
||||||
} else {
|
|
||||||
utils.Debugf("\t%s &frames.StopWaitingFrame{LeastUnacked: 0x%x}", dir, f.LeastUnacked)
|
|
||||||
}
|
|
||||||
case *AckFrame:
|
|
||||||
utils.Debugf("\t%s &frames.AckFrame{LargestAcked: 0x%x, LowestAcked: 0x%x, AckRanges: %#v, DelayTime: %s}", dir, f.LargestAcked, f.LowestAcked, f.AckRanges, f.DelayTime.String())
|
|
||||||
default:
|
|
||||||
utils.Debugf("\t%s %#v", dir, frame)
|
|
||||||
}
|
|
||||||
}
|
|
59
vendor/github.com/lucas-clemente/quic-go/frames/rst_stream_frame.go
generated
vendored
59
vendor/github.com/lucas-clemente/quic-go/frames/rst_stream_frame.go
generated
vendored
@ -1,59 +0,0 @@
|
|||||||
package frames
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
)
|
|
||||||
|
|
||||||
// A RstStreamFrame in QUIC
|
|
||||||
type RstStreamFrame struct {
|
|
||||||
StreamID protocol.StreamID
|
|
||||||
ErrorCode uint32
|
|
||||||
ByteOffset protocol.ByteCount
|
|
||||||
}
|
|
||||||
|
|
||||||
//Write writes a RST_STREAM frame
|
|
||||||
func (f *RstStreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
|
|
||||||
b.WriteByte(0x01)
|
|
||||||
utils.WriteUint32(b, uint32(f.StreamID))
|
|
||||||
utils.WriteUint64(b, uint64(f.ByteOffset))
|
|
||||||
utils.WriteUint32(b, f.ErrorCode)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MinLength of a written frame
|
|
||||||
func (f *RstStreamFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
|
|
||||||
return 1 + 4 + 8 + 4, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseRstStreamFrame parses a RST_STREAM frame
|
|
||||||
func ParseRstStreamFrame(r *bytes.Reader) (*RstStreamFrame, error) {
|
|
||||||
frame := &RstStreamFrame{}
|
|
||||||
|
|
||||||
// read the TypeByte
|
|
||||||
_, err := r.ReadByte()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
sid, err := utils.ReadUint32(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
frame.StreamID = protocol.StreamID(sid)
|
|
||||||
|
|
||||||
byteOffset, err := utils.ReadUint64(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
frame.ByteOffset = protocol.ByteCount(byteOffset)
|
|
||||||
|
|
||||||
frame.ErrorCode, err = utils.ReadUint32(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return frame, nil
|
|
||||||
}
|
|
54
vendor/github.com/lucas-clemente/quic-go/frames/window_update_frame.go
generated
vendored
54
vendor/github.com/lucas-clemente/quic-go/frames/window_update_frame.go
generated
vendored
@ -1,54 +0,0 @@
|
|||||||
package frames
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
)
|
|
||||||
|
|
||||||
// A WindowUpdateFrame in QUIC
|
|
||||||
type WindowUpdateFrame struct {
|
|
||||||
StreamID protocol.StreamID
|
|
||||||
ByteOffset protocol.ByteCount
|
|
||||||
}
|
|
||||||
|
|
||||||
//Write writes a RST_STREAM frame
|
|
||||||
func (f *WindowUpdateFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
|
|
||||||
typeByte := uint8(0x04)
|
|
||||||
b.WriteByte(typeByte)
|
|
||||||
|
|
||||||
utils.WriteUint32(b, uint32(f.StreamID))
|
|
||||||
utils.WriteUint64(b, uint64(f.ByteOffset))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MinLength of a written frame
|
|
||||||
func (f *WindowUpdateFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
|
|
||||||
return 1 + 4 + 8, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseWindowUpdateFrame parses a RST_STREAM frame
|
|
||||||
func ParseWindowUpdateFrame(r *bytes.Reader) (*WindowUpdateFrame, error) {
|
|
||||||
frame := &WindowUpdateFrame{}
|
|
||||||
|
|
||||||
// read the TypeByte
|
|
||||||
_, err := r.ReadByte()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
sid, err := utils.ReadUint32(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
frame.StreamID = protocol.StreamID(sid)
|
|
||||||
|
|
||||||
byteOffset, err := utils.ReadUint64(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
frame.ByteOffset = protocol.ByteCount(byteOffset)
|
|
||||||
|
|
||||||
return frame, nil
|
|
||||||
}
|
|
49
vendor/github.com/lucas-clemente/quic-go/h2quic/client.go
generated
vendored
49
vendor/github.com/lucas-clemente/quic-go/h2quic/client.go
generated
vendored
@ -15,8 +15,8 @@ import (
|
|||||||
"golang.org/x/net/idna"
|
"golang.org/x/net/idna"
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
quic "github.com/lucas-clemente/quic-go"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
"github.com/lucas-clemente/quic-go/qerr"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -35,9 +35,9 @@ type client struct {
|
|||||||
opts *roundTripperOpts
|
opts *roundTripperOpts
|
||||||
|
|
||||||
hostname string
|
hostname string
|
||||||
encryptionLevel protocol.EncryptionLevel
|
|
||||||
handshakeErr error
|
handshakeErr error
|
||||||
dialOnce sync.Once
|
dialOnce sync.Once
|
||||||
|
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
|
||||||
|
|
||||||
session quic.Session
|
session quic.Session
|
||||||
headerStream quic.Stream
|
headerStream quic.Stream
|
||||||
@ -51,7 +51,7 @@ type client struct {
|
|||||||
var _ http.RoundTripper = &client{}
|
var _ http.RoundTripper = &client{}
|
||||||
|
|
||||||
var defaultQuicConfig = &quic.Config{
|
var defaultQuicConfig = &quic.Config{
|
||||||
RequestConnectionIDTruncation: true,
|
RequestConnectionIDOmission: true,
|
||||||
KeepAlive: true,
|
KeepAlive: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -61,6 +61,7 @@ func newClient(
|
|||||||
tlsConfig *tls.Config,
|
tlsConfig *tls.Config,
|
||||||
opts *roundTripperOpts,
|
opts *roundTripperOpts,
|
||||||
quicConfig *quic.Config,
|
quicConfig *quic.Config,
|
||||||
|
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error),
|
||||||
) *client {
|
) *client {
|
||||||
config := defaultQuicConfig
|
config := defaultQuicConfig
|
||||||
if quicConfig != nil {
|
if quicConfig != nil {
|
||||||
@ -69,18 +70,22 @@ func newClient(
|
|||||||
return &client{
|
return &client{
|
||||||
hostname: authorityAddr("https", hostname),
|
hostname: authorityAddr("https", hostname),
|
||||||
responses: make(map[protocol.StreamID]chan *http.Response),
|
responses: make(map[protocol.StreamID]chan *http.Response),
|
||||||
encryptionLevel: protocol.EncryptionUnencrypted,
|
|
||||||
tlsConf: tlsConfig,
|
tlsConf: tlsConfig,
|
||||||
config: config,
|
config: config,
|
||||||
opts: opts,
|
opts: opts,
|
||||||
headerErrored: make(chan struct{}),
|
headerErrored: make(chan struct{}),
|
||||||
|
dialer: dialer,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// dial dials the connection
|
// dial dials the connection
|
||||||
func (c *client) dial() error {
|
func (c *client) dial() error {
|
||||||
var err error
|
var err error
|
||||||
|
if c.dialer != nil {
|
||||||
|
c.session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
|
||||||
|
} else {
|
||||||
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
|
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -90,9 +95,6 @@ func (c *client) dial() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if c.headerStream.StreamID() != 3 {
|
|
||||||
return errors.New("h2quic Client BUG: StreamID of Header Stream is not 3")
|
|
||||||
}
|
|
||||||
c.requestWriter = newRequestWriter(c.headerStream)
|
c.requestWriter = newRequestWriter(c.headerStream)
|
||||||
go c.handleHeaderStream()
|
go c.handleHeaderStream()
|
||||||
return nil
|
return nil
|
||||||
@ -102,45 +104,44 @@ func (c *client) handleHeaderStream() {
|
|||||||
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
|
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
|
||||||
h2framer := http2.NewFramer(nil, c.headerStream)
|
h2framer := http2.NewFramer(nil, c.headerStream)
|
||||||
|
|
||||||
var lastStream protocol.StreamID
|
var err error
|
||||||
|
for err == nil {
|
||||||
|
err = c.readResponse(h2framer, decoder)
|
||||||
|
}
|
||||||
|
utils.Debugf("Error handling header stream: %s", err)
|
||||||
|
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, err.Error())
|
||||||
|
// stop all running request
|
||||||
|
close(c.headerErrored)
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
func (c *client) readResponse(h2framer *http2.Framer, decoder *hpack.Decoder) error {
|
||||||
frame, err := h2framer.ReadFrame()
|
frame, err := h2framer.ReadFrame()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.headerErr = qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")
|
return err
|
||||||
break
|
|
||||||
}
|
}
|
||||||
lastStream = protocol.StreamID(frame.Header().StreamID)
|
|
||||||
hframe, ok := frame.(*http2.HeadersFrame)
|
hframe, ok := frame.(*http2.HeadersFrame)
|
||||||
if !ok {
|
if !ok {
|
||||||
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame")
|
return errors.New("not a headers frame")
|
||||||
break
|
|
||||||
}
|
}
|
||||||
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
|
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
|
||||||
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
|
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields")
|
return fmt.Errorf("cannot read header fields: %s", err.Error())
|
||||||
break
|
|
||||||
}
|
}
|
||||||
|
|
||||||
c.mutex.RLock()
|
c.mutex.RLock()
|
||||||
responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
|
responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
|
||||||
c.mutex.RUnlock()
|
c.mutex.RUnlock()
|
||||||
if !ok {
|
if !ok {
|
||||||
c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream))
|
return fmt.Errorf("response channel for stream %d not found", hframe.StreamID)
|
||||||
break
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rsp, err := responseFromHeaders(mhframe)
|
rsp, err := responseFromHeaders(mhframe)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.headerErr = qerr.Error(qerr.InternalError, err.Error())
|
return err
|
||||||
}
|
}
|
||||||
responseChan <- rsp
|
responseChan <- rsp
|
||||||
}
|
return nil
|
||||||
|
|
||||||
// stop all running request
|
|
||||||
utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error())
|
|
||||||
close(c.headerErrored)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Roundtrip executes a request and returns a response
|
// Roundtrip executes a request and returns a response
|
||||||
|
2
vendor/github.com/lucas-clemente/quic-go/h2quic/request_writer.go
generated
vendored
2
vendor/github.com/lucas-clemente/quic-go/h2quic/request_writer.go
generated
vendored
@ -13,8 +13,8 @@ import (
|
|||||||
"golang.org/x/net/lex/httplex"
|
"golang.org/x/net/lex/httplex"
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
quic "github.com/lucas-clemente/quic-go"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type requestWriter struct {
|
type requestWriter struct {
|
||||||
|
4
vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer.go
generated
vendored
4
vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer.go
generated
vendored
@ -8,8 +8,8 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
quic "github.com/lucas-clemente/quic-go"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
"golang.org/x/net/http2/hpack"
|
"golang.org/x/net/http2/hpack"
|
||||||
)
|
)
|
||||||
@ -83,7 +83,7 @@ func (w *responseWriter) Write(p []byte) (int, error) {
|
|||||||
|
|
||||||
func (w *responseWriter) Flush() {}
|
func (w *responseWriter) Flush() {}
|
||||||
|
|
||||||
// TODO: Implement a functional CloseNotify method.
|
// This is a NOP. Use http.Request.Context
|
||||||
func (w *responseWriter) CloseNotify() <-chan bool { return make(<-chan bool) }
|
func (w *responseWriter) CloseNotify() <-chan bool { return make(<-chan bool) }
|
||||||
|
|
||||||
// test that we implement http.Flusher
|
// test that we implement http.Flusher
|
||||||
|
13
vendor/github.com/lucas-clemente/quic-go/h2quic/roundtrip.go
generated
vendored
13
vendor/github.com/lucas-clemente/quic-go/h2quic/roundtrip.go
generated
vendored
@ -41,6 +41,11 @@ type RoundTripper struct {
|
|||||||
// If nil, reasonable default values will be used.
|
// If nil, reasonable default values will be used.
|
||||||
QuicConfig *quic.Config
|
QuicConfig *quic.Config
|
||||||
|
|
||||||
|
// Dial specifies an optional dial function for creating QUIC
|
||||||
|
// connections for requests.
|
||||||
|
// If Dial is nil, quic.DialAddr will be used.
|
||||||
|
Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
|
||||||
|
|
||||||
clients map[string]roundTripCloser
|
clients map[string]roundTripCloser
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -120,7 +125,13 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTr
|
|||||||
if onlyCached {
|
if onlyCached {
|
||||||
return nil, ErrNoCachedConn
|
return nil, ErrNoCachedConn
|
||||||
}
|
}
|
||||||
client = newClient(hostname, r.TLSClientConfig, &roundTripperOpts{DisableCompression: r.DisableCompression}, r.QuicConfig)
|
client = newClient(
|
||||||
|
hostname,
|
||||||
|
r.TLSClientConfig,
|
||||||
|
&roundTripperOpts{DisableCompression: r.DisableCompression},
|
||||||
|
r.QuicConfig,
|
||||||
|
r.Dial,
|
||||||
|
)
|
||||||
r.clients[hostname] = client
|
r.clients[hostname] = client
|
||||||
}
|
}
|
||||||
return client, nil
|
return client, nil
|
||||||
|
47
vendor/github.com/lucas-clemente/quic-go/h2quic/server.go
generated
vendored
47
vendor/github.com/lucas-clemente/quic-go/h2quic/server.go
generated
vendored
@ -7,14 +7,14 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
quic "github.com/lucas-clemente/quic-go"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
"github.com/lucas-clemente/quic-go/qerr"
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
"golang.org/x/net/http2/hpack"
|
"golang.org/x/net/http2/hpack"
|
||||||
@ -50,6 +50,7 @@ type Server struct {
|
|||||||
|
|
||||||
listenerMutex sync.Mutex
|
listenerMutex sync.Mutex
|
||||||
listener quic.Listener
|
listener quic.Listener
|
||||||
|
closed bool
|
||||||
|
|
||||||
supportedVersionsAsString string
|
supportedVersionsAsString string
|
||||||
}
|
}
|
||||||
@ -88,6 +89,10 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
|
|||||||
return errors.New("use of h2quic.Server without http.Server")
|
return errors.New("use of h2quic.Server without http.Server")
|
||||||
}
|
}
|
||||||
s.listenerMutex.Lock()
|
s.listenerMutex.Lock()
|
||||||
|
if s.closed {
|
||||||
|
s.listenerMutex.Unlock()
|
||||||
|
return errors.New("Server is already closed")
|
||||||
|
}
|
||||||
if s.listener != nil {
|
if s.listener != nil {
|
||||||
s.listenerMutex.Unlock()
|
s.listenerMutex.Unlock()
|
||||||
return errors.New("ListenAndServe may only be called once")
|
return errors.New("ListenAndServe may only be called once")
|
||||||
@ -122,15 +127,10 @@ func (s *Server) handleHeaderStream(session streamCreator) {
|
|||||||
session.Close(qerr.Error(qerr.InvalidHeadersStreamData, err.Error()))
|
session.Close(qerr.Error(qerr.InvalidHeadersStreamData, err.Error()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if stream.StreamID() != 3 {
|
|
||||||
session.Close(qerr.Error(qerr.InternalError, "h2quic server BUG: header stream does not have stream ID 3"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
hpackDecoder := hpack.NewDecoder(4096, nil)
|
hpackDecoder := hpack.NewDecoder(4096, nil)
|
||||||
h2framer := http2.NewFramer(nil, stream)
|
h2framer := http2.NewFramer(nil, stream)
|
||||||
|
|
||||||
go func() {
|
|
||||||
var headerStreamMutex sync.Mutex // Protects concurrent calls to Write()
|
var headerStreamMutex sync.Mutex // Protects concurrent calls to Write()
|
||||||
for {
|
for {
|
||||||
if err := s.handleRequest(session, stream, &headerStreamMutex, hpackDecoder, h2framer); err != nil {
|
if err := s.handleRequest(session, stream, &headerStreamMutex, hpackDecoder, h2framer); err != nil {
|
||||||
@ -144,7 +144,6 @@ func (s *Server) handleHeaderStream(session streamCreator) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, headerStreamMutex *sync.Mutex, hpackDecoder *hpack.Decoder, h2framer *http2.Framer) error {
|
func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, headerStreamMutex *sync.Mutex, hpackDecoder *hpack.Decoder, h2framer *http2.Framer) error {
|
||||||
@ -170,8 +169,6 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
req.RemoteAddr = session.RemoteAddr().String()
|
|
||||||
|
|
||||||
if utils.Debug() {
|
if utils.Debug() {
|
||||||
utils.Infof("%s %s%s, on data stream %d", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID)
|
utils.Infof("%s %s%s, on data stream %d", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID)
|
||||||
} else {
|
} else {
|
||||||
@ -187,19 +184,25 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var streamEnded bool
|
// handleRequest should be as non-blocking as possible to minimize
|
||||||
if h2headersFrame.StreamEnded() {
|
// head-of-line blocking. Potentially blocking code is run in a separate
|
||||||
|
// goroutine, enabling handleRequest to return before the code is executed.
|
||||||
|
go func() {
|
||||||
|
streamEnded := h2headersFrame.StreamEnded()
|
||||||
|
if streamEnded {
|
||||||
dataStream.(remoteCloser).CloseRemote(0)
|
dataStream.(remoteCloser).CloseRemote(0)
|
||||||
streamEnded = true
|
streamEnded = true
|
||||||
_, _ = dataStream.Read([]byte{0}) // read the eof
|
_, _ = dataStream.Read([]byte{0}) // read the eof
|
||||||
}
|
}
|
||||||
|
|
||||||
|
req = req.WithContext(dataStream.Context())
|
||||||
reqBody := newRequestBody(dataStream)
|
reqBody := newRequestBody(dataStream)
|
||||||
req.Body = reqBody
|
req.Body = reqBody
|
||||||
|
|
||||||
|
req.RemoteAddr = session.RemoteAddr().String()
|
||||||
|
|
||||||
responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID))
|
responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID))
|
||||||
|
|
||||||
go func() {
|
|
||||||
handler := s.Handler
|
handler := s.Handler
|
||||||
if handler == nil {
|
if handler == nil {
|
||||||
handler = http.DefaultServeMux
|
handler = http.DefaultServeMux
|
||||||
@ -225,7 +228,8 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
|
|||||||
}
|
}
|
||||||
if responseWriter.dataStream != nil {
|
if responseWriter.dataStream != nil {
|
||||||
if !streamEnded && !reqBody.requestRead {
|
if !streamEnded && !reqBody.requestRead {
|
||||||
responseWriter.dataStream.Reset(nil)
|
// in gQUIC, the error code doesn't matter, so just use 0 here
|
||||||
|
responseWriter.dataStream.CancelRead(0)
|
||||||
}
|
}
|
||||||
responseWriter.dataStream.Close()
|
responseWriter.dataStream.Close()
|
||||||
}
|
}
|
||||||
@ -243,6 +247,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
|
|||||||
func (s *Server) Close() error {
|
func (s *Server) Close() error {
|
||||||
s.listenerMutex.Lock()
|
s.listenerMutex.Lock()
|
||||||
defer s.listenerMutex.Unlock()
|
defer s.listenerMutex.Unlock()
|
||||||
|
s.closed = true
|
||||||
if s.listener != nil {
|
if s.listener != nil {
|
||||||
err := s.listener.Close()
|
err := s.listener.Close()
|
||||||
s.listener = nil
|
s.listener = nil
|
||||||
@ -279,12 +284,11 @@ func (s *Server) SetQuicHeaders(hdr http.Header) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if s.supportedVersionsAsString == "" {
|
if s.supportedVersionsAsString == "" {
|
||||||
for i, v := range protocol.SupportedVersions {
|
var versions []string
|
||||||
s.supportedVersionsAsString += strconv.Itoa(int(v))
|
for _, v := range protocol.SupportedVersions {
|
||||||
if i != len(protocol.SupportedVersions)-1 {
|
versions = append(versions, v.ToAltSvc())
|
||||||
s.supportedVersionsAsString += ","
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
s.supportedVersionsAsString = strings.Join(versions, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, s.supportedVersionsAsString))
|
hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, s.supportedVersionsAsString))
|
||||||
@ -344,6 +348,9 @@ func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error
|
|||||||
}
|
}
|
||||||
defer tcpConn.Close()
|
defer tcpConn.Close()
|
||||||
|
|
||||||
|
tlsConn := tls.NewListener(tcpConn, config)
|
||||||
|
defer tlsConn.Close()
|
||||||
|
|
||||||
// Start the servers
|
// Start the servers
|
||||||
httpServer := &http.Server{
|
httpServer := &http.Server{
|
||||||
Addr: addr,
|
Addr: addr,
|
||||||
@ -365,7 +372,7 @@ func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error
|
|||||||
hErr := make(chan error)
|
hErr := make(chan error)
|
||||||
qErr := make(chan error)
|
qErr := make(chan error)
|
||||||
go func() {
|
go func() {
|
||||||
hErr <- httpServer.Serve(tcpConn)
|
hErr <- httpServer.Serve(tlsConn)
|
||||||
}()
|
}()
|
||||||
go func() {
|
go func() {
|
||||||
qErr <- quicServer.Serve(udpConn)
|
qErr <- quicServer.Serve(udpConn)
|
||||||
|
265
vendor/github.com/lucas-clemente/quic-go/handshake/connection_parameters_manager.go
generated
vendored
265
vendor/github.com/lucas-clemente/quic-go/handshake/connection_parameters_manager.go
generated
vendored
@ -1,265 +0,0 @@
|
|||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ConnectionParametersManager negotiates and stores the connection parameters
|
|
||||||
// A ConnectionParametersManager can be used for a server as well as a client
|
|
||||||
// For the server:
|
|
||||||
// 1. call SetFromMap with the values received in the CHLO. This sets the corresponding values here, subject to negotiation
|
|
||||||
// 2. call GetHelloMap to get the values to send in the SHLO
|
|
||||||
// For the client:
|
|
||||||
// 1. call GetHelloMap to get the values to send in a CHLO
|
|
||||||
// 2. call SetFromMap with the values received in the SHLO
|
|
||||||
type ConnectionParametersManager interface {
|
|
||||||
SetFromMap(map[Tag][]byte) error
|
|
||||||
GetHelloMap() (map[Tag][]byte, error)
|
|
||||||
|
|
||||||
GetSendStreamFlowControlWindow() protocol.ByteCount
|
|
||||||
GetSendConnectionFlowControlWindow() protocol.ByteCount
|
|
||||||
GetReceiveStreamFlowControlWindow() protocol.ByteCount
|
|
||||||
GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount
|
|
||||||
GetReceiveConnectionFlowControlWindow() protocol.ByteCount
|
|
||||||
GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount
|
|
||||||
GetMaxOutgoingStreams() uint32
|
|
||||||
GetMaxIncomingStreams() uint32
|
|
||||||
GetIdleConnectionStateLifetime() time.Duration
|
|
||||||
TruncateConnectionID() bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type connectionParametersManager struct {
|
|
||||||
mutex sync.RWMutex
|
|
||||||
|
|
||||||
version protocol.VersionNumber
|
|
||||||
perspective protocol.Perspective
|
|
||||||
|
|
||||||
flowControlNegotiated bool
|
|
||||||
|
|
||||||
truncateConnectionID bool
|
|
||||||
maxStreamsPerConnection uint32
|
|
||||||
maxIncomingDynamicStreamsPerConnection uint32
|
|
||||||
idleConnectionStateLifetime time.Duration
|
|
||||||
sendStreamFlowControlWindow protocol.ByteCount
|
|
||||||
sendConnectionFlowControlWindow protocol.ByteCount
|
|
||||||
receiveStreamFlowControlWindow protocol.ByteCount
|
|
||||||
receiveConnectionFlowControlWindow protocol.ByteCount
|
|
||||||
maxReceiveStreamFlowControlWindow protocol.ByteCount
|
|
||||||
maxReceiveConnectionFlowControlWindow protocol.ByteCount
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ ConnectionParametersManager = &connectionParametersManager{}
|
|
||||||
|
|
||||||
// ErrMalformedTag is returned when the tag value cannot be read
|
|
||||||
var (
|
|
||||||
ErrMalformedTag = qerr.Error(qerr.InvalidCryptoMessageParameter, "malformed Tag value")
|
|
||||||
ErrFlowControlRenegotiationNotSupported = qerr.Error(qerr.InvalidCryptoMessageParameter, "renegotiation of flow control parameters not supported")
|
|
||||||
)
|
|
||||||
|
|
||||||
// NewConnectionParamatersManager creates a new connection parameters manager
|
|
||||||
func NewConnectionParamatersManager(
|
|
||||||
pers protocol.Perspective, v protocol.VersionNumber,
|
|
||||||
maxReceiveStreamFlowControlWindow protocol.ByteCount, maxReceiveConnectionFlowControlWindow protocol.ByteCount,
|
|
||||||
) ConnectionParametersManager {
|
|
||||||
h := &connectionParametersManager{
|
|
||||||
perspective: pers,
|
|
||||||
version: v,
|
|
||||||
sendStreamFlowControlWindow: protocol.InitialStreamFlowControlWindow, // can only be changed by the client
|
|
||||||
sendConnectionFlowControlWindow: protocol.InitialConnectionFlowControlWindow, // can only be changed by the client
|
|
||||||
receiveStreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
|
|
||||||
receiveConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
|
|
||||||
maxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
|
|
||||||
maxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
|
|
||||||
}
|
|
||||||
|
|
||||||
if h.perspective == protocol.PerspectiveServer {
|
|
||||||
h.idleConnectionStateLifetime = protocol.DefaultIdleTimeout
|
|
||||||
h.maxStreamsPerConnection = protocol.MaxStreamsPerConnection // this is the value negotiated based on what the client sent
|
|
||||||
h.maxIncomingDynamicStreamsPerConnection = protocol.MaxStreamsPerConnection // "incoming" seen from the client's perspective
|
|
||||||
} else {
|
|
||||||
h.idleConnectionStateLifetime = protocol.MaxIdleTimeoutClient
|
|
||||||
h.maxStreamsPerConnection = protocol.MaxStreamsPerConnection // this is the value negotiated based on what the client sent
|
|
||||||
h.maxIncomingDynamicStreamsPerConnection = protocol.MaxStreamsPerConnection // "incoming" seen from the server's perspective
|
|
||||||
}
|
|
||||||
|
|
||||||
return h
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetFromMap reads all params
|
|
||||||
func (h *connectionParametersManager) SetFromMap(params map[Tag][]byte) error {
|
|
||||||
h.mutex.Lock()
|
|
||||||
defer h.mutex.Unlock()
|
|
||||||
|
|
||||||
if value, ok := params[TagTCID]; ok && h.perspective == protocol.PerspectiveServer {
|
|
||||||
clientValue, err := utils.ReadUint32(bytes.NewBuffer(value))
|
|
||||||
if err != nil {
|
|
||||||
return ErrMalformedTag
|
|
||||||
}
|
|
||||||
h.truncateConnectionID = (clientValue == 0)
|
|
||||||
}
|
|
||||||
if value, ok := params[TagMSPC]; ok {
|
|
||||||
clientValue, err := utils.ReadUint32(bytes.NewBuffer(value))
|
|
||||||
if err != nil {
|
|
||||||
return ErrMalformedTag
|
|
||||||
}
|
|
||||||
h.maxStreamsPerConnection = h.negotiateMaxStreamsPerConnection(clientValue)
|
|
||||||
}
|
|
||||||
if value, ok := params[TagMIDS]; ok {
|
|
||||||
clientValue, err := utils.ReadUint32(bytes.NewBuffer(value))
|
|
||||||
if err != nil {
|
|
||||||
return ErrMalformedTag
|
|
||||||
}
|
|
||||||
h.maxIncomingDynamicStreamsPerConnection = h.negotiateMaxIncomingDynamicStreamsPerConnection(clientValue)
|
|
||||||
}
|
|
||||||
if value, ok := params[TagICSL]; ok {
|
|
||||||
clientValue, err := utils.ReadUint32(bytes.NewBuffer(value))
|
|
||||||
if err != nil {
|
|
||||||
return ErrMalformedTag
|
|
||||||
}
|
|
||||||
h.idleConnectionStateLifetime = h.negotiateIdleConnectionStateLifetime(time.Duration(clientValue) * time.Second)
|
|
||||||
}
|
|
||||||
if value, ok := params[TagSFCW]; ok {
|
|
||||||
if h.flowControlNegotiated {
|
|
||||||
return ErrFlowControlRenegotiationNotSupported
|
|
||||||
}
|
|
||||||
sendStreamFlowControlWindow, err := utils.ReadUint32(bytes.NewBuffer(value))
|
|
||||||
if err != nil {
|
|
||||||
return ErrMalformedTag
|
|
||||||
}
|
|
||||||
h.sendStreamFlowControlWindow = protocol.ByteCount(sendStreamFlowControlWindow)
|
|
||||||
}
|
|
||||||
if value, ok := params[TagCFCW]; ok {
|
|
||||||
if h.flowControlNegotiated {
|
|
||||||
return ErrFlowControlRenegotiationNotSupported
|
|
||||||
}
|
|
||||||
sendConnectionFlowControlWindow, err := utils.ReadUint32(bytes.NewBuffer(value))
|
|
||||||
if err != nil {
|
|
||||||
return ErrMalformedTag
|
|
||||||
}
|
|
||||||
h.sendConnectionFlowControlWindow = protocol.ByteCount(sendConnectionFlowControlWindow)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, containsSFCW := params[TagSFCW]
|
|
||||||
_, containsCFCW := params[TagCFCW]
|
|
||||||
if containsCFCW || containsSFCW {
|
|
||||||
h.flowControlNegotiated = true
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *connectionParametersManager) negotiateMaxStreamsPerConnection(clientValue uint32) uint32 {
|
|
||||||
return utils.MinUint32(clientValue, protocol.MaxStreamsPerConnection)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *connectionParametersManager) negotiateMaxIncomingDynamicStreamsPerConnection(clientValue uint32) uint32 {
|
|
||||||
return utils.MinUint32(clientValue, protocol.MaxIncomingDynamicStreamsPerConnection)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *connectionParametersManager) negotiateIdleConnectionStateLifetime(clientValue time.Duration) time.Duration {
|
|
||||||
if h.perspective == protocol.PerspectiveServer {
|
|
||||||
return utils.MinDuration(clientValue, protocol.MaxIdleTimeoutServer)
|
|
||||||
}
|
|
||||||
return utils.MinDuration(clientValue, protocol.MaxIdleTimeoutClient)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetHelloMap gets all parameters needed for the Hello message
|
|
||||||
func (h *connectionParametersManager) GetHelloMap() (map[Tag][]byte, error) {
|
|
||||||
sfcw := bytes.NewBuffer([]byte{})
|
|
||||||
utils.WriteUint32(sfcw, uint32(h.GetReceiveStreamFlowControlWindow()))
|
|
||||||
cfcw := bytes.NewBuffer([]byte{})
|
|
||||||
utils.WriteUint32(cfcw, uint32(h.GetReceiveConnectionFlowControlWindow()))
|
|
||||||
mspc := bytes.NewBuffer([]byte{})
|
|
||||||
utils.WriteUint32(mspc, h.maxStreamsPerConnection)
|
|
||||||
mids := bytes.NewBuffer([]byte{})
|
|
||||||
utils.WriteUint32(mids, protocol.MaxIncomingDynamicStreamsPerConnection)
|
|
||||||
icsl := bytes.NewBuffer([]byte{})
|
|
||||||
utils.WriteUint32(icsl, uint32(h.GetIdleConnectionStateLifetime()/time.Second))
|
|
||||||
|
|
||||||
return map[Tag][]byte{
|
|
||||||
TagICSL: icsl.Bytes(),
|
|
||||||
TagMSPC: mspc.Bytes(),
|
|
||||||
TagMIDS: mids.Bytes(),
|
|
||||||
TagCFCW: cfcw.Bytes(),
|
|
||||||
TagSFCW: sfcw.Bytes(),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSendStreamFlowControlWindow gets the size of the stream-level flow control window for sending data
|
|
||||||
func (h *connectionParametersManager) GetSendStreamFlowControlWindow() protocol.ByteCount {
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
return h.sendStreamFlowControlWindow
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSendConnectionFlowControlWindow gets the size of the stream-level flow control window for sending data
|
|
||||||
func (h *connectionParametersManager) GetSendConnectionFlowControlWindow() protocol.ByteCount {
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
return h.sendConnectionFlowControlWindow
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetReceiveStreamFlowControlWindow gets the size of the stream-level flow control window for receiving data
|
|
||||||
func (h *connectionParametersManager) GetReceiveStreamFlowControlWindow() protocol.ByteCount {
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
return h.receiveStreamFlowControlWindow
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetMaxReceiveStreamFlowControlWindow gets the maximum size of the stream-level flow control window for sending data
|
|
||||||
func (h *connectionParametersManager) GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount {
|
|
||||||
return h.maxReceiveStreamFlowControlWindow
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetReceiveConnectionFlowControlWindow gets the size of the stream-level flow control window for receiving data
|
|
||||||
func (h *connectionParametersManager) GetReceiveConnectionFlowControlWindow() protocol.ByteCount {
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
return h.receiveConnectionFlowControlWindow
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetMaxReceiveConnectionFlowControlWindow gets the maximum size of the stream-level flow control window for sending data
|
|
||||||
func (h *connectionParametersManager) GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount {
|
|
||||||
return h.maxReceiveConnectionFlowControlWindow
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetMaxOutgoingStreams gets the maximum number of outgoing streams per connection
|
|
||||||
func (h *connectionParametersManager) GetMaxOutgoingStreams() uint32 {
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
|
|
||||||
return h.maxIncomingDynamicStreamsPerConnection
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetMaxIncomingStreams get the maximum number of incoming streams per connection
|
|
||||||
func (h *connectionParametersManager) GetMaxIncomingStreams() uint32 {
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
|
|
||||||
maxStreams := protocol.MaxIncomingDynamicStreamsPerConnection
|
|
||||||
return utils.MaxUint32(uint32(maxStreams)+protocol.MaxStreamsMinimumIncrement, uint32(float64(maxStreams)*protocol.MaxStreamsMultiplier))
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetIdleConnectionStateLifetime gets the idle timeout
|
|
||||||
func (h *connectionParametersManager) GetIdleConnectionStateLifetime() time.Duration {
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
return h.idleConnectionStateLifetime
|
|
||||||
}
|
|
||||||
|
|
||||||
// TruncateConnectionID determines if the client requests truncated ConnectionIDs
|
|
||||||
func (h *connectionParametersManager) TruncateConnectionID() bool {
|
|
||||||
if h.perspective == protocol.PerspectiveClient {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
return h.truncateConnectionID
|
|
||||||
}
|
|
24
vendor/github.com/lucas-clemente/quic-go/handshake/interface.go
generated
vendored
24
vendor/github.com/lucas-clemente/quic-go/handshake/interface.go
generated
vendored
@ -1,24 +0,0 @@
|
|||||||
package handshake
|
|
||||||
|
|
||||||
import "github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
|
|
||||||
// Sealer seals a packet
|
|
||||||
type Sealer func(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
|
|
||||||
|
|
||||||
// CryptoSetup is a crypto setup
|
|
||||||
type CryptoSetup interface {
|
|
||||||
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
|
|
||||||
HandleCryptoStream() error
|
|
||||||
// TODO: clean up this interface
|
|
||||||
DiversificationNonce() []byte // only needed for cryptoSetupServer
|
|
||||||
SetDiversificationNonce([]byte) // only needed for cryptoSetupClient
|
|
||||||
|
|
||||||
GetSealer() (protocol.EncryptionLevel, Sealer)
|
|
||||||
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
|
|
||||||
GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TransportParameters are parameters sent to the peer during the handshake
|
|
||||||
type TransportParameters struct {
|
|
||||||
RequestConnectionIDTruncation bool
|
|
||||||
}
|
|
100
vendor/github.com/lucas-clemente/quic-go/handshake/stk_generator.go
generated
vendored
100
vendor/github.com/lucas-clemente/quic-go/handshake/stk_generator.go
generated
vendored
@ -1,100 +0,0 @@
|
|||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/asn1"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/crypto"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
stkPrefixIP byte = iota
|
|
||||||
stkPrefixString
|
|
||||||
)
|
|
||||||
|
|
||||||
// An STK is a source address token
|
|
||||||
type STK struct {
|
|
||||||
RemoteAddr string
|
|
||||||
SentTime time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
// token is the struct that is used for ASN1 serialization and deserialization
|
|
||||||
type token struct {
|
|
||||||
Data []byte
|
|
||||||
Timestamp int64
|
|
||||||
}
|
|
||||||
|
|
||||||
// An STKGenerator generates STKs
|
|
||||||
type STKGenerator struct {
|
|
||||||
stkSource crypto.StkSource
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewSTKGenerator initializes a new STKGenerator
|
|
||||||
func NewSTKGenerator() (*STKGenerator, error) {
|
|
||||||
stkSource, err := crypto.NewStkSource()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &STKGenerator{
|
|
||||||
stkSource: stkSource,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewToken generates a new STK token for a given source address
|
|
||||||
func (g *STKGenerator) NewToken(raddr net.Addr) ([]byte, error) {
|
|
||||||
data, err := asn1.Marshal(token{
|
|
||||||
Data: encodeRemoteAddr(raddr),
|
|
||||||
Timestamp: time.Now().Unix(),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return g.stkSource.NewToken(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DecodeToken decodes an STK token
|
|
||||||
func (g *STKGenerator) DecodeToken(encrypted []byte) (*STK, error) {
|
|
||||||
// if the client didn't send any STK, DecodeToken will be called with a nil-slice
|
|
||||||
if len(encrypted) == 0 {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := g.stkSource.DecodeToken(encrypted)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
t := &token{}
|
|
||||||
rest, err := asn1.Unmarshal(data, t)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if len(rest) != 0 {
|
|
||||||
return nil, fmt.Errorf("rest when unpacking token: %d", len(rest))
|
|
||||||
}
|
|
||||||
return &STK{
|
|
||||||
RemoteAddr: decodeRemoteAddr(t.Data),
|
|
||||||
SentTime: time.Unix(t.Timestamp, 0),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// encodeRemoteAddr encodes a remote address such that it can be saved in the STK
|
|
||||||
func encodeRemoteAddr(remoteAddr net.Addr) []byte {
|
|
||||||
if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok {
|
|
||||||
return append([]byte{stkPrefixIP}, udpAddr.IP...)
|
|
||||||
}
|
|
||||||
return append([]byte{stkPrefixString}, []byte(remoteAddr.String())...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// decodeRemoteAddr decodes the remote address saved in the STK
|
|
||||||
func decodeRemoteAddr(data []byte) string {
|
|
||||||
// data will never be empty for an STK that we generated. Check it to be on the safe side
|
|
||||||
if len(data) == 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
if data[0] == stkPrefixIP {
|
|
||||||
return net.IP(data[1:]).String()
|
|
||||||
}
|
|
||||||
return string(data[1:])
|
|
||||||
}
|
|
1
vendor/github.com/lucas-clemente/quic-go/integrationtests/chrome/chrome.go
generated
vendored
1
vendor/github.com/lucas-clemente/quic-go/integrationtests/chrome/chrome.go
generated
vendored
@ -1 +0,0 @@
|
|||||||
package chrome
|
|
1
vendor/github.com/lucas-clemente/quic-go/integrationtests/gquic/gquic.go
generated
vendored
1
vendor/github.com/lucas-clemente/quic-go/integrationtests/gquic/gquic.go
generated
vendored
@ -1 +0,0 @@
|
|||||||
package gquic
|
|
1
vendor/github.com/lucas-clemente/quic-go/integrationtests/self/self.go
generated
vendored
1
vendor/github.com/lucas-clemente/quic-go/integrationtests/self/self.go
generated
vendored
@ -1 +0,0 @@
|
|||||||
package self
|
|
73
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/proxy/proxy.go
generated
vendored
73
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/proxy/proxy.go
generated
vendored
@ -1,14 +1,12 @@
|
|||||||
package quicproxy
|
package quicproxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Connection is a UDP connection
|
// Connection is a UDP connection
|
||||||
@ -28,21 +26,43 @@ const (
|
|||||||
DirectionIncoming Direction = iota
|
DirectionIncoming Direction = iota
|
||||||
// DirectionOutgoing is the direction from the server to the client.
|
// DirectionOutgoing is the direction from the server to the client.
|
||||||
DirectionOutgoing
|
DirectionOutgoing
|
||||||
|
// DirectionBoth is both incoming and outgoing
|
||||||
|
DirectionBoth
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (d Direction) String() string {
|
||||||
|
switch d {
|
||||||
|
case DirectionIncoming:
|
||||||
|
return "incoming"
|
||||||
|
case DirectionOutgoing:
|
||||||
|
return "outgoing"
|
||||||
|
case DirectionBoth:
|
||||||
|
return "both"
|
||||||
|
default:
|
||||||
|
panic("unknown direction")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d Direction) Is(dir Direction) bool {
|
||||||
|
if d == DirectionBoth || dir == DirectionBoth {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return d == dir
|
||||||
|
}
|
||||||
|
|
||||||
// DropCallback is a callback that determines which packet gets dropped.
|
// DropCallback is a callback that determines which packet gets dropped.
|
||||||
type DropCallback func(Direction, protocol.PacketNumber) bool
|
type DropCallback func(dir Direction, packetCount uint64) bool
|
||||||
|
|
||||||
// NoDropper doesn't drop packets.
|
// NoDropper doesn't drop packets.
|
||||||
var NoDropper DropCallback = func(Direction, protocol.PacketNumber) bool {
|
var NoDropper DropCallback = func(Direction, uint64) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// DelayCallback is a callback that determines how much delay to apply to a packet.
|
// DelayCallback is a callback that determines how much delay to apply to a packet.
|
||||||
type DelayCallback func(Direction, protocol.PacketNumber) time.Duration
|
type DelayCallback func(dir Direction, packetCount uint64) time.Duration
|
||||||
|
|
||||||
// NoDelay doesn't apply a delay.
|
// NoDelay doesn't apply a delay.
|
||||||
var NoDelay DelayCallback = func(Direction, protocol.PacketNumber) time.Duration {
|
var NoDelay DelayCallback = func(Direction, uint64) time.Duration {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -62,6 +82,8 @@ type Opts struct {
|
|||||||
type QuicProxy struct {
|
type QuicProxy struct {
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
|
|
||||||
|
version protocol.VersionNumber
|
||||||
|
|
||||||
conn *net.UDPConn
|
conn *net.UDPConn
|
||||||
serverAddr *net.UDPAddr
|
serverAddr *net.UDPAddr
|
||||||
|
|
||||||
@ -73,7 +95,10 @@ type QuicProxy struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewQuicProxy creates a new UDP proxy
|
// NewQuicProxy creates a new UDP proxy
|
||||||
func NewQuicProxy(local string, opts Opts) (*QuicProxy, error) {
|
func NewQuicProxy(local string, version protocol.VersionNumber, opts *Opts) (*QuicProxy, error) {
|
||||||
|
if opts == nil {
|
||||||
|
opts = &Opts{}
|
||||||
|
}
|
||||||
laddr, err := net.ResolveUDPAddr("udp", local)
|
laddr, err := net.ResolveUDPAddr("udp", local)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -103,6 +128,7 @@ func NewQuicProxy(local string, opts Opts) (*QuicProxy, error) {
|
|||||||
serverAddr: raddr,
|
serverAddr: raddr,
|
||||||
dropPacket: packetDropper,
|
dropPacket: packetDropper,
|
||||||
delayPacket: packetDelayer,
|
delayPacket: packetDelayer,
|
||||||
|
version: version,
|
||||||
}
|
}
|
||||||
|
|
||||||
go p.runProxy()
|
go p.runProxy()
|
||||||
@ -119,6 +145,7 @@ func (p *QuicProxy) LocalAddr() net.Addr {
|
|||||||
return p.conn.LocalAddr()
|
return p.conn.LocalAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LocalPort is the UDP port number the proxy is listening on.
|
||||||
func (p *QuicProxy) LocalPort() int {
|
func (p *QuicProxy) LocalPort() int {
|
||||||
return p.conn.LocalAddr().(*net.UDPAddr).Port
|
return p.conn.LocalAddr().(*net.UDPAddr).Port
|
||||||
}
|
}
|
||||||
@ -137,7 +164,7 @@ func (p *QuicProxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) {
|
|||||||
// runProxy listens on the proxy address and handles incoming packets.
|
// runProxy listens on the proxy address and handles incoming packets.
|
||||||
func (p *QuicProxy) runProxy() error {
|
func (p *QuicProxy) runProxy() error {
|
||||||
for {
|
for {
|
||||||
buffer := make([]byte, protocol.MaxPacketSize)
|
buffer := make([]byte, protocol.MaxReceivePacketSize)
|
||||||
n, cliaddr, err := p.conn.ReadFromUDP(buffer)
|
n, cliaddr, err := p.conn.ReadFromUDP(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -159,20 +186,14 @@ func (p *QuicProxy) runProxy() error {
|
|||||||
}
|
}
|
||||||
p.mutex.Unlock()
|
p.mutex.Unlock()
|
||||||
|
|
||||||
atomic.AddUint64(&conn.incomingPacketCounter, 1)
|
packetCount := atomic.AddUint64(&conn.incomingPacketCounter, 1)
|
||||||
|
|
||||||
r := bytes.NewReader(raw)
|
if p.dropPacket(DirectionIncoming, packetCount) {
|
||||||
hdr, err := quic.ParsePublicHeader(r, protocol.PerspectiveClient)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if p.dropPacket(DirectionIncoming, hdr.PacketNumber) {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send the packet to the server
|
// Send the packet to the server
|
||||||
delay := p.delayPacket(DirectionIncoming, hdr.PacketNumber)
|
delay := p.delayPacket(DirectionIncoming, packetCount)
|
||||||
if delay != 0 {
|
if delay != 0 {
|
||||||
time.AfterFunc(delay, func() {
|
time.AfterFunc(delay, func() {
|
||||||
// TODO: handle error
|
// TODO: handle error
|
||||||
@ -190,28 +211,20 @@ func (p *QuicProxy) runProxy() error {
|
|||||||
// runConnection handles packets from server to a single client
|
// runConnection handles packets from server to a single client
|
||||||
func (p *QuicProxy) runConnection(conn *connection) error {
|
func (p *QuicProxy) runConnection(conn *connection) error {
|
||||||
for {
|
for {
|
||||||
buffer := make([]byte, protocol.MaxPacketSize)
|
buffer := make([]byte, protocol.MaxReceivePacketSize)
|
||||||
n, err := conn.ServerConn.Read(buffer)
|
n, err := conn.ServerConn.Read(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
raw := buffer[0:n]
|
raw := buffer[0:n]
|
||||||
|
|
||||||
// TODO: Switch back to using the public header once Chrome properly sets the type byte.
|
packetCount := atomic.AddUint64(&conn.outgoingPacketCounter, 1)
|
||||||
// r := bytes.NewReader(raw)
|
|
||||||
// , err := quic.ParsePublicHeader(r, protocol.PerspectiveServer)
|
|
||||||
// if err != nil {
|
|
||||||
// return err
|
|
||||||
// }
|
|
||||||
|
|
||||||
v := atomic.AddUint64(&conn.outgoingPacketCounter, 1)
|
if p.dropPacket(DirectionOutgoing, packetCount) {
|
||||||
|
|
||||||
packetNumber := protocol.PacketNumber(v)
|
|
||||||
if p.dropPacket(DirectionOutgoing, packetNumber) {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
delay := p.delayPacket(DirectionOutgoing, packetNumber)
|
delay := p.delayPacket(DirectionOutgoing, packetCount)
|
||||||
if delay != 0 {
|
if delay != 0 {
|
||||||
time.AfterFunc(delay, func() {
|
time.AfterFunc(delay, func() {
|
||||||
// TODO: handle error
|
// TODO: handle error
|
||||||
|
2
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testlog/testlog.go
generated
vendored
2
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testlog/testlog.go
generated
vendored
@ -27,7 +27,7 @@ var _ = BeforeEach(func() {
|
|||||||
|
|
||||||
if len(logFileName) > 0 {
|
if len(logFileName) > 0 {
|
||||||
var err error
|
var err error
|
||||||
logFile, err = os.Create("./log.txt")
|
logFile, err = os.Create(logFileName)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
log.SetOutput(logFile)
|
log.SetOutput(logFile)
|
||||||
utils.SetLogLevel(utils.LogLevelDebug)
|
utils.SetLogLevel(utils.LogLevelDebug)
|
||||||
|
14
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testserver/server.go
generated
vendored
14
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testserver/server.go
generated
vendored
@ -7,7 +7,9 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
quic "github.com/lucas-clemente/quic-go"
|
||||||
"github.com/lucas-clemente/quic-go/h2quic"
|
"github.com/lucas-clemente/quic-go/h2quic"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
. "github.com/onsi/ginkgo"
|
||||||
@ -24,6 +26,7 @@ var (
|
|||||||
PRDataLong = GeneratePRData(dataLenLong)
|
PRDataLong = GeneratePRData(dataLenLong)
|
||||||
|
|
||||||
server *h2quic.Server
|
server *h2quic.Server
|
||||||
|
stoppedServing chan struct{}
|
||||||
port string
|
port string
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -75,11 +78,16 @@ func GeneratePRData(l int) []byte {
|
|||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func StartQuicServer() {
|
// StartQuicServer starts a h2quic.Server.
|
||||||
|
// versions is a slice of supported QUIC versions. It may be nil, then all supported versions are used.
|
||||||
|
func StartQuicServer(versions []protocol.VersionNumber) {
|
||||||
server = &h2quic.Server{
|
server = &h2quic.Server{
|
||||||
Server: &http.Server{
|
Server: &http.Server{
|
||||||
TLSConfig: testdata.GetTLSConfig(),
|
TLSConfig: testdata.GetTLSConfig(),
|
||||||
},
|
},
|
||||||
|
QuicConfig: &quic.Config{
|
||||||
|
Versions: versions,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
addr, err := net.ResolveUDPAddr("udp", "0.0.0.0:0")
|
addr, err := net.ResolveUDPAddr("udp", "0.0.0.0:0")
|
||||||
@ -88,14 +96,18 @@ func StartQuicServer() {
|
|||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
port = strconv.Itoa(conn.LocalAddr().(*net.UDPAddr).Port)
|
port = strconv.Itoa(conn.LocalAddr().(*net.UDPAddr).Port)
|
||||||
|
|
||||||
|
stoppedServing = make(chan struct{})
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
server.Serve(conn)
|
server.Serve(conn)
|
||||||
|
close(stoppedServing)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func StopQuicServer() {
|
func StopQuicServer() {
|
||||||
Expect(server.Close()).NotTo(HaveOccurred())
|
Expect(server.Close()).NotTo(HaveOccurred())
|
||||||
|
Eventually(stoppedServing).Should(BeClosed())
|
||||||
}
|
}
|
||||||
|
|
||||||
func Port() string {
|
func Port() string {
|
||||||
|
121
vendor/github.com/lucas-clemente/quic-go/interface.go
generated
vendored
121
vendor/github.com/lucas-clemente/quic-go/interface.go
generated
vendored
@ -6,23 +6,55 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// The StreamID is the ID of a QUIC stream.
|
||||||
|
type StreamID = protocol.StreamID
|
||||||
|
|
||||||
|
// A VersionNumber is a QUIC version number.
|
||||||
|
type VersionNumber = protocol.VersionNumber
|
||||||
|
|
||||||
|
// A Cookie can be used to verify the ownership of the client address.
|
||||||
|
type Cookie = handshake.Cookie
|
||||||
|
|
||||||
|
// ConnectionState records basic details about the QUIC connection.
|
||||||
|
type ConnectionState = handshake.ConnectionState
|
||||||
|
|
||||||
|
// An ErrorCode is an application-defined error code.
|
||||||
|
type ErrorCode = protocol.ApplicationErrorCode
|
||||||
|
|
||||||
// Stream is the interface implemented by QUIC streams
|
// Stream is the interface implemented by QUIC streams
|
||||||
type Stream interface {
|
type Stream interface {
|
||||||
|
// StreamID returns the stream ID.
|
||||||
|
StreamID() StreamID
|
||||||
// Read reads data from the stream.
|
// Read reads data from the stream.
|
||||||
// Read can be made to time out and return a net.Error with Timeout() == true
|
// Read can be made to time out and return a net.Error with Timeout() == true
|
||||||
// after a fixed time limit; see SetDeadline and SetReadDeadline.
|
// after a fixed time limit; see SetDeadline and SetReadDeadline.
|
||||||
|
// If the stream was canceled by the peer, the error implements the StreamError
|
||||||
|
// interface, and Canceled() == true.
|
||||||
io.Reader
|
io.Reader
|
||||||
// Write writes data to the stream.
|
// Write writes data to the stream.
|
||||||
// Write can be made to time out and return a net.Error with Timeout() == true
|
// Write can be made to time out and return a net.Error with Timeout() == true
|
||||||
// after a fixed time limit; see SetDeadline and SetWriteDeadline.
|
// after a fixed time limit; see SetDeadline and SetWriteDeadline.
|
||||||
|
// If the stream was canceled by the peer, the error implements the StreamError
|
||||||
|
// interface, and Canceled() == true.
|
||||||
io.Writer
|
io.Writer
|
||||||
|
// Close closes the write-direction of the stream.
|
||||||
|
// Future calls to Write are not permitted after calling Close.
|
||||||
|
// It must not be called concurrently with Write.
|
||||||
|
// It must not be called after calling CancelWrite.
|
||||||
io.Closer
|
io.Closer
|
||||||
StreamID() protocol.StreamID
|
// CancelWrite aborts sending on this stream.
|
||||||
// Reset closes the stream with an error.
|
// It must not be called after Close.
|
||||||
Reset(error)
|
// Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably.
|
||||||
|
// Write will unblock immediately, and future calls to Write will fail.
|
||||||
|
CancelWrite(ErrorCode) error
|
||||||
|
// CancelRead aborts receiving on this stream.
|
||||||
|
// It will ask the peer to stop transmitting stream data.
|
||||||
|
// Read will unblock immediately, and future Read calls will fail.
|
||||||
|
CancelRead(ErrorCode) error
|
||||||
// The context is canceled as soon as the write-side of the stream is closed.
|
// The context is canceled as soon as the write-side of the stream is closed.
|
||||||
// This happens when Close() is called, or when the stream is reset (either locally or remotely).
|
// This happens when Close() is called, or when the stream is reset (either locally or remotely).
|
||||||
// Warning: This API should not be considered stable and might change soon.
|
// Warning: This API should not be considered stable and might change soon.
|
||||||
@ -43,6 +75,41 @@ type Stream interface {
|
|||||||
SetDeadline(t time.Time) error
|
SetDeadline(t time.Time) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A ReceiveStream is a unidirectional Receive Stream.
|
||||||
|
type ReceiveStream interface {
|
||||||
|
// see Stream.StreamID
|
||||||
|
StreamID() StreamID
|
||||||
|
// see Stream.Read
|
||||||
|
io.Reader
|
||||||
|
// see Stream.CancelRead
|
||||||
|
CancelRead(ErrorCode) error
|
||||||
|
// see Stream.SetReadDealine
|
||||||
|
SetReadDeadline(t time.Time) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// A SendStream is a unidirectional Send Stream.
|
||||||
|
type SendStream interface {
|
||||||
|
// see Stream.StreamID
|
||||||
|
StreamID() StreamID
|
||||||
|
// see Stream.Write
|
||||||
|
io.Writer
|
||||||
|
// see Stream.Close
|
||||||
|
io.Closer
|
||||||
|
// see Stream.CancelWrite
|
||||||
|
CancelWrite(ErrorCode) error
|
||||||
|
// see Stream.Context
|
||||||
|
Context() context.Context
|
||||||
|
// see Stream.SetWriteDeadline
|
||||||
|
SetWriteDeadline(t time.Time) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamError is returned by Read and Write when the peer cancels the stream.
|
||||||
|
type StreamError interface {
|
||||||
|
error
|
||||||
|
Canceled() bool
|
||||||
|
ErrorCode() ErrorCode
|
||||||
|
}
|
||||||
|
|
||||||
// A Session is a QUIC connection between two peers.
|
// A Session is a QUIC connection between two peers.
|
||||||
type Session interface {
|
type Session interface {
|
||||||
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
|
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
|
||||||
@ -64,53 +131,41 @@ type Session interface {
|
|||||||
// The context is cancelled when the session is closed.
|
// The context is cancelled when the session is closed.
|
||||||
// Warning: This API should not be considered stable and might change soon.
|
// Warning: This API should not be considered stable and might change soon.
|
||||||
Context() context.Context
|
Context() context.Context
|
||||||
}
|
// ConnectionState returns basic details about the QUIC connection.
|
||||||
|
// Warning: This API should not be considered stable and might change soon.
|
||||||
// A NonFWSession is a QUIC connection between two peers half-way through the handshake.
|
ConnectionState() ConnectionState
|
||||||
// The communication is encrypted, but not yet forward secure.
|
|
||||||
type NonFWSession interface {
|
|
||||||
Session
|
|
||||||
WaitUntilHandshakeComplete() error
|
|
||||||
}
|
|
||||||
|
|
||||||
// An STK is a Source Address token.
|
|
||||||
// It is issued by the server and sent to the client. For the client, it is an opaque blob.
|
|
||||||
// The client can send the STK in subsequent handshakes to prove ownership of its IP address.
|
|
||||||
type STK struct {
|
|
||||||
// The remote address this token was issued for.
|
|
||||||
// If the server is run on a net.UDPConn, this is the string representation of the IP address (net.IP.String())
|
|
||||||
// Otherwise, this is the string representation of the net.Addr (net.Addr.String())
|
|
||||||
remoteAddr string
|
|
||||||
// The time that the STK was issued (resolution 1 second)
|
|
||||||
sentTime time.Time
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config contains all configuration data needed for a QUIC server or client.
|
// Config contains all configuration data needed for a QUIC server or client.
|
||||||
// More config parameters (such as timeouts) will be added soon, see e.g. https://github.com/lucas-clemente/quic-go/issues/441.
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
// The QUIC versions that can be negotiated.
|
// The QUIC versions that can be negotiated.
|
||||||
// If not set, it uses all versions available.
|
// If not set, it uses all versions available.
|
||||||
// Warning: This API should not be considered stable and will change soon.
|
// Warning: This API should not be considered stable and will change soon.
|
||||||
Versions []protocol.VersionNumber
|
Versions []VersionNumber
|
||||||
// Ask the server to truncate the connection ID sent in the Public Header.
|
// Ask the server to omit the connection ID sent in the Public Header.
|
||||||
// This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated.
|
// This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated.
|
||||||
// Currently only valid for the client.
|
// Currently only valid for the client.
|
||||||
RequestConnectionIDTruncation bool
|
RequestConnectionIDOmission bool
|
||||||
// HandshakeTimeout is the maximum duration that the cryptographic handshake may take.
|
// HandshakeTimeout is the maximum duration that the cryptographic handshake may take.
|
||||||
// If the timeout is exceeded, the connection is closed.
|
// If the timeout is exceeded, the connection is closed.
|
||||||
// If this value is zero, the timeout is set to 10 seconds.
|
// If this value is zero, the timeout is set to 10 seconds.
|
||||||
HandshakeTimeout time.Duration
|
HandshakeTimeout time.Duration
|
||||||
// AcceptSTK determines if an STK is accepted.
|
// IdleTimeout is the maximum duration that may pass without any incoming network activity.
|
||||||
// It is called with stk = nil if the client didn't send an STK.
|
// This value only applies after the handshake has completed.
|
||||||
// If not set, it verifies that the address matches, and that the STK was issued within the last 24 hours.
|
// If the timeout is exceeded, the connection is closed.
|
||||||
|
// If this value is zero, the timeout is set to 30 seconds.
|
||||||
|
IdleTimeout time.Duration
|
||||||
|
// AcceptCookie determines if a Cookie is accepted.
|
||||||
|
// It is called with cookie = nil if the client didn't send an Cookie.
|
||||||
|
// If not set, it verifies that the address matches, and that the Cookie was issued within the last 24 hours.
|
||||||
// This option is only valid for the server.
|
// This option is only valid for the server.
|
||||||
AcceptSTK func(clientAddr net.Addr, stk *STK) bool
|
AcceptCookie func(clientAddr net.Addr, cookie *Cookie) bool
|
||||||
// MaxReceiveStreamFlowControlWindow is the maximum stream-level flow control window for receiving data.
|
// MaxReceiveStreamFlowControlWindow is the maximum stream-level flow control window for receiving data.
|
||||||
// If this value is zero, it will default to 1 MB for the server and 6 MB for the client.
|
// If this value is zero, it will default to 1 MB for the server and 6 MB for the client.
|
||||||
MaxReceiveStreamFlowControlWindow protocol.ByteCount
|
MaxReceiveStreamFlowControlWindow uint64
|
||||||
// MaxReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data.
|
// MaxReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data.
|
||||||
// If this value is zero, it will default to 1.5 MB for the server and 15 MB for the client.
|
// If this value is zero, it will default to 1.5 MB for the server and 15 MB for the client.
|
||||||
MaxReceiveConnectionFlowControlWindow protocol.ByteCount
|
MaxReceiveConnectionFlowControlWindow uint64
|
||||||
// KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive.
|
// KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive.
|
||||||
KeepAlive bool
|
KeepAlive bool
|
||||||
}
|
}
|
||||||
|
48
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go
generated
vendored
Normal file
48
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go
generated
vendored
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
package ackhandler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SentPacketHandler handles ACKs received for outgoing packets
|
||||||
|
type SentPacketHandler interface {
|
||||||
|
// SentPacket may modify the packet
|
||||||
|
SentPacket(packet *Packet) error
|
||||||
|
ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error
|
||||||
|
SetHandshakeComplete()
|
||||||
|
|
||||||
|
// SendingAllowed says if a packet can be sent.
|
||||||
|
// Sending packets might not be possible because:
|
||||||
|
// * we're congestion limited
|
||||||
|
// * we're tracking the maximum number of sent packets
|
||||||
|
SendingAllowed() bool
|
||||||
|
// TimeUntilSend is the time when the next packet should be sent.
|
||||||
|
// It is used for pacing packets.
|
||||||
|
TimeUntilSend() time.Time
|
||||||
|
// ShouldSendNumPackets returns the number of packets that should be sent immediately.
|
||||||
|
// It always returns a number greater or equal than 1.
|
||||||
|
// A number greater than 1 is returned when the pacing delay is smaller than the minimum pacing delay.
|
||||||
|
// Note that the number of packets is only calculated based on the pacing algorithm.
|
||||||
|
// Before sending any packet, SendingAllowed() must be called to learn if we can actually send it.
|
||||||
|
ShouldSendNumPackets() int
|
||||||
|
|
||||||
|
GetStopWaitingFrame(force bool) *wire.StopWaitingFrame
|
||||||
|
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
|
||||||
|
DequeuePacketForRetransmission() (packet *Packet)
|
||||||
|
GetLeastUnacked() protocol.PacketNumber
|
||||||
|
|
||||||
|
GetAlarmTimeout() time.Time
|
||||||
|
OnAlarm()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
|
||||||
|
type ReceivedPacketHandler interface {
|
||||||
|
ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error
|
||||||
|
IgnoreBelow(protocol.PacketNumber)
|
||||||
|
|
||||||
|
GetAlarmTimeout() time.Time
|
||||||
|
GetAckFrame() *wire.AckFrame
|
||||||
|
}
|
@ -3,29 +3,30 @@ package ackhandler
|
|||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/frames"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
)
|
)
|
||||||
|
|
||||||
// A Packet is a packet
|
// A Packet is a packet
|
||||||
// +gen linkedlist
|
// +gen linkedlist
|
||||||
type Packet struct {
|
type Packet struct {
|
||||||
PacketNumber protocol.PacketNumber
|
PacketNumber protocol.PacketNumber
|
||||||
Frames []frames.Frame
|
Frames []wire.Frame
|
||||||
Length protocol.ByteCount
|
Length protocol.ByteCount
|
||||||
EncryptionLevel protocol.EncryptionLevel
|
EncryptionLevel protocol.EncryptionLevel
|
||||||
|
|
||||||
SendTime time.Time
|
largestAcked protocol.PacketNumber // if the packet contains an ACK, the LargestAcked value of that ACK
|
||||||
|
sendTime time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetFramesForRetransmission gets all the frames for retransmission
|
// GetFramesForRetransmission gets all the frames for retransmission
|
||||||
func (p *Packet) GetFramesForRetransmission() []frames.Frame {
|
func (p *Packet) GetFramesForRetransmission() []wire.Frame {
|
||||||
var fs []frames.Frame
|
var fs []wire.Frame
|
||||||
for _, frame := range p.Frames {
|
for _, frame := range p.Frames {
|
||||||
switch frame.(type) {
|
switch frame.(type) {
|
||||||
case *frames.AckFrame:
|
case *wire.AckFrame:
|
||||||
continue
|
continue
|
||||||
case *frames.StopWaitingFrame:
|
case *wire.StopWaitingFrame:
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
fs = append(fs, frame)
|
fs = append(fs, frame)
|
@ -1,18 +1,15 @@
|
|||||||
package ackhandler
|
package ackhandler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/frames"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
)
|
)
|
||||||
|
|
||||||
var errInvalidPacketNumber = errors.New("ReceivedPacketHandler: Invalid packet number")
|
|
||||||
|
|
||||||
type receivedPacketHandler struct {
|
type receivedPacketHandler struct {
|
||||||
largestObserved protocol.PacketNumber
|
largestObserved protocol.PacketNumber
|
||||||
lowerLimit protocol.PacketNumber
|
ignoreBelow protocol.PacketNumber
|
||||||
largestObservedReceivedTime time.Time
|
largestObservedReceivedTime time.Time
|
||||||
|
|
||||||
packetHistory *receivedPacketHistory
|
packetHistory *receivedPacketHistory
|
||||||
@ -23,46 +20,45 @@ type receivedPacketHandler struct {
|
|||||||
retransmittablePacketsReceivedSinceLastAck int
|
retransmittablePacketsReceivedSinceLastAck int
|
||||||
ackQueued bool
|
ackQueued bool
|
||||||
ackAlarm time.Time
|
ackAlarm time.Time
|
||||||
lastAck *frames.AckFrame
|
lastAck *wire.AckFrame
|
||||||
|
|
||||||
|
version protocol.VersionNumber
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewReceivedPacketHandler creates a new receivedPacketHandler
|
// NewReceivedPacketHandler creates a new receivedPacketHandler
|
||||||
func NewReceivedPacketHandler() ReceivedPacketHandler {
|
func NewReceivedPacketHandler(version protocol.VersionNumber) ReceivedPacketHandler {
|
||||||
return &receivedPacketHandler{
|
return &receivedPacketHandler{
|
||||||
packetHistory: newReceivedPacketHistory(),
|
packetHistory: newReceivedPacketHistory(),
|
||||||
ackSendDelay: protocol.AckSendDelay,
|
ackSendDelay: protocol.AckSendDelay,
|
||||||
|
version: version,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error {
|
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error {
|
||||||
if packetNumber == 0 {
|
|
||||||
return errInvalidPacketNumber
|
|
||||||
}
|
|
||||||
|
|
||||||
if packetNumber > h.largestObserved {
|
if packetNumber > h.largestObserved {
|
||||||
h.largestObserved = packetNumber
|
h.largestObserved = packetNumber
|
||||||
h.largestObservedReceivedTime = time.Now()
|
h.largestObservedReceivedTime = rcvTime
|
||||||
}
|
}
|
||||||
|
|
||||||
if packetNumber <= h.lowerLimit {
|
if packetNumber < h.ignoreBelow {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
|
if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
h.maybeQueueAck(packetNumber, shouldInstigateAck)
|
h.maybeQueueAck(packetNumber, rcvTime, shouldInstigateAck)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetLowerLimit sets a lower limit for acking packets.
|
// IgnoreBelow sets a lower limit for acking packets.
|
||||||
// Packets with packet numbers smaller or equal than p will not be acked.
|
// Packets with packet numbers smaller than p will not be acked.
|
||||||
func (h *receivedPacketHandler) SetLowerLimit(p protocol.PacketNumber) {
|
func (h *receivedPacketHandler) IgnoreBelow(p protocol.PacketNumber) {
|
||||||
h.lowerLimit = p
|
h.ignoreBelow = p
|
||||||
h.packetHistory.DeleteUpTo(p)
|
h.packetHistory.DeleteBelow(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) {
|
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) {
|
||||||
h.packetsReceivedSinceLastAck++
|
h.packetsReceivedSinceLastAck++
|
||||||
|
|
||||||
if shouldInstigateAck {
|
if shouldInstigateAck {
|
||||||
@ -74,12 +70,6 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber
|
|||||||
h.ackQueued = true
|
h.ackQueued = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Always send an ack every 20 packets in order to allow the peer to discard
|
|
||||||
// information from the SentPacketManager and provide an RTT measurement.
|
|
||||||
if h.packetsReceivedSinceLastAck >= protocol.MaxPacketsReceivedBeforeAckSend {
|
|
||||||
h.ackQueued = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// if the packet number is smaller than the largest acked packet, it must have been reported missing with the last ACK
|
// if the packet number is smaller than the largest acked packet, it must have been reported missing with the last ACK
|
||||||
// note that it cannot be a duplicate because they're already filtered out by ReceivedPacket()
|
// note that it cannot be a duplicate because they're already filtered out by ReceivedPacket()
|
||||||
if h.lastAck != nil && packetNumber < h.lastAck.LargestAcked {
|
if h.lastAck != nil && packetNumber < h.lastAck.LargestAcked {
|
||||||
@ -87,7 +77,7 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber
|
|||||||
}
|
}
|
||||||
|
|
||||||
// check if a new missing range above the previously was created
|
// check if a new missing range above the previously was created
|
||||||
if h.lastAck != nil && h.packetHistory.GetHighestAckRange().FirstPacketNumber > h.lastAck.LargestAcked {
|
if h.lastAck != nil && h.packetHistory.GetHighestAckRange().First > h.lastAck.LargestAcked {
|
||||||
h.ackQueued = true
|
h.ackQueued = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -96,7 +86,7 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber
|
|||||||
h.ackQueued = true
|
h.ackQueued = true
|
||||||
} else {
|
} else {
|
||||||
if h.ackAlarm.IsZero() {
|
if h.ackAlarm.IsZero() {
|
||||||
h.ackAlarm = time.Now().Add(h.ackSendDelay)
|
h.ackAlarm = rcvTime.Add(h.ackSendDelay)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -107,15 +97,15 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *receivedPacketHandler) GetAckFrame() *frames.AckFrame {
|
func (h *receivedPacketHandler) GetAckFrame() *wire.AckFrame {
|
||||||
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(time.Now())) {
|
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(time.Now())) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ackRanges := h.packetHistory.GetAckRanges()
|
ackRanges := h.packetHistory.GetAckRanges()
|
||||||
ack := &frames.AckFrame{
|
ack := &wire.AckFrame{
|
||||||
LargestAcked: h.largestObserved,
|
LargestAcked: h.largestObserved,
|
||||||
LowestAcked: ackRanges[len(ackRanges)-1].FirstPacketNumber,
|
LowestAcked: ackRanges[len(ackRanges)-1].First,
|
||||||
PacketReceivedTime: h.largestObservedReceivedTime,
|
PacketReceivedTime: h.largestObservedReceivedTime,
|
||||||
}
|
}
|
||||||
|
|
@ -1,9 +1,9 @@
|
|||||||
package ackhandler
|
package ackhandler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/lucas-clemente/quic-go/frames"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
"github.com/lucas-clemente/quic-go/qerr"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -12,21 +12,15 @@ import (
|
|||||||
type receivedPacketHistory struct {
|
type receivedPacketHistory struct {
|
||||||
ranges *utils.PacketIntervalList
|
ranges *utils.PacketIntervalList
|
||||||
|
|
||||||
// the map is used as a replacement for a set here. The bool is always supposed to be set to true
|
|
||||||
receivedPacketNumbers map[protocol.PacketNumber]bool
|
|
||||||
lowestInReceivedPacketNumbers protocol.PacketNumber
|
lowestInReceivedPacketNumbers protocol.PacketNumber
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var errTooManyOutstandingReceivedAckRanges = qerr.Error(qerr.TooManyOutstandingReceivedPackets, "Too many outstanding received ACK ranges")
|
||||||
errTooManyOutstandingReceivedAckRanges = qerr.Error(qerr.TooManyOutstandingReceivedPackets, "Too many outstanding received ACK ranges")
|
|
||||||
errTooManyOutstandingReceivedPackets = qerr.Error(qerr.TooManyOutstandingReceivedPackets, "Too many outstanding received packets")
|
|
||||||
)
|
|
||||||
|
|
||||||
// newReceivedPacketHistory creates a new received packet history
|
// newReceivedPacketHistory creates a new received packet history
|
||||||
func newReceivedPacketHistory() *receivedPacketHistory {
|
func newReceivedPacketHistory() *receivedPacketHistory {
|
||||||
return &receivedPacketHistory{
|
return &receivedPacketHistory{
|
||||||
ranges: utils.NewPacketIntervalList(),
|
ranges: utils.NewPacketIntervalList(),
|
||||||
receivedPacketNumbers: make(map[protocol.PacketNumber]bool),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -36,12 +30,6 @@ func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) error {
|
|||||||
return errTooManyOutstandingReceivedAckRanges
|
return errTooManyOutstandingReceivedAckRanges
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(h.receivedPacketNumbers) >= protocol.MaxTrackedReceivedPackets {
|
|
||||||
return errTooManyOutstandingReceivedPackets
|
|
||||||
}
|
|
||||||
|
|
||||||
h.receivedPacketNumbers[p] = true
|
|
||||||
|
|
||||||
if h.ranges.Len() == 0 {
|
if h.ranges.Len() == 0 {
|
||||||
h.ranges.PushBack(utils.PacketInterval{Start: p, End: p})
|
h.ranges.PushBack(utils.PacketInterval{Start: p, End: p})
|
||||||
return nil
|
return nil
|
||||||
@ -86,23 +74,20 @@ func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteUpTo deletes all entries up to (and including) p
|
// DeleteBelow deletes all entries below (but not including) p
|
||||||
func (h *receivedPacketHistory) DeleteUpTo(p protocol.PacketNumber) {
|
func (h *receivedPacketHistory) DeleteBelow(p protocol.PacketNumber) {
|
||||||
h.lowestInReceivedPacketNumbers = utils.MaxPacketNumber(h.lowestInReceivedPacketNumbers, p+1)
|
if p <= h.lowestInReceivedPacketNumbers {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.lowestInReceivedPacketNumbers = p
|
||||||
|
|
||||||
nextEl := h.ranges.Front()
|
nextEl := h.ranges.Front()
|
||||||
for el := h.ranges.Front(); nextEl != nil; el = nextEl {
|
for el := h.ranges.Front(); nextEl != nil; el = nextEl {
|
||||||
nextEl = el.Next()
|
nextEl = el.Next()
|
||||||
|
|
||||||
if p >= el.Value.Start && p < el.Value.End {
|
if p > el.Value.Start && p <= el.Value.End {
|
||||||
for i := el.Value.Start; i <= p; i++ { // adjust start value of a range
|
el.Value.Start = p
|
||||||
delete(h.receivedPacketNumbers, i)
|
} else if el.Value.End < p { // delete a whole range
|
||||||
}
|
|
||||||
el.Value.Start = p + 1
|
|
||||||
} else if el.Value.End <= p { // delete a whole range
|
|
||||||
for i := el.Value.Start; i <= el.Value.End; i++ {
|
|
||||||
delete(h.receivedPacketNumbers, i)
|
|
||||||
}
|
|
||||||
h.ranges.Remove(el)
|
h.ranges.Remove(el)
|
||||||
} else { // no ranges affected. Nothing to do
|
} else { // no ranges affected. Nothing to do
|
||||||
return
|
return
|
||||||
@ -110,38 +95,27 @@ func (h *receivedPacketHistory) DeleteUpTo(p protocol.PacketNumber) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsDuplicate determines if a packet should be regarded as a duplicate packet
|
|
||||||
// note that after receiving a StopWaitingFrame, all packets below the LeastUnacked should be regarded as duplicates, even if the packet was just delayed
|
|
||||||
func (h *receivedPacketHistory) IsDuplicate(p protocol.PacketNumber) bool {
|
|
||||||
if p < h.lowestInReceivedPacketNumbers {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
_, ok := h.receivedPacketNumbers[p]
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAckRanges gets a slice of all AckRanges that can be used in an AckFrame
|
// GetAckRanges gets a slice of all AckRanges that can be used in an AckFrame
|
||||||
func (h *receivedPacketHistory) GetAckRanges() []frames.AckRange {
|
func (h *receivedPacketHistory) GetAckRanges() []wire.AckRange {
|
||||||
if h.ranges.Len() == 0 {
|
if h.ranges.Len() == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var ackRanges []frames.AckRange
|
ackRanges := make([]wire.AckRange, h.ranges.Len())
|
||||||
|
i := 0
|
||||||
for el := h.ranges.Back(); el != nil; el = el.Prev() {
|
for el := h.ranges.Back(); el != nil; el = el.Prev() {
|
||||||
ackRanges = append(ackRanges, frames.AckRange{FirstPacketNumber: el.Value.Start, LastPacketNumber: el.Value.End})
|
ackRanges[i] = wire.AckRange{First: el.Value.Start, Last: el.Value.End}
|
||||||
|
i++
|
||||||
}
|
}
|
||||||
|
|
||||||
return ackRanges
|
return ackRanges
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *receivedPacketHistory) GetHighestAckRange() frames.AckRange {
|
func (h *receivedPacketHistory) GetHighestAckRange() wire.AckRange {
|
||||||
ackRange := frames.AckRange{}
|
ackRange := wire.AckRange{}
|
||||||
if h.ranges.Len() > 0 {
|
if h.ranges.Len() > 0 {
|
||||||
r := h.ranges.Back().Value
|
r := h.ranges.Back().Value
|
||||||
ackRange.FirstPacketNumber = r.Start
|
ackRange.First = r.Start
|
||||||
ackRange.LastPacketNumber = r.End
|
ackRange.Last = r.End
|
||||||
}
|
}
|
||||||
return ackRange
|
return ackRange
|
||||||
}
|
}
|
@ -1,12 +1,10 @@
|
|||||||
package ackhandler
|
package ackhandler
|
||||||
|
|
||||||
import (
|
import "github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
"github.com/lucas-clemente/quic-go/frames"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Returns a new slice with all non-retransmittable frames deleted.
|
// Returns a new slice with all non-retransmittable frames deleted.
|
||||||
func stripNonRetransmittableFrames(fs []frames.Frame) []frames.Frame {
|
func stripNonRetransmittableFrames(fs []wire.Frame) []wire.Frame {
|
||||||
res := make([]frames.Frame, 0, len(fs))
|
res := make([]wire.Frame, 0, len(fs))
|
||||||
for _, f := range fs {
|
for _, f := range fs {
|
||||||
if IsFrameRetransmittable(f) {
|
if IsFrameRetransmittable(f) {
|
||||||
res = append(res, f)
|
res = append(res, f)
|
||||||
@ -16,11 +14,11 @@ func stripNonRetransmittableFrames(fs []frames.Frame) []frames.Frame {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// IsFrameRetransmittable returns true if the frame should be retransmitted.
|
// IsFrameRetransmittable returns true if the frame should be retransmitted.
|
||||||
func IsFrameRetransmittable(f frames.Frame) bool {
|
func IsFrameRetransmittable(f wire.Frame) bool {
|
||||||
switch f.(type) {
|
switch f.(type) {
|
||||||
case *frames.StopWaitingFrame:
|
case *wire.StopWaitingFrame:
|
||||||
return false
|
return false
|
||||||
case *frames.AckFrame:
|
case *wire.AckFrame:
|
||||||
return false
|
return false
|
||||||
default:
|
default:
|
||||||
return true
|
return true
|
||||||
@ -28,7 +26,7 @@ func IsFrameRetransmittable(f frames.Frame) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HasRetransmittableFrames returns true if at least one frame is retransmittable.
|
// HasRetransmittableFrames returns true if at least one frame is retransmittable.
|
||||||
func HasRetransmittableFrames(fs []frames.Frame) bool {
|
func HasRetransmittableFrames(fs []wire.Frame) bool {
|
||||||
for _, f := range fs {
|
for _, f := range fs {
|
||||||
if IsFrameRetransmittable(f) {
|
if IsFrameRetransmittable(f) {
|
||||||
return true
|
return true
|
@ -3,12 +3,13 @@ package ackhandler
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/congestion"
|
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||||
"github.com/lucas-clemente/quic-go/frames"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
"github.com/lucas-clemente/quic-go/qerr"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -16,33 +17,33 @@ const (
|
|||||||
// Maximum reordering in time space before time based loss detection considers a packet lost.
|
// Maximum reordering in time space before time based loss detection considers a packet lost.
|
||||||
// In fraction of an RTT.
|
// In fraction of an RTT.
|
||||||
timeReorderingFraction = 1.0 / 8
|
timeReorderingFraction = 1.0 / 8
|
||||||
|
// The default RTT used before an RTT sample is taken.
|
||||||
|
// Note: This constant is also defined in the congestion package.
|
||||||
|
defaultInitialRTT = 100 * time.Millisecond
|
||||||
// defaultRTOTimeout is the RTO time on new connections
|
// defaultRTOTimeout is the RTO time on new connections
|
||||||
defaultRTOTimeout = 500 * time.Millisecond
|
defaultRTOTimeout = 500 * time.Millisecond
|
||||||
|
// Minimum time in the future a tail loss probe alarm may be set for.
|
||||||
|
minTPLTimeout = 10 * time.Millisecond
|
||||||
// Minimum time in the future an RTO alarm may be set for.
|
// Minimum time in the future an RTO alarm may be set for.
|
||||||
minRTOTimeout = 200 * time.Millisecond
|
minRTOTimeout = 200 * time.Millisecond
|
||||||
// maxRTOTimeout is the maximum RTO time
|
// maxRTOTimeout is the maximum RTO time
|
||||||
maxRTOTimeout = 60 * time.Second
|
maxRTOTimeout = 60 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
// ErrDuplicateOrOutOfOrderAck occurs when a duplicate or an out-of-order ACK is received
|
||||||
// ErrDuplicateOrOutOfOrderAck occurs when a duplicate or an out-of-order ACK is received
|
var ErrDuplicateOrOutOfOrderAck = errors.New("SentPacketHandler: Duplicate or out-of-order ACK")
|
||||||
ErrDuplicateOrOutOfOrderAck = errors.New("SentPacketHandler: Duplicate or out-of-order ACK")
|
|
||||||
// ErrTooManyTrackedSentPackets occurs when the sentPacketHandler has to keep track of too many packets
|
|
||||||
ErrTooManyTrackedSentPackets = errors.New("Too many outstanding non-acked and non-retransmitted packets")
|
|
||||||
// ErrAckForSkippedPacket occurs when the client sent an ACK for a packet number that we intentionally skipped
|
|
||||||
ErrAckForSkippedPacket = qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number")
|
|
||||||
errAckForUnsentPacket = qerr.Error(qerr.InvalidAckData, "Received ACK for an unsent package")
|
|
||||||
)
|
|
||||||
|
|
||||||
var errPacketNumberNotIncreasing = errors.New("Already sent a packet with a higher packet number")
|
|
||||||
|
|
||||||
type sentPacketHandler struct {
|
type sentPacketHandler struct {
|
||||||
lastSentPacketNumber protocol.PacketNumber
|
lastSentPacketNumber protocol.PacketNumber
|
||||||
|
nextPacketSendTime time.Time
|
||||||
skippedPackets []protocol.PacketNumber
|
skippedPackets []protocol.PacketNumber
|
||||||
|
|
||||||
LargestAcked protocol.PacketNumber
|
largestAcked protocol.PacketNumber
|
||||||
|
|
||||||
largestReceivedPacketWithAck protocol.PacketNumber
|
largestReceivedPacketWithAck protocol.PacketNumber
|
||||||
|
// lowestPacketNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived
|
||||||
|
// example: we send an ACK for packets 90-100 with packet number 20
|
||||||
|
// once we receive an ACK from the peer for packet 20, the lowestPacketNotConfirmedAcked is 101
|
||||||
|
lowestPacketNotConfirmedAcked protocol.PacketNumber
|
||||||
|
|
||||||
packetHistory *PacketList
|
packetHistory *PacketList
|
||||||
stopWaitingManager stopWaitingManager
|
stopWaitingManager stopWaitingManager
|
||||||
@ -54,6 +55,10 @@ type sentPacketHandler struct {
|
|||||||
congestion congestion.SendAlgorithm
|
congestion congestion.SendAlgorithm
|
||||||
rttStats *congestion.RTTStats
|
rttStats *congestion.RTTStats
|
||||||
|
|
||||||
|
handshakeComplete bool
|
||||||
|
// The number of times the handshake packets have been retransmitted without receiving an ack.
|
||||||
|
handshakeCount uint32
|
||||||
|
|
||||||
// The number of times an RTO has been sent without receiving an ack.
|
// The number of times an RTO has been sent without receiving an ack.
|
||||||
rtoCount uint32
|
rtoCount uint32
|
||||||
|
|
||||||
@ -82,20 +87,27 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats) SentPacketHandler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *sentPacketHandler) largestInOrderAcked() protocol.PacketNumber {
|
func (h *sentPacketHandler) lowestUnacked() protocol.PacketNumber {
|
||||||
if f := h.packetHistory.Front(); f != nil {
|
if f := h.packetHistory.Front(); f != nil {
|
||||||
return f.Value.PacketNumber - 1
|
return f.Value.PacketNumber
|
||||||
}
|
}
|
||||||
return h.LargestAcked
|
return h.largestAcked + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) SetHandshakeComplete() {
|
||||||
|
var queue []*Packet
|
||||||
|
for _, packet := range h.retransmissionQueue {
|
||||||
|
if packet.EncryptionLevel == protocol.EncryptionForwardSecure {
|
||||||
|
queue = append(queue, packet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.retransmissionQueue = queue
|
||||||
|
h.handshakeComplete = true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *sentPacketHandler) SentPacket(packet *Packet) error {
|
func (h *sentPacketHandler) SentPacket(packet *Packet) error {
|
||||||
if packet.PacketNumber <= h.lastSentPacketNumber {
|
|
||||||
return errPacketNumberNotIncreasing
|
|
||||||
}
|
|
||||||
|
|
||||||
if protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()+1) > protocol.MaxTrackedSentPackets {
|
if protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()+1) > protocol.MaxTrackedSentPackets {
|
||||||
return ErrTooManyTrackedSentPackets
|
return errors.New("Too many outstanding non-acked and non-retransmitted packets")
|
||||||
}
|
}
|
||||||
|
|
||||||
for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ {
|
for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ {
|
||||||
@ -106,14 +118,22 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
h.lastSentPacketNumber = packet.PacketNumber
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
h.lastSentPacketNumber = packet.PacketNumber
|
||||||
|
|
||||||
|
var largestAcked protocol.PacketNumber
|
||||||
|
if len(packet.Frames) > 0 {
|
||||||
|
if ackFrame, ok := packet.Frames[0].(*wire.AckFrame); ok {
|
||||||
|
largestAcked = ackFrame.LargestAcked
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
packet.Frames = stripNonRetransmittableFrames(packet.Frames)
|
packet.Frames = stripNonRetransmittableFrames(packet.Frames)
|
||||||
isRetransmittable := len(packet.Frames) != 0
|
isRetransmittable := len(packet.Frames) != 0
|
||||||
|
|
||||||
if isRetransmittable {
|
if isRetransmittable {
|
||||||
packet.SendTime = now
|
packet.sendTime = now
|
||||||
|
packet.largestAcked = largestAcked
|
||||||
h.bytesInFlight += packet.Length
|
h.bytesInFlight += packet.Length
|
||||||
h.packetHistory.PushBack(*packet)
|
h.packetHistory.PushBack(*packet)
|
||||||
}
|
}
|
||||||
@ -126,29 +146,32 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error {
|
|||||||
isRetransmittable,
|
isRetransmittable,
|
||||||
)
|
)
|
||||||
|
|
||||||
h.updateLossDetectionAlarm()
|
h.nextPacketSendTime = utils.MaxTime(h.nextPacketSendTime, now).Add(h.congestion.TimeUntilSend(h.bytesInFlight))
|
||||||
|
|
||||||
|
h.updateLossDetectionAlarm(now)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *sentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame, withPacketNumber protocol.PacketNumber, rcvTime time.Time) error {
|
func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time) error {
|
||||||
if ackFrame.LargestAcked > h.lastSentPacketNumber {
|
if ackFrame.LargestAcked > h.lastSentPacketNumber {
|
||||||
return errAckForUnsentPacket
|
return qerr.Error(qerr.InvalidAckData, "Received ACK for an unsent package")
|
||||||
}
|
}
|
||||||
|
|
||||||
// duplicate or out-of-order ACK
|
// duplicate or out-of-order ACK
|
||||||
|
// if withPacketNumber <= h.largestReceivedPacketWithAck && withPacketNumber != 0 {
|
||||||
if withPacketNumber <= h.largestReceivedPacketWithAck {
|
if withPacketNumber <= h.largestReceivedPacketWithAck {
|
||||||
return ErrDuplicateOrOutOfOrderAck
|
return ErrDuplicateOrOutOfOrderAck
|
||||||
}
|
}
|
||||||
h.largestReceivedPacketWithAck = withPacketNumber
|
h.largestReceivedPacketWithAck = withPacketNumber
|
||||||
|
|
||||||
// ignore repeated ACK (ACKs that don't have a higher LargestAcked than the last ACK)
|
// ignore repeated ACK (ACKs that don't have a higher LargestAcked than the last ACK)
|
||||||
if ackFrame.LargestAcked <= h.largestInOrderAcked() {
|
if ackFrame.LargestAcked < h.lowestUnacked() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
h.LargestAcked = ackFrame.LargestAcked
|
h.largestAcked = ackFrame.LargestAcked
|
||||||
|
|
||||||
if h.skippedPacketsAcked(ackFrame) {
|
if h.skippedPacketsAcked(ackFrame) {
|
||||||
return ErrAckForSkippedPacket
|
return qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number")
|
||||||
}
|
}
|
||||||
|
|
||||||
rttUpdated := h.maybeUpdateRTT(ackFrame.LargestAcked, ackFrame.DelayTime, rcvTime)
|
rttUpdated := h.maybeUpdateRTT(ackFrame.LargestAcked, ackFrame.DelayTime, rcvTime)
|
||||||
@ -164,13 +187,22 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame, withPacketNum
|
|||||||
|
|
||||||
if len(ackedPackets) > 0 {
|
if len(ackedPackets) > 0 {
|
||||||
for _, p := range ackedPackets {
|
for _, p := range ackedPackets {
|
||||||
|
if encLevel < p.Value.EncryptionLevel {
|
||||||
|
return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.Value.PacketNumber, p.Value.EncryptionLevel)
|
||||||
|
}
|
||||||
|
// largestAcked == 0 either means that the packet didn't contain an ACK, or it just acked packet 0
|
||||||
|
// It is safe to ignore the corner case of packets that just acked packet 0, because
|
||||||
|
// the lowestPacketNotConfirmedAcked is only used to limit the number of ACK ranges we will send.
|
||||||
|
if p.Value.largestAcked != 0 {
|
||||||
|
h.lowestPacketNotConfirmedAcked = utils.MaxPacketNumber(h.lowestPacketNotConfirmedAcked, p.Value.largestAcked+1)
|
||||||
|
}
|
||||||
h.onPacketAcked(p)
|
h.onPacketAcked(p)
|
||||||
h.congestion.OnPacketAcked(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight)
|
h.congestion.OnPacketAcked(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
h.detectLostPackets()
|
h.detectLostPackets(rcvTime)
|
||||||
h.updateLossDetectionAlarm()
|
h.updateLossDetectionAlarm(rcvTime)
|
||||||
|
|
||||||
h.garbageCollectSkippedPackets()
|
h.garbageCollectSkippedPackets()
|
||||||
h.stopWaitingManager.ReceivedAck(ackFrame)
|
h.stopWaitingManager.ReceivedAck(ackFrame)
|
||||||
@ -178,7 +210,11 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame, withPacketNum
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *frames.AckFrame) ([]*PacketElement, error) {
|
func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber {
|
||||||
|
return h.lowestPacketNotConfirmedAcked
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*PacketElement, error) {
|
||||||
var ackedPackets []*PacketElement
|
var ackedPackets []*PacketElement
|
||||||
ackRangeIndex := 0
|
ackRangeIndex := 0
|
||||||
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
|
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
|
||||||
@ -197,14 +233,14 @@ func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *frames.AckFrame
|
|||||||
if ackFrame.HasMissingRanges() {
|
if ackFrame.HasMissingRanges() {
|
||||||
ackRange := ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
|
ackRange := ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
|
||||||
|
|
||||||
for packetNumber > ackRange.LastPacketNumber && ackRangeIndex < len(ackFrame.AckRanges)-1 {
|
for packetNumber > ackRange.Last && ackRangeIndex < len(ackFrame.AckRanges)-1 {
|
||||||
ackRangeIndex++
|
ackRangeIndex++
|
||||||
ackRange = ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
|
ackRange = ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
|
||||||
}
|
}
|
||||||
|
|
||||||
if packetNumber >= ackRange.FirstPacketNumber { // packet i contained in ACK range
|
if packetNumber >= ackRange.First { // packet i contained in ACK range
|
||||||
if packetNumber > ackRange.LastPacketNumber {
|
if packetNumber > ackRange.Last {
|
||||||
return nil, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", packetNumber, ackRange.FirstPacketNumber, ackRange.LastPacketNumber)
|
return nil, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", packetNumber, ackRange.First, ackRange.Last)
|
||||||
}
|
}
|
||||||
ackedPackets = append(ackedPackets, el)
|
ackedPackets = append(ackedPackets, el)
|
||||||
}
|
}
|
||||||
@ -212,7 +248,6 @@ func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *frames.AckFrame
|
|||||||
ackedPackets = append(ackedPackets, el)
|
ackedPackets = append(ackedPackets, el)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return ackedPackets, nil
|
return ackedPackets, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -220,7 +255,7 @@ func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, a
|
|||||||
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
|
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
|
||||||
packet := el.Value
|
packet := el.Value
|
||||||
if packet.PacketNumber == largestAcked {
|
if packet.PacketNumber == largestAcked {
|
||||||
h.rttStats.UpdateRTT(rcvTime.Sub(packet.SendTime), ackDelay, time.Now())
|
h.rttStats.UpdateRTT(rcvTime.Sub(packet.sendTime), ackDelay, rcvTime)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
// Packets are sorted by number, so we can stop searching
|
// Packets are sorted by number, so we can stop searching
|
||||||
@ -231,27 +266,27 @@ func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, a
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *sentPacketHandler) updateLossDetectionAlarm() {
|
func (h *sentPacketHandler) updateLossDetectionAlarm(now time.Time) {
|
||||||
// Cancel the alarm if no packets are outstanding
|
// Cancel the alarm if no packets are outstanding
|
||||||
if h.packetHistory.Len() == 0 {
|
if h.packetHistory.Len() == 0 {
|
||||||
h.alarm = time.Time{}
|
h.alarm = time.Time{}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(#496): Handle handshake packets separately
|
|
||||||
// TODO(#497): TLP
|
// TODO(#497): TLP
|
||||||
if !h.lossTime.IsZero() {
|
if !h.handshakeComplete {
|
||||||
|
h.alarm = now.Add(h.computeHandshakeTimeout())
|
||||||
|
} else if !h.lossTime.IsZero() {
|
||||||
// Early retransmit timer or time loss detection.
|
// Early retransmit timer or time loss detection.
|
||||||
h.alarm = h.lossTime
|
h.alarm = h.lossTime
|
||||||
} else {
|
} else {
|
||||||
// RTO
|
// RTO
|
||||||
h.alarm = time.Now().Add(h.computeRTOTimeout())
|
h.alarm = now.Add(h.computeRTOTimeout())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *sentPacketHandler) detectLostPackets() {
|
func (h *sentPacketHandler) detectLostPackets(now time.Time) {
|
||||||
h.lossTime = time.Time{}
|
h.lossTime = time.Time{}
|
||||||
now := time.Now()
|
|
||||||
|
|
||||||
maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT()))
|
maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT()))
|
||||||
delayUntilLost := time.Duration((1.0 + timeReorderingFraction) * maxRTT)
|
delayUntilLost := time.Duration((1.0 + timeReorderingFraction) * maxRTT)
|
||||||
@ -260,11 +295,11 @@ func (h *sentPacketHandler) detectLostPackets() {
|
|||||||
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
|
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
|
||||||
packet := el.Value
|
packet := el.Value
|
||||||
|
|
||||||
if packet.PacketNumber > h.LargestAcked {
|
if packet.PacketNumber > h.largestAcked {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
timeSinceSent := now.Sub(packet.SendTime)
|
timeSinceSent := now.Sub(packet.sendTime)
|
||||||
if timeSinceSent > delayUntilLost {
|
if timeSinceSent > delayUntilLost {
|
||||||
lostPackets = append(lostPackets, el)
|
lostPackets = append(lostPackets, el)
|
||||||
} else if h.lossTime.IsZero() {
|
} else if h.lossTime.IsZero() {
|
||||||
@ -282,18 +317,22 @@ func (h *sentPacketHandler) detectLostPackets() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *sentPacketHandler) OnAlarm() {
|
func (h *sentPacketHandler) OnAlarm() {
|
||||||
// TODO(#496): Handle handshake packets separately
|
now := time.Now()
|
||||||
|
|
||||||
// TODO(#497): TLP
|
// TODO(#497): TLP
|
||||||
if !h.lossTime.IsZero() {
|
if !h.handshakeComplete {
|
||||||
|
h.queueHandshakePacketsForRetransmission()
|
||||||
|
h.handshakeCount++
|
||||||
|
} else if !h.lossTime.IsZero() {
|
||||||
// Early retransmit or time loss detection
|
// Early retransmit or time loss detection
|
||||||
h.detectLostPackets()
|
h.detectLostPackets(now)
|
||||||
} else {
|
} else {
|
||||||
// RTO
|
// RTO
|
||||||
h.retransmitOldestTwoPackets()
|
h.retransmitOldestTwoPackets()
|
||||||
h.rtoCount++
|
h.rtoCount++
|
||||||
}
|
}
|
||||||
|
|
||||||
h.updateLossDetectionAlarm()
|
h.updateLossDetectionAlarm(now)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *sentPacketHandler) GetAlarmTimeout() time.Time {
|
func (h *sentPacketHandler) GetAlarmTimeout() time.Time {
|
||||||
@ -303,6 +342,7 @@ func (h *sentPacketHandler) GetAlarmTimeout() time.Time {
|
|||||||
func (h *sentPacketHandler) onPacketAcked(packetElement *PacketElement) {
|
func (h *sentPacketHandler) onPacketAcked(packetElement *PacketElement) {
|
||||||
h.bytesInFlight -= packetElement.Value.Length
|
h.bytesInFlight -= packetElement.Value.Length
|
||||||
h.rtoCount = 0
|
h.rtoCount = 0
|
||||||
|
h.handshakeCount = 0
|
||||||
// TODO(#497): h.tlpCount = 0
|
// TODO(#497): h.tlpCount = 0
|
||||||
h.packetHistory.Remove(packetElement)
|
h.packetHistory.Remove(packetElement)
|
||||||
}
|
}
|
||||||
@ -320,20 +360,19 @@ func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *sentPacketHandler) GetLeastUnacked() protocol.PacketNumber {
|
func (h *sentPacketHandler) GetLeastUnacked() protocol.PacketNumber {
|
||||||
return h.largestInOrderAcked() + 1
|
return h.lowestUnacked()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *frames.StopWaitingFrame {
|
func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
|
||||||
return h.stopWaitingManager.GetStopWaitingFrame(force)
|
return h.stopWaitingManager.GetStopWaitingFrame(force)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *sentPacketHandler) SendingAllowed() bool {
|
func (h *sentPacketHandler) SendingAllowed() bool {
|
||||||
congestionLimited := h.bytesInFlight > h.congestion.GetCongestionWindow()
|
cwnd := h.congestion.GetCongestionWindow()
|
||||||
|
congestionLimited := h.bytesInFlight > cwnd
|
||||||
maxTrackedLimited := protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()) >= protocol.MaxTrackedSentPackets
|
maxTrackedLimited := protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()) >= protocol.MaxTrackedSentPackets
|
||||||
if congestionLimited {
|
if congestionLimited {
|
||||||
utils.Debugf("Congestion limited: bytes in flight %d, window %d",
|
utils.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, cwnd)
|
||||||
h.bytesInFlight,
|
|
||||||
h.congestion.GetCongestionWindow())
|
|
||||||
}
|
}
|
||||||
// Workaround for #555:
|
// Workaround for #555:
|
||||||
// Always allow sending of retransmissions. This should probably be limited
|
// Always allow sending of retransmissions. This should probably be limited
|
||||||
@ -342,6 +381,18 @@ func (h *sentPacketHandler) SendingAllowed() bool {
|
|||||||
return !maxTrackedLimited && (!congestionLimited || haveRetransmissions)
|
return !maxTrackedLimited && (!congestionLimited || haveRetransmissions)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) TimeUntilSend() time.Time {
|
||||||
|
return h.nextPacketSendTime
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) ShouldSendNumPackets() int {
|
||||||
|
delay := h.congestion.TimeUntilSend(h.bytesInFlight)
|
||||||
|
if delay == 0 || delay > protocol.MinPacingDelay {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return int(math.Ceil(float64(protocol.MinPacingDelay) / float64(delay)))
|
||||||
|
}
|
||||||
|
|
||||||
func (h *sentPacketHandler) retransmitOldestTwoPackets() {
|
func (h *sentPacketHandler) retransmitOldestTwoPackets() {
|
||||||
if p := h.packetHistory.Front(); p != nil {
|
if p := h.packetHistory.Front(); p != nil {
|
||||||
h.queueRTO(p)
|
h.queueRTO(p)
|
||||||
@ -363,6 +414,18 @@ func (h *sentPacketHandler) queueRTO(el *PacketElement) {
|
|||||||
h.congestion.OnRetransmissionTimeout(true)
|
h.congestion.OnRetransmissionTimeout(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() {
|
||||||
|
var handshakePackets []*PacketElement
|
||||||
|
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
|
||||||
|
if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure {
|
||||||
|
handshakePackets = append(handshakePackets, el)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, el := range handshakePackets {
|
||||||
|
h.queuePacketForRetransmission(el)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (h *sentPacketHandler) queuePacketForRetransmission(packetElement *PacketElement) {
|
func (h *sentPacketHandler) queuePacketForRetransmission(packetElement *PacketElement) {
|
||||||
packet := &packetElement.Value
|
packet := &packetElement.Value
|
||||||
h.bytesInFlight -= packet.Length
|
h.bytesInFlight -= packet.Length
|
||||||
@ -371,6 +434,17 @@ func (h *sentPacketHandler) queuePacketForRetransmission(packetElement *PacketEl
|
|||||||
h.stopWaitingManager.QueuedRetransmissionForPacketNumber(packet.PacketNumber)
|
h.stopWaitingManager.QueuedRetransmissionForPacketNumber(packet.PacketNumber)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration {
|
||||||
|
duration := 2 * h.rttStats.SmoothedRTT()
|
||||||
|
if duration == 0 {
|
||||||
|
duration = 2 * defaultInitialRTT
|
||||||
|
}
|
||||||
|
duration = utils.MaxDuration(duration, minTPLTimeout)
|
||||||
|
// exponential backoff
|
||||||
|
// There's an implicit limit to this set by the handshake timeout.
|
||||||
|
return duration << h.handshakeCount
|
||||||
|
}
|
||||||
|
|
||||||
func (h *sentPacketHandler) computeRTOTimeout() time.Duration {
|
func (h *sentPacketHandler) computeRTOTimeout() time.Duration {
|
||||||
rto := h.congestion.RetransmissionDelay()
|
rto := h.congestion.RetransmissionDelay()
|
||||||
if rto == 0 {
|
if rto == 0 {
|
||||||
@ -382,7 +456,7 @@ func (h *sentPacketHandler) computeRTOTimeout() time.Duration {
|
|||||||
return utils.MinDuration(rto, maxRTOTimeout)
|
return utils.MinDuration(rto, maxRTOTimeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *frames.AckFrame) bool {
|
func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *wire.AckFrame) bool {
|
||||||
for _, p := range h.skippedPackets {
|
for _, p := range h.skippedPackets {
|
||||||
if ackFrame.AcksPacket(p) {
|
if ackFrame.AcksPacket(p) {
|
||||||
return true
|
return true
|
||||||
@ -392,10 +466,10 @@ func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *frames.AckFrame) bool
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *sentPacketHandler) garbageCollectSkippedPackets() {
|
func (h *sentPacketHandler) garbageCollectSkippedPackets() {
|
||||||
lioa := h.largestInOrderAcked()
|
lowestUnacked := h.lowestUnacked()
|
||||||
deleteIndex := 0
|
deleteIndex := 0
|
||||||
for i, p := range h.skippedPackets {
|
for i, p := range h.skippedPackets {
|
||||||
if p <= lioa {
|
if p < lowestUnacked {
|
||||||
deleteIndex = i + 1
|
deleteIndex = i + 1
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -1,8 +1,8 @@
|
|||||||
package ackhandler
|
package ackhandler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/lucas-clemente/quic-go/frames"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
)
|
)
|
||||||
|
|
||||||
// This stopWaitingManager is not supposed to satisfy the StopWaitingManager interface, which is a remnant of the legacy AckHandler, and should be remove once we drop support for QUIC 33
|
// This stopWaitingManager is not supposed to satisfy the StopWaitingManager interface, which is a remnant of the legacy AckHandler, and should be remove once we drop support for QUIC 33
|
||||||
@ -10,10 +10,10 @@ type stopWaitingManager struct {
|
|||||||
largestLeastUnackedSent protocol.PacketNumber
|
largestLeastUnackedSent protocol.PacketNumber
|
||||||
nextLeastUnacked protocol.PacketNumber
|
nextLeastUnacked protocol.PacketNumber
|
||||||
|
|
||||||
lastStopWaitingFrame *frames.StopWaitingFrame
|
lastStopWaitingFrame *wire.StopWaitingFrame
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *frames.StopWaitingFrame {
|
func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
|
||||||
if s.nextLeastUnacked <= s.largestLeastUnackedSent {
|
if s.nextLeastUnacked <= s.largestLeastUnackedSent {
|
||||||
if force {
|
if force {
|
||||||
return s.lastStopWaitingFrame
|
return s.lastStopWaitingFrame
|
||||||
@ -22,14 +22,14 @@ func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *frames.StopWaiting
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.largestLeastUnackedSent = s.nextLeastUnacked
|
s.largestLeastUnackedSent = s.nextLeastUnacked
|
||||||
swf := &frames.StopWaitingFrame{
|
swf := &wire.StopWaitingFrame{
|
||||||
LeastUnacked: s.nextLeastUnacked,
|
LeastUnacked: s.nextLeastUnacked,
|
||||||
}
|
}
|
||||||
s.lastStopWaitingFrame = swf
|
s.lastStopWaitingFrame = swf
|
||||||
return swf
|
return swf
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stopWaitingManager) ReceivedAck(ack *frames.AckFrame) {
|
func (s *stopWaitingManager) ReceivedAck(ack *wire.AckFrame) {
|
||||||
if ack.LargestAcked >= s.nextLeastUnacked {
|
if ack.LargestAcked >= s.nextLeastUnacked {
|
||||||
s.nextLeastUnacked = ack.LargestAcked + 1
|
s.nextLeastUnacked = ack.LargestAcked + 1
|
||||||
}
|
}
|
@ -3,7 +3,7 @@ package congestion
|
|||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Bandwidth of a connection
|
// Bandwidth of a connection
|
@ -4,8 +4,8 @@ import (
|
|||||||
"math"
|
"math"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// This cubic implementation is based on the one found in Chromiums's QUIC
|
// This cubic implementation is based on the one found in Chromiums's QUIC
|
@ -3,8 +3,8 @@ package congestion
|
|||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -76,15 +76,19 @@ func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestio
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *cubicSender) TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration {
|
// TimeUntilSend returns when the next packet should be sent.
|
||||||
|
func (c *cubicSender) TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration {
|
||||||
if c.InRecovery() {
|
if c.InRecovery() {
|
||||||
// PRR is used when in recovery.
|
// PRR is used when in recovery.
|
||||||
return c.prr.TimeUntilSend(c.GetCongestionWindow(), bytesInFlight, c.GetSlowStartThreshold())
|
if c.prr.TimeUntilSend(c.GetCongestionWindow(), bytesInFlight, c.GetSlowStartThreshold()) == 0 {
|
||||||
}
|
|
||||||
if c.GetCongestionWindow() > bytesInFlight {
|
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
return utils.InfDuration
|
}
|
||||||
|
delay := c.rttStats.SmoothedRTT() / time.Duration(2*c.GetCongestionWindow()/protocol.DefaultTCPMSS)
|
||||||
|
if !c.InSlowStart() { // adjust delay, such that it's 1.25*cwd/rtt
|
||||||
|
delay = delay * 8 / 5
|
||||||
|
}
|
||||||
|
return delay
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *cubicSender) OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool {
|
func (c *cubicSender) OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool {
|
@ -3,8 +3,8 @@ package congestion
|
|||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Note(pwestin): the magic clamping numbers come from the original code in
|
// Note(pwestin): the magic clamping numbers come from the original code in
|
@ -3,12 +3,12 @@ package congestion
|
|||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
)
|
)
|
||||||
|
|
||||||
// A SendAlgorithm performs congestion control and calculates the congestion window
|
// A SendAlgorithm performs congestion control and calculates the congestion window
|
||||||
type SendAlgorithm interface {
|
type SendAlgorithm interface {
|
||||||
TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration
|
TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration
|
||||||
OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool
|
OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool
|
||||||
GetCongestionWindow() protocol.ByteCount
|
GetCongestionWindow() protocol.ByteCount
|
||||||
MaybeExitSlowStart()
|
MaybeExitSlowStart()
|
@ -3,8 +3,8 @@ package congestion
|
|||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// PrrSender implements the Proportional Rate Reduction (PRR) per RFC 6937
|
// PrrSender implements the Proportional Rate Reduction (PRR) per RFC 6937
|
@ -7,6 +7,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// Note: This constant is also defined in the ackhandler package.
|
||||||
initialRTTus = 100 * 1000
|
initialRTTus = 100 * 1000
|
||||||
rttAlpha float32 = 0.125
|
rttAlpha float32 = 0.125
|
||||||
oneMinusAlpha float32 = (1 - rttAlpha)
|
oneMinusAlpha float32 = (1 - rttAlpha)
|
||||||
@ -97,10 +98,10 @@ func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) {
|
|||||||
r.updateRecentMinRTT(sendDelta, now)
|
r.updateRecentMinRTT(sendDelta, now)
|
||||||
|
|
||||||
// Correct for ackDelay if information received from the peer results in a
|
// Correct for ackDelay if information received from the peer results in a
|
||||||
// positive RTT sample. Otherwise, we use the sendDelta as a reasonable
|
// an RTT sample at least as large as minRTT. Otherwise, only use the
|
||||||
// measure for smoothedRTT.
|
// sendDelta.
|
||||||
sample := sendDelta
|
sample := sendDelta
|
||||||
if sample > ackDelay {
|
if sample-r.minRTT >= ackDelay {
|
||||||
sample -= ackDelay
|
sample -= ackDelay
|
||||||
}
|
}
|
||||||
r.latestRTT = sample
|
r.latestRTT = sample
|
@ -1,6 +1,6 @@
|
|||||||
package congestion
|
package congestion
|
||||||
|
|
||||||
import "github.com/lucas-clemente/quic-go/protocol"
|
import "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
|
||||||
type connectionStats struct {
|
type connectionStats struct {
|
||||||
slowstartPacketsLost protocol.PacketNumber
|
slowstartPacketsLost protocol.PacketNumber
|
@ -1,9 +1,10 @@
|
|||||||
package crypto
|
package crypto
|
||||||
|
|
||||||
import "github.com/lucas-clemente/quic-go/protocol"
|
import "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
|
||||||
// An AEAD implements QUIC's authenticated encryption and associated data
|
// An AEAD implements QUIC's authenticated encryption and associated data
|
||||||
type AEAD interface {
|
type AEAD interface {
|
||||||
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
|
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
|
||||||
Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
|
Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
|
||||||
|
Overhead() int
|
||||||
}
|
}
|
72
vendor/github.com/lucas-clemente/quic-go/internal/crypto/aesgcm12_aead.go
generated
vendored
Normal file
72
vendor/github.com/lucas-clemente/quic-go/internal/crypto/aesgcm12_aead.go
generated
vendored
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
package crypto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/cipher"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/aes12"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
)
|
||||||
|
|
||||||
|
type aeadAESGCM12 struct {
|
||||||
|
otherIV []byte
|
||||||
|
myIV []byte
|
||||||
|
encrypter cipher.AEAD
|
||||||
|
decrypter cipher.AEAD
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ AEAD = &aeadAESGCM12{}
|
||||||
|
|
||||||
|
// NewAEADAESGCM12 creates a AEAD using AES-GCM with 12 bytes tag size
|
||||||
|
//
|
||||||
|
// AES-GCM support is a bit hacky, since the go stdlib does not support 12 byte
|
||||||
|
// tag size, and couples the cipher and aes packages closely.
|
||||||
|
// See https://github.com/lucas-clemente/aes12.
|
||||||
|
func NewAEADAESGCM12(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) {
|
||||||
|
if len(myKey) != 16 || len(otherKey) != 16 || len(myIV) != 4 || len(otherIV) != 4 {
|
||||||
|
return nil, errors.New("AES-GCM: expected 16-byte keys and 4-byte IVs")
|
||||||
|
}
|
||||||
|
encrypterCipher, err := aes12.NewCipher(myKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
encrypter, err := aes12.NewGCM(encrypterCipher)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
decrypterCipher, err := aes12.NewCipher(otherKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
decrypter, err := aes12.NewGCM(decrypterCipher)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &aeadAESGCM12{
|
||||||
|
otherIV: otherIV,
|
||||||
|
myIV: myIV,
|
||||||
|
encrypter: encrypter,
|
||||||
|
decrypter: decrypter,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (aead *aeadAESGCM12) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
||||||
|
return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (aead *aeadAESGCM12) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
||||||
|
return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (aead *aeadAESGCM12) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
|
||||||
|
res := make([]byte, 12)
|
||||||
|
copy(res[0:4], iv)
|
||||||
|
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber))
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func (aead *aeadAESGCM12) Overhead() int {
|
||||||
|
return aead.encrypter.Overhead()
|
||||||
|
}
|
74
vendor/github.com/lucas-clemente/quic-go/internal/crypto/aesgcm_aead.go
generated
vendored
Normal file
74
vendor/github.com/lucas-clemente/quic-go/internal/crypto/aesgcm_aead.go
generated
vendored
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
package crypto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
)
|
||||||
|
|
||||||
|
type aeadAESGCM struct {
|
||||||
|
otherIV []byte
|
||||||
|
myIV []byte
|
||||||
|
encrypter cipher.AEAD
|
||||||
|
decrypter cipher.AEAD
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ AEAD = &aeadAESGCM{}
|
||||||
|
|
||||||
|
const ivLen = 12
|
||||||
|
|
||||||
|
// NewAEADAESGCM creates a AEAD using AES-GCM
|
||||||
|
func NewAEADAESGCM(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) {
|
||||||
|
// the IVs need to be at least 8 bytes long, otherwise we can't compute the nonce
|
||||||
|
if len(otherIV) != ivLen || len(myIV) != ivLen {
|
||||||
|
return nil, errors.New("AES-GCM: expected 12 byte IVs")
|
||||||
|
}
|
||||||
|
|
||||||
|
encrypterCipher, err := aes.NewCipher(myKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
encrypter, err := cipher.NewGCM(encrypterCipher)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
decrypterCipher, err := aes.NewCipher(otherKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
decrypter, err := cipher.NewGCM(decrypterCipher)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &aeadAESGCM{
|
||||||
|
otherIV: otherIV,
|
||||||
|
myIV: myIV,
|
||||||
|
encrypter: encrypter,
|
||||||
|
decrypter: decrypter,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (aead *aeadAESGCM) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
||||||
|
return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (aead *aeadAESGCM) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
||||||
|
return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (aead *aeadAESGCM) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
|
||||||
|
nonce := make([]byte, ivLen)
|
||||||
|
binary.BigEndian.PutUint64(nonce[ivLen-8:], uint64(packetNumber))
|
||||||
|
for i := 0; i < ivLen; i++ {
|
||||||
|
nonce[i] ^= iv[i]
|
||||||
|
}
|
||||||
|
return nonce
|
||||||
|
}
|
||||||
|
|
||||||
|
func (aead *aeadAESGCM) Overhead() int {
|
||||||
|
return aead.encrypter.Overhead()
|
||||||
|
}
|
@ -5,7 +5,7 @@ import (
|
|||||||
"hash/fnv"
|
"hash/fnv"
|
||||||
|
|
||||||
"github.com/hashicorp/golang-lru"
|
"github.com/hashicorp/golang-lru"
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
@ -51,10 +51,10 @@ func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]by
|
|||||||
res.WriteByte(uint8(e.t))
|
res.WriteByte(uint8(e.t))
|
||||||
switch e.t {
|
switch e.t {
|
||||||
case entryCached:
|
case entryCached:
|
||||||
utils.WriteUint64(res, e.h)
|
utils.LittleEndian.WriteUint64(res, e.h)
|
||||||
case entryCommon:
|
case entryCommon:
|
||||||
utils.WriteUint64(res, e.h)
|
utils.LittleEndian.WriteUint64(res, e.h)
|
||||||
utils.WriteUint32(res, e.i)
|
utils.LittleEndian.WriteUint32(res, e.i)
|
||||||
case entryCompressed:
|
case entryCompressed:
|
||||||
totalUncompressedLen += 4 + len(chain[i])
|
totalUncompressedLen += 4 + len(chain[i])
|
||||||
}
|
}
|
||||||
@ -67,7 +67,7 @@ func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]by
|
|||||||
return nil, fmt.Errorf("cert compression failed: %s", err.Error())
|
return nil, fmt.Errorf("cert compression failed: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
utils.WriteUint32(res, uint32(totalUncompressedLen))
|
utils.LittleEndian.WriteUint32(res, uint32(totalUncompressedLen))
|
||||||
|
|
||||||
for i, e := range entries {
|
for i, e := range entries {
|
||||||
if e.t != entryCompressed {
|
if e.t != entryCompressed {
|
||||||
@ -115,11 +115,11 @@ func decompressChain(data []byte) ([][]byte, error) {
|
|||||||
return nil, errors.New("unexpected cached certificate")
|
return nil, errors.New("unexpected cached certificate")
|
||||||
case entryCommon:
|
case entryCommon:
|
||||||
e := entry{t: entryCommon}
|
e := entry{t: entryCommon}
|
||||||
e.h, err = utils.ReadUint64(r)
|
e.h, err = utils.LittleEndian.ReadUint64(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
e.i, err = utils.ReadUint32(r)
|
e.i, err = utils.LittleEndian.ReadUint32(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -146,7 +146,7 @@ func decompressChain(data []byte) ([][]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if hasCompressedCerts {
|
if hasCompressedCerts {
|
||||||
uncompressedLength, err := utils.ReadUint32(r)
|
uncompressedLength, err := utils.LittleEndian.ReadUint32(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(4)
|
fmt.Println(4)
|
||||||
return nil, err
|
return nil, err
|
@ -18,6 +18,7 @@ type CertManager interface {
|
|||||||
GetLeafCertHash() (uint64, error)
|
GetLeafCertHash() (uint64, error)
|
||||||
VerifyServerProof(proof, chlo, serverConfigData []byte) bool
|
VerifyServerProof(proof, chlo, serverConfigData []byte) bool
|
||||||
Verify(hostname string) error
|
Verify(hostname string) error
|
||||||
|
GetChain() []*x509.Certificate
|
||||||
}
|
}
|
||||||
|
|
||||||
type certManager struct {
|
type certManager struct {
|
||||||
@ -54,6 +55,10 @@ func (c *certManager) SetData(data []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *certManager) GetChain() []*x509.Certificate {
|
||||||
|
return c.chain
|
||||||
|
}
|
||||||
|
|
||||||
func (c *certManager) GetCommonCertificateHashes() []byte {
|
func (c *certManager) GetCommonCertificateHashes() []byte {
|
||||||
return getCommonCertificateHashes()
|
return getCommonCertificateHashes()
|
||||||
}
|
}
|
@ -4,11 +4,12 @@ package crypto
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/cipher"
|
"crypto/cipher"
|
||||||
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"github.com/aead/chacha20"
|
"github.com/aead/chacha20"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
)
|
)
|
||||||
|
|
||||||
type aeadChacha20Poly1305 struct {
|
type aeadChacha20Poly1305 struct {
|
||||||
@ -45,9 +46,16 @@ func NewAEADChacha20Poly1305(otherKey []byte, myKey []byte, otherIV []byte, myIV
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (aead *aeadChacha20Poly1305) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
func (aead *aeadChacha20Poly1305) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
||||||
return aead.decrypter.Open(dst, makeNonce(aead.otherIV, packetNumber), src, associatedData)
|
return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (aead *aeadChacha20Poly1305) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
func (aead *aeadChacha20Poly1305) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
||||||
return aead.encrypter.Seal(dst, makeNonce(aead.myIV, packetNumber), src, associatedData)
|
return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (aead *aeadChacha20Poly1305) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
|
||||||
|
res := make([]byte, 12)
|
||||||
|
copy(res[0:4], iv)
|
||||||
|
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber))
|
||||||
|
return res
|
||||||
}
|
}
|
49
vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation.go
generated
vendored
Normal file
49
vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation.go
generated
vendored
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
package crypto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/bifurcation/mint"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
clientExporterLabel = "EXPORTER-QUIC client 1-RTT Secret"
|
||||||
|
serverExporterLabel = "EXPORTER-QUIC server 1-RTT Secret"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A TLSExporter gets the negotiated ciphersuite and computes exporter
|
||||||
|
type TLSExporter interface {
|
||||||
|
GetCipherSuite() mint.CipherSuiteParams
|
||||||
|
ComputeExporter(label string, context []byte, keyLength int) ([]byte, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeriveAESKeys derives the AES keys and creates a matching AES-GCM AEAD instance
|
||||||
|
func DeriveAESKeys(tls TLSExporter, pers protocol.Perspective) (AEAD, error) {
|
||||||
|
var myLabel, otherLabel string
|
||||||
|
if pers == protocol.PerspectiveClient {
|
||||||
|
myLabel = clientExporterLabel
|
||||||
|
otherLabel = serverExporterLabel
|
||||||
|
} else {
|
||||||
|
myLabel = serverExporterLabel
|
||||||
|
otherLabel = clientExporterLabel
|
||||||
|
}
|
||||||
|
myKey, myIV, err := computeKeyAndIV(tls, myLabel)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
otherKey, otherIV, err := computeKeyAndIV(tls, otherLabel)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return NewAEADAESGCM(otherKey, myKey, otherIV, myIV)
|
||||||
|
}
|
||||||
|
|
||||||
|
func computeKeyAndIV(tls TLSExporter, label string) (key, iv []byte, err error) {
|
||||||
|
cs := tls.GetCipherSuite()
|
||||||
|
secret, err := tls.ComputeExporter(label, nil, cs.Hash.Size())
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
key = mint.HkdfExpandLabel(cs.Hash, secret, "key", nil, cs.KeyLen)
|
||||||
|
iv = mint.HkdfExpandLabel(cs.Hash, secret, "iv", nil, cs.IvLen)
|
||||||
|
return key, iv, nil
|
||||||
|
}
|
@ -5,8 +5,8 @@ import (
|
|||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/hkdf"
|
"golang.org/x/crypto/hkdf"
|
||||||
)
|
)
|
||||||
@ -20,8 +20,8 @@ import (
|
|||||||
// return NewAEADChacha20Poly1305(otherKey, myKey, otherIV, myIV)
|
// return NewAEADChacha20Poly1305(otherKey, myKey, otherIV, myIV)
|
||||||
// }
|
// }
|
||||||
|
|
||||||
// DeriveKeysAESGCM derives the client and server keys and creates a matching AES-GCM AEAD instance
|
// DeriveQuicCryptoAESKeys derives the client and server keys and creates a matching AES-GCM AEAD instance
|
||||||
func DeriveKeysAESGCM(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (AEAD, error) {
|
func DeriveQuicCryptoAESKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (AEAD, error) {
|
||||||
var swap bool
|
var swap bool
|
||||||
if pers == protocol.PerspectiveClient {
|
if pers == protocol.PerspectiveClient {
|
||||||
swap = true
|
swap = true
|
||||||
@ -30,7 +30,7 @@ func DeriveKeysAESGCM(forwardSecure bool, sharedSecret, nonces []byte, connID pr
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return NewAEADAESGCM(otherKey, myKey, otherIV, myIV)
|
return NewAEADAESGCM12(otherKey, myKey, otherIV, myIV)
|
||||||
}
|
}
|
||||||
|
|
||||||
// deriveKeys derives the keys and the IVs
|
// deriveKeys derives the keys and the IVs
|
||||||
@ -42,7 +42,7 @@ func deriveKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol
|
|||||||
} else {
|
} else {
|
||||||
info.Write([]byte("QUIC key expansion\x00"))
|
info.Write([]byte("QUIC key expansion\x00"))
|
||||||
}
|
}
|
||||||
utils.WriteUint64(&info, uint64(connID))
|
utils.BigEndian.WriteUint64(&info, uint64(connID))
|
||||||
info.Write(chlo)
|
info.Write(chlo)
|
||||||
info.Write(scfg)
|
info.Write(scfg)
|
||||||
info.Write(cert)
|
info.Write(cert)
|
11
vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead.go
generated
vendored
Normal file
11
vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead.go
generated
vendored
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
package crypto
|
||||||
|
|
||||||
|
import "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
|
||||||
|
// NewNullAEAD creates a NullAEAD
|
||||||
|
func NewNullAEAD(p protocol.Perspective, connID protocol.ConnectionID, v protocol.VersionNumber) (AEAD, error) {
|
||||||
|
if v.UsesTLS() {
|
||||||
|
return newNullAEADAESGCM(connID, p)
|
||||||
|
}
|
||||||
|
return &nullAEADFNV128a{perspective: p}, nil
|
||||||
|
}
|
44
vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_aesgcm.go
generated
vendored
Normal file
44
vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_aesgcm.go
generated
vendored
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
package crypto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto"
|
||||||
|
"encoding/binary"
|
||||||
|
|
||||||
|
"github.com/bifurcation/mint"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
)
|
||||||
|
|
||||||
|
var quicVersion1Salt = []byte{0xaf, 0xc8, 0x24, 0xec, 0x5f, 0xc7, 0x7e, 0xca, 0x1e, 0x9d, 0x36, 0xf3, 0x7f, 0xb2, 0xd4, 0x65, 0x18, 0xc3, 0x66, 0x39}
|
||||||
|
|
||||||
|
func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspective) (AEAD, error) {
|
||||||
|
clientSecret, serverSecret := computeSecrets(connectionID)
|
||||||
|
|
||||||
|
var mySecret, otherSecret []byte
|
||||||
|
if pers == protocol.PerspectiveClient {
|
||||||
|
mySecret = clientSecret
|
||||||
|
otherSecret = serverSecret
|
||||||
|
} else {
|
||||||
|
mySecret = serverSecret
|
||||||
|
otherSecret = clientSecret
|
||||||
|
}
|
||||||
|
|
||||||
|
myKey, myIV := computeNullAEADKeyAndIV(mySecret)
|
||||||
|
otherKey, otherIV := computeNullAEADKeyAndIV(otherSecret)
|
||||||
|
|
||||||
|
return NewAEADAESGCM(otherKey, myKey, otherIV, myIV)
|
||||||
|
}
|
||||||
|
|
||||||
|
func computeSecrets(connectionID protocol.ConnectionID) (clientSecret, serverSecret []byte) {
|
||||||
|
connID := make([]byte, 8)
|
||||||
|
binary.BigEndian.PutUint64(connID, uint64(connectionID))
|
||||||
|
cleartextSecret := mint.HkdfExtract(crypto.SHA256, []byte(quicVersion1Salt), connID)
|
||||||
|
clientSecret = mint.HkdfExpandLabel(crypto.SHA256, cleartextSecret, "QUIC client cleartext Secret", []byte{}, crypto.SHA256.Size())
|
||||||
|
serverSecret = mint.HkdfExpandLabel(crypto.SHA256, cleartextSecret, "QUIC server cleartext Secret", []byte{}, crypto.SHA256.Size())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func computeNullAEADKeyAndIV(secret []byte) (key, iv []byte) {
|
||||||
|
key = mint.HkdfExpandLabel(crypto.SHA256, secret, "key", nil, 16)
|
||||||
|
iv = mint.HkdfExpandLabel(crypto.SHA256, secret, "iv", nil, 12)
|
||||||
|
return
|
||||||
|
}
|
@ -5,27 +5,18 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"github.com/lucas-clemente/fnv128a"
|
"github.com/lucas-clemente/fnv128a"
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
)
|
)
|
||||||
|
|
||||||
// nullAEAD handles not-yet encrypted packets
|
// nullAEAD handles not-yet encrypted packets
|
||||||
type nullAEAD struct {
|
type nullAEADFNV128a struct {
|
||||||
perspective protocol.Perspective
|
perspective protocol.Perspective
|
||||||
version protocol.VersionNumber
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ AEAD = &nullAEAD{}
|
var _ AEAD = &nullAEADFNV128a{}
|
||||||
|
|
||||||
// NewNullAEAD creates a NullAEAD
|
|
||||||
func NewNullAEAD(p protocol.Perspective, v protocol.VersionNumber) AEAD {
|
|
||||||
return &nullAEAD{
|
|
||||||
perspective: p,
|
|
||||||
version: v,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Open and verify the ciphertext
|
// Open and verify the ciphertext
|
||||||
func (n *nullAEAD) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
func (n *nullAEADFNV128a) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
||||||
if len(src) < 12 {
|
if len(src) < 12 {
|
||||||
return nil, errors.New("NullAEAD: ciphertext cannot be less than 12 bytes long")
|
return nil, errors.New("NullAEAD: ciphertext cannot be less than 12 bytes long")
|
||||||
}
|
}
|
||||||
@ -33,13 +24,11 @@ func (n *nullAEAD) Open(dst, src []byte, packetNumber protocol.PacketNumber, ass
|
|||||||
hash := fnv128a.New()
|
hash := fnv128a.New()
|
||||||
hash.Write(associatedData)
|
hash.Write(associatedData)
|
||||||
hash.Write(src[12:])
|
hash.Write(src[12:])
|
||||||
if n.version >= protocol.Version37 {
|
|
||||||
if n.perspective == protocol.PerspectiveServer {
|
if n.perspective == protocol.PerspectiveServer {
|
||||||
hash.Write([]byte("Client"))
|
hash.Write([]byte("Client"))
|
||||||
} else {
|
} else {
|
||||||
hash.Write([]byte("Server"))
|
hash.Write([]byte("Server"))
|
||||||
}
|
}
|
||||||
}
|
|
||||||
testHigh, testLow := hash.Sum128()
|
testHigh, testLow := hash.Sum128()
|
||||||
|
|
||||||
low := binary.LittleEndian.Uint64(src)
|
low := binary.LittleEndian.Uint64(src)
|
||||||
@ -52,7 +41,7 @@ func (n *nullAEAD) Open(dst, src []byte, packetNumber protocol.PacketNumber, ass
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Seal writes hash and ciphertext to the buffer
|
// Seal writes hash and ciphertext to the buffer
|
||||||
func (n *nullAEAD) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
func (n *nullAEADFNV128a) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
||||||
if cap(dst) < 12+len(src) {
|
if cap(dst) < 12+len(src) {
|
||||||
dst = make([]byte, 12+len(src))
|
dst = make([]byte, 12+len(src))
|
||||||
} else {
|
} else {
|
||||||
@ -63,13 +52,11 @@ func (n *nullAEAD) Seal(dst, src []byte, packetNumber protocol.PacketNumber, ass
|
|||||||
hash.Write(associatedData)
|
hash.Write(associatedData)
|
||||||
hash.Write(src)
|
hash.Write(src)
|
||||||
|
|
||||||
if n.version >= protocol.Version37 {
|
|
||||||
if n.perspective == protocol.PerspectiveServer {
|
if n.perspective == protocol.PerspectiveServer {
|
||||||
hash.Write([]byte("Server"))
|
hash.Write([]byte("Server"))
|
||||||
} else {
|
} else {
|
||||||
hash.Write([]byte("Client"))
|
hash.Write([]byte("Client"))
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
high, low := hash.Sum128()
|
high, low := hash.Sum128()
|
||||||
|
|
||||||
@ -78,3 +65,7 @@ func (n *nullAEAD) Seal(dst, src []byte, packetNumber protocol.PacketNumber, ass
|
|||||||
binary.LittleEndian.PutUint32(dst[8:], uint32(high))
|
binary.LittleEndian.PutUint32(dst[8:], uint32(high))
|
||||||
return dst
|
return dst
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (n *nullAEADFNV128a) Overhead() int {
|
||||||
|
return 12
|
||||||
|
}
|
108
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go
generated
vendored
Normal file
108
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go
generated
vendored
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
package flowcontrol
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type baseFlowController struct {
|
||||||
|
// for sending data
|
||||||
|
bytesSent protocol.ByteCount
|
||||||
|
sendWindow protocol.ByteCount
|
||||||
|
|
||||||
|
// for receiving data
|
||||||
|
mutex sync.RWMutex
|
||||||
|
bytesRead protocol.ByteCount
|
||||||
|
highestReceived protocol.ByteCount
|
||||||
|
receiveWindow protocol.ByteCount
|
||||||
|
receiveWindowSize protocol.ByteCount
|
||||||
|
maxReceiveWindowSize protocol.ByteCount
|
||||||
|
|
||||||
|
epochStartTime time.Time
|
||||||
|
epochStartOffset protocol.ByteCount
|
||||||
|
rttStats *congestion.RTTStats
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) {
|
||||||
|
c.bytesSent += n
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSendWindow should be called after receiving a WindowUpdateFrame
|
||||||
|
// it returns true if the window was actually updated
|
||||||
|
func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) {
|
||||||
|
if offset > c.sendWindow {
|
||||||
|
c.sendWindow = offset
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *baseFlowController) sendWindowSize() protocol.ByteCount {
|
||||||
|
// this only happens during connection establishment, when data is sent before we receive the peer's transport parameters
|
||||||
|
if c.bytesSent > c.sendWindow {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return c.sendWindow - c.bytesSent
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *baseFlowController) AddBytesRead(n protocol.ByteCount) {
|
||||||
|
c.mutex.Lock()
|
||||||
|
defer c.mutex.Unlock()
|
||||||
|
|
||||||
|
// pretend we sent a WindowUpdate when reading the first byte
|
||||||
|
// this way auto-tuning of the window size already works for the first WindowUpdate
|
||||||
|
if c.bytesRead == 0 {
|
||||||
|
c.startNewAutoTuningEpoch()
|
||||||
|
}
|
||||||
|
c.bytesRead += n
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *baseFlowController) hasWindowUpdate() bool {
|
||||||
|
bytesRemaining := c.receiveWindow - c.bytesRead
|
||||||
|
// update the window when more than the threshold was consumed
|
||||||
|
return bytesRemaining <= protocol.ByteCount((float64(c.receiveWindowSize) * float64((1 - protocol.WindowUpdateThreshold))))
|
||||||
|
}
|
||||||
|
|
||||||
|
// getWindowUpdate updates the receive window, if necessary
|
||||||
|
// it returns the new offset
|
||||||
|
func (c *baseFlowController) getWindowUpdate() protocol.ByteCount {
|
||||||
|
if !c.hasWindowUpdate() {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
c.maybeAdjustWindowSize()
|
||||||
|
c.receiveWindow = c.bytesRead + c.receiveWindowSize
|
||||||
|
return c.receiveWindow
|
||||||
|
}
|
||||||
|
|
||||||
|
// maybeAdjustWindowSize increases the receiveWindowSize if we're sending updates too often.
|
||||||
|
// For details about auto-tuning, see https://docs.google.com/document/d/1SExkMmGiz8VYzV3s9E35JQlJ73vhzCekKkDi85F1qCE/edit?usp=sharing.
|
||||||
|
func (c *baseFlowController) maybeAdjustWindowSize() {
|
||||||
|
bytesReadInEpoch := c.bytesRead - c.epochStartOffset
|
||||||
|
// don't do anything if less than half the window has been consumed
|
||||||
|
if bytesReadInEpoch <= c.receiveWindowSize/2 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rtt := c.rttStats.SmoothedRTT()
|
||||||
|
if rtt == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fraction := float64(bytesReadInEpoch) / float64(c.receiveWindowSize)
|
||||||
|
if time.Since(c.epochStartTime) < time.Duration(4*fraction*float64(rtt)) {
|
||||||
|
// window is consumed too fast, try to increase the window size
|
||||||
|
c.receiveWindowSize = utils.MinByteCount(2*c.receiveWindowSize, c.maxReceiveWindowSize)
|
||||||
|
}
|
||||||
|
c.startNewAutoTuningEpoch()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *baseFlowController) startNewAutoTuningEpoch() {
|
||||||
|
c.epochStartTime = time.Now()
|
||||||
|
c.epochStartOffset = c.bytesRead
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *baseFlowController) checkFlowControlViolation() bool {
|
||||||
|
return c.highestReceived > c.receiveWindow
|
||||||
|
}
|
83
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go
generated
vendored
Normal file
83
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go
generated
vendored
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
package flowcontrol
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
|
"github.com/lucas-clemente/quic-go/qerr"
|
||||||
|
)
|
||||||
|
|
||||||
|
type connectionFlowController struct {
|
||||||
|
lastBlockedAt protocol.ByteCount
|
||||||
|
baseFlowController
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ ConnectionFlowController = &connectionFlowController{}
|
||||||
|
|
||||||
|
// NewConnectionFlowController gets a new flow controller for the connection
|
||||||
|
// It is created before we receive the peer's transport paramenters, thus it starts with a sendWindow of 0.
|
||||||
|
func NewConnectionFlowController(
|
||||||
|
receiveWindow protocol.ByteCount,
|
||||||
|
maxReceiveWindow protocol.ByteCount,
|
||||||
|
rttStats *congestion.RTTStats,
|
||||||
|
) ConnectionFlowController {
|
||||||
|
return &connectionFlowController{
|
||||||
|
baseFlowController: baseFlowController{
|
||||||
|
rttStats: rttStats,
|
||||||
|
receiveWindow: receiveWindow,
|
||||||
|
receiveWindowSize: receiveWindow,
|
||||||
|
maxReceiveWindowSize: maxReceiveWindow,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connectionFlowController) SendWindowSize() protocol.ByteCount {
|
||||||
|
return c.baseFlowController.sendWindowSize()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsNewlyBlocked says if it is newly blocked by flow control.
|
||||||
|
// For every offset, it only returns true once.
|
||||||
|
// If it is blocked, the offset is returned.
|
||||||
|
func (c *connectionFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) {
|
||||||
|
if c.sendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt {
|
||||||
|
return false, 0
|
||||||
|
}
|
||||||
|
c.lastBlockedAt = c.sendWindow
|
||||||
|
return true, c.sendWindow
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncrementHighestReceived adds an increment to the highestReceived value
|
||||||
|
func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error {
|
||||||
|
c.mutex.Lock()
|
||||||
|
defer c.mutex.Unlock()
|
||||||
|
|
||||||
|
c.highestReceived += increment
|
||||||
|
if c.checkFlowControlViolation() {
|
||||||
|
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", c.highestReceived, c.receiveWindow))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount {
|
||||||
|
c.mutex.Lock()
|
||||||
|
oldWindowSize := c.receiveWindowSize
|
||||||
|
offset := c.baseFlowController.getWindowUpdate()
|
||||||
|
if oldWindowSize < c.receiveWindowSize {
|
||||||
|
utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
|
||||||
|
}
|
||||||
|
c.mutex.Unlock()
|
||||||
|
return offset
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnsureMinimumWindowSize sets a minimum window size
|
||||||
|
// it should make sure that the connection-level window is increased when a stream-level window grows
|
||||||
|
func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount) {
|
||||||
|
c.mutex.Lock()
|
||||||
|
if inc > c.receiveWindowSize {
|
||||||
|
c.receiveWindowSize = utils.MinByteCount(inc, c.maxReceiveWindowSize)
|
||||||
|
c.startNewAutoTuningEpoch()
|
||||||
|
}
|
||||||
|
c.mutex.Unlock()
|
||||||
|
}
|
42
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go
generated
vendored
Normal file
42
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go
generated
vendored
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
package flowcontrol
|
||||||
|
|
||||||
|
import "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
|
||||||
|
type flowController interface {
|
||||||
|
// for sending
|
||||||
|
SendWindowSize() protocol.ByteCount
|
||||||
|
UpdateSendWindow(protocol.ByteCount)
|
||||||
|
AddBytesSent(protocol.ByteCount)
|
||||||
|
// for receiving
|
||||||
|
AddBytesRead(protocol.ByteCount)
|
||||||
|
GetWindowUpdate() protocol.ByteCount // returns 0 if no update is necessary
|
||||||
|
}
|
||||||
|
|
||||||
|
// A StreamFlowController is a flow controller for a QUIC stream.
|
||||||
|
type StreamFlowController interface {
|
||||||
|
flowController
|
||||||
|
// for sending
|
||||||
|
IsBlocked() (bool, protocol.ByteCount)
|
||||||
|
// for receiving
|
||||||
|
// UpdateHighestReceived should be called when a new highest offset is received
|
||||||
|
// final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RST_STREAM frame
|
||||||
|
UpdateHighestReceived(offset protocol.ByteCount, final bool) error
|
||||||
|
// HasWindowUpdate says if it is necessary to update the window
|
||||||
|
HasWindowUpdate() bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// The ConnectionFlowController is the flow controller for the connection.
|
||||||
|
type ConnectionFlowController interface {
|
||||||
|
flowController
|
||||||
|
// for sending
|
||||||
|
IsNewlyBlocked() (bool, protocol.ByteCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
type connectionFlowControllerI interface {
|
||||||
|
ConnectionFlowController
|
||||||
|
// The following two methods are not supposed to be called from outside this packet, but are needed internally
|
||||||
|
// for sending
|
||||||
|
EnsureMinimumWindowSize(protocol.ByteCount)
|
||||||
|
// for receiving
|
||||||
|
IncrementHighestReceived(protocol.ByteCount) error
|
||||||
|
}
|
147
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go
generated
vendored
Normal file
147
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go
generated
vendored
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
package flowcontrol
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
|
"github.com/lucas-clemente/quic-go/qerr"
|
||||||
|
)
|
||||||
|
|
||||||
|
type streamFlowController struct {
|
||||||
|
baseFlowController
|
||||||
|
|
||||||
|
streamID protocol.StreamID
|
||||||
|
|
||||||
|
connection connectionFlowControllerI
|
||||||
|
contributesToConnection bool // does the stream contribute to connection level flow control
|
||||||
|
|
||||||
|
receivedFinalOffset bool
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ StreamFlowController = &streamFlowController{}
|
||||||
|
|
||||||
|
// NewStreamFlowController gets a new flow controller for a stream
|
||||||
|
func NewStreamFlowController(
|
||||||
|
streamID protocol.StreamID,
|
||||||
|
contributesToConnection bool,
|
||||||
|
cfc ConnectionFlowController,
|
||||||
|
receiveWindow protocol.ByteCount,
|
||||||
|
maxReceiveWindow protocol.ByteCount,
|
||||||
|
initialSendWindow protocol.ByteCount,
|
||||||
|
rttStats *congestion.RTTStats,
|
||||||
|
) StreamFlowController {
|
||||||
|
return &streamFlowController{
|
||||||
|
streamID: streamID,
|
||||||
|
contributesToConnection: contributesToConnection,
|
||||||
|
connection: cfc.(connectionFlowControllerI),
|
||||||
|
baseFlowController: baseFlowController{
|
||||||
|
rttStats: rttStats,
|
||||||
|
receiveWindow: receiveWindow,
|
||||||
|
receiveWindowSize: receiveWindow,
|
||||||
|
maxReceiveWindowSize: maxReceiveWindow,
|
||||||
|
sendWindow: initialSendWindow,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateHighestReceived updates the highestReceived value, if the byteOffset is higher
|
||||||
|
// it returns an ErrReceivedSmallerByteOffset if the received byteOffset is smaller than any byteOffset received before
|
||||||
|
func (c *streamFlowController) UpdateHighestReceived(byteOffset protocol.ByteCount, final bool) error {
|
||||||
|
c.mutex.Lock()
|
||||||
|
defer c.mutex.Unlock()
|
||||||
|
|
||||||
|
// when receiving a final offset, check that this final offset is consistent with a final offset we might have received earlier
|
||||||
|
if final && c.receivedFinalOffset && byteOffset != c.highestReceived {
|
||||||
|
return qerr.Error(qerr.StreamDataAfterTermination, fmt.Sprintf("Received inconsistent final offset for stream %d (old: %d, new: %d bytes)", c.streamID, c.highestReceived, byteOffset))
|
||||||
|
}
|
||||||
|
// if we already received a final offset, check that the offset in the STREAM frames is below the final offset
|
||||||
|
if c.receivedFinalOffset && byteOffset > c.highestReceived {
|
||||||
|
return qerr.StreamDataAfterTermination
|
||||||
|
}
|
||||||
|
if final {
|
||||||
|
c.receivedFinalOffset = true
|
||||||
|
}
|
||||||
|
if byteOffset == c.highestReceived {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if byteOffset <= c.highestReceived {
|
||||||
|
// a STREAM_FRAME with a higher offset was received before.
|
||||||
|
if final {
|
||||||
|
// If the current byteOffset is smaller than the offset in that STREAM_FRAME, this STREAM_FRAME contained data after the end of the stream
|
||||||
|
return qerr.StreamDataAfterTermination
|
||||||
|
}
|
||||||
|
// this is a reordered STREAM_FRAME
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
increment := byteOffset - c.highestReceived
|
||||||
|
c.highestReceived = byteOffset
|
||||||
|
if c.checkFlowControlViolation() {
|
||||||
|
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, c.streamID, c.receiveWindow))
|
||||||
|
}
|
||||||
|
if c.contributesToConnection {
|
||||||
|
return c.connection.IncrementHighestReceived(increment)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) {
|
||||||
|
c.baseFlowController.AddBytesRead(n)
|
||||||
|
if c.contributesToConnection {
|
||||||
|
c.connection.AddBytesRead(n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) {
|
||||||
|
c.baseFlowController.AddBytesSent(n)
|
||||||
|
if c.contributesToConnection {
|
||||||
|
c.connection.AddBytesSent(n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
|
||||||
|
window := c.baseFlowController.sendWindowSize()
|
||||||
|
if c.contributesToConnection {
|
||||||
|
window = utils.MinByteCount(window, c.connection.SendWindowSize())
|
||||||
|
}
|
||||||
|
return window
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsBlocked says if it is blocked by stream-level flow control.
|
||||||
|
// If it is blocked, the offset is returned.
|
||||||
|
func (c *streamFlowController) IsBlocked() (bool, protocol.ByteCount) {
|
||||||
|
if c.sendWindowSize() != 0 {
|
||||||
|
return false, 0
|
||||||
|
}
|
||||||
|
return true, c.sendWindow
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *streamFlowController) HasWindowUpdate() bool {
|
||||||
|
c.mutex.Lock()
|
||||||
|
hasWindowUpdate := !c.receivedFinalOffset && c.hasWindowUpdate()
|
||||||
|
c.mutex.Unlock()
|
||||||
|
return hasWindowUpdate
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
|
||||||
|
// don't use defer for unlocking the mutex here, GetWindowUpdate() is called frequently and defer shows up in the profiler
|
||||||
|
c.mutex.Lock()
|
||||||
|
// if we already received the final offset for this stream, the peer won't need any additional flow control credit
|
||||||
|
if c.receivedFinalOffset {
|
||||||
|
c.mutex.Unlock()
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
oldWindowSize := c.receiveWindowSize
|
||||||
|
offset := c.baseFlowController.getWindowUpdate()
|
||||||
|
if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size
|
||||||
|
utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
|
||||||
|
if c.contributesToConnection {
|
||||||
|
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.mutex.Unlock()
|
||||||
|
return offset
|
||||||
|
}
|
101
vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go
generated
vendored
Normal file
101
vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go
generated
vendored
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
package handshake
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/asn1"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bifurcation/mint"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
cookiePrefixIP byte = iota
|
||||||
|
cookiePrefixString
|
||||||
|
)
|
||||||
|
|
||||||
|
// A Cookie is derived from the client address and can be used to verify the ownership of this address.
|
||||||
|
type Cookie struct {
|
||||||
|
RemoteAddr string
|
||||||
|
// The time that the STK was issued (resolution 1 second)
|
||||||
|
SentTime time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// token is the struct that is used for ASN1 serialization and deserialization
|
||||||
|
type token struct {
|
||||||
|
Data []byte
|
||||||
|
Timestamp int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// A CookieGenerator generates Cookies
|
||||||
|
type CookieGenerator struct {
|
||||||
|
cookieProtector mint.CookieProtector
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCookieGenerator initializes a new CookieGenerator
|
||||||
|
func NewCookieGenerator() (*CookieGenerator, error) {
|
||||||
|
cookieProtector, err := mint.NewDefaultCookieProtector()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &CookieGenerator{
|
||||||
|
cookieProtector: cookieProtector,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewToken generates a new Cookie for a given source address
|
||||||
|
func (g *CookieGenerator) NewToken(raddr net.Addr) ([]byte, error) {
|
||||||
|
data, err := asn1.Marshal(token{
|
||||||
|
Data: encodeRemoteAddr(raddr),
|
||||||
|
Timestamp: time.Now().Unix(),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return g.cookieProtector.NewToken(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeToken decodes a Cookie
|
||||||
|
func (g *CookieGenerator) DecodeToken(encrypted []byte) (*Cookie, error) {
|
||||||
|
// if the client didn't send any Cookie, DecodeToken will be called with a nil-slice
|
||||||
|
if len(encrypted) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := g.cookieProtector.DecodeToken(encrypted)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
t := &token{}
|
||||||
|
rest, err := asn1.Unmarshal(data, t)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(rest) != 0 {
|
||||||
|
return nil, fmt.Errorf("rest when unpacking token: %d", len(rest))
|
||||||
|
}
|
||||||
|
return &Cookie{
|
||||||
|
RemoteAddr: decodeRemoteAddr(t.Data),
|
||||||
|
SentTime: time.Unix(t.Timestamp, 0),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// encodeRemoteAddr encodes a remote address such that it can be saved in the Cookie
|
||||||
|
func encodeRemoteAddr(remoteAddr net.Addr) []byte {
|
||||||
|
if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok {
|
||||||
|
return append([]byte{cookiePrefixIP}, udpAddr.IP...)
|
||||||
|
}
|
||||||
|
return append([]byte{cookiePrefixString}, []byte(remoteAddr.String())...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeRemoteAddr decodes the remote address saved in the Cookie
|
||||||
|
func decodeRemoteAddr(data []byte) string {
|
||||||
|
// data will never be empty for a Cookie that we generated. Check it to be on the safe side
|
||||||
|
if len(data) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if data[0] == cookiePrefixIP {
|
||||||
|
return net.IP(data[1:]).String()
|
||||||
|
}
|
||||||
|
return string(data[1:])
|
||||||
|
}
|
43
vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_handler.go
generated
vendored
Normal file
43
vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_handler.go
generated
vendored
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
package handshake
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/bifurcation/mint"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CookieHandler struct {
|
||||||
|
callback func(net.Addr, *Cookie) bool
|
||||||
|
|
||||||
|
cookieGenerator *CookieGenerator
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ mint.CookieHandler = &CookieHandler{}
|
||||||
|
|
||||||
|
func NewCookieHandler(callback func(net.Addr, *Cookie) bool) (*CookieHandler, error) {
|
||||||
|
cookieGenerator, err := NewCookieGenerator()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &CookieHandler{
|
||||||
|
callback: callback,
|
||||||
|
cookieGenerator: cookieGenerator,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *CookieHandler) Generate(conn *mint.Conn) ([]byte, error) {
|
||||||
|
if h.callback(conn.RemoteAddr(), nil) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return h.cookieGenerator.NewToken(conn.RemoteAddr())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *CookieHandler) Validate(conn *mint.Conn, token []byte) bool {
|
||||||
|
data, err := h.cookieGenerator.DecodeToken(token)
|
||||||
|
if err != nil {
|
||||||
|
utils.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error())
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return h.callback(conn.RemoteAddr(), data)
|
||||||
|
}
|
@ -11,9 +11,9 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/crypto"
|
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
"github.com/lucas-clemente/quic-go/qerr"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -23,6 +23,7 @@ type cryptoSetupClient struct {
|
|||||||
hostname string
|
hostname string
|
||||||
connID protocol.ConnectionID
|
connID protocol.ConnectionID
|
||||||
version protocol.VersionNumber
|
version protocol.VersionNumber
|
||||||
|
initialVersion protocol.VersionNumber
|
||||||
negotiatedVersions []protocol.VersionNumber
|
negotiatedVersions []protocol.VersionNumber
|
||||||
|
|
||||||
cryptoStream io.ReadWriter
|
cryptoStream io.ReadWriter
|
||||||
@ -42,17 +43,18 @@ type cryptoSetupClient struct {
|
|||||||
|
|
||||||
clientHelloCounter int
|
clientHelloCounter int
|
||||||
serverVerified bool // has the certificate chain and the proof already been verified
|
serverVerified bool // has the certificate chain and the proof already been verified
|
||||||
keyDerivation KeyDerivationFunction
|
keyDerivation QuicCryptoKeyDerivationFunction
|
||||||
keyExchange KeyExchangeFunction
|
keyExchange KeyExchangeFunction
|
||||||
|
|
||||||
receivedSecurePacket bool
|
receivedSecurePacket bool
|
||||||
nullAEAD crypto.AEAD
|
nullAEAD crypto.AEAD
|
||||||
secureAEAD crypto.AEAD
|
secureAEAD crypto.AEAD
|
||||||
forwardSecureAEAD crypto.AEAD
|
forwardSecureAEAD crypto.AEAD
|
||||||
aeadChanged chan<- protocol.EncryptionLevel
|
|
||||||
|
paramsChan chan<- TransportParameters
|
||||||
|
handshakeEvent chan<- struct{}
|
||||||
|
|
||||||
params *TransportParameters
|
params *TransportParameters
|
||||||
connectionParameters ConnectionParametersManager
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ CryptoSetup = &cryptoSetupClient{}
|
var _ CryptoSetup = &cryptoSetupClient{}
|
||||||
@ -65,36 +67,42 @@ var (
|
|||||||
|
|
||||||
// NewCryptoSetupClient creates a new CryptoSetup instance for a client
|
// NewCryptoSetupClient creates a new CryptoSetup instance for a client
|
||||||
func NewCryptoSetupClient(
|
func NewCryptoSetupClient(
|
||||||
|
cryptoStream io.ReadWriter,
|
||||||
hostname string,
|
hostname string,
|
||||||
connID protocol.ConnectionID,
|
connID protocol.ConnectionID,
|
||||||
version protocol.VersionNumber,
|
version protocol.VersionNumber,
|
||||||
cryptoStream io.ReadWriter,
|
|
||||||
tlsConfig *tls.Config,
|
tlsConfig *tls.Config,
|
||||||
connectionParameters ConnectionParametersManager,
|
|
||||||
aeadChanged chan<- protocol.EncryptionLevel,
|
|
||||||
params *TransportParameters,
|
params *TransportParameters,
|
||||||
|
paramsChan chan<- TransportParameters,
|
||||||
|
handshakeEvent chan<- struct{},
|
||||||
|
initialVersion protocol.VersionNumber,
|
||||||
negotiatedVersions []protocol.VersionNumber,
|
negotiatedVersions []protocol.VersionNumber,
|
||||||
) (CryptoSetup, error) {
|
) (CryptoSetup, error) {
|
||||||
|
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return &cryptoSetupClient{
|
return &cryptoSetupClient{
|
||||||
|
cryptoStream: cryptoStream,
|
||||||
hostname: hostname,
|
hostname: hostname,
|
||||||
connID: connID,
|
connID: connID,
|
||||||
version: version,
|
version: version,
|
||||||
cryptoStream: cryptoStream,
|
|
||||||
certManager: crypto.NewCertManager(tlsConfig),
|
certManager: crypto.NewCertManager(tlsConfig),
|
||||||
connectionParameters: connectionParameters,
|
params: params,
|
||||||
keyDerivation: crypto.DeriveKeysAESGCM,
|
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
|
||||||
keyExchange: getEphermalKEX,
|
keyExchange: getEphermalKEX,
|
||||||
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version),
|
nullAEAD: nullAEAD,
|
||||||
aeadChanged: aeadChanged,
|
paramsChan: paramsChan,
|
||||||
|
handshakeEvent: handshakeEvent,
|
||||||
|
initialVersion: initialVersion,
|
||||||
negotiatedVersions: negotiatedVersions,
|
negotiatedVersions: negotiatedVersions,
|
||||||
divNonceChan: make(chan []byte),
|
divNonceChan: make(chan []byte),
|
||||||
params: params,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupClient) HandleCryptoStream() error {
|
func (h *cryptoSetupClient) HandleCryptoStream() error {
|
||||||
messageChan := make(chan HandshakeMessage)
|
messageChan := make(chan HandshakeMessage)
|
||||||
errorChan := make(chan error)
|
errorChan := make(chan error, 1)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
@ -141,15 +149,21 @@ func (h *cryptoSetupClient) HandleCryptoStream() error {
|
|||||||
utils.Debugf("Got %s", message)
|
utils.Debugf("Got %s", message)
|
||||||
switch message.Tag {
|
switch message.Tag {
|
||||||
case TagREJ:
|
case TagREJ:
|
||||||
err = h.handleREJMessage(message.Data)
|
if err := h.handleREJMessage(message.Data); err != nil {
|
||||||
case TagSHLO:
|
return err
|
||||||
err = h.handleSHLOMessage(message.Data)
|
|
||||||
default:
|
|
||||||
return qerr.InvalidCryptoMessageType
|
|
||||||
}
|
}
|
||||||
|
case TagSHLO:
|
||||||
|
params, err := h.handleSHLOMessage(message.Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
// blocks until the session has received the parameters
|
||||||
|
h.paramsChan <- *params
|
||||||
|
h.handshakeEvent <- struct{}{}
|
||||||
|
close(h.handshakeEvent)
|
||||||
|
default:
|
||||||
|
return qerr.InvalidCryptoMessageType
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -215,12 +229,12 @@ func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error {
|
func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) (*TransportParameters, error) {
|
||||||
h.mutex.Lock()
|
h.mutex.Lock()
|
||||||
defer h.mutex.Unlock()
|
defer h.mutex.Unlock()
|
||||||
|
|
||||||
if !h.receivedSecurePacket {
|
if !h.receivedSecurePacket {
|
||||||
return qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message")
|
return nil, qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message")
|
||||||
}
|
}
|
||||||
|
|
||||||
if sno, ok := cryptoData[TagSNO]; ok {
|
if sno, ok := cryptoData[TagSNO]; ok {
|
||||||
@ -229,22 +243,22 @@ func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error {
|
|||||||
|
|
||||||
serverPubs, ok := cryptoData[TagPUBS]
|
serverPubs, ok := cryptoData[TagPUBS]
|
||||||
if !ok {
|
if !ok {
|
||||||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
|
return nil, qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
|
||||||
}
|
}
|
||||||
|
|
||||||
verTag, ok := cryptoData[TagVER]
|
verTag, ok := cryptoData[TagVER]
|
||||||
if !ok {
|
if !ok {
|
||||||
return qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list")
|
return nil, qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list")
|
||||||
}
|
}
|
||||||
if !h.validateVersionList(verTag) {
|
if !h.validateVersionList(verTag) {
|
||||||
return qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")
|
return nil, qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")
|
||||||
}
|
}
|
||||||
|
|
||||||
nonce := append(h.nonc, h.sno...)
|
nonce := append(h.nonc, h.sno...)
|
||||||
|
|
||||||
ephermalSharedSecret, err := h.serverConfig.kex.CalculateSharedKey(serverPubs)
|
ephermalSharedSecret, err := h.serverConfig.kex.CalculateSharedKey(serverPubs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
leafCert := h.certManager.GetLeafCert()
|
leafCert := h.certManager.GetLeafCert()
|
||||||
@ -261,39 +275,32 @@ func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error {
|
|||||||
protocol.PerspectiveClient,
|
protocol.PerspectiveClient,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.connectionParameters.SetFromMap(cryptoData)
|
params, err := readHelloMap(cryptoData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return qerr.InvalidCryptoMessageParameter
|
return nil, qerr.InvalidCryptoMessageParameter
|
||||||
}
|
}
|
||||||
|
return params, nil
|
||||||
h.aeadChanged <- protocol.EncryptionForwardSecure
|
|
||||||
close(h.aeadChanged)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupClient) validateVersionList(verTags []byte) bool {
|
func (h *cryptoSetupClient) validateVersionList(verTags []byte) bool {
|
||||||
if len(h.negotiatedVersions) == 0 {
|
numNegotiatedVersions := len(h.negotiatedVersions)
|
||||||
|
if numNegotiatedVersions == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if len(verTags)%4 != 0 || len(verTags)/4 != len(h.negotiatedVersions) {
|
if len(verTags)%4 != 0 || len(verTags)/4 != numNegotiatedVersions {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
b := bytes.NewReader(verTags)
|
b := bytes.NewReader(verTags)
|
||||||
for _, negotiatedVersion := range h.negotiatedVersions {
|
for i := 0; i < numNegotiatedVersions; i++ {
|
||||||
verTag, err := utils.ReadUint32(b)
|
v, err := utils.BigEndian.ReadUint32(b)
|
||||||
if err != nil { // should never occur, since the length was already checked
|
if err != nil { // should never occur, since the length was already checked
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
ver := protocol.VersionTagToNumber(verTag)
|
if protocol.VersionNumber(v) != h.negotiatedVersions[i] {
|
||||||
if !protocol.IsSupportedVersion(protocol.SupportedVersions, ver) {
|
|
||||||
ver = protocol.VersionUnsupported
|
|
||||||
}
|
|
||||||
if ver != negotiatedVersion {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -333,16 +340,16 @@ func (h *cryptoSetupClient) GetSealer() (protocol.EncryptionLevel, Sealer) {
|
|||||||
h.mutex.RLock()
|
h.mutex.RLock()
|
||||||
defer h.mutex.RUnlock()
|
defer h.mutex.RUnlock()
|
||||||
if h.forwardSecureAEAD != nil {
|
if h.forwardSecureAEAD != nil {
|
||||||
return protocol.EncryptionForwardSecure, h.sealForwardSecure
|
return protocol.EncryptionForwardSecure, h.forwardSecureAEAD
|
||||||
} else if h.secureAEAD != nil {
|
} else if h.secureAEAD != nil {
|
||||||
return protocol.EncryptionSecure, h.sealSecure
|
return protocol.EncryptionSecure, h.secureAEAD
|
||||||
} else {
|
} else {
|
||||||
return protocol.EncryptionUnencrypted, h.sealUnencrypted
|
return protocol.EncryptionUnencrypted, h.nullAEAD
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupClient) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
|
func (h *cryptoSetupClient) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
|
||||||
return protocol.EncryptionUnencrypted, h.sealUnencrypted
|
return protocol.EncryptionUnencrypted, h.nullAEAD
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) {
|
func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) {
|
||||||
@ -351,33 +358,21 @@ func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.Encry
|
|||||||
|
|
||||||
switch encLevel {
|
switch encLevel {
|
||||||
case protocol.EncryptionUnencrypted:
|
case protocol.EncryptionUnencrypted:
|
||||||
return h.sealUnencrypted, nil
|
return h.nullAEAD, nil
|
||||||
case protocol.EncryptionSecure:
|
case protocol.EncryptionSecure:
|
||||||
if h.secureAEAD == nil {
|
if h.secureAEAD == nil {
|
||||||
return nil, errors.New("CryptoSetupClient: no secureAEAD")
|
return nil, errors.New("CryptoSetupClient: no secureAEAD")
|
||||||
}
|
}
|
||||||
return h.sealSecure, nil
|
return h.secureAEAD, nil
|
||||||
case protocol.EncryptionForwardSecure:
|
case protocol.EncryptionForwardSecure:
|
||||||
if h.forwardSecureAEAD == nil {
|
if h.forwardSecureAEAD == nil {
|
||||||
return nil, errors.New("CryptoSetupClient: no forwardSecureAEAD")
|
return nil, errors.New("CryptoSetupClient: no forwardSecureAEAD")
|
||||||
}
|
}
|
||||||
return h.sealForwardSecure, nil
|
return h.forwardSecureAEAD, nil
|
||||||
}
|
}
|
||||||
return nil, errors.New("CryptoSetupClient: no encryption level specified")
|
return nil, errors.New("CryptoSetupClient: no encryption level specified")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupClient) sealUnencrypted(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
|
||||||
return h.nullAEAD.Seal(dst, src, packetNumber, associatedData)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupClient) sealSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
|
||||||
return h.secureAEAD.Seal(dst, src, packetNumber, associatedData)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupClient) sealForwardSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
|
||||||
return h.forwardSecureAEAD.Seal(dst, src, packetNumber, associatedData)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupClient) DiversificationNonce() []byte {
|
func (h *cryptoSetupClient) DiversificationNonce() []byte {
|
||||||
panic("not needed for cryptoSetupClient")
|
panic("not needed for cryptoSetupClient")
|
||||||
}
|
}
|
||||||
@ -386,6 +381,15 @@ func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) {
|
|||||||
h.divNonceChan <- data
|
h.divNonceChan <- data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *cryptoSetupClient) ConnectionState() ConnectionState {
|
||||||
|
h.mutex.Lock()
|
||||||
|
defer h.mutex.Unlock()
|
||||||
|
return ConnectionState{
|
||||||
|
HandshakeComplete: h.forwardSecureAEAD != nil,
|
||||||
|
PeerCertificates: h.certManager.GetChain(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupClient) sendCHLO() error {
|
func (h *cryptoSetupClient) sendCHLO() error {
|
||||||
h.clientHelloCounter++
|
h.clientHelloCounter++
|
||||||
if h.clientHelloCounter > protocol.MaxClientHellos {
|
if h.clientHelloCounter > protocol.MaxClientHellos {
|
||||||
@ -413,15 +417,11 @@ func (h *cryptoSetupClient) sendCHLO() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
h.lastSentCHLO = b.Bytes()
|
h.lastSentCHLO = b.Bytes()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) {
|
func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) {
|
||||||
tags, err := h.connectionParameters.GetHelloMap()
|
tags := h.params.getHelloMap()
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
tags[TagSNI] = []byte(h.hostname)
|
tags[TagSNI] = []byte(h.hostname)
|
||||||
tags[TagPDMD] = []byte("X509")
|
tags[TagPDMD] = []byte("X509")
|
||||||
|
|
||||||
@ -431,12 +431,9 @@ func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
versionTag := make([]byte, 4)
|
versionTag := make([]byte, 4)
|
||||||
binary.LittleEndian.PutUint32(versionTag, protocol.VersionNumberToTag(h.version))
|
binary.BigEndian.PutUint32(versionTag, uint32(h.initialVersion))
|
||||||
tags[TagVER] = versionTag
|
tags[TagVER] = versionTag
|
||||||
|
|
||||||
if h.params.RequestConnectionIDTruncation {
|
|
||||||
tags[TagTCID] = []byte{0, 0, 0, 0}
|
|
||||||
}
|
|
||||||
if len(h.stk) > 0 {
|
if len(h.stk) > 0 {
|
||||||
tags[TagSTK] = h.stk
|
tags[TagSTK] = h.stk
|
||||||
}
|
}
|
||||||
@ -470,7 +467,7 @@ func (h *cryptoSetupClient) addPadding(tags map[Tag][]byte) {
|
|||||||
for _, tag := range tags {
|
for _, tag := range tags {
|
||||||
size += 8 + len(tag) // 4 bytes for the tag + 4 bytes for the offset + the length of the data
|
size += 8 + len(tag) // 4 bytes for the tag + 4 bytes for the offset + the length of the data
|
||||||
}
|
}
|
||||||
paddingSize := protocol.ClientHelloMinimumSize - size
|
paddingSize := protocol.MinClientHelloSize - size
|
||||||
if paddingSize > 0 {
|
if paddingSize > 0 {
|
||||||
tags[TagPAD] = bytes.Repeat([]byte{0}, paddingSize)
|
tags[TagPAD] = bytes.Repeat([]byte{0}, paddingSize)
|
||||||
}
|
}
|
||||||
@ -508,10 +505,8 @@ func (h *cryptoSetupClient) maybeUpgradeCrypto() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
h.handshakeEvent <- struct{}{}
|
||||||
h.aeadChanged <- protocol.EncryptionSecure
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user