vendor: Update dependencies

This commit is contained in:
Matthew Holt 2017-07-27 16:11:56 -06:00
parent 74940af624
commit a48e4ecb5a
No known key found for this signature in database
GPG Key ID: 2A349DD577D586A5
185 changed files with 24095 additions and 13722 deletions

View File

@ -119,7 +119,7 @@ GLOBL ·rol8_AVX2<>(SB), (NOPTR+RODATA), $32
#define TMP_1 256(SP) #define TMP_1 256(SP)
// func xorKeyStreamAVX(dst, src []byte, block, state *[64]byte, rounds int) int // func xorKeyStreamAVX(dst, src []byte, block, state *[64]byte, rounds int) int
TEXT ·xorKeyStreamAVX2(SB),4,$320-80 TEXT ·xorKeyStreamAVX2(SB), 4, $320-80
MOVQ dst_base+0(FP), DI MOVQ dst_base+0(FP), DI
MOVQ src_base+24(FP), SI MOVQ src_base+24(FP), SI
MOVQ src_len+32(FP), CX MOVQ src_len+32(FP), CX
@ -182,6 +182,7 @@ at_least_512:
VPADDQ TWO, Y11, Y15 VPADDQ TWO, Y11, Y15
MOVQ DX, R9 MOVQ DX, R9
chacha_loop_512: chacha_loop_512:
VMOVDQA Y8, TMP_0 VMOVDQA Y8, TMP_0
CHACHA_QROUND(Y0, Y1, Y2, Y3, Y8, C16, C8) CHACHA_QROUND(Y0, Y1, Y2, Y3, Y8, C16, C8)
@ -286,6 +287,7 @@ between_320_and_448:
VPADDQ TWO, Y7, Y11 VPADDQ TWO, Y7, Y11
MOVQ DX, R9 MOVQ DX, R9
chacha_loop_384: chacha_loop_384:
CHACHA_QROUND(Y0, Y1, Y2, Y3, Y13, Y14, Y15) CHACHA_QROUND(Y0, Y1, Y2, Y3, Y13, Y14, Y15)
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15) CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
@ -357,6 +359,7 @@ between_192_and_320:
VPADDQ TWO, Y3, Y11 VPADDQ TWO, Y3, Y11
MOVQ DX, R9 MOVQ DX, R9
chacha_loop_256: chacha_loop_256:
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15) CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_QROUND(Y8, Y9, Y10, Y11, Y13, Y14, Y15) CHACHA_QROUND(Y8, Y9, Y10, Y11, Y13, Y14, Y15)
@ -408,6 +411,7 @@ between_64_and_192:
VMOVDQA Y3, Y7 VMOVDQA Y3, Y7
MOVQ DX, R9 MOVQ DX, R9
chacha_loop_128: chacha_loop_128:
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15) CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_SHUFFLE(Y5, Y6, Y7) CHACHA_SHUFFLE(Y5, Y6, Y7)
@ -449,6 +453,7 @@ between_0_and_64:
VMOVDQA X3, X7 VMOVDQA X3, X7
MOVQ DX, R9 MOVQ DX, R9
chacha_loop_64: chacha_loop_64:
CHACHA_QROUND(X4, X5, X6, X7, X13, X14, X15) CHACHA_QROUND(X4, X5, X6, X7, X13, X14, X15)
CHACHA_SHUFFLE(X5, X6, X7) CHACHA_SHUFFLE(X5, X6, X7)
@ -481,6 +486,7 @@ finalize:
XORQ R11, R11 XORQ R11, R11
XORQ R12, R12 XORQ R12, R12
MOVQ CX, BP MOVQ CX, BP
xor_loop: xor_loop:
MOVB 0(SI), R11 MOVB 0(SI), R11
MOVB 0(BX), R12 MOVB 0(BX), R12
@ -513,6 +519,7 @@ TEXT ·hChaCha20AVX(SB), 4, $0-24
VMOVDQU ·rol8_AVX2<>(SB), X6 VMOVDQU ·rol8_AVX2<>(SB), X6
MOVQ $20, CX MOVQ $20, CX
chacha_loop: chacha_loop:
CHACHA_QROUND(X0, X1, X2, X3, X4, X5, X6) CHACHA_QROUND(X0, X1, X2, X3, X4, X5, X6)
CHACHA_SHUFFLE(X1, X2, X3) CHACHA_SHUFFLE(X1, X2, X3)

View File

@ -30,7 +30,7 @@ GLOBL ·rol8<>(SB), (NOPTR+RODATA), $16
PSRLL $(32-n), v; \ PSRLL $(32-n), v; \
PXOR t, v PXOR t, v
#define CHACHA_QROUND_SSE2(v0 , v1 , v2 , v3 , t0) \ #define CHACHA_QROUND_SSE2(v0, v1, v2, v3, t0) \
PADDL v1, v0; \ PADDL v1, v0; \
PXOR v0, v3; \ PXOR v0, v3; \
ROTL_SSE2(16, t0, v3); \ ROTL_SSE2(16, t0, v3); \
@ -44,7 +44,7 @@ GLOBL ·rol8<>(SB), (NOPTR+RODATA), $16
PXOR v2, v1; \ PXOR v2, v1; \
ROTL_SSE2(7, t0, v1) ROTL_SSE2(7, t0, v1)
#define CHACHA_QROUND_SSSE3(v0 , v1 , v2 , v3 , t0, r16, r8) \ #define CHACHA_QROUND_SSSE3(v0, v1, v2, v3, t0, r16, r8) \
PADDL v1, v0; \ PADDL v1, v0; \
PXOR v0, v3; \ PXOR v0, v3; \
PSHUFB r16, v3; \ PSHUFB r16, v3; \
@ -63,7 +63,7 @@ GLOBL ·rol8<>(SB), (NOPTR+RODATA), $16
PSHUFL $0x4E, v2, v2; \ PSHUFL $0x4E, v2, v2; \
PSHUFL $0x93, v3, v3 PSHUFL $0x93, v3, v3
#define XOR(dst, src, off, v0 , v1 , v2 , v3 , t0) \ #define XOR(dst, src, off, v0, v1, v2, v3, t0) \
MOVOU 0+off(src), t0; \ MOVOU 0+off(src), t0; \
PXOR v0, t0; \ PXOR v0, t0; \
MOVOU t0, 0+off(dst); \ MOVOU t0, 0+off(dst); \
@ -92,7 +92,7 @@ GLOBL ·rol8<>(SB), (NOPTR+RODATA), $16
JA finalize \ JA finalize \
// func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int // func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int
TEXT ·xorKeyStreamSSE2(SB),4,$0-40 TEXT ·xorKeyStreamSSE2(SB), 4, $0-40
MOVL dst_base+0(FP), DI MOVL dst_base+0(FP), DI
MOVL src_base+12(FP), SI MOVL src_base+12(FP), SI
MOVL src_len+16(FP), CX MOVL src_len+16(FP), CX
@ -114,6 +114,7 @@ at_least_64:
MOVO X3, X7 MOVO X3, X7
MOVL DX, BX MOVL DX, BX
chacha_loop: chacha_loop:
CHACHA_QROUND_SSE2(X4, X5, X6, X7, X0) CHACHA_QROUND_SSE2(X4, X5, X6, X7, X0)
CHACHA_SHUFFLE(X5, X6, X7) CHACHA_SHUFFLE(X5, X6, X7)
@ -159,7 +160,7 @@ done:
RET RET
// func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int // func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int
TEXT ·xorKeyStreamSSSE3(SB),4,$64-40 TEXT ·xorKeyStreamSSSE3(SB), 4, $64-40
MOVL dst_base+0(FP), DI MOVL dst_base+0(FP), DI
MOVL src_base+12(FP), SI MOVL src_base+12(FP), SI
MOVL src_len+16(FP), CX MOVL src_len+16(FP), CX
@ -192,6 +193,7 @@ at_least_64:
MOVO X3, X7 MOVO X3, X7
MOVL DX, BX MOVL DX, BX
chacha_loop: chacha_loop:
CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X0, X1, X2) CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X0, X1, X2)
CHACHA_SHUFFLE(X5, X6, X7) CHACHA_SHUFFLE(X5, X6, X7)
@ -268,6 +270,7 @@ TEXT ·hChaCha20SSE2(SB), 4, $0-12
MOVOU 0(AX), X3 MOVOU 0(AX), X3
MOVL $20, CX MOVL $20, CX
chacha_loop: chacha_loop:
CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4) CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4)
CHACHA_SHUFFLE(X1, X2, X3) CHACHA_SHUFFLE(X1, X2, X3)
@ -294,6 +297,7 @@ TEXT ·hChaCha20SSSE3(SB), 4, $0-12
MOVOU ·rol8<>(SB), X6 MOVOU ·rol8<>(SB), X6
MOVL $20, CX MOVL $20, CX
chacha_loop: chacha_loop:
CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6) CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6)
CHACHA_SHUFFLE(X1, X2, X3) CHACHA_SHUFFLE(X1, X2, X3)

View File

@ -30,7 +30,7 @@ GLOBL ·rol8<>(SB), (NOPTR+RODATA), $16
PSRLL $(32-n), v; \ PSRLL $(32-n), v; \
PXOR t, v PXOR t, v
#define CHACHA_QROUND_SSE2(v0 , v1 , v2 , v3 , t0) \ #define CHACHA_QROUND_SSE2(v0, v1, v2, v3, t0) \
PADDL v1, v0; \ PADDL v1, v0; \
PXOR v0, v3; \ PXOR v0, v3; \
ROTL_SSE2(16, t0, v3); \ ROTL_SSE2(16, t0, v3); \
@ -44,7 +44,7 @@ GLOBL ·rol8<>(SB), (NOPTR+RODATA), $16
PXOR v2, v1; \ PXOR v2, v1; \
ROTL_SSE2(7, t0, v1) ROTL_SSE2(7, t0, v1)
#define CHACHA_QROUND_SSSE3(v0 , v1 , v2 , v3 , t0, r16, r8) \ #define CHACHA_QROUND_SSSE3(v0, v1, v2, v3, t0, r16, r8) \
PADDL v1, v0; \ PADDL v1, v0; \
PXOR v0, v3; \ PXOR v0, v3; \
PSHUFB r16, v3; \ PSHUFB r16, v3; \
@ -63,7 +63,7 @@ GLOBL ·rol8<>(SB), (NOPTR+RODATA), $16
PSHUFL $0x4E, v2, v2; \ PSHUFL $0x4E, v2, v2; \
PSHUFL $0x93, v3, v3 PSHUFL $0x93, v3, v3
#define XOR(dst, src, off, v0 , v1 , v2 , v3 , t0) \ #define XOR(dst, src, off, v0, v1, v2, v3, t0) \
MOVOU 0+off(src), t0; \ MOVOU 0+off(src), t0; \
PXOR v0, t0; \ PXOR v0, t0; \
MOVOU t0, 0+off(dst); \ MOVOU t0, 0+off(dst); \
@ -78,7 +78,7 @@ GLOBL ·rol8<>(SB), (NOPTR+RODATA), $16
MOVOU t0, 48+off(dst) MOVOU t0, 48+off(dst)
// func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int // func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int
TEXT ·xorKeyStreamSSE2(SB),4,$112-80 TEXT ·xorKeyStreamSSE2(SB), 4, $112-80
MOVQ dst_base+0(FP), DI MOVQ dst_base+0(FP), DI
MOVQ src_base+24(FP), SI MOVQ src_base+24(FP), SI
MOVQ src_len+32(FP), CX MOVQ src_len+32(FP), CX
@ -115,6 +115,7 @@ TEXT ·xorKeyStreamSSE2(SB),4,$112-80
JBE between_128_and_192 JBE between_128_and_192
MOVQ $192, R14 MOVQ $192, R14
at_least_256: at_least_256:
MOVO X0, X4 MOVO X0, X4
MOVO X1, X5 MOVO X1, X5
@ -133,6 +134,7 @@ at_least_256:
PADDQ 64(SP), X11 PADDQ 64(SP), X11
MOVQ DX, R8 MOVQ DX, R8
chacha_loop_256: chacha_loop_256:
MOVO X8, 80(SP) MOVO X8, 80(SP)
CHACHA_QROUND_SSE2(X0, X1, X2, X3, X8) CHACHA_QROUND_SSE2(X0, X1, X2, X3, X8)
@ -236,6 +238,7 @@ between_128_and_192:
PADDQ X15, X11 PADDQ X15, X11
MOVQ DX, R8 MOVQ DX, R8
chacha_loop_192: chacha_loop_192:
CHACHA_QROUND_SSE2(X0, X1, X2, X3, X12) CHACHA_QROUND_SSE2(X0, X1, X2, X3, X12)
CHACHA_QROUND_SSE2(X4, X5, X6, X7, X12) CHACHA_QROUND_SSE2(X4, X5, X6, X7, X12)
@ -297,6 +300,7 @@ between_64_and_128:
PADDQ X15, X11 PADDQ X15, X11
MOVQ DX, R8 MOVQ DX, R8
chacha_loop_128: chacha_loop_128:
CHACHA_QROUND_SSE2(X4, X5, X6, X7, X12) CHACHA_QROUND_SSE2(X4, X5, X6, X7, X12)
CHACHA_QROUND_SSE2(X8, X9, X10, X11, X12) CHACHA_QROUND_SSE2(X8, X9, X10, X11, X12)
@ -335,6 +339,7 @@ between_0_and_64:
MOVO X2, X10 MOVO X2, X10
MOVO X3, X11 MOVO X3, X11
MOVQ DX, R8 MOVQ DX, R8
chacha_loop_64: chacha_loop_64:
CHACHA_QROUND_SSE2(X8, X9, X10, X11, X12) CHACHA_QROUND_SSE2(X8, X9, X10, X11, X12)
CHACHA_SHUFFLE(X9, X10, X11) CHACHA_SHUFFLE(X9, X10, X11)
@ -367,6 +372,7 @@ less_than_64:
XORQ R11, R11 XORQ R11, R11
XORQ R12, R12 XORQ R12, R12
MOVQ CX, BP MOVQ CX, BP
xor_loop: xor_loop:
MOVB 0(SI), R11 MOVB 0(SI), R11
MOVB 0(BX), R12 MOVB 0(BX), R12
@ -385,7 +391,7 @@ done:
RET RET
// func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int // func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int
TEXT ·xorKeyStreamSSSE3(SB),4,$144-80 TEXT ·xorKeyStreamSSSE3(SB), 4, $144-80
MOVQ dst_base+0(FP), DI MOVQ dst_base+0(FP), DI
MOVQ src_base+24(FP), SI MOVQ src_base+24(FP), SI
MOVQ src_len+32(FP), CX MOVQ src_len+32(FP), CX
@ -426,6 +432,7 @@ TEXT ·xorKeyStreamSSSE3(SB),4,$144-80
MOVO X13, 96(SP) MOVO X13, 96(SP)
MOVO X14, 112(SP) MOVO X14, 112(SP)
MOVQ $192, R14 MOVQ $192, R14
at_least_256: at_least_256:
MOVO X0, X4 MOVO X0, X4
MOVO X1, X5 MOVO X1, X5
@ -444,6 +451,7 @@ at_least_256:
PADDQ 64(SP), X11 PADDQ 64(SP), X11
MOVQ DX, R8 MOVQ DX, R8
chacha_loop_256: chacha_loop_256:
MOVO X8, 80(SP) MOVO X8, 80(SP)
CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X8, 96(SP), 112(SP)) CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X8, 96(SP), 112(SP))
@ -548,6 +556,7 @@ between_128_and_192:
PADDQ X15, X11 PADDQ X15, X11
MOVQ DX, R8 MOVQ DX, R8
chacha_loop_192: chacha_loop_192:
CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X12, X13, X14) CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X12, X13, X14)
CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X12, X13, X14) CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X12, X13, X14)
@ -609,6 +618,7 @@ between_64_and_128:
PADDQ X15, X11 PADDQ X15, X11
MOVQ DX, R8 MOVQ DX, R8
chacha_loop_128: chacha_loop_128:
CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X12, X13, X14) CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X12, X13, X14)
CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X12, X13, X14) CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X12, X13, X14)
@ -647,6 +657,7 @@ between_0_and_64:
MOVO X2, X10 MOVO X2, X10
MOVO X3, X11 MOVO X3, X11
MOVQ DX, R8 MOVQ DX, R8
chacha_loop_64: chacha_loop_64:
CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X12, X13, X14) CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X12, X13, X14)
CHACHA_SHUFFLE(X9, X10, X11) CHACHA_SHUFFLE(X9, X10, X11)
@ -679,6 +690,7 @@ less_than_64:
XORQ R11, R11 XORQ R11, R11
XORQ R12, R12 XORQ R12, R12
MOVQ CX, BP MOVQ CX, BP
xor_loop: xor_loop:
MOVB 0(SI), R11 MOVB 0(SI), R11
MOVB 0(BX), R12 MOVB 0(BX), R12
@ -735,6 +747,7 @@ TEXT ·hChaCha20SSE2(SB), 4, $0-24
MOVOU 0(AX), X3 MOVOU 0(AX), X3
MOVQ $20, CX MOVQ $20, CX
chacha_loop: chacha_loop:
CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4) CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4)
CHACHA_SHUFFLE(X1, X2, X3) CHACHA_SHUFFLE(X1, X2, X3)
@ -761,6 +774,7 @@ TEXT ·hChaCha20SSSE3(SB), 4, $0-24
MOVOU ·rol8<>(SB), X6 MOVOU ·rol8<>(SB), X6
MOVQ $20, CX MOVQ $20, CX
chacha_loop: chacha_loop:
CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6) CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6)
CHACHA_SHUFFLE(X1, X2, X3) CHACHA_SHUFFLE(X1, X2, X3)

View File

@ -174,11 +174,11 @@ func sizeFixed32(x uint64) int {
// This is the format used for the sint64 protocol buffer type. // This is the format used for the sint64 protocol buffer type.
func (p *Buffer) EncodeZigzag64(x uint64) error { func (p *Buffer) EncodeZigzag64(x uint64) error {
// use signed number to get arithmetic right shift. // use signed number to get arithmetic right shift.
return p.EncodeVarint(uint64((x << 1) ^ uint64((int64(x) >> 63)))) return p.EncodeVarint((x << 1) ^ uint64((int64(x) >> 63)))
} }
func sizeZigzag64(x uint64) int { func sizeZigzag64(x uint64) int {
return sizeVarint(uint64((x << 1) ^ uint64((int64(x) >> 63)))) return sizeVarint((x << 1) ^ uint64((int64(x) >> 63)))
} }
// EncodeZigzag32 writes a zigzag-encoded 32-bit integer // EncodeZigzag32 writes a zigzag-encoded 32-bit integer

View File

@ -865,7 +865,7 @@ func (p *textParser) readAny(v reflect.Value, props *Properties) error {
return p.readStruct(fv, terminator) return p.readStruct(fv, terminator)
case reflect.Uint32: case reflect.Uint32:
if x, err := strconv.ParseUint(tok.value, 0, 32); err == nil { if x, err := strconv.ParseUint(tok.value, 0, 32); err == nil {
fv.SetUint(uint64(x)) fv.SetUint(x)
return nil return nil
} }
case reflect.Uint64: case reflect.Uint64:

View File

@ -6,9 +6,8 @@
// //
// Overview // Overview
// //
// The Conn type represents a WebSocket connection. A server application uses // The Conn type represents a WebSocket connection. A server application calls
// the Upgrade function from an Upgrader object with a HTTP request handler // the Upgrader.Upgrade method from an HTTP request handler to get a *Conn:
// to get a pointer to a Conn:
// //
// var upgrader = websocket.Upgrader{ // var upgrader = websocket.Upgrader{
// ReadBufferSize: 1024, // ReadBufferSize: 1024,
@ -33,7 +32,7 @@
// if err != nil { // if err != nil {
// return // return
// } // }
// if err = conn.WriteMessage(messageType, p); err != nil { // if err := conn.WriteMessage(messageType, p); err != nil {
// return err // return err
// } // }
// } // }
@ -147,9 +146,9 @@
// CheckOrigin: func(r *http.Request) bool { return true }, // CheckOrigin: func(r *http.Request) bool { return true },
// } // }
// //
// The deprecated Upgrade function does not enforce an origin policy. It's the // The deprecated package-level Upgrade function does not perform origin
// application's responsibility to check the Origin header before calling // checking. The application is responsible for checking the Origin header
// Upgrade. // before calling the Upgrade function.
// //
// Compression EXPERIMENTAL // Compression EXPERIMENTAL
// //

View File

@ -129,6 +129,9 @@ func serveWs(hub *Hub, w http.ResponseWriter, r *http.Request) {
} }
client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256)} client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256)}
client.hub.register <- client client.hub.register <- client
// Allow collection of memory referenced by the caller by doing all work in
// new goroutines.
go client.writePump() go client.writePump()
client.readPump() go client.readPump()
} }

View File

@ -9,12 +9,14 @@ import (
"io" "io"
) )
// WriteJSON is deprecated, use c.WriteJSON instead. // WriteJSON writes the JSON encoding of v as a message.
//
// Deprecated: Use c.WriteJSON instead.
func WriteJSON(c *Conn, v interface{}) error { func WriteJSON(c *Conn, v interface{}) error {
return c.WriteJSON(v) return c.WriteJSON(v)
} }
// WriteJSON writes the JSON encoding of v to the connection. // WriteJSON writes the JSON encoding of v as a message.
// //
// See the documentation for encoding/json Marshal for details about the // See the documentation for encoding/json Marshal for details about the
// conversion of Go values to JSON. // conversion of Go values to JSON.
@ -31,7 +33,10 @@ func (c *Conn) WriteJSON(v interface{}) error {
return err2 return err2
} }
// ReadJSON is deprecated, use c.ReadJSON instead. // ReadJSON reads the next JSON-encoded message from the connection and stores
// it in the value pointed to by v.
//
// Deprecated: Use c.ReadJSON instead.
func ReadJSON(c *Conn, v interface{}) error { func ReadJSON(c *Conn, v interface{}) error {
return c.ReadJSON(v) return c.ReadJSON(v)
} }

View File

@ -230,10 +230,11 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
// Upgrade upgrades the HTTP server connection to the WebSocket protocol. // Upgrade upgrades the HTTP server connection to the WebSocket protocol.
// //
// This function is deprecated, use websocket.Upgrader instead. // Deprecated: Use websocket.Upgrader instead.
// //
// The application is responsible for checking the request origin before // Upgrade does not perform origin checking. The application is responsible for
// calling Upgrade. An example implementation of the same origin policy is: // checking the Origin header before calling Upgrade. An example implementation
// of the same origin policy check is:
// //
// if req.Header.Get("Origin") != "http://"+req.Host { // if req.Header.Get("Origin") != "http://"+req.Host {
// http.Error(w, "Origin not allowed", 403) // http.Error(w, "Origin not allowed", 403)

View File

@ -111,14 +111,14 @@ func nextTokenOrQuoted(s string) (value string, rest string) {
case escape: case escape:
escape = false escape = false
p[j] = b p[j] = b
j += 1 j++
case b == '\\': case b == '\\':
escape = true escape = true
case b == '"': case b == '"':
return string(p[:j]), s[i+1:] return string(p[:j]), s[i+1:]
default: default:
p[j] = b p[j] = b
j += 1 j++
} }
} }
return "", "" return "", ""

View File

@ -9,6 +9,7 @@ import (
// SentPacketHandler handles ACKs received for outgoing packets // SentPacketHandler handles ACKs received for outgoing packets
type SentPacketHandler interface { type SentPacketHandler interface {
// SentPacket may modify the packet
SentPacket(packet *Packet) error SentPacket(packet *Packet) error
ReceivedAck(ackFrame *frames.AckFrame, withPacketNumber protocol.PacketNumber, recvTime time.Time) error ReceivedAck(ackFrame *frames.AckFrame, withPacketNumber protocol.PacketNumber, recvTime time.Time) error
@ -26,5 +27,6 @@ type ReceivedPacketHandler interface {
ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error
ReceivedStopWaiting(*frames.StopWaitingFrame) error ReceivedStopWaiting(*frames.StopWaitingFrame) error
GetAlarmTimeout() time.Time
GetAckFrame() *frames.AckFrame GetAckFrame() *frames.AckFrame
} }

View File

@ -8,13 +8,6 @@ import (
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
) )
var (
// ErrDuplicatePacket occurres when a duplicate packet is received
ErrDuplicatePacket = errors.New("ReceivedPacketHandler: Duplicate Packet")
// ErrPacketSmallerThanLastStopWaiting occurs when a packet arrives with a packet number smaller than the largest LeastUnacked of a StopWaitingFrame. If this error occurs, the packet should be ignored
ErrPacketSmallerThanLastStopWaiting = errors.New("ReceivedPacketHandler: Packet number smaller than highest StopWaiting")
)
var errInvalidPacketNumber = errors.New("ReceivedPacketHandler: Invalid packet number") var errInvalidPacketNumber = errors.New("ReceivedPacketHandler: Invalid packet number")
type receivedPacketHandler struct { type receivedPacketHandler struct {
@ -30,19 +23,13 @@ type receivedPacketHandler struct {
retransmittablePacketsReceivedSinceLastAck int retransmittablePacketsReceivedSinceLastAck int
ackQueued bool ackQueued bool
ackAlarm time.Time ackAlarm time.Time
ackAlarmResetCallback func(time.Time)
lastAck *frames.AckFrame lastAck *frames.AckFrame
} }
// NewReceivedPacketHandler creates a new receivedPacketHandler // NewReceivedPacketHandler creates a new receivedPacketHandler
func NewReceivedPacketHandler(ackAlarmResetCallback func(time.Time)) ReceivedPacketHandler { func NewReceivedPacketHandler() ReceivedPacketHandler {
// create a stopped timer, see https://github.com/golang/go/issues/12721#issuecomment-143010182
timer := time.NewTimer(0)
<-timer.C
return &receivedPacketHandler{ return &receivedPacketHandler{
packetHistory: newReceivedPacketHistory(), packetHistory: newReceivedPacketHistory(),
ackAlarmResetCallback: ackAlarmResetCallback,
ackSendDelay: protocol.AckSendDelay, ackSendDelay: protocol.AckSendDelay,
} }
} }
@ -52,20 +39,11 @@ func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumbe
return errInvalidPacketNumber return errInvalidPacketNumber
} }
// if the packet number is smaller than the largest LeastUnacked value of a StopWaiting we received, we cannot detect if this packet has a duplicate number if packetNumber > h.ignorePacketsBelow {
// the packet has to be ignored anyway if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
if packetNumber <= h.ignorePacketsBelow {
return ErrPacketSmallerThanLastStopWaiting
}
if h.packetHistory.IsDuplicate(packetNumber) {
return ErrDuplicatePacket
}
err := h.packetHistory.ReceivedPacket(packetNumber)
if err != nil {
return err return err
} }
}
if packetNumber > h.largestObserved { if packetNumber > h.largestObserved {
h.largestObserved = packetNumber h.largestObserved = packetNumber
@ -89,7 +67,6 @@ func (h *receivedPacketHandler) ReceivedStopWaiting(f *frames.StopWaitingFrame)
} }
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) { func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) {
var ackAlarmSet bool
h.packetsReceivedSinceLastAck++ h.packetsReceivedSinceLastAck++
if shouldInstigateAck { if shouldInstigateAck {
@ -124,7 +101,6 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber
} else { } else {
if h.ackAlarm.IsZero() { if h.ackAlarm.IsZero() {
h.ackAlarm = time.Now().Add(h.ackSendDelay) h.ackAlarm = time.Now().Add(h.ackSendDelay)
ackAlarmSet = true
} }
} }
} }
@ -132,11 +108,6 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber
if h.ackQueued { if h.ackQueued {
// cancel the ack alarm // cancel the ack alarm
h.ackAlarm = time.Time{} h.ackAlarm = time.Time{}
ackAlarmSet = false
}
if ackAlarmSet {
h.ackAlarmResetCallback(h.ackAlarm)
} }
} }
@ -164,3 +135,5 @@ func (h *receivedPacketHandler) GetAckFrame() *frames.AckFrame {
return ack return ack
} }
func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { return h.ackAlarm }

View File

@ -2,9 +2,9 @@ package ackhandler
import ( import (
"github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
type receivedPacketHistory struct { type receivedPacketHistory struct {

View File

@ -0,0 +1,38 @@
package ackhandler
import (
"github.com/lucas-clemente/quic-go/frames"
)
// Returns a new slice with all non-retransmittable frames deleted.
func stripNonRetransmittableFrames(fs []frames.Frame) []frames.Frame {
res := make([]frames.Frame, 0, len(fs))
for _, f := range fs {
if IsFrameRetransmittable(f) {
res = append(res, f)
}
}
return res
}
// IsFrameRetransmittable returns true if the frame should be retransmitted.
func IsFrameRetransmittable(f frames.Frame) bool {
switch f.(type) {
case *frames.StopWaitingFrame:
return false
case *frames.AckFrame:
return false
default:
return true
}
}
// HasRetransmittableFrames returns true if at least one frame is retransmittable.
func HasRetransmittableFrames(fs []frames.Frame) bool {
for _, f := range fs {
if IsFrameRetransmittable(f) {
return true
}
}
return false
}

View File

@ -7,9 +7,9 @@ import (
"github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
const ( const (
@ -106,26 +106,27 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error {
} }
} }
now := time.Now()
packet.SendTime = now
if packet.Length == 0 {
return errors.New("SentPacketHandler: packet cannot be empty")
}
h.bytesInFlight += packet.Length
h.lastSentPacketNumber = packet.PacketNumber h.lastSentPacketNumber = packet.PacketNumber
now := time.Now()
packet.Frames = stripNonRetransmittableFrames(packet.Frames)
isRetransmittable := len(packet.Frames) != 0
if isRetransmittable {
packet.SendTime = now
h.bytesInFlight += packet.Length
h.packetHistory.PushBack(*packet) h.packetHistory.PushBack(*packet)
}
h.congestion.OnPacketSent( h.congestion.OnPacketSent(
now, now,
h.bytesInFlight, h.bytesInFlight,
packet.PacketNumber, packet.PacketNumber,
packet.Length, packet.Length,
true, /* TODO: is retransmittable */ isRetransmittable,
) )
h.updateLossDetectionAlarm() h.updateLossDetectionAlarm()
return nil return nil
} }
@ -310,10 +311,11 @@ func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet {
if len(h.retransmissionQueue) == 0 { if len(h.retransmissionQueue) == 0 {
return nil return nil
} }
queueLen := len(h.retransmissionQueue) packet := h.retransmissionQueue[0]
// packets are usually NACKed in descending order. So use the slice as a stack // Shift the slice and don't retain anything that isn't needed.
packet := h.retransmissionQueue[queueLen-1] copy(h.retransmissionQueue, h.retransmissionQueue[1:])
h.retransmissionQueue = h.retransmissionQueue[:queueLen-1] h.retransmissionQueue[len(h.retransmissionQueue)-1] = nil
h.retransmissionQueue = h.retransmissionQueue[:len(h.retransmissionQueue)-1]
return packet return packet
} }
@ -333,7 +335,11 @@ func (h *sentPacketHandler) SendingAllowed() bool {
h.bytesInFlight, h.bytesInFlight,
h.congestion.GetCongestionWindow()) h.congestion.GetCongestionWindow())
} }
return !(congestionLimited || maxTrackedLimited) // Workaround for #555:
// Always allow sending of retransmissions. This should probably be limited
// to RTOs, but we currently don't have a nice way of distinguishing them.
haveRetransmissions := len(h.retransmissionQueue) > 0
return !maxTrackedLimited && (!congestionLimited || haveRetransmissions)
} }
func (h *sentPacketHandler) retransmitOldestTwoPackets() { func (h *sentPacketHandler) retransmitOldestTwoPackets() {

View File

@ -2,6 +2,7 @@ package quic
import ( import (
"bytes" "bytes"
"crypto/tls"
"errors" "errors"
"fmt" "fmt"
"net" "net"
@ -9,9 +10,9 @@ import (
"sync" "sync"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
type client struct { type client struct {
@ -24,6 +25,7 @@ type client struct {
errorChan chan struct{} errorChan chan struct{}
handshakeChan <-chan handshakeEvent handshakeChan <-chan handshakeEvent
tlsConf *tls.Config
config *Config config *Config
versionNegotiated bool // has version negotiation completed yet versionNegotiated bool // has version negotiation completed yet
@ -39,7 +41,7 @@ var (
// DialAddr establishes a new QUIC connection to a server. // DialAddr establishes a new QUIC connection to a server.
// The hostname for SNI is taken from the given address. // The hostname for SNI is taken from the given address.
func DialAddr(addr string, config *Config) (Session, error) { func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr) udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -48,12 +50,16 @@ func DialAddr(addr string, config *Config) (Session, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return Dial(udpConn, udpAddr, addr, config) return Dial(udpConn, udpAddr, addr, tlsConf, config)
} }
// DialAddrNonFWSecure establishes a new QUIC connection to a server. // DialAddrNonFWSecure establishes a new QUIC connection to a server.
// The hostname for SNI is taken from the given address. // The hostname for SNI is taken from the given address.
func DialAddrNonFWSecure(addr string, config *Config) (NonFWSession, error) { func DialAddrNonFWSecure(
addr string,
tlsConf *tls.Config,
config *Config,
) (NonFWSession, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr) udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -62,27 +68,41 @@ func DialAddrNonFWSecure(addr string, config *Config) (NonFWSession, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return DialNonFWSecure(udpConn, udpAddr, addr, config) return DialNonFWSecure(udpConn, udpAddr, addr, tlsConf, config)
} }
// DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn. // DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn.
// The host parameter is used for SNI. // The host parameter is used for SNI.
func DialNonFWSecure(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config) (NonFWSession, error) { func DialNonFWSecure(
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (NonFWSession, error) {
connID, err := utils.GenerateConnectionID() connID, err := utils.GenerateConnectionID()
if err != nil { if err != nil {
return nil, err return nil, err
} }
hostname, _, err := net.SplitHostPort(host) var hostname string
if tlsConf != nil {
hostname = tlsConf.ServerName
}
if hostname == "" {
hostname, _, err = net.SplitHostPort(host)
if err != nil { if err != nil {
return nil, err return nil, err
} }
}
clientConfig := populateClientConfig(config) clientConfig := populateClientConfig(config)
c := &client{ c := &client{
conn: &conn{pconn: pconn, currentAddr: remoteAddr}, conn: &conn{pconn: pconn, currentAddr: remoteAddr},
connectionID: connID, connectionID: connID,
hostname: hostname, hostname: hostname,
tlsConf: tlsConf,
config: clientConfig, config: clientConfig,
version: clientConfig.Versions[0], version: clientConfig.Versions[0],
errorChan: make(chan struct{}), errorChan: make(chan struct{}),
@ -93,15 +113,21 @@ func DialNonFWSecure(pconn net.PacketConn, remoteAddr net.Addr, host string, con
return nil, err return nil, err
} }
utils.Infof("Starting new connection to %s (%s), connectionID %x, version %d", hostname, c.conn.RemoteAddr().String(), c.connectionID, c.version) utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %d", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
return c.session.(NonFWSession), c.establishSecureConnection() return c.session.(NonFWSession), c.establishSecureConnection()
} }
// Dial establishes a new QUIC connection to a server using a net.PacketConn. // Dial establishes a new QUIC connection to a server using a net.PacketConn.
// The host parameter is used for SNI. // The host parameter is used for SNI.
func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config) (Session, error) { func Dial(
sess, err := DialNonFWSecure(pconn, remoteAddr, host, config) pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (Session, error) {
sess, err := DialNonFWSecure(pconn, remoteAddr, host, tlsConf, config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -112,16 +138,38 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config
return sess, nil return sess, nil
} }
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateClientConfig(config *Config) *Config { func populateClientConfig(config *Config) *Config {
if config == nil {
config = &Config{}
}
versions := config.Versions versions := config.Versions
if len(versions) == 0 { if len(versions) == 0 {
versions = protocol.SupportedVersions versions = protocol.SupportedVersions
} }
handshakeTimeout := protocol.DefaultHandshakeTimeout
if config.HandshakeTimeout != 0 {
handshakeTimeout = config.HandshakeTimeout
}
maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
if maxReceiveStreamFlowControlWindow == 0 {
maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindowClient
}
maxReceiveConnectionFlowControlWindow := config.MaxReceiveConnectionFlowControlWindow
if maxReceiveConnectionFlowControlWindow == 0 {
maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowClient
}
return &Config{ return &Config{
TLSConfig: config.TLSConfig,
Versions: versions, Versions: versions,
HandshakeTimeout: handshakeTimeout,
RequestConnectionIDTruncation: config.RequestConnectionIDTruncation, RequestConnectionIDTruncation: config.RequestConnectionIDTruncation,
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
KeepAlive: config.KeepAlive,
} }
} }
@ -163,31 +211,46 @@ func (c *client) listen() {
} }
data = data[:n] data = data[:n]
err = c.handlePacket(addr, data) c.handlePacket(addr, data)
if err != nil {
utils.Errorf("error handling packet: %s", err.Error())
c.session.Close(err)
break
}
} }
} }
func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error { func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
rcvTime := time.Now() rcvTime := time.Now()
r := bytes.NewReader(packet) r := bytes.NewReader(packet)
hdr, err := ParsePublicHeader(r, protocol.PerspectiveServer) hdr, err := ParsePublicHeader(r, protocol.PerspectiveServer)
if err != nil { if err != nil {
return qerr.Error(qerr.InvalidPacketHeader, err.Error()) utils.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error())
// drop this packet if we can't parse the Public Header
return
} }
hdr.Raw = packet[:len(packet)-r.Len()] hdr.Raw = packet[:len(packet)-r.Len()]
c.mutex.Lock() c.mutex.Lock()
defer c.mutex.Unlock() defer c.mutex.Unlock()
if hdr.ResetFlag {
cr := c.conn.RemoteAddr()
// check if the remote address and the connection ID match
// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || hdr.ConnectionID != c.connectionID {
utils.Infof("Received a spoofed Public Reset. Ignoring.")
return
}
pr, err := parsePublicReset(r)
if err != nil {
utils.Infof("Received a Public Reset for connection %x. An error occurred parsing the packet.")
return
}
utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.rejectedPacketNumber)
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.rejectedPacketNumber)))
return
}
// ignore delayed / duplicated version negotiation packets // ignore delayed / duplicated version negotiation packets
if c.versionNegotiated && hdr.VersionFlag { if c.versionNegotiated && hdr.VersionFlag {
return nil return
} }
// this is the first packet after the client sent a packet with the VersionFlag set // this is the first packet after the client sent a packet with the VersionFlag set
@ -198,7 +261,10 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error {
if hdr.VersionFlag { if hdr.VersionFlag {
// version negotiation packets have no payload // version negotiation packets have no payload
return c.handlePacketWithVersionFlag(hdr) if err := c.handlePacketWithVersionFlag(hdr); err != nil {
c.session.Close(err)
}
return
} }
c.session.handlePacket(&receivedPacket{ c.session.handlePacket(&receivedPacket{
@ -207,7 +273,6 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error {
data: packet[len(packet)-r.Len():], data: packet[len(packet)-r.Len():],
rcvTime: rcvTime, rcvTime: rcvTime,
}) })
return nil
} }
func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error { func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error {
@ -246,6 +311,7 @@ func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) e
c.hostname, c.hostname,
c.version, c.version,
c.connectionID, c.connectionID,
c.tlsConf,
c.config, c.config,
negotiatedVersions, negotiatedVersions,
) )

View File

@ -4,8 +4,8 @@ import (
"math" "math"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
// This cubic implementation is based on the one found in Chromiums's QUIC // This cubic implementation is based on the one found in Chromiums's QUIC

View File

@ -3,8 +3,8 @@ package congestion
import ( import (
"time" "time"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
const ( const (

View File

@ -3,8 +3,8 @@ package congestion
import ( import (
"time" "time"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
// Note(pwestin): the magic clamping numbers come from the original code in // Note(pwestin): the magic clamping numbers come from the original code in

View File

@ -3,8 +3,8 @@ package congestion
import ( import (
"time" "time"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
// PrrSender implements the Proportional Rate Reduction (PRR) per RFC 6937 // PrrSender implements the Proportional Rate Reduction (PRR) per RFC 6937

View File

@ -3,7 +3,7 @@ package congestion
import ( import (
"time" "time"
"github.com/lucas-clemente/quic-go/utils" "github.com/lucas-clemente/quic-go/internal/utils"
) )
const ( const (

View File

@ -102,3 +102,12 @@ func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
// If nothing matches, return the first certificate. // If nothing matches, return the first certificate.
return &c.Certificates[0], nil return &c.Certificates[0], nil
} }
func maybeGetConfigForClient(c *tls.Config, sni string) (*tls.Config, error) {
if c.GetConfigForClient == nil {
return c, nil
}
return c.GetConfigForClient(&tls.ClientHelloInfo{
ServerName: sni,
})
}

View File

@ -9,7 +9,7 @@ import (
"fmt" "fmt"
"hash/fnv" "hash/fnv"
"github.com/lucas-clemente/quic-go/utils" "github.com/lucas-clemente/quic-go/internal/utils"
) )
type entryType uint8 type entryType uint8

View File

@ -107,15 +107,14 @@ func (c *certManager) Verify(hostname string) error {
var opts x509.VerifyOptions var opts x509.VerifyOptions
if c.config != nil { if c.config != nil {
opts.Roots = c.config.RootCAs opts.Roots = c.config.RootCAs
opts.DNSName = c.config.ServerName
if c.config.Time == nil { if c.config.Time == nil {
opts.CurrentTime = time.Now() opts.CurrentTime = time.Now()
} else { } else {
opts.CurrentTime = c.config.Time() opts.CurrentTime = c.config.Time()
} }
} else {
opts.DNSName = hostname
} }
// we don't need to care about the tls.Config.ServerName here, since hostname has already been set to that value in the session setup
opts.DNSName = hostname
// the first certificate is the leaf certificate, all others are intermediates // the first certificate is the leaf certificate, all others are intermediates
if len(c.chain) > 1 { if len(c.chain) > 1 {

View File

@ -1,14 +0,0 @@
// +build go1.8
package crypto
import "crypto/tls"
func maybeGetConfigForClient(c *tls.Config, sni string) (*tls.Config, error) {
if c.GetConfigForClient == nil {
return c, nil
}
return c.GetConfigForClient(&tls.ClientHelloInfo{
ServerName: sni,
})
}

View File

@ -1,9 +0,0 @@
// +build !go1.8
package crypto
import "crypto/tls"
func maybeGetConfigForClient(c *tls.Config, sni string) (*tls.Config, error) {
return c, nil
}

View File

@ -5,8 +5,8 @@ import (
"crypto/sha256" "crypto/sha256"
"io" "io"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
"golang.org/x/crypto/hkdf" "golang.org/x/crypto/hkdf"
) )

View File

@ -8,7 +8,7 @@ import (
"sync" "sync"
"github.com/lucas-clemente/quic-go/h2quic" "github.com/lucas-clemente/quic-go/h2quic"
"github.com/lucas-clemente/quic-go/utils" "github.com/lucas-clemente/quic-go/internal/utils"
) )
func main() { func main() {
@ -24,7 +24,7 @@ func main() {
utils.SetLogTimeFormat("") utils.SetLogTimeFormat("")
hclient := &http.Client{ hclient := &http.Client{
Transport: &h2quic.QuicRoundTripper{}, Transport: &h2quic.RoundTripper{},
} }
var wg sync.WaitGroup var wg sync.WaitGroup

View File

@ -31,10 +31,7 @@ func main() {
// Start a server that echos all data on the first stream opened by the client // Start a server that echos all data on the first stream opened by the client
func echoServer() error { func echoServer() error {
cfgServer := &quic.Config{ listener, err := quic.ListenAddr(addr, generateTLSConfig(), nil)
TLSConfig: generateTLSConfig(),
}
listener, err := quic.ListenAddr(addr, cfgServer)
if err != nil { if err != nil {
return err return err
} }
@ -52,10 +49,7 @@ func echoServer() error {
} }
func clientMain() error { func clientMain() error {
cfgClient := &quic.Config{ session, err := quic.DialAddr(addr, &tls.Config{InsecureSkipVerify: true}, nil)
TLSConfig: &tls.Config{InsecureSkipVerify: true},
}
session, err := quic.DialAddr(addr, cfgClient)
if err != nil { if err != nil {
return err return err
} }

View File

@ -18,7 +18,7 @@ import (
_ "net/http/pprof" _ "net/http/pprof"
"github.com/lucas-clemente/quic-go/h2quic" "github.com/lucas-clemente/quic-go/h2quic"
"github.com/lucas-clemente/quic-go/utils" "github.com/lucas-clemente/quic-go/internal/utils"
) )
type binds []string type binds []string

View File

@ -7,9 +7,9 @@ import (
"github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
type flowControlManager struct { type flowControlManager struct {
@ -78,7 +78,7 @@ func (f *flowControlManager) ResetStream(streamID protocol.StreamID, byteOffset
if streamFlowController.ContributesToConnection() { if streamFlowController.ContributesToConnection() {
f.connFlowController.IncrementHighestReceived(increment) f.connFlowController.IncrementHighestReceived(increment)
if f.connFlowController.CheckFlowControlViolation() { if f.connFlowController.CheckFlowControlViolation() {
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", byteOffset, f.connFlowController.receiveWindow)) return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", f.connFlowController.highestReceived, f.connFlowController.receiveWindow))
} }
} }
@ -107,7 +107,7 @@ func (f *flowControlManager) UpdateHighestReceived(streamID protocol.StreamID, b
if streamFlowController.ContributesToConnection() { if streamFlowController.ContributesToConnection() {
f.connFlowController.IncrementHighestReceived(increment) f.connFlowController.IncrementHighestReceived(increment)
if f.connFlowController.CheckFlowControlViolation() { if f.connFlowController.CheckFlowControlViolation() {
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", byteOffset, f.connFlowController.receiveWindow)) return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", f.connFlowController.highestReceived, f.connFlowController.receiveWindow))
} }
} }
@ -157,6 +157,11 @@ func (f *flowControlManager) GetReceiveWindow(streamID protocol.StreamID) (proto
f.mutex.RLock() f.mutex.RLock()
defer f.mutex.RUnlock() defer f.mutex.RUnlock()
// StreamID can be 0 when retransmitting
if streamID == 0 {
return f.connFlowController.receiveWindow, nil
}
flowController, err := f.getFlowController(streamID) flowController, err := f.getFlowController(streamID)
if err != nil { if err != nil {
return 0, err return 0, err

View File

@ -6,8 +6,8 @@ import (
"github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
type flowController struct { type flowController struct {

View File

@ -5,8 +5,8 @@ import (
"errors" "errors"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
var ( var (

View File

@ -3,8 +3,8 @@ package frames
import ( import (
"bytes" "bytes"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
// A BlockedFrame in QUIC // A BlockedFrame in QUIC

View File

@ -6,9 +6,9 @@ import (
"io" "io"
"math" "math"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// A ConnectionCloseFrame in QUIC // A ConnectionCloseFrame in QUIC

View File

@ -4,9 +4,9 @@ import (
"bytes" "bytes"
"io" "io"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// A GoawayFrame is a GOAWAY frame // A GoawayFrame is a GOAWAY frame

View File

@ -1,6 +1,6 @@
package frames package frames
import "github.com/lucas-clemente/quic-go/utils" import "github.com/lucas-clemente/quic-go/internal/utils"
// LogFrame logs a frame, either sent or received // LogFrame logs a frame, either sent or received
func LogFrame(frame Frame, sent bool) { func LogFrame(frame Frame, sent bool) {

View File

@ -3,8 +3,8 @@ package frames
import ( import (
"bytes" "bytes"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
// A RstStreamFrame in QUIC // A RstStreamFrame in QUIC

View File

@ -4,9 +4,9 @@ import (
"bytes" "bytes"
"errors" "errors"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// A StopWaitingFrame in QUIC // A StopWaitingFrame in QUIC

View File

@ -4,9 +4,9 @@ import (
"bytes" "bytes"
"errors" "errors"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// A StreamFrame of QUIC // A StreamFrame of QUIC

View File

@ -3,8 +3,8 @@ package frames
import ( import (
"bytes" "bytes"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
// A WindowUpdateFrame in QUIC // A WindowUpdateFrame in QUIC

View File

@ -15,59 +15,72 @@ import (
"golang.org/x/net/idna" "golang.org/x/net/idna"
quic "github.com/lucas-clemente/quic-go" quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// Client is a HTTP2 client doing QUIC requests type roundTripperOpts struct {
type Client struct { DisableCompression bool
}
var dialAddr = quic.DialAddr
// client is a HTTP2 client doing QUIC requests
type client struct {
mutex sync.RWMutex mutex sync.RWMutex
dialAddr func(hostname string, config *quic.Config) (quic.Session, error) tlsConf *tls.Config
config *quic.Config config *quic.Config
opts *roundTripperOpts
t *QuicRoundTripper
hostname string hostname string
encryptionLevel protocol.EncryptionLevel encryptionLevel protocol.EncryptionLevel
handshakeErr error handshakeErr error
dialChan chan struct{} // will be closed once the handshake is complete and the header stream has been opened dialOnce sync.Once
session quic.Session session quic.Session
headerStream quic.Stream headerStream quic.Stream
headerErr *qerr.QuicError headerErr *qerr.QuicError
headerErrored chan struct{} // this channel is closed if an error occurs on the header stream
requestWriter *requestWriter requestWriter *requestWriter
responses map[protocol.StreamID]chan *http.Response responses map[protocol.StreamID]chan *http.Response
} }
var _ h2quicClient = &Client{} var _ http.RoundTripper = &client{}
// NewClient creates a new client var defaultQuicConfig = &quic.Config{
func NewClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) *Client { RequestConnectionIDTruncation: true,
return &Client{ KeepAlive: true,
t: t, }
dialAddr: quic.DialAddr,
// newClient creates a new client
func newClient(
hostname string,
tlsConfig *tls.Config,
opts *roundTripperOpts,
quicConfig *quic.Config,
) *client {
config := defaultQuicConfig
if quicConfig != nil {
config = quicConfig
}
return &client{
hostname: authorityAddr("https", hostname), hostname: authorityAddr("https", hostname),
responses: make(map[protocol.StreamID]chan *http.Response), responses: make(map[protocol.StreamID]chan *http.Response),
encryptionLevel: protocol.EncryptionUnencrypted, encryptionLevel: protocol.EncryptionUnencrypted,
config: &quic.Config{ tlsConf: tlsConfig,
TLSConfig: tlsConfig, config: config,
RequestConnectionIDTruncation: true, opts: opts,
}, headerErrored: make(chan struct{}),
dialChan: make(chan struct{}),
} }
} }
// Dial dials the connection // dial dials the connection
func (c *Client) Dial() (err error) { func (c *client) dial() error {
defer func() { var err error
c.handshakeErr = err c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
close(c.dialChan)
}()
c.session, err = c.dialAddr(c.hostname, c.config)
if err != nil { if err != nil {
return err return err
} }
@ -82,10 +95,10 @@ func (c *Client) Dial() (err error) {
} }
c.requestWriter = newRequestWriter(c.headerStream) c.requestWriter = newRequestWriter(c.headerStream)
go c.handleHeaderStream() go c.handleHeaderStream()
return return nil
} }
func (c *Client) handleHeaderStream() { func (c *client) handleHeaderStream() {
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {}) decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
h2framer := http2.NewFramer(nil, c.headerStream) h2framer := http2.NewFramer(nil, c.headerStream)
@ -111,7 +124,7 @@ func (c *Client) handleHeaderStream() {
} }
c.mutex.RLock() c.mutex.RLock()
headerChan, ok := c.responses[protocol.StreamID(hframe.StreamID)] responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
c.mutex.RUnlock() c.mutex.RUnlock()
if !ok { if !ok {
c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream)) c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream))
@ -122,41 +135,38 @@ func (c *Client) handleHeaderStream() {
if err != nil { if err != nil {
c.headerErr = qerr.Error(qerr.InternalError, err.Error()) c.headerErr = qerr.Error(qerr.InternalError, err.Error())
} }
headerChan <- rsp responseChan <- rsp
} }
// stop all running request // stop all running request
utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error()) utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error())
c.mutex.Lock() close(c.headerErrored)
for _, responseChan := range c.responses {
close(responseChan)
}
c.mutex.Unlock()
} }
// Do executes a request and returns a response // Roundtrip executes a request and returns a response
func (c *Client) Do(req *http.Request) (*http.Response, error) { func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
// TODO: add port to address, if it doesn't have one // TODO: add port to address, if it doesn't have one
if req.URL.Scheme != "https" { if req.URL.Scheme != "https" {
return nil, errors.New("quic http2: unsupported scheme") return nil, errors.New("quic http2: unsupported scheme")
} }
if authorityAddr("https", hostnameFromRequest(req)) != c.hostname { if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
utils.Debugf("%s vs %s", req.Host, c.hostname) return nil, fmt.Errorf("h2quic Client BUG: RoundTrip called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
return nil, errors.New("h2quic Client BUG: Do called for the wrong client")
} }
hasBody := (req.Body != nil) c.dialOnce.Do(func() {
c.handshakeErr = c.dial()
})
// wait until the handshake is complete
<-c.dialChan
if c.handshakeErr != nil { if c.handshakeErr != nil {
return nil, c.handshakeErr return nil, c.handshakeErr
} }
hasBody := (req.Body != nil)
responseChan := make(chan *http.Response) responseChan := make(chan *http.Response)
dataStream, err := c.session.OpenStreamSync() dataStream, err := c.session.OpenStreamSync()
if err != nil { if err != nil {
c.Close(err) _ = c.CloseWithError(err)
return nil, err return nil, err
} }
c.mutex.Lock() c.mutex.Lock()
@ -164,14 +174,14 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
c.mutex.Unlock() c.mutex.Unlock()
var requestedGzip bool var requestedGzip bool
if !c.t.disableCompression() && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" { if !c.opts.DisableCompression && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" {
requestedGzip = true requestedGzip = true
} }
// TODO: add support for trailers // TODO: add support for trailers
endStream := !hasBody endStream := !hasBody
err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip) err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip)
if err != nil { if err != nil {
c.Close(err) _ = c.CloseWithError(err)
return nil, err return nil, err
} }
@ -198,15 +208,15 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
c.mutex.Lock() c.mutex.Lock()
delete(c.responses, dataStream.StreamID()) delete(c.responses, dataStream.StreamID())
c.mutex.Unlock() c.mutex.Unlock()
if res == nil { // an error occured on the header stream
c.Close(c.headerErr)
return nil, c.headerErr
}
case err := <-resc: case err := <-resc:
bodySent = true bodySent = true
if err != nil { if err != nil {
return nil, err return nil, err
} }
case <-c.headerErrored:
// an error occured on the header stream
_ = c.CloseWithError(c.headerErr)
return nil, c.headerErr
} }
} }
@ -230,11 +240,10 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
} }
res.Request = req res.Request = req
return res, nil return res, nil
} }
func (c *Client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (err error) { func (c *client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (err error) {
defer func() { defer func() {
cerr := body.Close() cerr := body.Close()
if err == nil { if err == nil {
@ -252,8 +261,15 @@ func (c *Client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (e
} }
// Close closes the client // Close closes the client
func (c *Client) Close(e error) { func (c *client) CloseWithError(e error) error {
_ = c.session.Close(e) if c.session == nil {
return nil
}
return c.session.Close(e)
}
func (c *client) Close() error {
return c.CloseWithError(nil)
} }
// copied from net/transport.go // copied from net/transport.go

View File

@ -13,8 +13,8 @@ import (
"golang.org/x/net/lex/httplex" "golang.org/x/net/lex/httplex"
quic "github.com/lucas-clemente/quic-go" quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
type requestWriter struct { type requestWriter struct {

View File

@ -8,8 +8,8 @@ import (
"sync" "sync"
quic "github.com/lucas-clemente/quic-go" quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
) )

View File

@ -4,20 +4,23 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
"io"
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
quic "github.com/lucas-clemente/quic-go"
"golang.org/x/net/lex/httplex" "golang.org/x/net/lex/httplex"
) )
type h2quicClient interface { type roundTripCloser interface {
Dial() error http.RoundTripper
Do(*http.Request) (*http.Response, error) io.Closer
} }
// QuicRoundTripper implements the http.RoundTripper interface // RoundTripper implements the http.RoundTripper interface
type QuicRoundTripper struct { type RoundTripper struct {
mutex sync.Mutex mutex sync.Mutex
// DisableCompression, if true, prevents the Transport from // DisableCompression, if true, prevents the Transport from
@ -34,13 +37,29 @@ type QuicRoundTripper struct {
// tls.Client. If nil, the default configuration is used. // tls.Client. If nil, the default configuration is used.
TLSClientConfig *tls.Config TLSClientConfig *tls.Config
clients map[string]h2quicClient // QuicConfig is the quic.Config used for dialing new connections.
// If nil, reasonable default values will be used.
QuicConfig *quic.Config
clients map[string]roundTripCloser
} }
var _ http.RoundTripper = &QuicRoundTripper{} // RoundTripOpt are options for the Transport.RoundTripOpt method.
type RoundTripOpt struct {
// OnlyCachedConn controls whether the RoundTripper may
// create a new QUIC connection. If set true and
// no cached connection is available, RoundTrip
// will return ErrNoCachedConn.
OnlyCachedConn bool
}
// RoundTrip does a round trip var _ roundTripCloser = &RoundTripper{}
func (r *QuicRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
// ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set
var ErrNoCachedConn = errors.New("h2quic: no cached connection was available")
// RoundTripOpt is like RoundTrip, but takes options.
func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
if req.URL == nil { if req.URL == nil {
closeRequestBody(req) closeRequestBody(req)
return nil, errors.New("quic: nil Request.URL") return nil, errors.New("quic: nil Request.URL")
@ -76,35 +95,48 @@ func (r *QuicRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
} }
hostname := authorityAddr("https", hostnameFromRequest(req)) hostname := authorityAddr("https", hostnameFromRequest(req))
client, err := r.getClient(hostname) cl, err := r.getClient(hostname, opt.OnlyCachedConn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return client.Do(req) return cl.RoundTrip(req)
} }
func (r *QuicRoundTripper) getClient(hostname string) (h2quicClient, error) { // RoundTrip does a round trip.
func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return r.RoundTripOpt(req, RoundTripOpt{})
}
func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTripper, error) {
r.mutex.Lock() r.mutex.Lock()
defer r.mutex.Unlock() defer r.mutex.Unlock()
if r.clients == nil { if r.clients == nil {
r.clients = make(map[string]h2quicClient) r.clients = make(map[string]roundTripCloser)
} }
client, ok := r.clients[hostname] client, ok := r.clients[hostname]
if !ok { if !ok {
client = NewClient(r, r.TLSClientConfig, hostname) if onlyCached {
err := client.Dial() return nil, ErrNoCachedConn
if err != nil {
return nil, err
} }
client = newClient(hostname, r.TLSClientConfig, &roundTripperOpts{DisableCompression: r.DisableCompression}, r.QuicConfig)
r.clients[hostname] = client r.clients[hostname] = client
} }
return client, nil return client, nil
} }
func (r *QuicRoundTripper) disableCompression() bool { // Close closes the QUIC connections that this RoundTripper has used
return r.DisableCompression func (r *RoundTripper) Close() error {
r.mutex.Lock()
defer r.mutex.Unlock()
for _, client := range r.clients {
if err := client.Close(); err != nil {
return err
}
}
r.clients = nil
return nil
} }
func closeRequestBody(req *http.Request) { func closeRequestBody(req *http.Request) {

View File

@ -13,9 +13,9 @@ import (
"time" "time"
quic "github.com/lucas-clemente/quic-go" quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
) )
@ -29,10 +29,20 @@ type remoteCloser interface {
CloseRemote(protocol.ByteCount) CloseRemote(protocol.ByteCount)
} }
// allows mocking of quic.Listen and quic.ListenAddr
var (
quicListen = quic.Listen
quicListenAddr = quic.ListenAddr
)
// Server is a HTTP2 server listening for QUIC connections. // Server is a HTTP2 server listening for QUIC connections.
type Server struct { type Server struct {
*http.Server *http.Server
// By providing a quic.Config, it is possible to set parameters of the QUIC connection.
// If nil, it uses reasonable default values.
QuicConfig *quic.Config
// Private flag for demo, do not use // Private flag for demo, do not use
CloseAfterFirstRequest bool CloseAfterFirstRequest bool
@ -69,11 +79,11 @@ func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
} }
// Serve an existing UDP connection. // Serve an existing UDP connection.
func (s *Server) Serve(conn *net.UDPConn) error { func (s *Server) Serve(conn net.PacketConn) error {
return s.serveImpl(s.TLSConfig, conn) return s.serveImpl(s.TLSConfig, conn)
} }
func (s *Server) serveImpl(tlsConfig *tls.Config, conn *net.UDPConn) error { func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
if s.Server == nil { if s.Server == nil {
return errors.New("use of h2quic.Server without http.Server") return errors.New("use of h2quic.Server without http.Server")
} }
@ -83,17 +93,12 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn *net.UDPConn) error {
return errors.New("ListenAndServe may only be called once") return errors.New("ListenAndServe may only be called once")
} }
config := quic.Config{
TLSConfig: tlsConfig,
Versions: protocol.SupportedVersions,
}
var ln quic.Listener var ln quic.Listener
var err error var err error
if conn == nil { if conn == nil {
ln, err = quic.ListenAddr(s.Addr, &config) ln, err = quicListenAddr(s.Addr, tlsConfig, s.QuicConfig)
} else { } else {
ln, err = quic.Listen(conn, &config) ln, err = quicListen(conn, tlsConfig, s.QuicConfig)
} }
if err != nil { if err != nil {
s.listenerMutex.Unlock() s.listenerMutex.Unlock()
@ -255,7 +260,6 @@ func (s *Server) CloseGracefully(timeout time.Duration) error {
// SetQuicHeaders can be used to set the proper headers that announce that this server supports QUIC. // SetQuicHeaders can be used to set the proper headers that announce that this server supports QUIC.
// The values that are set depend on the port information from s.Server.Addr, and currently look like this (if Addr has port 443): // The values that are set depend on the port information from s.Server.Addr, and currently look like this (if Addr has port 443):
// Alternate-Protocol: 443:quic
// Alt-Svc: quic=":443"; ma=2592000; v="33,32,31,30" // Alt-Svc: quic=":443"; ma=2592000; v="33,32,31,30"
func (s *Server) SetQuicHeaders(hdr http.Header) error { func (s *Server) SetQuicHeaders(hdr http.Header) error {
port := atomic.LoadUint32(&s.port) port := atomic.LoadUint32(&s.port)
@ -283,7 +287,6 @@ func (s *Server) SetQuicHeaders(hdr http.Header) error {
} }
} }
hdr.Add("Alternate-Protocol", fmt.Sprintf("%d:quic", port))
hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, s.supportedVersionsAsString)) hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, s.supportedVersionsAsString))
return nil return nil

View File

@ -5,9 +5,9 @@ import (
"sync" "sync"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// ConnectionParametersManager negotiates and stores the connection parameters // ConnectionParametersManager negotiates and stores the connection parameters
@ -50,6 +50,8 @@ type connectionParametersManager struct {
sendConnectionFlowControlWindow protocol.ByteCount sendConnectionFlowControlWindow protocol.ByteCount
receiveStreamFlowControlWindow protocol.ByteCount receiveStreamFlowControlWindow protocol.ByteCount
receiveConnectionFlowControlWindow protocol.ByteCount receiveConnectionFlowControlWindow protocol.ByteCount
maxReceiveStreamFlowControlWindow protocol.ByteCount
maxReceiveConnectionFlowControlWindow protocol.ByteCount
} }
var _ ConnectionParametersManager = &connectionParametersManager{} var _ ConnectionParametersManager = &connectionParametersManager{}
@ -61,7 +63,10 @@ var (
) )
// NewConnectionParamatersManager creates a new connection parameters manager // NewConnectionParamatersManager creates a new connection parameters manager
func NewConnectionParamatersManager(pers protocol.Perspective, v protocol.VersionNumber) ConnectionParametersManager { func NewConnectionParamatersManager(
pers protocol.Perspective, v protocol.VersionNumber,
maxReceiveStreamFlowControlWindow protocol.ByteCount, maxReceiveConnectionFlowControlWindow protocol.ByteCount,
) ConnectionParametersManager {
h := &connectionParametersManager{ h := &connectionParametersManager{
perspective: pers, perspective: pers,
version: v, version: v,
@ -69,6 +74,8 @@ func NewConnectionParamatersManager(pers protocol.Perspective, v protocol.Versio
sendConnectionFlowControlWindow: protocol.InitialConnectionFlowControlWindow, // can only be changed by the client sendConnectionFlowControlWindow: protocol.InitialConnectionFlowControlWindow, // can only be changed by the client
receiveStreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, receiveStreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
receiveConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, receiveConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
maxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
maxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
} }
if h.perspective == protocol.PerspectiveServer { if h.perspective == protocol.PerspectiveServer {
@ -207,10 +214,7 @@ func (h *connectionParametersManager) GetReceiveStreamFlowControlWindow() protoc
// GetMaxReceiveStreamFlowControlWindow gets the maximum size of the stream-level flow control window for sending data // GetMaxReceiveStreamFlowControlWindow gets the maximum size of the stream-level flow control window for sending data
func (h *connectionParametersManager) GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount { func (h *connectionParametersManager) GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount {
if h.perspective == protocol.PerspectiveServer { return h.maxReceiveStreamFlowControlWindow
return protocol.MaxReceiveStreamFlowControlWindowServer
}
return protocol.MaxReceiveStreamFlowControlWindowClient
} }
// GetReceiveConnectionFlowControlWindow gets the size of the stream-level flow control window for receiving data // GetReceiveConnectionFlowControlWindow gets the size of the stream-level flow control window for receiving data
@ -222,10 +226,7 @@ func (h *connectionParametersManager) GetReceiveConnectionFlowControlWindow() pr
// GetMaxReceiveConnectionFlowControlWindow gets the maximum size of the stream-level flow control window for sending data // GetMaxReceiveConnectionFlowControlWindow gets the maximum size of the stream-level flow control window for sending data
func (h *connectionParametersManager) GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount { func (h *connectionParametersManager) GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount {
if h.perspective == protocol.PerspectiveServer { return h.maxReceiveConnectionFlowControlWindow
return protocol.MaxReceiveConnectionFlowControlWindowServer
}
return protocol.MaxReceiveConnectionFlowControlWindowClient
} }
// GetMaxOutgoingStreams gets the maximum number of outgoing streams per connection // GetMaxOutgoingStreams gets the maximum number of outgoing streams per connection

View File

@ -12,9 +12,9 @@ import (
"time" "time"
"github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/crypto"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
type cryptoSetupClient struct { type cryptoSetupClient struct {
@ -332,7 +332,6 @@ func (h *cryptoSetupClient) Open(dst, src []byte, packetNumber protocol.PacketNu
func (h *cryptoSetupClient) GetSealer() (protocol.EncryptionLevel, Sealer) { func (h *cryptoSetupClient) GetSealer() (protocol.EncryptionLevel, Sealer) {
h.mutex.RLock() h.mutex.RLock()
defer h.mutex.RUnlock() defer h.mutex.RUnlock()
if h.forwardSecureAEAD != nil { if h.forwardSecureAEAD != nil {
return protocol.EncryptionForwardSecure, h.sealForwardSecure return protocol.EncryptionForwardSecure, h.sealForwardSecure
} else if h.secureAEAD != nil { } else if h.secureAEAD != nil {
@ -342,6 +341,10 @@ func (h *cryptoSetupClient) GetSealer() (protocol.EncryptionLevel, Sealer) {
} }
} }
func (h *cryptoSetupClient) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
return protocol.EncryptionUnencrypted, h.sealUnencrypted
}
func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) { func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) {
h.mutex.RLock() h.mutex.RLock()
defer h.mutex.RUnlock() defer h.mutex.RUnlock()

View File

@ -10,9 +10,9 @@ import (
"sync" "sync"
"github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/crypto"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// KeyDerivationFunction is used for key derivation // KeyDerivationFunction is used for key derivation
@ -214,12 +214,16 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu
func (h *cryptoSetupServer) GetSealer() (protocol.EncryptionLevel, Sealer) { func (h *cryptoSetupServer) GetSealer() (protocol.EncryptionLevel, Sealer) {
h.mutex.RLock() h.mutex.RLock()
defer h.mutex.RUnlock() defer h.mutex.RUnlock()
if h.forwardSecureAEAD != nil {
if h.forwardSecureAEAD != nil && h.sentSHLO {
return protocol.EncryptionForwardSecure, h.sealForwardSecure return protocol.EncryptionForwardSecure, h.sealForwardSecure
} else if h.secureAEAD != nil { }
// secureAEAD and forwardSecureAEAD are created at the same time (when receiving the CHLO) return protocol.EncryptionUnencrypted, h.sealUnencrypted
// make sure that the SHLO isn't sent forward-secure }
func (h *cryptoSetupServer) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.secureAEAD != nil {
return protocol.EncryptionSecure, h.sealSecure return protocol.EncryptionSecure, h.sealSecure
} }
return protocol.EncryptionUnencrypted, h.sealUnencrypted return protocol.EncryptionUnencrypted, h.sealUnencrypted
@ -251,7 +255,6 @@ func (h *cryptoSetupServer) sealUnencrypted(dst, src []byte, packetNumber protoc
} }
func (h *cryptoSetupServer) sealSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { func (h *cryptoSetupServer) sealSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
h.sentSHLO = true
return h.secureAEAD.Seal(dst, src, packetNumber, associatedData) return h.secureAEAD.Seal(dst, src, packetNumber, associatedData)
} }

View File

@ -5,8 +5,8 @@ import (
"time" "time"
"github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/crypto"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
var ( var (

View File

@ -7,9 +7,9 @@ import (
"io" "io"
"sort" "sort"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// A HandshakeMessage is a handshake message // A HandshakeMessage is a handshake message

View File

@ -15,6 +15,7 @@ type CryptoSetup interface {
GetSealer() (protocol.EncryptionLevel, Sealer) GetSealer() (protocol.EncryptionLevel, Sealer)
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error) GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer)
} }
// TransportParameters are parameters sent to the peer during the handshake // TransportParameters are parameters sent to the peer during the handshake

View File

@ -8,8 +8,8 @@ import (
"time" "time"
"github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/crypto"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
type serverConfigClient struct { type serverConfigClient struct {

View File

@ -0,0 +1,133 @@
package handshaketests
import (
"crypto/tls"
"fmt"
"net"
"time"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/integrationtests/proxy"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Handshake integration tets", func() {
var (
proxy *quicproxy.QuicProxy
server quic.Listener
serverConfig *quic.Config
testStartedAt time.Time
)
rtt := 400 * time.Millisecond
BeforeEach(func() {
serverConfig = &quic.Config{}
})
AfterEach(func() {
Expect(proxy.Close()).To(Succeed())
Expect(server.Close()).To(Succeed())
})
runServerAndProxy := func() {
var err error
// start the server
server, err = quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
// start the proxy
proxy, err = quicproxy.NewQuicProxy("localhost:0", quicproxy.Opts{
RemoteAddr: server.Addr().String(),
DelayPacket: func(_ quicproxy.Direction, _ protocol.PacketNumber) time.Duration { return rtt / 2 },
})
Expect(err).ToNot(HaveOccurred())
testStartedAt = time.Now()
go func() {
for {
_, _ = server.Accept()
}
}()
}
expectDurationInRTTs := func(num int) {
testDuration := time.Since(testStartedAt)
expectedDuration := time.Duration(num) * rtt
Expect(testDuration).To(SatisfyAll(
BeNumerically(">=", expectedDuration),
BeNumerically("<", expectedDuration+rtt),
))
}
It("fails when there's no matching version, after 1 RTT", func() {
Expect(len(protocol.SupportedVersions)).To(BeNumerically(">", 1))
serverConfig.Versions = protocol.SupportedVersions[:1]
runServerAndProxy()
clientConfig := &quic.Config{
Versions: protocol.SupportedVersions[1:2],
}
_, err := quic.DialAddr(proxy.LocalAddr().String(), nil, clientConfig)
Expect(err).To(HaveOccurred())
Expect(err.(qerr.ErrorCode)).To(Equal(qerr.InvalidVersion))
expectDurationInRTTs(1)
})
// 1 RTT for verifying the source address
// 1 RTT to become secure
// 1 RTT to become forward-secure
It("is forward-secure after 3 RTTs", func() {
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(3)
})
// 1 RTT for verifying the source address
// 1 RTT to become secure
// TODO (marten-seemann): enable this test (see #625)
PIt("is secure after 2 RTTs", func() {
utils.SetLogLevel(utils.LogLevelDebug)
runServerAndProxy()
_, err := quic.DialAddrNonFWSecure(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
fmt.Println("#### is non fw secure ###")
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(2)
})
It("is forward-secure after 2 RTTs when the server doesn't require an STK", func() {
serverConfig.AcceptSTK = func(_ net.Addr, _ *quic.STK) bool {
return true
}
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(2)
})
It("doesn't complete the handshake when the server never accepts the STK", func() {
serverConfig.AcceptSTK = func(_ net.Addr, _ *quic.STK) bool {
return false
}
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).To(HaveOccurred())
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.CryptoTooManyRejects))
})
It("doesn't complete the handshake when the handshake timeout is too short", func() {
serverConfig.HandshakeTimeout = 2 * rtt
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).To(HaveOccurred())
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeTimeout))
// 2 RTTs during the timeout
// plus 1 RTT: the timer starts 0.5 RTTs after sending the first packet, and the CONNECTION_CLOSE needs another 0.5 RTTs to reach the client
expectDurationInRTTs(3)
})
})

View File

@ -1,7 +1,6 @@
package quic package quic
import ( import (
"crypto/tls"
"io" "io"
"net" "net"
"time" "time"
@ -11,12 +10,32 @@ import (
// Stream is the interface implemented by QUIC streams // Stream is the interface implemented by QUIC streams
type Stream interface { type Stream interface {
// Read reads data from the stream.
// Read can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetReadDeadline.
io.Reader io.Reader
// Write writes data to the stream.
// Write can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetWriteDeadline.
io.Writer io.Writer
io.Closer io.Closer
StreamID() protocol.StreamID StreamID() protocol.StreamID
// Reset closes the stream with an error. // Reset closes the stream with an error.
Reset(error) Reset(error)
// SetReadDeadline sets the deadline for future Read calls and
// any currently-blocked Read call.
// A zero value for t means Read will not time out.
SetReadDeadline(t time.Time) error
// SetWriteDeadline sets the deadline for future Write calls
// and any currently-blocked Write call.
// Even if write times out, it may return n > 0, indicating that
// some of the data was successfully written.
// A zero value for t means Write will not time out.
SetWriteDeadline(t time.Time) error
// SetDeadline sets the read and write deadlines associated
// with the connection. It is equivalent to calling both
// SetReadDeadline and SetWriteDeadline.
SetDeadline(t time.Time) error
} }
// A Session is a QUIC connection between two peers. // A Session is a QUIC connection between two peers.
@ -37,6 +56,9 @@ type Session interface {
RemoteAddr() net.Addr RemoteAddr() net.Addr
// Close closes the connection. The error will be sent to the remote peer in a CONNECTION_CLOSE frame. An error value of nil is allowed and will cause a normal PeerGoingAway to be sent. // Close closes the connection. The error will be sent to the remote peer in a CONNECTION_CLOSE frame. An error value of nil is allowed and will cause a normal PeerGoingAway to be sent.
Close(error) error Close(error) error
// WaitUntilClosed() blocks until the session is closed.
// Warning: This API should not be considered stable and might change soon.
WaitUntilClosed()
} }
// A NonFWSession is a QUIC connection between two peers half-way through the handshake. // A NonFWSession is a QUIC connection between two peers half-way through the handshake.
@ -61,21 +83,31 @@ type STK struct {
// Config contains all configuration data needed for a QUIC server or client. // Config contains all configuration data needed for a QUIC server or client.
// More config parameters (such as timeouts) will be added soon, see e.g. https://github.com/lucas-clemente/quic-go/issues/441. // More config parameters (such as timeouts) will be added soon, see e.g. https://github.com/lucas-clemente/quic-go/issues/441.
type Config struct { type Config struct {
TLSConfig *tls.Config
// The QUIC versions that can be negotiated. // The QUIC versions that can be negotiated.
// If not set, it uses all versions available. // If not set, it uses all versions available.
// Warning: This API should not be considered stable and will change soon. // Warning: This API should not be considered stable and will change soon.
Versions []protocol.VersionNumber Versions []protocol.VersionNumber
// Ask the server to truncate the connection ID sent in the Public Header. // Ask the server to truncate the connection ID sent in the Public Header.
// If not set, the default checks if
// This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated. // This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated.
// Currently only valid for the client. // Currently only valid for the client.
RequestConnectionIDTruncation bool RequestConnectionIDTruncation bool
// HandshakeTimeout is the maximum duration that the cryptographic handshake may take.
// If the timeout is exceeded, the connection is closed.
// If this value is zero, the timeout is set to 10 seconds.
HandshakeTimeout time.Duration
// AcceptSTK determines if an STK is accepted. // AcceptSTK determines if an STK is accepted.
// It is called with stk = nil if the client didn't send an STK. // It is called with stk = nil if the client didn't send an STK.
// If not set, it verifies that the address matches, and that the STK was issued within the last 24 hours // If not set, it verifies that the address matches, and that the STK was issued within the last 24 hours.
// This option is only valid for the server. // This option is only valid for the server.
AcceptSTK func(clientAddr net.Addr, stk *STK) bool AcceptSTK func(clientAddr net.Addr, stk *STK) bool
// MaxReceiveStreamFlowControlWindow is the maximum stream-level flow control window for receiving data.
// If this value is zero, it will default to 1 MB for the server and 6 MB for the client.
MaxReceiveStreamFlowControlWindow protocol.ByteCount
// MaxReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data.
// If this value is zero, it will default to 1.5 MB for the server and 15 MB for the client.
MaxReceiveConnectionFlowControlWindow protocol.ByteCount
// KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive.
KeepAlive bool
} }
// A Listener for incoming QUIC connections // A Listener for incoming QUIC connections

View File

@ -0,0 +1,153 @@
// Automatically generated by MockGen. DO NOT EDIT!
// Source: github.com/lucas-clemente/quic-go/handshake (interfaces: ConnectionParametersManager)
package mocks
import (
gomock "github.com/golang/mock/gomock"
handshake "github.com/lucas-clemente/quic-go/handshake"
protocol "github.com/lucas-clemente/quic-go/protocol"
time "time"
)
// Mock of ConnectionParametersManager interface
type MockConnectionParametersManager struct {
ctrl *gomock.Controller
recorder *_MockConnectionParametersManagerRecorder
}
// Recorder for MockConnectionParametersManager (not exported)
type _MockConnectionParametersManagerRecorder struct {
mock *MockConnectionParametersManager
}
func NewMockConnectionParametersManager(ctrl *gomock.Controller) *MockConnectionParametersManager {
mock := &MockConnectionParametersManager{ctrl: ctrl}
mock.recorder = &_MockConnectionParametersManagerRecorder{mock}
return mock
}
func (_m *MockConnectionParametersManager) EXPECT() *_MockConnectionParametersManagerRecorder {
return _m.recorder
}
func (_m *MockConnectionParametersManager) GetHelloMap() (map[handshake.Tag][]byte, error) {
ret := _m.ctrl.Call(_m, "GetHelloMap")
ret0, _ := ret[0].(map[handshake.Tag][]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
func (_mr *_MockConnectionParametersManagerRecorder) GetHelloMap() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetHelloMap")
}
func (_m *MockConnectionParametersManager) GetIdleConnectionStateLifetime() time.Duration {
ret := _m.ctrl.Call(_m, "GetIdleConnectionStateLifetime")
ret0, _ := ret[0].(time.Duration)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) GetIdleConnectionStateLifetime() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetIdleConnectionStateLifetime")
}
func (_m *MockConnectionParametersManager) GetMaxIncomingStreams() uint32 {
ret := _m.ctrl.Call(_m, "GetMaxIncomingStreams")
ret0, _ := ret[0].(uint32)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) GetMaxIncomingStreams() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetMaxIncomingStreams")
}
func (_m *MockConnectionParametersManager) GetMaxOutgoingStreams() uint32 {
ret := _m.ctrl.Call(_m, "GetMaxOutgoingStreams")
ret0, _ := ret[0].(uint32)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) GetMaxOutgoingStreams() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetMaxOutgoingStreams")
}
func (_m *MockConnectionParametersManager) GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount {
ret := _m.ctrl.Call(_m, "GetMaxReceiveConnectionFlowControlWindow")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) GetMaxReceiveConnectionFlowControlWindow() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetMaxReceiveConnectionFlowControlWindow")
}
func (_m *MockConnectionParametersManager) GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount {
ret := _m.ctrl.Call(_m, "GetMaxReceiveStreamFlowControlWindow")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) GetMaxReceiveStreamFlowControlWindow() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetMaxReceiveStreamFlowControlWindow")
}
func (_m *MockConnectionParametersManager) GetReceiveConnectionFlowControlWindow() protocol.ByteCount {
ret := _m.ctrl.Call(_m, "GetReceiveConnectionFlowControlWindow")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) GetReceiveConnectionFlowControlWindow() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetReceiveConnectionFlowControlWindow")
}
func (_m *MockConnectionParametersManager) GetReceiveStreamFlowControlWindow() protocol.ByteCount {
ret := _m.ctrl.Call(_m, "GetReceiveStreamFlowControlWindow")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) GetReceiveStreamFlowControlWindow() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetReceiveStreamFlowControlWindow")
}
func (_m *MockConnectionParametersManager) GetSendConnectionFlowControlWindow() protocol.ByteCount {
ret := _m.ctrl.Call(_m, "GetSendConnectionFlowControlWindow")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) GetSendConnectionFlowControlWindow() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetSendConnectionFlowControlWindow")
}
func (_m *MockConnectionParametersManager) GetSendStreamFlowControlWindow() protocol.ByteCount {
ret := _m.ctrl.Call(_m, "GetSendStreamFlowControlWindow")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) GetSendStreamFlowControlWindow() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetSendStreamFlowControlWindow")
}
func (_m *MockConnectionParametersManager) SetFromMap(_param0 map[handshake.Tag][]byte) error {
ret := _m.ctrl.Call(_m, "SetFromMap", _param0)
ret0, _ := ret[0].(error)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) SetFromMap(arg0 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "SetFromMap", arg0)
}
func (_m *MockConnectionParametersManager) TruncateConnectionID() bool {
ret := _m.ctrl.Call(_m, "TruncateConnectionID")
ret0, _ := ret[0].(bool)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) TruncateConnectionID() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "TruncateConnectionID")
}

View File

@ -0,0 +1,4 @@
package mocks
//go:generate mockgen -destination mocks_fc/flow_control_manager.go -package mocks_fc github.com/lucas-clemente/quic-go/flowcontrol FlowControlManager
//go:generate mockgen -destination cpm.go -package mocks github.com/lucas-clemente/quic-go/handshake ConnectionParametersManager

View File

@ -0,0 +1,140 @@
// Automatically generated by MockGen. DO NOT EDIT!
// Source: github.com/lucas-clemente/quic-go/flowcontrol (interfaces: FlowControlManager)
package mocks_fc
import (
gomock "github.com/golang/mock/gomock"
flowcontrol "github.com/lucas-clemente/quic-go/flowcontrol"
protocol "github.com/lucas-clemente/quic-go/protocol"
)
// Mock of FlowControlManager interface
type MockFlowControlManager struct {
ctrl *gomock.Controller
recorder *_MockFlowControlManagerRecorder
}
// Recorder for MockFlowControlManager (not exported)
type _MockFlowControlManagerRecorder struct {
mock *MockFlowControlManager
}
func NewMockFlowControlManager(ctrl *gomock.Controller) *MockFlowControlManager {
mock := &MockFlowControlManager{ctrl: ctrl}
mock.recorder = &_MockFlowControlManagerRecorder{mock}
return mock
}
func (_m *MockFlowControlManager) EXPECT() *_MockFlowControlManagerRecorder {
return _m.recorder
}
func (_m *MockFlowControlManager) AddBytesRead(_param0 protocol.StreamID, _param1 protocol.ByteCount) error {
ret := _m.ctrl.Call(_m, "AddBytesRead", _param0, _param1)
ret0, _ := ret[0].(error)
return ret0
}
func (_mr *_MockFlowControlManagerRecorder) AddBytesRead(arg0, arg1 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "AddBytesRead", arg0, arg1)
}
func (_m *MockFlowControlManager) AddBytesSent(_param0 protocol.StreamID, _param1 protocol.ByteCount) error {
ret := _m.ctrl.Call(_m, "AddBytesSent", _param0, _param1)
ret0, _ := ret[0].(error)
return ret0
}
func (_mr *_MockFlowControlManagerRecorder) AddBytesSent(arg0, arg1 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "AddBytesSent", arg0, arg1)
}
func (_m *MockFlowControlManager) GetReceiveWindow(_param0 protocol.StreamID) (protocol.ByteCount, error) {
ret := _m.ctrl.Call(_m, "GetReceiveWindow", _param0)
ret0, _ := ret[0].(protocol.ByteCount)
ret1, _ := ret[1].(error)
return ret0, ret1
}
func (_mr *_MockFlowControlManagerRecorder) GetReceiveWindow(arg0 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetReceiveWindow", arg0)
}
func (_m *MockFlowControlManager) GetWindowUpdates() []flowcontrol.WindowUpdate {
ret := _m.ctrl.Call(_m, "GetWindowUpdates")
ret0, _ := ret[0].([]flowcontrol.WindowUpdate)
return ret0
}
func (_mr *_MockFlowControlManagerRecorder) GetWindowUpdates() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetWindowUpdates")
}
func (_m *MockFlowControlManager) NewStream(_param0 protocol.StreamID, _param1 bool) {
_m.ctrl.Call(_m, "NewStream", _param0, _param1)
}
func (_mr *_MockFlowControlManagerRecorder) NewStream(arg0, arg1 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "NewStream", arg0, arg1)
}
func (_m *MockFlowControlManager) RemainingConnectionWindowSize() protocol.ByteCount {
ret := _m.ctrl.Call(_m, "RemainingConnectionWindowSize")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
func (_mr *_MockFlowControlManagerRecorder) RemainingConnectionWindowSize() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "RemainingConnectionWindowSize")
}
func (_m *MockFlowControlManager) RemoveStream(_param0 protocol.StreamID) {
_m.ctrl.Call(_m, "RemoveStream", _param0)
}
func (_mr *_MockFlowControlManagerRecorder) RemoveStream(arg0 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "RemoveStream", arg0)
}
func (_m *MockFlowControlManager) ResetStream(_param0 protocol.StreamID, _param1 protocol.ByteCount) error {
ret := _m.ctrl.Call(_m, "ResetStream", _param0, _param1)
ret0, _ := ret[0].(error)
return ret0
}
func (_mr *_MockFlowControlManagerRecorder) ResetStream(arg0, arg1 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "ResetStream", arg0, arg1)
}
func (_m *MockFlowControlManager) SendWindowSize(_param0 protocol.StreamID) (protocol.ByteCount, error) {
ret := _m.ctrl.Call(_m, "SendWindowSize", _param0)
ret0, _ := ret[0].(protocol.ByteCount)
ret1, _ := ret[1].(error)
return ret0, ret1
}
func (_mr *_MockFlowControlManagerRecorder) SendWindowSize(arg0 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "SendWindowSize", arg0)
}
func (_m *MockFlowControlManager) UpdateHighestReceived(_param0 protocol.StreamID, _param1 protocol.ByteCount) error {
ret := _m.ctrl.Call(_m, "UpdateHighestReceived", _param0, _param1)
ret0, _ := ret[0].(error)
return ret0
}
func (_mr *_MockFlowControlManagerRecorder) UpdateHighestReceived(arg0, arg1 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "UpdateHighestReceived", arg0, arg1)
}
func (_m *MockFlowControlManager) UpdateWindow(_param0 protocol.StreamID, _param1 protocol.ByteCount) (bool, error) {
ret := _m.ctrl.Call(_m, "UpdateWindow", _param0, _param1)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
func (_mr *_MockFlowControlManagerRecorder) UpdateWindow(arg0, arg1 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "UpdateWindow", arg0, arg1)
}

View File

@ -1,26 +1,26 @@
package utils package utils
import ( import (
"fmt"
"log" "log"
"os" "os"
"strconv"
"time" "time"
) )
// LogLevel of quic-go // LogLevel of quic-go
type LogLevel uint8 type LogLevel uint8
const ( const logEnv = "QUIC_GO_LOG_LEVEL"
logEnv = "QUIC_GO_LOG_LEVEL"
// LogLevelDebug enables debug logs (e.g. packet contents) const (
LogLevelDebug LogLevel = iota // LogLevelNothing disables
// LogLevelInfo enables info logs (e.g. packets) LogLevelNothing LogLevel = iota
LogLevelInfo
// LogLevelError enables err logs // LogLevelError enables err logs
LogLevelError LogLevelError
// LogLevelNothing disables // LogLevelInfo enables info logs (e.g. packets)
LogLevelNothing LogLevelInfo
// LogLevelDebug enables debug logs (e.g. packet contents)
LogLevelDebug
) )
var ( var (
@ -49,14 +49,14 @@ func Debugf(format string, args ...interface{}) {
// Infof logs something // Infof logs something
func Infof(format string, args ...interface{}) { func Infof(format string, args ...interface{}) {
if logLevel <= LogLevelInfo { if logLevel >= LogLevelInfo {
logMessage(format, args...) logMessage(format, args...)
} }
} }
// Errorf logs something // Errorf logs something
func Errorf(format string, args ...interface{}) { func Errorf(format string, args ...interface{}) {
if logLevel <= LogLevelError { if logLevel >= LogLevelError {
logMessage(format, args...) logMessage(format, args...)
} }
} }
@ -79,13 +79,16 @@ func init() {
} }
func readLoggingEnv() { func readLoggingEnv() {
env := os.Getenv(logEnv) switch os.Getenv(logEnv) {
if env == "" { case "":
return return
case "DEBUG":
logLevel = LogLevelDebug
case "INFO":
logLevel = LogLevelInfo
case "ERROR":
logLevel = LogLevelError
default:
fmt.Fprintln(os.Stderr, "invalid quic-go log level, see https://github.com/lucas-clemente/quic-go/wiki/Logging")
} }
level, err := strconv.Atoi(env)
if err != nil {
return
}
logLevel = LogLevel(level)
} }

View File

@ -0,0 +1,43 @@
package utils
import "time"
// A Timer wrapper that behaves correctly when resetting
type Timer struct {
t *time.Timer
read bool
deadline time.Time
}
// NewTimer creates a new timer that is not set
func NewTimer() *Timer {
return &Timer{t: time.NewTimer(0)}
}
// Chan returns the channel of the wrapped timer
func (t *Timer) Chan() <-chan time.Time {
return t.t.C
}
// Reset the timer, no matter whether the value was read or not
func (t *Timer) Reset(deadline time.Time) {
if deadline.Equal(t.deadline) {
// No need to reset the timer
return
}
// We need to drain the timer if the value from its channel was not read yet.
// See https://groups.google.com/forum/#!topic/golang-dev/c9UUfASVPoU
if !t.t.Stop() && !t.read {
<-t.t.C
}
t.t.Reset(deadline.Sub(time.Now()))
t.read = false
t.deadline = deadline
}
// SetRead should be called after the value from the chan was read
func (t *Timer) SetRead() {
t.read = true
}

View File

@ -9,7 +9,6 @@ import (
"github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
) )
type packedPacket struct { type packedPacket struct {
@ -24,18 +23,24 @@ type packetPacker struct {
perspective protocol.Perspective perspective protocol.Perspective
version protocol.VersionNumber version protocol.VersionNumber
cryptoSetup handshake.CryptoSetup cryptoSetup handshake.CryptoSetup
// as long as packets are not sent with forward-secure encryption, we limit the MaxPacketSize such that they can be retransmitted as a whole
isForwardSecure bool
packetNumberGenerator *packetNumberGenerator packetNumberGenerator *packetNumberGenerator
connectionParameters handshake.ConnectionParametersManager connectionParameters handshake.ConnectionParametersManager
streamFramer *streamFramer streamFramer *streamFramer
controlFrames []frames.Frame controlFrames []frames.Frame
stopWaiting *frames.StopWaitingFrame
ackFrame *frames.AckFrame
leastUnacked protocol.PacketNumber
} }
func newPacketPacker(connectionID protocol.ConnectionID, cryptoSetup handshake.CryptoSetup, connectionParameters handshake.ConnectionParametersManager, streamFramer *streamFramer, perspective protocol.Perspective, version protocol.VersionNumber) *packetPacker { func newPacketPacker(connectionID protocol.ConnectionID,
cryptoSetup handshake.CryptoSetup,
connectionParameters handshake.ConnectionParametersManager,
streamFramer *streamFramer,
perspective protocol.Perspective,
version protocol.VersionNumber,
) *packetPacker {
return &packetPacker{ return &packetPacker{
cryptoSetup: cryptoSetup, cryptoSetup: cryptoSetup,
connectionID: connectionID, connectionID: connectionID,
@ -48,181 +53,168 @@ func newPacketPacker(connectionID protocol.ConnectionID, cryptoSetup handshake.C
} }
// PackConnectionClose packs a packet that ONLY contains a ConnectionCloseFrame // PackConnectionClose packs a packet that ONLY contains a ConnectionCloseFrame
func (p *packetPacker) PackConnectionClose(ccf *frames.ConnectionCloseFrame, leastUnacked protocol.PacketNumber) (*packedPacket, error) { func (p *packetPacker) PackConnectionClose(ccf *frames.ConnectionCloseFrame) (*packedPacket, error) {
// in case the connection is closed, all queued control frames aren't of any use anymore frames := []frames.Frame{ccf}
// discard them and queue the ConnectionCloseFrame encLevel, sealer := p.cryptoSetup.GetSealer()
p.controlFrames = []frames.Frame{ccf} ph := p.getPublicHeader(encLevel)
return p.packPacket(nil, leastUnacked, nil) raw, err := p.writeAndSealPacket(ph, frames, sealer)
return &packedPacket{
number: ph.PacketNumber,
raw: raw,
frames: frames,
encryptionLevel: encLevel,
}, err
} }
// RetransmitNonForwardSecurePacket retransmits a handshake packet, that was sent with less than forward-secure encryption func (p *packetPacker) PackAckPacket() (*packedPacket, error) {
func (p *packetPacker) RetransmitNonForwardSecurePacket(stopWaitingFrame *frames.StopWaitingFrame, packet *ackhandler.Packet) (*packedPacket, error) { if p.ackFrame == nil {
return nil, errors.New("packet packer BUG: no ack frame queued")
}
encLevel, sealer := p.cryptoSetup.GetSealer()
ph := p.getPublicHeader(encLevel)
frames := []frames.Frame{p.ackFrame}
if p.stopWaiting != nil {
p.stopWaiting.PacketNumber = ph.PacketNumber
p.stopWaiting.PacketNumberLen = ph.PacketNumberLen
frames = append(frames, p.stopWaiting)
p.stopWaiting = nil
}
p.ackFrame = nil
raw, err := p.writeAndSealPacket(ph, frames, sealer)
return &packedPacket{
number: ph.PacketNumber,
raw: raw,
frames: frames,
encryptionLevel: encLevel,
}, err
}
// PackHandshakeRetransmission retransmits a handshake packet, that was sent with less than forward-secure encryption
func (p *packetPacker) PackHandshakeRetransmission(packet *ackhandler.Packet) (*packedPacket, error) {
if packet.EncryptionLevel == protocol.EncryptionForwardSecure { if packet.EncryptionLevel == protocol.EncryptionForwardSecure {
return nil, errors.New("PacketPacker BUG: forward-secure encrypted handshake packets don't need special treatment") return nil, errors.New("PacketPacker BUG: forward-secure encrypted handshake packets don't need special treatment")
} }
if stopWaitingFrame == nil { sealer, err := p.cryptoSetup.GetSealerWithEncryptionLevel(packet.EncryptionLevel)
if err != nil {
return nil, err
}
if p.stopWaiting == nil {
return nil, errors.New("PacketPacker BUG: Handshake retransmissions must contain a StopWaitingFrame") return nil, errors.New("PacketPacker BUG: Handshake retransmissions must contain a StopWaitingFrame")
} }
ph := p.getPublicHeader(packet.EncryptionLevel)
return p.packPacket(stopWaitingFrame, 0, packet) p.stopWaiting.PacketNumber = ph.PacketNumber
p.stopWaiting.PacketNumberLen = ph.PacketNumberLen
frames := append([]frames.Frame{p.stopWaiting}, packet.Frames...)
p.stopWaiting = nil
raw, err := p.writeAndSealPacket(ph, frames, sealer)
return &packedPacket{
number: ph.PacketNumber,
raw: raw,
frames: frames,
encryptionLevel: packet.EncryptionLevel,
}, err
} }
// PackPacket packs a new packet // PackPacket packs a new packet
// the stopWaitingFrame is *guaranteed* to be included in the next packet
// the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise // the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise
func (p *packetPacker) PackPacket(stopWaitingFrame *frames.StopWaitingFrame, controlFrames []frames.Frame, leastUnacked protocol.PacketNumber) (*packedPacket, error) { func (p *packetPacker) PackPacket() (*packedPacket, error) {
p.controlFrames = append(p.controlFrames, controlFrames...) if p.streamFramer.HasCryptoStreamFrame() {
return p.packPacket(stopWaitingFrame, leastUnacked, nil) return p.packCryptoPacket()
} }
func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, leastUnacked protocol.PacketNumber, handshakePacketToRetransmit *ackhandler.Packet) (*packedPacket, error) { encLevel, sealer := p.cryptoSetup.GetSealer()
// handshakePacketToRetransmit is only set for handshake retransmissions
isHandshakeRetransmission := (handshakePacketToRetransmit != nil)
var sealFunc handshake.Sealer publicHeader := p.getPublicHeader(encLevel)
var encLevel protocol.EncryptionLevel publicHeaderLength, err := publicHeader.GetLength(p.perspective)
if isHandshakeRetransmission {
var err error
encLevel = handshakePacketToRetransmit.EncryptionLevel
sealFunc, err = p.cryptoSetup.GetSealerWithEncryptionLevel(encLevel)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else { if p.stopWaiting != nil {
encLevel, sealFunc = p.cryptoSetup.GetSealer() p.stopWaiting.PacketNumber = publicHeader.PacketNumber
p.stopWaiting.PacketNumberLen = publicHeader.PacketNumberLen
} }
currentPacketNumber := p.packetNumberGenerator.Peek()
packetNumberLen := protocol.GetPacketNumberLengthForPublicHeader(currentPacketNumber, leastUnacked)
responsePublicHeader := &PublicHeader{
ConnectionID: p.connectionID,
PacketNumber: currentPacketNumber,
PacketNumberLen: packetNumberLen,
TruncateConnectionID: p.connectionParameters.TruncateConnectionID(),
}
if p.perspective == protocol.PerspectiveServer && encLevel == protocol.EncryptionSecure {
responsePublicHeader.DiversificationNonce = p.cryptoSetup.DiversificationNonce()
}
if p.perspective == protocol.PerspectiveClient && encLevel != protocol.EncryptionForwardSecure {
responsePublicHeader.VersionFlag = true
responsePublicHeader.VersionNumber = p.version
}
publicHeaderLength, err := responsePublicHeader.GetLength(p.perspective)
if err != nil {
return nil, err
}
if stopWaitingFrame != nil {
stopWaitingFrame.PacketNumber = currentPacketNumber
stopWaitingFrame.PacketNumberLen = packetNumberLen
}
// we're packing a ConnectionClose, don't add any StreamFrames
var isConnectionClose bool
if len(p.controlFrames) == 1 {
_, isConnectionClose = p.controlFrames[0].(*frames.ConnectionCloseFrame)
}
var payloadFrames []frames.Frame
if isHandshakeRetransmission {
payloadFrames = append(payloadFrames, stopWaitingFrame)
// don't retransmit Acks and StopWaitings
for _, f := range handshakePacketToRetransmit.Frames {
switch f.(type) {
case *frames.AckFrame:
continue
case *frames.StopWaitingFrame:
continue
}
payloadFrames = append(payloadFrames, f)
}
} else if isConnectionClose {
payloadFrames = []frames.Frame{p.controlFrames[0]}
} else {
maxSize := protocol.MaxFrameAndPublicHeaderSize - publicHeaderLength maxSize := protocol.MaxFrameAndPublicHeaderSize - publicHeaderLength
if !p.isForwardSecure { payloadFrames, err := p.composeNextPacket(maxSize, p.canSendData(encLevel))
maxSize -= protocol.NonForwardSecurePacketSizeReduction
}
payloadFrames, err = p.composeNextPacket(stopWaitingFrame, maxSize)
if err != nil { if err != nil {
return nil, err return nil, err
} }
}
// Check if we have enough frames to send // Check if we have enough frames to send
if len(payloadFrames) == 0 { if len(payloadFrames) == 0 {
return nil, nil return nil, nil
} }
// Don't send out packets that only contain a StopWaitingFrame // Don't send out packets that only contain a StopWaitingFrame
if len(payloadFrames) == 1 && stopWaitingFrame != nil { if len(payloadFrames) == 1 && p.stopWaiting != nil {
return nil, nil return nil, nil
} }
p.stopWaiting = nil
p.ackFrame = nil
raw := getPacketBuffer() raw, err := p.writeAndSealPacket(publicHeader, payloadFrames, sealer)
buffer := bytes.NewBuffer(raw)
if err = responsePublicHeader.Write(buffer, p.version, p.perspective); err != nil {
return nil, err
}
payloadStartIndex := buffer.Len()
var hasNonCryptoStreamData bool // does this frame contain any stream frame on a stream > 1
for _, frame := range payloadFrames {
if sf, ok := frame.(*frames.StreamFrame); ok && sf.StreamID != 1 {
hasNonCryptoStreamData = true
}
err = frame.Write(buffer, p.version)
if err != nil { if err != nil {
return nil, err return nil, err
} }
}
if protocol.ByteCount(buffer.Len()+12) > protocol.MaxPacketSize {
return nil, errors.New("PacketPacker BUG: packet too large")
}
raw = raw[0:buffer.Len()]
_ = sealFunc(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], currentPacketNumber, raw[:payloadStartIndex])
raw = raw[0 : buffer.Len()+12]
if hasNonCryptoStreamData && encLevel <= protocol.EncryptionUnencrypted {
return nil, qerr.AttemptToSendUnencryptedStreamData
}
num := p.packetNumberGenerator.Pop()
if num != currentPacketNumber {
return nil, errors.New("PacketPacker BUG: Peeked and Popped packet numbers do not match.")
}
return &packedPacket{ return &packedPacket{
number: currentPacketNumber, number: publicHeader.PacketNumber,
raw: raw, raw: raw,
frames: payloadFrames, frames: payloadFrames,
encryptionLevel: encLevel, encryptionLevel: encLevel,
}, nil }, nil
} }
func (p *packetPacker) composeNextPacket(stopWaitingFrame *frames.StopWaitingFrame, maxFrameSize protocol.ByteCount) ([]frames.Frame, error) { func (p *packetPacker) packCryptoPacket() (*packedPacket, error) {
var payloadLength protocol.ByteCount encLevel, sealer := p.cryptoSetup.GetSealerForCryptoStream()
var payloadFrames []frames.Frame publicHeader := p.getPublicHeader(encLevel)
publicHeaderLength, err := publicHeader.GetLength(p.perspective)
if stopWaitingFrame != nil {
payloadFrames = append(payloadFrames, stopWaitingFrame)
minLength, err := stopWaitingFrame.MinLength(p.version)
if err != nil { if err != nil {
return nil, err return nil, err
} }
payloadLength += minLength maxLen := protocol.MaxFrameAndPublicHeaderSize - protocol.NonForwardSecurePacketSizeReduction - publicHeaderLength
frames := []frames.Frame{p.streamFramer.PopCryptoStreamFrame(maxLen)}
raw, err := p.writeAndSealPacket(publicHeader, frames, sealer)
if err != nil {
return nil, err
}
return &packedPacket{
number: publicHeader.PacketNumber,
raw: raw,
frames: frames,
encryptionLevel: encLevel,
}, nil
}
func (p *packetPacker) composeNextPacket(
maxFrameSize protocol.ByteCount,
canSendStreamFrames bool,
) ([]frames.Frame, error) {
var payloadLength protocol.ByteCount
var payloadFrames []frames.Frame
// STOP_WAITING and ACK will always fit
if p.stopWaiting != nil {
payloadFrames = append(payloadFrames, p.stopWaiting)
l, err := p.stopWaiting.MinLength(p.version)
if err != nil {
return nil, err
}
payloadLength += l
}
if p.ackFrame != nil {
payloadFrames = append(payloadFrames, p.ackFrame)
l, err := p.ackFrame.MinLength(p.version)
if err != nil {
return nil, err
}
payloadLength += l
} }
for len(p.controlFrames) > 0 { for len(p.controlFrames) > 0 {
frame := p.controlFrames[len(p.controlFrames)-1] frame := p.controlFrames[len(p.controlFrames)-1]
minLength, _ := frame.MinLength(p.version) // controlFrames does not contain any StopWaitingFrames. So it will *never* return an error minLength, err := frame.MinLength(p.version)
if err != nil {
return nil, err
}
if payloadLength+minLength > maxFrameSize { if payloadLength+minLength > maxFrameSize {
break break
} }
@ -235,6 +227,10 @@ func (p *packetPacker) composeNextPacket(stopWaitingFrame *frames.StopWaitingFra
return nil, fmt.Errorf("Packet Packer BUG: packet payload (%d) too large (%d)", payloadLength, maxFrameSize) return nil, fmt.Errorf("Packet Packer BUG: packet payload (%d) too large (%d)", payloadLength, maxFrameSize)
} }
if !canSendStreamFrames {
return payloadFrames, nil
}
// temporarily increase the maxFrameSize by 2 bytes // temporarily increase the maxFrameSize by 2 bytes
// this leads to a properly sized packet in all cases, since we do all the packet length calculations with StreamFrames that have the DataLen set // this leads to a properly sized packet in all cases, since we do all the packet length calculations with StreamFrames that have the DataLen set
// however, for the last StreamFrame in the packet, we can omit the DataLen, thus saving 2 bytes and yielding a packet of exactly the correct size // however, for the last StreamFrame in the packet, we can omit the DataLen, thus saving 2 bytes and yielding a packet of exactly the correct size
@ -257,10 +253,79 @@ func (p *packetPacker) composeNextPacket(stopWaitingFrame *frames.StopWaitingFra
return payloadFrames, nil return payloadFrames, nil
} }
func (p *packetPacker) QueueControlFrameForNextPacket(f frames.Frame) { func (p *packetPacker) QueueControlFrame(frame frames.Frame) {
switch f := frame.(type) {
case *frames.StopWaitingFrame:
p.stopWaiting = f
case *frames.AckFrame:
p.ackFrame = f
default:
p.controlFrames = append(p.controlFrames, f) p.controlFrames = append(p.controlFrames, f)
}
} }
func (p *packetPacker) SetForwardSecure() { func (p *packetPacker) getPublicHeader(encLevel protocol.EncryptionLevel) *PublicHeader {
p.isForwardSecure = true pnum := p.packetNumberGenerator.Peek()
packetNumberLen := protocol.GetPacketNumberLengthForPublicHeader(pnum, p.leastUnacked)
publicHeader := &PublicHeader{
ConnectionID: p.connectionID,
PacketNumber: pnum,
PacketNumberLen: packetNumberLen,
TruncateConnectionID: p.connectionParameters.TruncateConnectionID(),
}
if p.perspective == protocol.PerspectiveServer && encLevel == protocol.EncryptionSecure {
publicHeader.DiversificationNonce = p.cryptoSetup.DiversificationNonce()
}
if p.perspective == protocol.PerspectiveClient && encLevel != protocol.EncryptionForwardSecure {
publicHeader.VersionFlag = true
publicHeader.VersionNumber = p.version
}
return publicHeader
}
func (p *packetPacker) writeAndSealPacket(
publicHeader *PublicHeader,
payloadFrames []frames.Frame,
sealer handshake.Sealer,
) ([]byte, error) {
raw := getPacketBuffer()
buffer := bytes.NewBuffer(raw)
if err := publicHeader.Write(buffer, p.version, p.perspective); err != nil {
return nil, err
}
payloadStartIndex := buffer.Len()
for _, frame := range payloadFrames {
err := frame.Write(buffer, p.version)
if err != nil {
return nil, err
}
}
if protocol.ByteCount(buffer.Len()+12) > protocol.MaxPacketSize {
return nil, errors.New("PacketPacker BUG: packet too large")
}
raw = raw[0:buffer.Len()]
_ = sealer(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], publicHeader.PacketNumber, raw[:payloadStartIndex])
raw = raw[0 : buffer.Len()+12]
num := p.packetNumberGenerator.Pop()
if num != publicHeader.PacketNumber {
return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
}
return raw, nil
}
func (p *packetPacker) canSendData(encLevel protocol.EncryptionLevel) bool {
if p.perspective == protocol.PerspectiveClient {
return encLevel >= protocol.EncryptionSecure
}
return encLevel == protocol.EncryptionForwardSecure
}
func (p *packetPacker) SetLeastUnacked(leastUnacked protocol.PacketNumber) {
p.leastUnacked = leastUnacked
} }

View File

@ -10,6 +10,11 @@ import (
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
) )
type unpackedPacket struct {
encryptionLevel protocol.EncryptionLevel
frames []frames.Frame
}
type quicAEAD interface { type quicAEAD interface {
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
} }

View File

@ -31,7 +31,7 @@ type StreamID uint32
type ByteCount uint64 type ByteCount uint64
// MaxByteCount is the maximum value of a ByteCount // MaxByteCount is the maximum value of a ByteCount
const MaxByteCount = math.MaxUint64 const MaxByteCount = ByteCount(math.MaxUint64)
// MaxReceivePacketSize maximum packet size of any QUIC packet, based on // MaxReceivePacketSize maximum packet size of any QUIC packet, based on
// ethernet's max size, minus the IP and UDP headers. IPv6 has a 40 byte header, // ethernet's max size, minus the IP and UDP headers. IPv6 has a 40 byte header,

View File

@ -39,21 +39,21 @@ const ReceiveStreamFlowControlWindow ByteCount = (1 << 10) * 32 // 32 kB
// This is the value that Google servers are using // This is the value that Google servers are using
const ReceiveConnectionFlowControlWindow ByteCount = (1 << 10) * 48 // 48 kB const ReceiveConnectionFlowControlWindow ByteCount = (1 << 10) * 48 // 48 kB
// MaxReceiveStreamFlowControlWindowServer is the maximum stream-level flow control window for receiving data // DefaultMaxReceiveStreamFlowControlWindowServer is the default maximum stream-level flow control window for receiving data, for the server
// This is the value that Google servers are using // This is the value that Google servers are using
const MaxReceiveStreamFlowControlWindowServer ByteCount = 1 * (1 << 20) // 1 MB const DefaultMaxReceiveStreamFlowControlWindowServer ByteCount = 1 * (1 << 20) // 1 MB
// MaxReceiveConnectionFlowControlWindowServer is the connection-level flow control window for receiving data // DefaultMaxReceiveConnectionFlowControlWindowServer is the default connection-level flow control window for receiving data, for the server
// This is the value that Google servers are using // This is the value that Google servers are using
const MaxReceiveConnectionFlowControlWindowServer ByteCount = 1.5 * (1 << 20) // 1.5 MB const DefaultMaxReceiveConnectionFlowControlWindowServer ByteCount = 1.5 * (1 << 20) // 1.5 MB
// MaxReceiveStreamFlowControlWindowClient is the maximum stream-level flow control window for receiving data, for the client // DefaultMaxReceiveStreamFlowControlWindowClient is the default maximum stream-level flow control window for receiving data, for the client
// This is the value that Chromium is using // This is the value that Chromium is using
const MaxReceiveStreamFlowControlWindowClient ByteCount = 6 * (1 << 20) // 6 MB const DefaultMaxReceiveStreamFlowControlWindowClient ByteCount = 6 * (1 << 20) // 6 MB
// MaxReceiveConnectionFlowControlWindowClient is the connection-level flow control window for receiving data, for the server // DefaultMaxReceiveConnectionFlowControlWindowClient is the default connection-level flow control window for receiving data, for the client
// This is the value that Google servers are using // This is the value that Google servers are using
const MaxReceiveConnectionFlowControlWindowClient ByteCount = 15 * (1 << 20) // 15 MB const DefaultMaxReceiveConnectionFlowControlWindowClient ByteCount = 15 * (1 << 20) // 15 MB
// ConnectionFlowControlMultiplier determines how much larger the connection flow control windows needs to be relative to any stream's flow control window // ConnectionFlowControlMultiplier determines how much larger the connection flow control windows needs to be relative to any stream's flow control window
// This is the value that Chromium is using // This is the value that Chromium is using
@ -128,8 +128,8 @@ const MaxIdleTimeoutServer = 1 * time.Minute
// MaxIdleTimeoutClient is the idle timeout that the client suggests to the server // MaxIdleTimeoutClient is the idle timeout that the client suggests to the server
const MaxIdleTimeoutClient = 2 * time.Minute const MaxIdleTimeoutClient = 2 * time.Minute
// MaxTimeForCryptoHandshake is the default timeout for a connection until the crypto handshake succeeds. // DefaultHandshakeTimeout is the default timeout for a connection until the crypto handshake succeeds.
const MaxTimeForCryptoHandshake = 10 * time.Second const DefaultHandshakeTimeout = 10 * time.Second
// ClosedSessionDeleteTimeout the server ignores packets arriving on a connection that is already closed // ClosedSessionDeleteTimeout the server ignores packets arriving on a connection that is already closed
// after this time all information about the old connection will be deleted // after this time all information about the old connection will be deleted

View File

@ -4,9 +4,9 @@ import (
"bytes" "bytes"
"errors" "errors"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
var ( var (

View File

@ -6,8 +6,8 @@ import (
"errors" "errors"
"github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
type publicReset struct { type publicReset struct {

View File

@ -3,7 +3,7 @@ package qerr
import ( import (
"fmt" "fmt"
"github.com/lucas-clemente/quic-go/utils" "github.com/lucas-clemente/quic-go/internal/utils"
) )
// ErrorCode can be used as a normal error without reason. // ErrorCode can be used as a normal error without reason.
@ -31,6 +31,16 @@ func (e *QuicError) Error() string {
return fmt.Sprintf("%s: %s", e.ErrorCode.String(), e.ErrorMessage) return fmt.Sprintf("%s: %s", e.ErrorCode.String(), e.ErrorMessage)
} }
func (e *QuicError) Timeout() bool {
switch e.ErrorCode {
case NetworkIdleTimeout,
HandshakeTimeout,
TimeoutsWithOpenStreams:
return true
}
return false
}
// ToQuicError converts an arbitrary error to a QuicError. It leaves QuicErrors // ToQuicError converts an arbitrary error to a QuicError. It leaves QuicErrors
// unchanged, and properly handles `ErrorCode`s. // unchanged, and properly handles `ErrorCode`s.
func ToQuicError(err error) *QuicError { func ToQuicError(err error) *QuicError {

View File

@ -2,6 +2,7 @@ package quic
import ( import (
"bytes" "bytes"
"crypto/tls"
"errors" "errors"
"net" "net"
"sync" "sync"
@ -9,9 +10,9 @@ import (
"github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/crypto"
"github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// packetHandler handles packets // packetHandler handles packets
@ -19,10 +20,12 @@ type packetHandler interface {
Session Session
handlePacket(*receivedPacket) handlePacket(*receivedPacket)
run() error run() error
closeRemote(error)
} }
// A Listener of QUIC // A Listener of QUIC
type server struct { type server struct {
tlsConf *tls.Config
config *Config config *Config
conn net.PacketConn conn net.PacketConn
@ -38,14 +41,15 @@ type server struct {
sessionQueue chan Session sessionQueue chan Session
errorChan chan struct{} errorChan chan struct{}
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, config *Config) (packetHandler, <-chan handshakeEvent, error) newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config) (packetHandler, <-chan handshakeEvent, error)
} }
var _ Listener = &server{} var _ Listener = &server{}
// ListenAddr creates a QUIC server listening on a given address. // ListenAddr creates a QUIC server listening on a given address.
// The listener is not active until Serve() is called. // The listener is not active until Serve() is called.
func ListenAddr(addr string, config *Config) (Listener, error) { // The tls.Config must not be nil, the quic.Config may be nil.
func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr) udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -54,13 +58,14 @@ func ListenAddr(addr string, config *Config) (Listener, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return Listen(conn, config) return Listen(conn, tlsConf, config)
} }
// Listen listens for QUIC connections on a given net.PacketConn. // Listen listens for QUIC connections on a given net.PacketConn.
// The listener is not active until Serve() is called. // The listener is not active until Serve() is called.
func Listen(conn net.PacketConn, config *Config) (Listener, error) { // The tls.Config must not be nil, the quic.Config may be nil.
certChain := crypto.NewCertChain(config.TLSConfig) func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) {
certChain := crypto.NewCertChain(tlsConf)
kex, err := crypto.NewCurve25519KEX() kex, err := crypto.NewCurve25519KEX()
if err != nil { if err != nil {
return nil, err return nil, err
@ -72,6 +77,7 @@ func Listen(conn net.PacketConn, config *Config) (Listener, error) {
s := &server{ s := &server{
conn: conn, conn: conn,
tlsConf: tlsConf,
config: populateServerConfig(config), config: populateServerConfig(config),
certChain: certChain, certChain: certChain,
scfg: scfg, scfg: scfg,
@ -101,20 +107,42 @@ var defaultAcceptSTK = func(clientAddr net.Addr, stk *STK) bool {
return sourceAddr == stk.remoteAddr return sourceAddr == stk.remoteAddr
} }
// populateServerConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateServerConfig(config *Config) *Config { func populateServerConfig(config *Config) *Config {
if config == nil {
config = &Config{}
}
versions := config.Versions versions := config.Versions
if len(versions) == 0 { if len(versions) == 0 {
versions = protocol.SupportedVersions versions = protocol.SupportedVersions
} }
vsa := defaultAcceptSTK vsa := defaultAcceptSTK
if config.AcceptSTK != nil { if config.AcceptSTK != nil {
vsa = config.AcceptSTK vsa = config.AcceptSTK
} }
handshakeTimeout := protocol.DefaultHandshakeTimeout
if config.HandshakeTimeout != 0 {
handshakeTimeout = config.HandshakeTimeout
}
maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
if maxReceiveStreamFlowControlWindow == 0 {
maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindowServer
}
maxReceiveConnectionFlowControlWindow := config.MaxReceiveConnectionFlowControlWindow
if maxReceiveConnectionFlowControlWindow == 0 {
maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowServer
}
return &Config{ return &Config{
TLSConfig: config.TLSConfig,
Versions: versions, Versions: versions,
HandshakeTimeout: handshakeTimeout,
AcceptSTK: vsa, AcceptSTK: vsa,
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
} }
} }
@ -238,6 +266,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
version, version,
hdr.ConnectionID, hdr.ConnectionID,
s.scfg, s.scfg,
s.tlsConf,
s.config, s.config,
) )
if err != nil { if err != nil {

View File

@ -1,10 +1,11 @@
package quic package quic
import ( import (
"crypto/tls"
"errors" "errors"
"fmt" "fmt"
"net" "net"
"sync/atomic" "sync"
"time" "time"
"github.com/lucas-clemente/quic-go/ackhandler" "github.com/lucas-clemente/quic-go/ackhandler"
@ -12,9 +13,9 @@ import (
"github.com/lucas-clemente/quic-go/flowcontrol" "github.com/lucas-clemente/quic-go/flowcontrol"
"github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
type unpacker interface { type unpacker interface {
@ -31,7 +32,6 @@ type receivedPacket struct {
var ( var (
errRstStreamOnInvalidStream = errors.New("RST_STREAM received for unknown stream") errRstStreamOnInvalidStream = errors.New("RST_STREAM received for unknown stream")
errWindowUpdateOnClosedStream = errors.New("WINDOW_UPDATE received for an already closed stream") errWindowUpdateOnClosedStream = errors.New("WINDOW_UPDATE received for an already closed stream")
errSessionAlreadyClosed = errors.New("cannot close session; it was already closed before")
) )
var ( var (
@ -54,6 +54,7 @@ type session struct {
connectionID protocol.ConnectionID connectionID protocol.ConnectionID
perspective protocol.Perspective perspective protocol.Perspective
version protocol.VersionNumber version protocol.VersionNumber
tlsConf *tls.Config
config *Config config *Config
conn connection conn connection
@ -77,8 +78,10 @@ type session struct {
sendingScheduled chan struct{} sendingScheduled chan struct{}
// closeChan is used to notify the run loop that it should terminate. // closeChan is used to notify the run loop that it should terminate.
closeChan chan closeError closeChan chan closeError
// runClosed is closed once the run loop exits
// it is used to block Close() and WaitUntilClosed()
runClosed chan struct{} runClosed chan struct{}
closed uint32 // atomic bool closeOnce sync.Once
// when we receive too many undecryptable packets during the handshake, we send a Public reset // when we receive too many undecryptable packets during the handshake, we send a Public reset
// but only after a time of protocol.PublicResetTimeout has passed // but only after a time of protocol.PublicResetTimeout has passed
@ -97,8 +100,6 @@ type session struct {
// it receives at most 3 handshake events: 2 when the encryption level changes, and one error // it receives at most 3 handshake events: 2 when the encryption level changes, and one error
handshakeChan chan<- handshakeEvent handshakeChan chan<- handshakeEvent
nextAckScheduledTime time.Time
connectionParameters handshake.ConnectionParametersManager connectionParameters handshake.ConnectionParametersManager
lastRcvdPacketNumber protocol.PacketNumber lastRcvdPacketNumber protocol.PacketNumber
@ -109,9 +110,10 @@ type session struct {
sessionCreationTime time.Time sessionCreationTime time.Time
lastNetworkActivityTime time.Time lastNetworkActivityTime time.Time
timer *time.Timer timer *utils.Timer
currentDeadline time.Time // keepAlivePingSent stores whether a Ping frame was sent to the peer or not
timerRead bool // it is reset as soon as we receive a packet from the peer
keepAlivePingSent bool
} }
var _ Session = &session{} var _ Session = &session{}
@ -122,6 +124,7 @@ func newSession(
v protocol.VersionNumber, v protocol.VersionNumber,
connectionID protocol.ConnectionID, connectionID protocol.ConnectionID,
sCfg *handshake.ServerConfig, sCfg *handshake.ServerConfig,
tlsConf *tls.Config,
config *Config, config *Config,
) (packetHandler, <-chan handshakeEvent, error) { ) (packetHandler, <-chan handshakeEvent, error) {
s := &session{ s := &session{
@ -130,46 +133,8 @@ func newSession(
perspective: protocol.PerspectiveServer, perspective: protocol.PerspectiveServer,
version: v, version: v,
config: config, config: config,
connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveServer, v),
} }
return s.setup(sCfg, "", nil)
s.setup()
cryptoStream, _ := s.GetOrOpenStream(1)
_, _ = s.AcceptStream() // don't expose the crypto stream
aeadChanged := make(chan protocol.EncryptionLevel, 2)
s.aeadChanged = aeadChanged
handshakeChan := make(chan handshakeEvent, 3)
s.handshakeChan = handshakeChan
verifySourceAddr := func(clientAddr net.Addr, hstk *handshake.STK) bool {
if hstk == nil {
return config.AcceptSTK(clientAddr, nil)
}
return config.AcceptSTK(
clientAddr,
&STK{remoteAddr: hstk.RemoteAddr, sentTime: hstk.SentTime},
)
}
var err error
s.cryptoSetup, err = newCryptoSetup(
connectionID,
conn.RemoteAddr(),
v,
sCfg,
cryptoStream,
s.connectionParameters,
config.Versions,
verifySourceAddr,
aeadChanged,
)
if err != nil {
return nil, nil, err
}
s.packer = newPacketPacker(connectionID, s.cryptoSetup, s.connectionParameters, s.streamFramer, s.perspective, s.version)
s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version}
return s, handshakeChan, err
} }
// declare this as a variable, such that we can it mock it in the tests // declare this as a variable, such that we can it mock it in the tests
@ -178,6 +143,7 @@ var newClientSession = func(
hostname string, hostname string,
v protocol.VersionNumber, v protocol.VersionNumber,
connectionID protocol.ConnectionID, connectionID protocol.ConnectionID,
tlsConf *tls.Config,
config *Config, config *Config,
negotiatedVersions []protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber,
) (packetHandler, <-chan handshakeEvent, error) { ) (packetHandler, <-chan handshakeEvent, error) {
@ -186,68 +152,92 @@ var newClientSession = func(
connectionID: connectionID, connectionID: connectionID,
perspective: protocol.PerspectiveClient, perspective: protocol.PerspectiveClient,
version: v, version: v,
tlsConf: tlsConf,
config: config, config: config,
connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveClient, v),
} }
return s.setup(nil, hostname, negotiatedVersions)
}
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.ackAlarmChanged) func (s *session) setup(
s.setup() scfg *handshake.ServerConfig,
hostname string,
negotiatedVersions []protocol.VersionNumber,
) (packetHandler, <-chan handshakeEvent, error) {
aeadChanged := make(chan protocol.EncryptionLevel, 2) aeadChanged := make(chan protocol.EncryptionLevel, 2)
s.aeadChanged = aeadChanged s.aeadChanged = aeadChanged
handshakeChan := make(chan handshakeEvent, 3) handshakeChan := make(chan handshakeEvent, 3)
s.handshakeChan = handshakeChan s.handshakeChan = handshakeChan
cryptoStream, _ := s.OpenStream() s.runClosed = make(chan struct{})
var err error s.handshakeCompleteChan = make(chan error, 1)
s.cryptoSetup, err = newCryptoSetupClient(
hostname,
connectionID,
v,
cryptoStream,
config.TLSConfig,
s.connectionParameters,
aeadChanged,
&handshake.TransportParameters{RequestConnectionIDTruncation: config.RequestConnectionIDTruncation},
negotiatedVersions,
)
if err != nil {
return nil, nil, err
}
s.packer = newPacketPacker(connectionID, s.cryptoSetup, s.connectionParameters, s.streamFramer, s.perspective, s.version)
s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version}
return s, handshakeChan, err
}
// setup is called from newSession and newClientSession and initializes values that are independent of the perspective
func (s *session) setup() {
s.rttStats = &congestion.RTTStats{}
flowControlManager := flowcontrol.NewFlowControlManager(s.connectionParameters, s.rttStats)
sentPacketHandler := ackhandler.NewSentPacketHandler(s.rttStats)
now := time.Now()
s.sentPacketHandler = sentPacketHandler
s.flowControlManager = flowControlManager
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.ackAlarmChanged)
s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets) s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets)
s.closeChan = make(chan closeError, 1) s.closeChan = make(chan closeError, 1)
s.sendingScheduled = make(chan struct{}, 1) s.sendingScheduled = make(chan struct{}, 1)
s.undecryptablePackets = make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets) s.undecryptablePackets = make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets)
s.aeadChanged = make(chan protocol.EncryptionLevel, 2)
s.runClosed = make(chan struct{})
s.handshakeCompleteChan = make(chan error, 1)
s.timer = time.NewTimer(0) s.timer = utils.NewTimer()
now := time.Now()
s.lastNetworkActivityTime = now s.lastNetworkActivityTime = now
s.sessionCreationTime = now s.sessionCreationTime = now
s.rttStats = &congestion.RTTStats{}
s.connectionParameters = handshake.NewConnectionParamatersManager(s.perspective, s.version,
s.config.MaxReceiveStreamFlowControlWindow, s.config.MaxReceiveConnectionFlowControlWindow)
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats)
s.flowControlManager = flowcontrol.NewFlowControlManager(s.connectionParameters, s.rttStats)
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler()
s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.connectionParameters) s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.connectionParameters)
s.streamFramer = newStreamFramer(s.streamsMap, s.flowControlManager) s.streamFramer = newStreamFramer(s.streamsMap, s.flowControlManager)
var err error
if s.perspective == protocol.PerspectiveServer {
cryptoStream, _ := s.GetOrOpenStream(1)
_, _ = s.AcceptStream() // don't expose the crypto stream
verifySourceAddr := func(clientAddr net.Addr, hstk *handshake.STK) bool {
var stk *STK
if hstk != nil {
stk = &STK{remoteAddr: hstk.RemoteAddr, sentTime: hstk.SentTime}
}
return s.config.AcceptSTK(clientAddr, stk)
}
s.cryptoSetup, err = newCryptoSetup(
s.connectionID,
s.conn.RemoteAddr(),
s.version,
scfg,
cryptoStream,
s.connectionParameters,
s.config.Versions,
verifySourceAddr,
aeadChanged,
)
} else {
cryptoStream, _ := s.OpenStream()
s.cryptoSetup, err = newCryptoSetupClient(
hostname,
s.connectionID,
s.version,
cryptoStream,
s.tlsConf,
s.connectionParameters,
aeadChanged,
&handshake.TransportParameters{RequestConnectionIDTruncation: s.config.RequestConnectionIDTruncation},
negotiatedVersions,
)
}
if err != nil {
return nil, nil, err
}
s.packer = newPacketPacker(s.connectionID,
s.cryptoSetup,
s.connectionParameters,
s.streamFramer,
s.perspective,
s.version,
)
s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version}
return s, handshakeChan, nil
} }
// run the session main loop // run the session main loop
@ -276,8 +266,8 @@ runLoop:
select { select {
case closeErr = <-s.closeChan: case closeErr = <-s.closeChan:
break runLoop break runLoop
case <-s.timer.C: case <-s.timer.Chan():
s.timerRead = true s.timer.SetRead()
// We do all the interesting stuff after the switch statement, so // We do all the interesting stuff after the switch statement, so
// nothing to see here. // nothing to see here.
case <-s.sendingScheduled: case <-s.sendingScheduled:
@ -290,7 +280,7 @@ runLoop:
s.tryQueueingUndecryptablePacket(p) s.tryQueueingUndecryptablePacket(p)
continue continue
} }
s.close(err) s.closeLocal(err)
continue continue
} }
// This is a bit unclean, but works properly, since the packet always // This is a bit unclean, but works properly, since the packet always
@ -303,32 +293,35 @@ runLoop:
close(s.handshakeChan) close(s.handshakeChan)
close(s.handshakeCompleteChan) close(s.handshakeCompleteChan)
} else { } else {
if l == protocol.EncryptionForwardSecure {
s.packer.SetForwardSecure()
}
s.tryDecryptingQueuedPackets() s.tryDecryptingQueuedPackets()
s.handshakeChan <- handshakeEvent{encLevel: l} s.handshakeChan <- handshakeEvent{encLevel: l}
} }
} }
now := time.Now() now := time.Now()
if s.sentPacketHandler.GetAlarmTimeout().Before(now) { if timeout := s.sentPacketHandler.GetAlarmTimeout(); !timeout.IsZero() && timeout.Before(now) {
// This could cause packets to be retransmitted, so check it before trying // This could cause packets to be retransmitted, so check it before trying
// to send packets. // to send packets.
s.sentPacketHandler.OnAlarm() s.sentPacketHandler.OnAlarm()
} }
if s.config.KeepAlive && s.handshakeComplete && time.Since(s.lastNetworkActivityTime) >= s.idleTimeout()/2 {
// send the PING frame since there is no activity in the session
s.packer.QueueControlFrame(&frames.PingFrame{})
s.keepAlivePingSent = true
}
if err := s.sendPacket(); err != nil { if err := s.sendPacket(); err != nil {
s.close(err) s.closeLocal(err)
} }
if !s.receivedTooManyUndecrytablePacketsTime.IsZero() && s.receivedTooManyUndecrytablePacketsTime.Add(protocol.PublicResetTimeout).Before(now) && len(s.undecryptablePackets) != 0 { if !s.receivedTooManyUndecrytablePacketsTime.IsZero() && s.receivedTooManyUndecrytablePacketsTime.Add(protocol.PublicResetTimeout).Before(now) && len(s.undecryptablePackets) != 0 {
s.close(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received")) s.closeLocal(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received"))
} }
if now.Sub(s.lastNetworkActivityTime) >= s.idleTimeout() { if now.Sub(s.lastNetworkActivityTime) >= s.idleTimeout() {
s.close(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity.")) s.closeLocal(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity."))
} }
if !s.handshakeComplete && now.Sub(s.sessionCreationTime) >= protocol.MaxTimeForCryptoHandshake { if !s.handshakeComplete && now.Sub(s.sessionCreationTime) >= s.config.HandshakeTimeout {
s.close(qerr.Error(qerr.NetworkIdleTimeout, "Crypto handshake did not complete in time.")) s.closeLocal(qerr.Error(qerr.HandshakeTimeout, "Crypto handshake did not complete in time."))
} }
s.garbageCollectStreams() s.garbageCollectStreams()
} }
@ -344,37 +337,33 @@ runLoop:
return closeErr.err return closeErr.err
} }
func (s *session) maybeResetTimer() { func (s *session) WaitUntilClosed() {
nextDeadline := s.lastNetworkActivityTime.Add(s.idleTimeout()) <-s.runClosed
}
if !s.nextAckScheduledTime.IsZero() { func (s *session) maybeResetTimer() {
nextDeadline = utils.MinTime(nextDeadline, s.nextAckScheduledTime) var deadline time.Time
if s.config.KeepAlive && s.handshakeComplete && !s.keepAlivePingSent {
deadline = s.lastNetworkActivityTime.Add(s.idleTimeout() / 2)
} else {
deadline = s.lastNetworkActivityTime.Add(s.idleTimeout())
}
if ackAlarm := s.receivedPacketHandler.GetAlarmTimeout(); !ackAlarm.IsZero() {
deadline = utils.MinTime(deadline, ackAlarm)
} }
if lossTime := s.sentPacketHandler.GetAlarmTimeout(); !lossTime.IsZero() { if lossTime := s.sentPacketHandler.GetAlarmTimeout(); !lossTime.IsZero() {
nextDeadline = utils.MinTime(nextDeadline, lossTime) deadline = utils.MinTime(deadline, lossTime)
} }
if !s.handshakeComplete { if !s.handshakeComplete {
handshakeDeadline := s.sessionCreationTime.Add(protocol.MaxTimeForCryptoHandshake) handshakeDeadline := s.sessionCreationTime.Add(s.config.HandshakeTimeout)
nextDeadline = utils.MinTime(nextDeadline, handshakeDeadline) deadline = utils.MinTime(deadline, handshakeDeadline)
} }
if !s.receivedTooManyUndecrytablePacketsTime.IsZero() { if !s.receivedTooManyUndecrytablePacketsTime.IsZero() {
nextDeadline = utils.MinTime(nextDeadline, s.receivedTooManyUndecrytablePacketsTime.Add(protocol.PublicResetTimeout)) deadline = utils.MinTime(deadline, s.receivedTooManyUndecrytablePacketsTime.Add(protocol.PublicResetTimeout))
} }
if nextDeadline.Equal(s.currentDeadline) { s.timer.Reset(deadline)
// No need to reset the timer
return
}
// We need to drain the timer if the value from its channel was not read yet.
// See https://groups.google.com/forum/#!topic/golang-dev/c9UUfASVPoU
if !s.timer.Stop() && !s.timerRead {
<-s.timer.C
}
s.timer.Reset(nextDeadline.Sub(time.Now()))
s.timerRead = false
s.currentDeadline = nextDeadline
} }
func (s *session) idleTimeout() time.Duration { func (s *session) idleTimeout() time.Duration {
@ -398,6 +387,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
} }
s.lastNetworkActivityTime = p.rcvTime s.lastNetworkActivityTime = p.rcvTime
s.keepAlivePingSent = false
hdr := p.publicHeader hdr := p.publicHeader
data := p.data data := p.data
@ -433,19 +423,8 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
// Only do this after decrypting, so we are sure the packet is not attacker-controlled // Only do this after decrypting, so we are sure the packet is not attacker-controlled
s.largestRcvdPacketNumber = utils.MaxPacketNumber(s.largestRcvdPacketNumber, hdr.PacketNumber) s.largestRcvdPacketNumber = utils.MaxPacketNumber(s.largestRcvdPacketNumber, hdr.PacketNumber)
err = s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, packet.IsRetransmittable()) isRetransmittable := ackhandler.HasRetransmittableFrames(packet.frames)
// ignore duplicate packets if err = s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, isRetransmittable); err != nil {
if err == ackhandler.ErrDuplicatePacket {
utils.Infof("Ignoring packet 0x%x due to ErrDuplicatePacket", hdr.PacketNumber)
return nil
}
// ignore packets with packet numbers smaller than the LeastUnacked of a StopWaiting
if err == ackhandler.ErrPacketSmallerThanLastStopWaiting {
utils.Infof("Ignoring packet 0x%x due to ErrPacketSmallerThanLastStopWaiting", hdr.PacketNumber)
return nil
}
if err != nil {
return err return err
} }
@ -462,7 +441,7 @@ func (s *session) handleFrames(fs []frames.Frame) error {
case *frames.AckFrame: case *frames.AckFrame:
err = s.handleAckFrame(frame) err = s.handleAckFrame(frame)
case *frames.ConnectionCloseFrame: case *frames.ConnectionCloseFrame:
s.registerClose(qerr.Error(frame.ErrorCode, frame.ReasonPhrase), true) s.closeRemote(qerr.Error(frame.ErrorCode, frame.ReasonPhrase))
case *frames.GoawayFrame: case *frames.GoawayFrame:
err = errors.New("unimplemented: handling GOAWAY frames") err = errors.New("unimplemented: handling GOAWAY frames")
case *frames.StopWaitingFrame: case *frames.StopWaitingFrame:
@ -548,48 +527,31 @@ func (s *session) handleAckFrame(frame *frames.AckFrame) error {
return s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, s.lastNetworkActivityTime) return s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, s.lastNetworkActivityTime)
} }
func (s *session) registerClose(e error, remoteClose bool) error { func (s *session) closeLocal(e error) {
// Only close once s.closeOnce.Do(func() {
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) { s.closeChan <- closeError{err: e, remote: false}
return errSessionAlreadyClosed })
} }
if e == nil { func (s *session) closeRemote(e error) {
e = qerr.PeerGoingAway s.closeOnce.Do(func() {
} s.closeChan <- closeError{err: e, remote: true}
})
if e == errCloseSessionForNewVersion {
s.streamsMap.CloseWithError(e)
s.closeStreamsWithError(e)
}
s.closeChan <- closeError{err: e, remote: remoteClose}
return nil
} }
// Close the connection. If err is nil it will be set to qerr.PeerGoingAway. // Close the connection. If err is nil it will be set to qerr.PeerGoingAway.
// It waits until the run loop has stopped before returning // It waits until the run loop has stopped before returning
func (s *session) Close(e error) error { func (s *session) Close(e error) error {
err := s.registerClose(e, false) s.closeLocal(e)
if err == errSessionAlreadyClosed {
return nil
}
// wait for the run loop to finish
<-s.runClosed <-s.runClosed
return err
}
// close the connection. Use this when called from the run loop
func (s *session) close(e error) error {
err := s.registerClose(e, false)
if err == errSessionAlreadyClosed {
return nil return nil
}
return err
} }
func (s *session) handleCloseError(closeErr closeError) error { func (s *session) handleCloseError(closeErr closeError) error {
if closeErr.err == nil {
closeErr.err = qerr.PeerGoingAway
}
var quicErr *qerr.QuicError var quicErr *qerr.QuicError
var ok bool var ok bool
if quicErr, ok = closeErr.err.(*qerr.QuicError); !ok { if quicErr, ok = closeErr.err.(*qerr.QuicError); !ok {
@ -602,13 +564,12 @@ func (s *session) handleCloseError(closeErr closeError) error {
utils.Errorf("Closing session with error: %s", closeErr.err.Error()) utils.Errorf("Closing session with error: %s", closeErr.err.Error())
} }
s.streamsMap.CloseWithError(quicErr)
if closeErr.err == errCloseSessionForNewVersion { if closeErr.err == errCloseSessionForNewVersion {
return nil return nil
} }
s.streamsMap.CloseWithError(quicErr)
s.closeStreamsWithError(quicErr)
// If this is a remote close we're done here // If this is a remote close we're done here
if closeErr.remote { if closeErr.remote {
return nil return nil
@ -620,27 +581,37 @@ func (s *session) handleCloseError(closeErr closeError) error {
return s.sendConnectionClose(quicErr) return s.sendConnectionClose(quicErr)
} }
func (s *session) closeStreamsWithError(err error) {
s.streamsMap.Iterate(func(str *stream) (bool, error) {
str.Cancel(err)
return true, nil
})
}
func (s *session) sendPacket() error { func (s *session) sendPacket() error {
// Repeatedly try sending until we don't have any more data, or run out of the congestion window s.packer.SetLeastUnacked(s.sentPacketHandler.GetLeastUnacked())
for {
if !s.sentPacketHandler.SendingAllowed() {
return nil
}
var controlFrames []frames.Frame // Get WindowUpdate frames
// get WindowUpdate frames
// this call triggers the flow controller to increase the flow control windows, if necessary // this call triggers the flow controller to increase the flow control windows, if necessary
windowUpdateFrames := s.getWindowUpdateFrames() windowUpdateFrames := s.getWindowUpdateFrames()
for _, wuf := range windowUpdateFrames { for _, wuf := range windowUpdateFrames {
controlFrames = append(controlFrames, wuf) s.packer.QueueControlFrame(wuf)
}
ack := s.receivedPacketHandler.GetAckFrame()
if ack != nil {
s.packer.QueueControlFrame(ack)
}
// Repeatedly try sending until we don't have any more data, or run out of the congestion window
for {
if !s.sentPacketHandler.SendingAllowed() {
if ack == nil {
return nil
}
// If we aren't allowed to send, at least try sending an ACK frame
swf := s.sentPacketHandler.GetStopWaitingFrame(false)
if swf != nil {
s.packer.QueueControlFrame(swf)
}
packet, err := s.packer.PackAckPacket()
if err != nil {
return err
}
return s.sendPackedPacket(packet)
} }
// check for retransmissions first // check for retransmissions first
@ -649,75 +620,67 @@ func (s *session) sendPacket() error {
if retransmitPacket == nil { if retransmitPacket == nil {
break break
} }
utils.Debugf("\tDequeueing retransmission for packet 0x%x", retransmitPacket.PacketNumber)
if retransmitPacket.EncryptionLevel != protocol.EncryptionForwardSecure { if retransmitPacket.EncryptionLevel != protocol.EncryptionForwardSecure {
if s.handshakeComplete {
// Don't retransmit handshake packets when the handshake is complete
continue
}
utils.Debugf("\tDequeueing handshake retransmission for packet 0x%x", retransmitPacket.PacketNumber) utils.Debugf("\tDequeueing handshake retransmission for packet 0x%x", retransmitPacket.PacketNumber)
stopWaitingFrame := s.sentPacketHandler.GetStopWaitingFrame(true) s.packer.QueueControlFrame(s.sentPacketHandler.GetStopWaitingFrame(true))
var packet *packedPacket packet, err := s.packer.PackHandshakeRetransmission(retransmitPacket)
packet, err := s.packer.RetransmitNonForwardSecurePacket(stopWaitingFrame, retransmitPacket)
if err != nil { if err != nil {
return err return err
} }
if packet == nil { if err = s.sendPackedPacket(packet); err != nil {
continue
}
err = s.sendPackedPacket(packet)
if err != nil {
return err return err
} }
continue
} else { } else {
utils.Debugf("\tDequeueing retransmission for packet 0x%x", retransmitPacket.PacketNumber)
// resend the frames that were in the packet // resend the frames that were in the packet
for _, frame := range retransmitPacket.GetFramesForRetransmission() { for _, frame := range retransmitPacket.GetFramesForRetransmission() {
switch frame.(type) { switch f := frame.(type) {
case *frames.StreamFrame: case *frames.StreamFrame:
s.streamFramer.AddFrameForRetransmission(frame.(*frames.StreamFrame)) s.streamFramer.AddFrameForRetransmission(f)
case *frames.WindowUpdateFrame: case *frames.WindowUpdateFrame:
// only retransmit WindowUpdates if the stream is not yet closed and the we haven't sent another WindowUpdate with a higher ByteOffset for the stream // only retransmit WindowUpdates if the stream is not yet closed and the we haven't sent another WindowUpdate with a higher ByteOffset for the stream
var currentOffset protocol.ByteCount
f := frame.(*frames.WindowUpdateFrame)
currentOffset, err := s.flowControlManager.GetReceiveWindow(f.StreamID) currentOffset, err := s.flowControlManager.GetReceiveWindow(f.StreamID)
if err == nil && f.ByteOffset >= currentOffset { if err == nil && f.ByteOffset >= currentOffset {
controlFrames = append(controlFrames, frame) s.packer.QueueControlFrame(f)
} }
default: default:
controlFrames = append(controlFrames, frame) s.packer.QueueControlFrame(frame)
} }
} }
} }
} }
ack := s.receivedPacketHandler.GetAckFrame()
if ack != nil {
controlFrames = append(controlFrames, ack)
}
hasRetransmission := s.streamFramer.HasFramesForRetransmission() hasRetransmission := s.streamFramer.HasFramesForRetransmission()
var stopWaitingFrame *frames.StopWaitingFrame
if ack != nil || hasRetransmission { if ack != nil || hasRetransmission {
stopWaitingFrame = s.sentPacketHandler.GetStopWaitingFrame(hasRetransmission) swf := s.sentPacketHandler.GetStopWaitingFrame(hasRetransmission)
if swf != nil {
s.packer.QueueControlFrame(swf)
} }
packet, err := s.packer.PackPacket(stopWaitingFrame, controlFrames, s.sentPacketHandler.GetLeastUnacked()) }
if err != nil { packet, err := s.packer.PackPacket()
if err != nil || packet == nil {
return err return err
} }
if packet == nil { if err = s.sendPackedPacket(packet); err != nil {
return nil return err
} }
// send every window update twice // send every window update twice
for _, f := range windowUpdateFrames { for _, f := range windowUpdateFrames {
s.packer.QueueControlFrameForNextPacket(f) s.packer.QueueControlFrame(f)
} }
windowUpdateFrames = nil
err = s.sendPackedPacket(packet) ack = nil
if err != nil {
return err
}
s.nextAckScheduledTime = time.Time{}
} }
} }
func (s *session) sendPackedPacket(packet *packedPacket) error { func (s *session) sendPackedPacket(packet *packedPacket) error {
defer putPacketBuffer(packet.raw)
err := s.sentPacketHandler.SentPacket(&ackhandler.Packet{ err := s.sentPacketHandler.SentPacket(&ackhandler.Packet{
PacketNumber: packet.number, PacketNumber: packet.number,
Frames: packet.frames, Frames: packet.frames,
@ -727,22 +690,19 @@ func (s *session) sendPackedPacket(packet *packedPacket) error {
if err != nil { if err != nil {
return err return err
} }
s.logPacket(packet) s.logPacket(packet)
return s.conn.Write(packet.raw)
err = s.conn.Write(packet.raw)
putPacketBuffer(packet.raw)
return err
} }
func (s *session) sendConnectionClose(quicErr *qerr.QuicError) error { func (s *session) sendConnectionClose(quicErr *qerr.QuicError) error {
packet, err := s.packer.PackConnectionClose(&frames.ConnectionCloseFrame{ErrorCode: quicErr.ErrorCode, ReasonPhrase: quicErr.ErrorMessage}, s.sentPacketHandler.GetLeastUnacked()) s.packer.SetLeastUnacked(s.sentPacketHandler.GetLeastUnacked())
packet, err := s.packer.PackConnectionClose(&frames.ConnectionCloseFrame{
ErrorCode: quicErr.ErrorCode,
ReasonPhrase: quicErr.ErrorMessage,
})
if err != nil { if err != nil {
return err return err
} }
if packet == nil {
return errors.New("Session BUG: expected packet not to be nil")
}
s.logPacket(packet) s.logPacket(packet)
return s.conn.Write(packet.raw) return s.conn.Write(packet.raw)
} }
@ -752,12 +712,10 @@ func (s *session) logPacket(packet *packedPacket) {
// We don't need to allocate the slices for calling the format functions // We don't need to allocate the slices for calling the format functions
return return
} }
if utils.Debug() { utils.Debugf("-> Sending packet 0x%x (%d bytes) for connection %x, %s", packet.number, len(packet.raw), s.connectionID, packet.encryptionLevel)
utils.Debugf("-> Sending packet 0x%x (%d bytes), %s", packet.number, len(packet.raw), packet.encryptionLevel)
for _, frame := range packet.frames { for _, frame := range packet.frames {
frames.LogFrame(frame, true) frames.LogFrame(frame, true)
} }
}
} }
// GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed. // GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed.
@ -790,27 +748,21 @@ func (s *session) WaitUntilHandshakeComplete() error {
} }
func (s *session) queueResetStreamFrame(id protocol.StreamID, offset protocol.ByteCount) { func (s *session) queueResetStreamFrame(id protocol.StreamID, offset protocol.ByteCount) {
s.packer.QueueControlFrameForNextPacket(&frames.RstStreamFrame{ s.packer.QueueControlFrame(&frames.RstStreamFrame{
StreamID: id, StreamID: id,
ByteOffset: offset, ByteOffset: offset,
}) })
s.scheduleSending() s.scheduleSending()
} }
func (s *session) newStream(id protocol.StreamID) (*stream, error) { func (s *session) newStream(id protocol.StreamID) *stream {
stream, err := newStream(id, s.scheduleSending, s.queueResetStreamFrame, s.flowControlManager)
if err != nil {
return nil, err
}
// TODO: find a better solution for determining which streams contribute to connection level flow control // TODO: find a better solution for determining which streams contribute to connection level flow control
if id == 1 || id == 3 { if id == 1 || id == 3 {
s.flowControlManager.NewStream(id, false) s.flowControlManager.NewStream(id, false)
} else { } else {
s.flowControlManager.NewStream(id, true) s.flowControlManager.NewStream(id, true)
} }
return newStream(id, s.scheduleSending, s.queueResetStreamFrame, s.flowControlManager)
return stream, nil
} }
// garbageCollectStreams goes through all streams and removes EOF'ed streams // garbageCollectStreams goes through all streams and removes EOF'ed streams
@ -844,6 +796,7 @@ func (s *session) scheduleSending() {
func (s *session) tryQueueingUndecryptablePacket(p *receivedPacket) { func (s *session) tryQueueingUndecryptablePacket(p *receivedPacket) {
if s.handshakeComplete { if s.handshakeComplete {
utils.Debugf("Received undecryptable packet from %s after the handshake: %#v, %d bytes data", p.remoteAddr.String(), p.publicHeader, len(p.data))
return return
} }
if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets { if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets {
@ -875,11 +828,6 @@ func (s *session) getWindowUpdateFrames() []*frames.WindowUpdateFrame {
return res return res
} }
func (s *session) ackAlarmChanged(t time.Time) {
s.nextAckScheduledTime = t
s.maybeResetTimer()
}
func (s *session) LocalAddr() net.Addr { func (s *session) LocalAddr() net.Addr {
return s.conn.LocalAddr() return s.conn.LocalAddr()
} }

View File

@ -3,12 +3,14 @@ package quic
import ( import (
"fmt" "fmt"
"io" "io"
"net"
"sync" "sync"
"time"
"github.com/lucas-clemente/quic-go/flowcontrol" "github.com/lucas-clemente/quic-go/flowcontrol"
"github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
// A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface // A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface
@ -41,30 +43,40 @@ type stream struct {
resetRemotely utils.AtomicBool resetRemotely utils.AtomicBool
frameQueue *streamFrameSorter frameQueue *streamFrameSorter
newFrameOrErrCond sync.Cond readChan chan struct{}
readDeadline time.Time
dataForWriting []byte dataForWriting []byte
finSent utils.AtomicBool finSent utils.AtomicBool
rstSent utils.AtomicBool rstSent utils.AtomicBool
doneWritingOrErrCond sync.Cond writeChan chan struct{}
writeDeadline time.Time
flowControlManager flowcontrol.FlowControlManager flowControlManager flowcontrol.FlowControlManager
} }
type deadlineError struct{}
func (deadlineError) Error() string { return "deadline exceeded" }
func (deadlineError) Temporary() bool { return true }
func (deadlineError) Timeout() bool { return true }
var errDeadline net.Error = &deadlineError{}
// newStream creates a new Stream // newStream creates a new Stream
func newStream(StreamID protocol.StreamID, onData func(), onReset func(protocol.StreamID, protocol.ByteCount), flowControlManager flowcontrol.FlowControlManager) (*stream, error) { func newStream(StreamID protocol.StreamID,
s := &stream{ onData func(),
onReset func(protocol.StreamID, protocol.ByteCount),
flowControlManager flowcontrol.FlowControlManager) *stream {
return &stream{
onData: onData, onData: onData,
onReset: onReset, onReset: onReset,
streamID: StreamID, streamID: StreamID,
flowControlManager: flowControlManager, flowControlManager: flowControlManager,
frameQueue: newStreamFrameSorter(), frameQueue: newStreamFrameSorter(),
readChan: make(chan struct{}, 1),
writeChan: make(chan struct{}, 1),
} }
s.newFrameOrErrCond.L = &s.mutex
s.doneWritingOrErrCond.L = &s.mutex
return s, nil
} }
// Read implements io.Reader. It is not thread safe! // Read implements io.Reader. It is not thread safe!
@ -83,10 +95,10 @@ func (s *stream) Read(p []byte) (int, error) {
for bytesRead < len(p) { for bytesRead < len(p) {
s.mutex.Lock() s.mutex.Lock()
frame := s.frameQueue.Head() frame := s.frameQueue.Head()
if frame == nil && bytesRead > 0 { if frame == nil && bytesRead > 0 {
err = s.err
s.mutex.Unlock() s.mutex.Unlock()
return bytesRead, s.err return bytesRead, err
} }
var err error var err error
@ -96,11 +108,28 @@ func (s *stream) Read(p []byte) (int, error) {
err = s.err err = s.err
break break
} }
deadline := s.readDeadline
if !deadline.IsZero() && !time.Now().Before(deadline) {
err = errDeadline
break
}
if frame != nil { if frame != nil {
s.readPosInFrame = int(s.readOffset - frame.Offset) s.readPosInFrame = int(s.readOffset - frame.Offset)
break break
} }
s.newFrameOrErrCond.Wait()
s.mutex.Unlock()
if deadline.IsZero() {
<-s.readChan
} else {
select {
case <-s.readChan:
case <-time.After(deadline.Sub(time.Now())):
}
}
s.mutex.Lock()
frame = s.frameQueue.Head() frame = s.frameQueue.Head()
} }
s.mutex.Unlock() s.mutex.Unlock()
@ -145,34 +174,49 @@ func (s *stream) Read(p []byte) (int, error) {
} }
func (s *stream) Write(p []byte) (int, error) { func (s *stream) Write(p []byte) (int, error) {
if s.resetLocally.Get() {
return 0, s.err
}
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
if s.err != nil { if s.resetLocally.Get() || s.err != nil {
return 0, s.err return 0, s.err
} }
if len(p) == 0 { if len(p) == 0 {
return 0, nil return 0, nil
} }
s.dataForWriting = make([]byte, len(p)) s.dataForWriting = make([]byte, len(p))
copy(s.dataForWriting, p) copy(s.dataForWriting, p)
s.onData() s.onData()
for s.dataForWriting != nil && s.err == nil { var err error
s.doneWritingOrErrCond.Wait() for {
deadline := s.writeDeadline
if !deadline.IsZero() && !time.Now().Before(deadline) {
err = errDeadline
break
}
if s.dataForWriting == nil || s.err != nil {
break
} }
s.mutex.Unlock()
if deadline.IsZero() {
<-s.writeChan
} else {
select {
case <-s.writeChan:
case <-time.After(deadline.Sub(time.Now())):
}
}
s.mutex.Lock()
}
if err != nil {
return 0, err
}
if s.err != nil { if s.err != nil {
return 0, s.err return len(p) - len(s.dataForWriting), s.err
} }
return len(p), nil return len(p), nil
} }
@ -188,14 +232,12 @@ func (s *stream) lenOfDataForWriting() protocol.ByteCount {
func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte { func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte {
s.mutex.Lock() s.mutex.Lock()
if s.err != nil { defer s.mutex.Unlock()
s.mutex.Unlock()
return nil if s.err != nil || s.dataForWriting == nil {
}
if s.dataForWriting == nil {
s.mutex.Unlock()
return nil return nil
} }
var ret []byte var ret []byte
if protocol.ByteCount(len(s.dataForWriting)) > maxBytes { if protocol.ByteCount(len(s.dataForWriting)) > maxBytes {
ret = s.dataForWriting[:maxBytes] ret = s.dataForWriting[:maxBytes]
@ -203,10 +245,9 @@ func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte {
} else { } else {
ret = s.dataForWriting ret = s.dataForWriting
s.dataForWriting = nil s.dataForWriting = nil
s.doneWritingOrErrCond.Signal() s.signalWrite()
} }
s.writeOffset += protocol.ByteCount(len(ret)) s.writeOffset += protocol.ByteCount(len(ret))
s.mutex.Unlock()
return ret return ret
} }
@ -249,7 +290,52 @@ func (s *stream) AddStreamFrame(frame *frames.StreamFrame) error {
if err != nil && err != errDuplicateStreamData { if err != nil && err != errDuplicateStreamData {
return err return err
} }
s.newFrameOrErrCond.Signal() s.signalRead()
return nil
}
// signalRead performs a non-blocking send on the readChan
func (s *stream) signalRead() {
select {
case s.readChan <- struct{}{}:
default:
}
}
// signalRead performs a non-blocking send on the writeChan
func (s *stream) signalWrite() {
select {
case s.writeChan <- struct{}{}:
default:
}
}
func (s *stream) SetReadDeadline(t time.Time) error {
s.mutex.Lock()
oldDeadline := s.readDeadline
s.readDeadline = t
s.mutex.Unlock()
// if the new deadline is before the currently set deadline, wake up Read()
if t.Before(oldDeadline) {
s.signalRead()
}
return nil
}
func (s *stream) SetWriteDeadline(t time.Time) error {
s.mutex.Lock()
oldDeadline := s.writeDeadline
s.writeDeadline = t
s.mutex.Unlock()
if t.Before(oldDeadline) {
s.signalWrite()
}
return nil
}
func (s *stream) SetDeadline(t time.Time) error {
_ = s.SetReadDeadline(t) // SetReadDeadline never errors
_ = s.SetWriteDeadline(t) // SetWriteDeadline never errors
return nil return nil
} }
@ -266,8 +352,8 @@ func (s *stream) Cancel(err error) {
// errors must not be changed! // errors must not be changed!
if s.err == nil { if s.err == nil {
s.err = err s.err = err
s.newFrameOrErrCond.Signal() s.signalRead()
s.doneWritingOrErrCond.Signal() s.signalWrite()
} }
s.mutex.Unlock() s.mutex.Unlock()
} }
@ -282,8 +368,8 @@ func (s *stream) Reset(err error) {
// errors must not be changed! // errors must not be changed!
if s.err == nil { if s.err == nil {
s.err = err s.err = err
s.newFrameOrErrCond.Signal() s.signalRead()
s.doneWritingOrErrCond.Signal() s.signalWrite()
} }
if s.shouldSendReset() { if s.shouldSendReset() {
s.onReset(s.streamID, s.writeOffset) s.onReset(s.streamID, s.writeOffset)
@ -302,7 +388,7 @@ func (s *stream) RegisterRemoteError(err error) {
// errors must not be changed! // errors must not be changed!
if s.err == nil { if s.err == nil {
s.err = err s.err = err
s.doneWritingOrErrCond.Signal() s.signalWrite()
} }
if s.shouldSendReset() { if s.shouldSendReset() {
s.onReset(s.streamID, s.writeOffset) s.onReset(s.streamID, s.writeOffset)

View File

@ -4,8 +4,8 @@ import (
"errors" "errors"
"github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
type streamFrameSorter struct { type streamFrameSorter struct {

View File

@ -3,8 +3,8 @@ package quic
import ( import (
"github.com/lucas-clemente/quic-go/flowcontrol" "github.com/lucas-clemente/quic-go/flowcontrol"
"github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
type streamFramer struct { type streamFramer struct {
@ -45,6 +45,28 @@ func (f *streamFramer) HasFramesForRetransmission() bool {
return len(f.retransmissionQueue) > 0 return len(f.retransmissionQueue) > 0
} }
func (f *streamFramer) HasCryptoStreamFrame() bool {
// TODO(#657): Flow control
cs, _ := f.streamsMap.GetOrOpenStream(1)
return cs.lenOfDataForWriting() > 0
}
// TODO(lclemente): This is somewhat duplicate with the normal path for generating frames.
// TODO(#657): Flow control
func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *frames.StreamFrame {
if !f.HasCryptoStreamFrame() {
return nil
}
cs, _ := f.streamsMap.GetOrOpenStream(1)
frame := &frames.StreamFrame{
StreamID: 1,
Offset: cs.writeOffset,
}
frameHeaderBytes, _ := frame.MinLength(protocol.VersionWhatever) // can never error
frame.Data = cs.getDataForWriting(maxLen - frameHeaderBytes)
return frame
}
func (f *streamFramer) maybePopFramesForRetransmission(maxLen protocol.ByteCount) (res []*frames.StreamFrame, currentLen protocol.ByteCount) { func (f *streamFramer) maybePopFramesForRetransmission(maxLen protocol.ByteCount) (res []*frames.StreamFrame, currentLen protocol.ByteCount) {
for len(f.retransmissionQueue) > 0 { for len(f.retransmissionQueue) > 0 {
frame := f.retransmissionQueue[0] frame := f.retransmissionQueue[0]
@ -76,7 +98,7 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res []
var currentLen protocol.ByteCount var currentLen protocol.ByteCount
fn := func(s *stream) (bool, error) { fn := func(s *stream) (bool, error) {
if s == nil { if s == nil || s.streamID == 1 /* crypto stream is handled separately */ {
return true, nil return true, nil
} }
@ -90,7 +112,8 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res []
maxLen := maxBytes - currentLen - frameHeaderBytes maxLen := maxBytes - currentLen - frameHeaderBytes
var sendWindowSize protocol.ByteCount var sendWindowSize protocol.ByteCount
if s.lenOfDataForWriting() != 0 { lenStreamData := s.lenOfDataForWriting()
if lenStreamData != 0 {
sendWindowSize, _ = f.flowControlManager.SendWindowSize(s.streamID) sendWindowSize, _ = f.flowControlManager.SendWindowSize(s.streamID)
maxLen = utils.MinByteCount(maxLen, sendWindowSize) maxLen = utils.MinByteCount(maxLen, sendWindowSize)
} }
@ -99,7 +122,12 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res []
return true, nil return true, nil
} }
data := s.getDataForWriting(maxLen) var data []byte
if lenStreamData != 0 {
// Only getDataForWriting() if we didn't have data earlier, so that we
// don't send without FC approval (if a Write() raced).
data = s.getDataForWriting(maxLen)
}
// This is unlikely, but check it nonetheless, the scheduler might have jumped in. Seems to happen in ~20% of cases in the tests. // This is unlikely, but check it nonetheless, the scheduler might have jumped in. Seems to happen in ~20% of cases in the tests.
shouldSendFin := s.shouldSendFin() shouldSendFin := s.shouldSendFin()

View File

@ -36,7 +36,7 @@ type streamsMap struct {
} }
type streamLambda func(*stream) (bool, error) type streamLambda func(*stream) (bool, error)
type newStreamLambda func(protocol.StreamID) (*stream, error) type newStreamLambda func(protocol.StreamID) *stream
var ( var (
errMapAccess = errors.New("streamsMap: Error accessing the streams map") errMapAccess = errors.New("streamsMap: Error accessing the streams map")
@ -83,16 +83,28 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
return s, nil return s, nil
} }
if id <= m.highestStreamOpenedByPeer { if m.perspective == protocol.PerspectiveServer {
if id%2 == 0 {
if id <= m.nextStream { // this is a server-side stream that we already opened. Must have been closed already
return nil, nil return nil, nil
} }
if m.perspective == protocol.PerspectiveServer && id%2 == 0 {
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from client-side", id)) return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from client-side", id))
} }
if m.perspective == protocol.PerspectiveClient && id%2 == 1 { if id <= m.highestStreamOpenedByPeer { // this is a client-side stream that doesn't exist anymore. Must have been closed already
return nil, nil
}
}
if m.perspective == protocol.PerspectiveClient {
if id%2 == 1 {
if id <= m.nextStream { // this is a client-side stream that we already opened.
return nil, nil
}
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from server-side", id)) return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from server-side", id))
} }
if id <= m.highestStreamOpenedByPeer { // this is a server-side stream that doesn't exist anymore. Must have been closed already
return nil, nil
}
}
// sid is the next stream that will be opened // sid is the next stream that will be opened
sid := m.highestStreamOpenedByPeer + 2 sid := m.highestStreamOpenedByPeer + 2
@ -120,11 +132,6 @@ func (m *streamsMap) openRemoteStream(id protocol.StreamID) (*stream, error) {
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByPeer)) return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByPeer))
} }
s, err := m.newStream(id)
if err != nil {
return nil, err
}
if m.perspective == protocol.PerspectiveServer { if m.perspective == protocol.PerspectiveServer {
m.numIncomingStreams++ m.numIncomingStreams++
} else { } else {
@ -135,6 +142,7 @@ func (m *streamsMap) openRemoteStream(id protocol.StreamID) (*stream, error) {
m.highestStreamOpenedByPeer = id m.highestStreamOpenedByPeer = id
} }
s := m.newStream(id)
m.putStream(s) m.putStream(s)
return s, nil return s, nil
} }
@ -145,11 +153,6 @@ func (m *streamsMap) openStreamImpl() (*stream, error) {
return nil, qerr.TooManyOpenStreams return nil, qerr.TooManyOpenStreams
} }
s, err := m.newStream(id)
if err != nil {
return nil, err
}
if m.perspective == protocol.PerspectiveServer { if m.perspective == protocol.PerspectiveServer {
m.numOutgoingStreams++ m.numOutgoingStreams++
} else { } else {
@ -157,6 +160,7 @@ func (m *streamsMap) openStreamImpl() (*stream, error) {
} }
m.nextStream += 2 m.nextStream += 2
s := m.newStream(id)
m.putStream(s) m.putStream(s)
return s, nil return s, nil
} }
@ -319,8 +323,11 @@ func (m *streamsMap) RemoveStream(id protocol.StreamID) error {
func (m *streamsMap) CloseWithError(err error) { func (m *streamsMap) CloseWithError(err error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock()
m.closeErr = err m.closeErr = err
m.nextStreamOrErrCond.Broadcast() m.nextStreamOrErrCond.Broadcast()
m.openStreamOrErrCond.Broadcast() m.openStreamOrErrCond.Broadcast()
m.mutex.Unlock() for _, s := range m.openStreams {
m.streams[s].Cancel(err)
}
} }

View File

@ -1,31 +0,0 @@
package quic
import (
"github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/protocol"
)
type unpackedPacket struct {
encryptionLevel protocol.EncryptionLevel
frames []frames.Frame
}
func (u *unpackedPacket) IsRetransmittable() bool {
for _, f := range u.frames {
switch f.(type) {
case *frames.StreamFrame:
return true
case *frames.RstStreamFrame:
return true
case *frames.WindowUpdateFrame:
return true
case *frames.BlockedFrame:
return true
case *frames.PingFrame:
return true
case *frames.GoawayFrame:
return true
}
}
return false
}

View File

@ -4,6 +4,7 @@ package dns
import ( import (
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"encoding/binary" "encoding/binary"
"io" "io"
@ -70,6 +71,43 @@ func Exchange(m *Msg, a string) (r *Msg, err error) {
return r, err return r, err
} }
// ExchangeContext performs a synchronous UDP query, like Exchange. It
// additionally obeys deadlines from the passed Context.
func ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, err error) {
// Combine context deadline with built-in timeout. Context chooses whichever
// is sooner.
timeoutCtx, cancel := context.WithTimeout(ctx, dnsTimeout)
defer cancel()
deadline, _ := timeoutCtx.Deadline()
co := new(Conn)
dialer := net.Dialer{}
co.Conn, err = dialer.DialContext(timeoutCtx, "udp", a)
if err != nil {
return nil, err
}
defer co.Conn.Close()
opt := m.IsEdns0()
// If EDNS0 is used use that for size.
if opt != nil && opt.UDPSize() >= MinMsgSize {
co.UDPSize = opt.UDPSize()
}
co.SetWriteDeadline(deadline)
if err = co.WriteMsg(m); err != nil {
return nil, err
}
co.SetReadDeadline(deadline)
r, err = co.ReadMsg()
if err == nil && r.Id != m.Id {
err = ErrId
}
return r, err
}
// ExchangeConn performs a synchronous query. It sends the message m via the connection // ExchangeConn performs a synchronous query. It sends the message m via the connection
// c and waits for a reply. The connection c is not closed by ExchangeConn. // c and waits for a reply. The connection c is not closed by ExchangeConn.
// This function is going away, but can easily be mimicked: // This function is going away, but can easily be mimicked:
@ -106,8 +144,18 @@ func ExchangeConn(c net.Conn, m *Msg) (r *Msg, err error) {
// buffer, see SetEdns0. Messages without an OPT RR will fallback to the historic limit // buffer, see SetEdns0. Messages without an OPT RR will fallback to the historic limit
// of 512 bytes. // of 512 bytes.
func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) { func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
return c.ExchangeContext(context.Background(), m, a)
}
// ExchangeContext acts like Exchange, but honors the deadline on the provided
// context, if present. If there is both a context deadline and a configured
// timeout on the client, the earliest of the two takes effect.
func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (
r *Msg,
rtt time.Duration,
err error) {
if !c.SingleInflight { if !c.SingleInflight {
return c.exchange(m, a) return c.exchange(ctx, m, a)
} }
// This adds a bunch of garbage, TODO(miek). // This adds a bunch of garbage, TODO(miek).
t := "nop" t := "nop"
@ -119,7 +167,7 @@ func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
cl = cl1 cl = cl1
} }
r, rtt, err, shared := c.group.Do(m.Question[0].Name+t+cl, func() (*Msg, time.Duration, error) { r, rtt, err, shared := c.group.Do(m.Question[0].Name+t+cl, func() (*Msg, time.Duration, error) {
return c.exchange(m, a) return c.exchange(ctx, m, a)
}) })
if r != nil && shared { if r != nil && shared {
r = r.Copy() r = r.Copy()
@ -154,7 +202,7 @@ func (c *Client) writeTimeout() time.Duration {
return dnsTimeout return dnsTimeout
} }
func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) { func (c *Client) exchange(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
var co *Conn var co *Conn
network := "udp" network := "udp"
tls := false tls := false
@ -180,10 +228,13 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
deadline = time.Now().Add(c.Timeout) deadline = time.Now().Add(c.Timeout)
} }
dialDeadline := deadlineOrTimeoutOrCtx(ctx, deadline, c.dialTimeout())
dialTimeout := dialDeadline.Sub(time.Now())
if tls { if tls {
co, err = DialTimeoutWithTLS(network, a, c.TLSConfig, c.dialTimeout()) co, err = DialTimeoutWithTLS(network, a, c.TLSConfig, dialTimeout)
} else { } else {
co, err = DialTimeout(network, a, c.dialTimeout()) co, err = DialTimeout(network, a, dialTimeout)
} }
if err != nil { if err != nil {
@ -202,12 +253,12 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
} }
co.TsigSecret = c.TsigSecret co.TsigSecret = c.TsigSecret
co.SetWriteDeadline(deadlineOrTimeout(deadline, c.writeTimeout())) co.SetWriteDeadline(deadlineOrTimeoutOrCtx(ctx, deadline, c.writeTimeout()))
if err = co.WriteMsg(m); err != nil { if err = co.WriteMsg(m); err != nil {
return nil, 0, err return nil, 0, err
} }
co.SetReadDeadline(deadlineOrTimeout(deadline, c.readTimeout())) co.SetReadDeadline(deadlineOrTimeoutOrCtx(ctx, deadline, c.readTimeout()))
r, err = co.ReadMsg() r, err = co.ReadMsg()
if err == nil && r.Id != m.Id { if err == nil && r.Id != m.Id {
err = ErrId err = ErrId
@ -459,9 +510,22 @@ func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout
return conn, nil return conn, nil
} }
// deadlineOrTimeout chooses between the provided deadline and timeout
// by always preferring the deadline so long as it's non-zero (regardless
// of which is bigger), and returns the equivalent deadline value.
func deadlineOrTimeout(deadline time.Time, timeout time.Duration) time.Time { func deadlineOrTimeout(deadline time.Time, timeout time.Duration) time.Time {
if deadline.IsZero() { if deadline.IsZero() {
return time.Now().Add(timeout) return time.Now().Add(timeout)
} }
return deadline return deadline
} }
// deadlineOrTimeoutOrCtx returns the earliest of: a context deadline, or the
// output of deadlineOrtimeout.
func deadlineOrTimeoutOrCtx(ctx context.Context, deadline time.Time, timeout time.Duration) time.Time {
result := deadlineOrTimeout(deadline, timeout)
if ctxDeadline, ok := ctx.Deadline(); ok && ctxDeadline.Before(result) {
result = ctxDeadline
}
return result
}

View File

@ -96,7 +96,7 @@ func unpackHeader(msg []byte, off int) (rr RR_Header, off1 int, truncmsg []byte,
return hdr, len(msg), msg, err return hdr, len(msg), msg, err
} }
msg, err = truncateMsgFromRdlength(msg, off, hdr.Rdlength) msg, err = truncateMsgFromRdlength(msg, off, hdr.Rdlength)
return hdr, off, msg, nil return hdr, off, msg, err
} }
// pack packs an RR header, returning the offset to the end of the header. // pack packs an RR header, returning the offset to the end of the header.

42
vendor/github.com/miekg/dns/types.go generated vendored
View File

@ -115,27 +115,27 @@ const (
ClassNONE = 254 ClassNONE = 254
ClassANY = 255 ClassANY = 255
// Message Response Codes. // Message Response Codes, see https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml
RcodeSuccess = 0 RcodeSuccess = 0 // NoError - No Error [DNS]
RcodeFormatError = 1 RcodeFormatError = 1 // FormErr - Format Error [DNS]
RcodeServerFailure = 2 RcodeServerFailure = 2 // ServFail - Server Failure [DNS]
RcodeNameError = 3 RcodeNameError = 3 // NXDomain - Non-Existent Domain [DNS]
RcodeNotImplemented = 4 RcodeNotImplemented = 4 // NotImp - Not Implemented [DNS]
RcodeRefused = 5 RcodeRefused = 5 // Refused - Query Refused [DNS]
RcodeYXDomain = 6 RcodeYXDomain = 6 // YXDomain - Name Exists when it should not [DNS Update]
RcodeYXRrset = 7 RcodeYXRrset = 7 // YXRRSet - RR Set Exists when it should not [DNS Update]
RcodeNXRrset = 8 RcodeNXRrset = 8 // NXRRSet - RR Set that should exist does not [DNS Update]
RcodeNotAuth = 9 RcodeNotAuth = 9 // NotAuth - Server Not Authoritative for zone [DNS Update]
RcodeNotZone = 10 RcodeNotZone = 10 // NotZone - Name not contained in zone [DNS Update/TSIG]
RcodeBadSig = 16 // TSIG RcodeBadSig = 16 // BADSIG - TSIG Signature Failure [TSIG]
RcodeBadVers = 16 // EDNS0 RcodeBadVers = 16 // BADVERS - Bad OPT Version [EDNS0]
RcodeBadKey = 17 RcodeBadKey = 17 // BADKEY - Key not recognized [TSIG]
RcodeBadTime = 18 RcodeBadTime = 18 // BADTIME - Signature out of time window [TSIG]
RcodeBadMode = 19 // TKEY RcodeBadMode = 19 // BADMODE - Bad TKEY Mode [TKEY]
RcodeBadName = 20 RcodeBadName = 20 // BADNAME - Duplicate key name [TKEY]
RcodeBadAlg = 21 RcodeBadAlg = 21 // BADALG - Algorithm not supported [TKEY]
RcodeBadTrunc = 22 // TSIG RcodeBadTrunc = 22 // BADTRUNC - Bad Truncation [TSIG]
RcodeBadCookie = 23 // DNS Cookies RcodeBadCookie = 23 // BADCOOKIE - Bad/missing Server Cookie [DNS Cookies]
// Message Opcodes. There is no 3. // Message Opcodes. There is no 3.
OpcodeQuery = 0 OpcodeQuery = 0

11
vendor/github.com/miekg/dns/xfr.go generated vendored
View File

@ -1,6 +1,7 @@
package dns package dns
import ( import (
"fmt"
"time" "time"
) )
@ -81,6 +82,10 @@ func (t *Transfer) inAxfr(id uint16, c chan *Envelope) {
return return
} }
if first { if first {
if in.Rcode != RcodeSuccess {
c <- &Envelope{in.Answer, &Error{err: fmt.Sprintf(errXFR, in.Rcode)}}
return
}
if !isSOAFirst(in) { if !isSOAFirst(in) {
c <- &Envelope{in.Answer, ErrSoa} c <- &Envelope{in.Answer, ErrSoa}
return return
@ -126,6 +131,10 @@ func (t *Transfer) inIxfr(id uint16, c chan *Envelope) {
return return
} }
if first { if first {
if in.Rcode != RcodeSuccess {
c <- &Envelope{in.Answer, &Error{err: fmt.Sprintf(errXFR, in.Rcode)}}
return
}
// A single SOA RR signals "no changes" // A single SOA RR signals "no changes"
if len(in.Answer) == 1 && isSOAFirst(in) { if len(in.Answer) == 1 && isSOAFirst(in) {
c <- &Envelope{in.Answer, nil} c <- &Envelope{in.Answer, nil}
@ -242,3 +251,5 @@ func isSOALast(in *Msg) bool {
} }
return false return false
} }
const errXFR = "bad xfr rcode: %d"

View File

@ -488,6 +488,7 @@ func link(p *parser, out *bytes.Buffer, data []byte, offset int) int {
} }
p.notes = append(p.notes, ref) p.notes = append(p.notes, ref)
p.notesRecord[string(ref.link)] = struct{}{}
link = ref.link link = ref.link
title = ref.title title = ref.title
@ -498,9 +499,10 @@ func link(p *parser, out *bytes.Buffer, data []byte, offset int) int {
return 0 return 0
} }
if t == linkDeferredFootnote { if t == linkDeferredFootnote && !p.isFootnote(lr) {
lr.noteId = len(p.notes) + 1 lr.noteId = len(p.notes) + 1
p.notes = append(p.notes, lr) p.notes = append(p.notes, lr)
p.notesRecord[string(lr.link)] = struct{}{}
} }
// keep link and title from reference // keep link and title from reference

View File

@ -219,6 +219,7 @@ type parser struct {
// presence. If a ref is also a footnote, it's stored both in refs and here // presence. If a ref is also a footnote, it's stored both in refs and here
// in notes. Slice is nil if footnotes not enabled. // in notes. Slice is nil if footnotes not enabled.
notes []*reference notes []*reference
notesRecord map[string]struct{}
} }
func (p *parser) getRef(refid string) (ref *reference, found bool) { func (p *parser) getRef(refid string) (ref *reference, found bool) {
@ -241,6 +242,11 @@ func (p *parser) getRef(refid string) (ref *reference, found bool) {
return ref, found return ref, found
} }
func (p *parser) isFootnote(ref *reference) bool {
_, ok := p.notesRecord[string(ref.link)]
return ok
}
// //
// //
// Public interface // Public interface
@ -376,6 +382,7 @@ func MarkdownOptions(input []byte, renderer Renderer, opts Options) []byte {
if extensions&EXTENSION_FOOTNOTES != 0 { if extensions&EXTENSION_FOOTNOTES != 0 {
p.notes = make([]*reference, 0) p.notes = make([]*reference, 0)
p.notesRecord = make(map[string]struct{})
} }
first := firstPass(p, input) first := firstPass(p, input)

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"runtime" "runtime"
"strings" "strings"
@ -15,7 +16,17 @@ import (
var UserAgent string var UserAgent string
// HTTPClient is an HTTP client with a reasonable timeout value. // HTTPClient is an HTTP client with a reasonable timeout value.
var HTTPClient = http.Client{Timeout: 10 * time.Second} var HTTPClient = http.Client{
Transport: &http.Transport{
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 15 * time.Second,
ResponseHeaderTimeout: 15 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
},
}
const ( const (
// defaultGoUserAgent is the Go HTTP package user agent string. Too // defaultGoUserAgent is the Go HTTP package user agent string. Too

View File

@ -207,7 +207,7 @@ func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration,
return nil, "", responseError(res) return nil, "", responseError(res)
} }
curl := res.Header.Get("location") // cert permanent URL curl := res.Header.Get("Location") // cert permanent URL
if res.ContentLength == 0 { if res.ContentLength == 0 {
// no cert in the body; poll until we get it // no cert in the body; poll until we get it
cert, err := c.FetchCert(ctx, curl, bundle) cert, err := c.FetchCert(ctx, curl, bundle)
@ -240,7 +240,7 @@ func (c *Client) FetchCert(ctx context.Context, url string, bundle bool) ([][]by
if res.StatusCode > 299 { if res.StatusCode > 299 {
return nil, responseError(res) return nil, responseError(res)
} }
d := retryAfter(res.Header.Get("retry-after"), 3*time.Second) d := retryAfter(res.Header.Get("Retry-After"), 3*time.Second)
select { select {
case <-time.After(d): case <-time.After(d):
// retry // retry
@ -444,7 +444,7 @@ func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorizat
if err != nil { if err != nil {
return nil, err return nil, err
} }
retry := res.Header.Get("retry-after") retry := res.Header.Get("Retry-After")
if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusAccepted { if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusAccepted {
res.Body.Close() res.Body.Close()
if err := sleep(retry, 1); err != nil { if err := sleep(retry, 1); err != nil {
@ -703,7 +703,7 @@ func (c *Client) retryPostJWS(ctx context.Context, key crypto.Signer, url string
// clear any nonces that we might've stored that might now be // clear any nonces that we might've stored that might now be
// considered bad // considered bad
c.clearNonces() c.clearNonces()
retry := res.Header.Get("retry-after") retry := res.Header.Get("Retry-After")
if err := sleep(retry, 1); err != nil { if err := sleep(retry, 1); err != nil {
return nil, err return nil, err
} }

View File

@ -36,6 +36,9 @@ import (
// operating system-specific cache or temp directory. This may not // operating system-specific cache or temp directory. This may not
// be suitable for servers spanning multiple machines. // be suitable for servers spanning multiple machines.
// //
// The returned listener uses a *tls.Config that enables HTTP/2, and
// should only be used with servers that support HTTP/2.
//
// The returned Listener also enables TCP keep-alives on the accepted // The returned Listener also enables TCP keep-alives on the accepted
// connections. The returned *tls.Conn are returned before their TLS // connections. The returned *tls.Conn are returned before their TLS
// handshake has completed. // handshake has completed.
@ -58,6 +61,9 @@ func NewListener(domains ...string) net.Listener {
// Listener listens on the standard TLS port (443) on all interfaces // Listener listens on the standard TLS port (443) on all interfaces
// and returns a net.Listener returning *tls.Conn connections. // and returns a net.Listener returning *tls.Conn connections.
// //
// The returned listener uses a *tls.Config that enables HTTP/2, and
// should only be used with servers that support HTTP/2.
//
// The returned Listener also enables TCP keep-alives on the accepted // The returned Listener also enables TCP keep-alives on the accepted
// connections. The returned *tls.Conn are returned before their TLS // connections. The returned *tls.Conn are returned before their TLS
// handshake has completed. // handshake has completed.
@ -69,6 +75,7 @@ func (m *Manager) Listener() net.Listener {
m: m, m: m,
conf: &tls.Config{ conf: &tls.Config{
GetCertificate: m.GetCertificate, // bonus: panic on nil m GetCertificate: m.GetCertificate, // bonus: panic on nil m
NextProtos: []string{"h2", "http/1.1"}, // Enable HTTP/2
}, },
} }
ln.tcpListener, ln.tcpListenErr = net.Listen("tcp", ":443") ln.tcpListener, ln.tcpListenErr = net.Listen("tcp", ":443")

View File

@ -3,6 +3,6 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This code was translated into a form compatible with 6a from the public // This code was translated into a form compatible with 6a from the public
// domain sources in SUPERCOP: http://bench.cr.yp.to/supercop.html // domain sources in SUPERCOP: https://bench.cr.yp.to/supercop.html
#define REDMASK51 0x0007FFFFFFFFFFFF #define REDMASK51 0x0007FFFFFFFFFFFF

View File

@ -3,7 +3,7 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This code was translated into a form compatible with 6a from the public // This code was translated into a form compatible with 6a from the public
// domain sources in SUPERCOP: http://bench.cr.yp.to/supercop.html // domain sources in SUPERCOP: https://bench.cr.yp.to/supercop.html
// +build amd64,!gccgo,!appengine // +build amd64,!gccgo,!appengine

View File

@ -3,7 +3,7 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Package curve25519 provides an implementation of scalar multiplication on // Package curve25519 provides an implementation of scalar multiplication on
// the elliptic curve known as curve25519. See http://cr.yp.to/ecdh.html // the elliptic curve known as curve25519. See https://cr.yp.to/ecdh.html
package curve25519 // import "golang.org/x/crypto/curve25519" package curve25519 // import "golang.org/x/crypto/curve25519"
// basePoint is the x coordinate of the generator of the curve. // basePoint is the x coordinate of the generator of the curve.

View File

@ -3,7 +3,7 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This code was translated into a form compatible with 6a from the public // This code was translated into a form compatible with 6a from the public
// domain sources in SUPERCOP: http://bench.cr.yp.to/supercop.html // domain sources in SUPERCOP: https://bench.cr.yp.to/supercop.html
// +build amd64,!gccgo,!appengine // +build amd64,!gccgo,!appengine

View File

@ -3,7 +3,7 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This code was translated into a form compatible with 6a from the public // This code was translated into a form compatible with 6a from the public
// domain sources in SUPERCOP: http://bench.cr.yp.to/supercop.html // domain sources in SUPERCOP: https://bench.cr.yp.to/supercop.html
// +build amd64,!gccgo,!appengine // +build amd64,!gccgo,!appengine

View File

@ -3,7 +3,7 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This code was translated into a form compatible with 6a from the public // This code was translated into a form compatible with 6a from the public
// domain sources in SUPERCOP: http://bench.cr.yp.to/supercop.html // domain sources in SUPERCOP: https://bench.cr.yp.to/supercop.html
// +build amd64,!gccgo,!appengine // +build amd64,!gccgo,!appengine

View File

@ -3,7 +3,7 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This code was translated into a form compatible with 6a from the public // This code was translated into a form compatible with 6a from the public
// domain sources in SUPERCOP: http://bench.cr.yp.to/supercop.html // domain sources in SUPERCOP: https://bench.cr.yp.to/supercop.html
// +build amd64,!gccgo,!appengine // +build amd64,!gccgo,!appengine

Some files were not shown because too many files have changed in this diff Show More